Initial support for multiple schemas

This commit is contained in:
Dan Sosedoff
2016-01-12 21:33:44 -06:00
parent 9c7eaf63d5
commit 9ffa05affb
11 changed files with 410 additions and 137 deletions

View File

@@ -3,6 +3,7 @@ package client
import (
"fmt"
"reflect"
"strings"
_ "github.com/lib/pq"
@@ -28,6 +29,14 @@ type RowsOptions struct {
SortOrder string // Sort direction (ASC, DESC)
}
func getSchemaAndTable(str string) (string, string) {
chunks := strings.Split(str, ".")
if len(chunks) == 1 {
return "public", chunks[0]
}
return chunks[0], chunks[1]
}
func New() (*Client, error) {
str, err := connection.BuildString(command.Opts)
@@ -89,16 +98,18 @@ func (client *Client) Schemas() ([]string, error) {
return client.fetchRows(statements.PG_SCHEMAS)
}
func (client *Client) Tables() ([]string, error) {
return client.fetchRows(statements.PG_TABLES)
func (client *Client) Objects() (*Result, error) {
return client.query(statements.PG_OBJECTS)
}
func (client *Client) Table(table string) (*Result, error) {
return client.query(statements.PG_TABLE_SCHEMA, table)
schema, table := getSchemaAndTable(table)
return client.query(statements.PG_TABLE_SCHEMA, schema, table)
}
func (client *Client) TableRows(table string, opts RowsOptions) (*Result, error) {
sql := fmt.Sprintf(`SELECT * FROM "%s"`, table)
schema, table := getSchemaAndTable(table)
sql := fmt.Sprintf(`SELECT * FROM "%s"."%s"`, schema, table)
if opts.Where != "" {
sql += fmt.Sprintf(" WHERE %s", opts.Where)
@@ -124,7 +135,8 @@ func (client *Client) TableRows(table string, opts RowsOptions) (*Result, error)
}
func (client *Client) TableRowsCount(table string, opts RowsOptions) (*Result, error) {
sql := fmt.Sprintf(`SELECT COUNT(1) FROM "%s"`, table)
schema, table := getSchemaAndTable(table)
sql := fmt.Sprintf(`SELECT COUNT(1) FROM "%s"."%s"`, schema, table)
if opts.Where != "" {
sql += fmt.Sprintf(" WHERE %s", opts.Where)
@@ -138,7 +150,8 @@ func (client *Client) TableInfo(table string) (*Result, error) {
}
func (client *Client) TableIndexes(table string) (*Result, error) {
res, err := client.query(statements.PG_TABLE_INDEXES, table)
schema, table := getSchemaAndTable(table)
res, err := client.query(statements.PG_TABLE_INDEXES, schema, table)
if err != nil {
return nil, err
@@ -148,7 +161,8 @@ func (client *Client) TableIndexes(table string) (*Result, error) {
}
func (client *Client) TableConstraints(table string) (*Result, error) {
res, err := client.query(statements.PG_TABLE_CONSTRAINTS, table)
schema, table := getSchemaAndTable(table)
res, err := client.query(statements.PG_TABLE_CONSTRAINTS, schema, table)
if err != nil {
return nil, err

View File

@@ -15,6 +15,14 @@ var (
testCommands map[string]string
)
func mapKeys(data map[string]*Objects) []string {
result := []string{}
for k, _ := range data {
result = append(result, k)
}
return result
}
func setupCommands() {
testCommands = map[string]string{
"createdb": "createdb",
@@ -112,10 +120,11 @@ func test_Databases(t *testing.T) {
assert.Contains(t, res, "postgres")
}
func test_Tables(t *testing.T) {
res, err := testClient.Tables()
func test_Objects(t *testing.T) {
res, err := testClient.Objects()
objects := ObjectsFromResult(res)
expected := []string{
tables := []string{
"alternate_stock",
"authors",
"book_backup",
@@ -132,19 +141,21 @@ func test_Tables(t *testing.T) {
"my_list",
"numeric_values",
"publishers",
"recent_shipments",
"schedules",
"shipments",
"states",
"stock",
"stock_backup",
"stock_view",
"subjects",
"text_sorting",
}
assert.Equal(t, nil, err)
assert.Equal(t, expected, res)
assert.Equal(t, []string{"schema", "name", "type", "owner"}, res.Columns)
assert.Equal(t, []string{"public"}, mapKeys(objects))
assert.Equal(t, tables, objects["public"].Tables)
assert.Equal(t, []string{"recent_shipments", "stock_view"}, objects["public"].Views)
assert.Equal(t, []string{"author_ids", "book_ids", "shipments_ship_id_seq", "subject_ids"}, objects["public"].Sequences)
}
func test_Table(t *testing.T) {
@@ -284,7 +295,7 @@ func TestAll(t *testing.T) {
test_Test(t)
test_Info(t)
test_Databases(t)
test_Tables(t)
test_Objects(t)
test_Table(t)
test_TableRows(t)
test_TableInfo(t)

View File

@@ -24,6 +24,12 @@ type Result struct {
Rows []Row `json:"rows"`
}
type Objects struct {
Tables []string `json:"tables"`
Views []string `json:"views"`
Sequences []string `json:"sequences"`
}
// Due to big int number limitations in javascript, numbers should be encoded
// as strings so they could be properly loaded on the frontend.
func (res *Result) PrepareBigints() {
@@ -98,3 +104,32 @@ func (res *Result) JSON() []byte {
data, _ := json.Marshal(res.Format())
return data
}
func ObjectsFromResult(res *Result) map[string]*Objects {
objects := map[string]*Objects{}
for _, row := range res.Rows {
schema := row[0].(string)
name := row[1].(string)
object_type := row[2].(string)
if objects[schema] == nil {
objects[schema] = &Objects{
Tables: []string{},
Views: []string{},
Sequences: []string{},
}
}
switch object_type {
case "table":
objects[schema].Tables = append(objects[schema].Tables, name)
case "view":
objects[schema].Views = append(objects[schema].Views, name)
case "sequence":
objects[schema].Sequences = append(objects[schema].Sequences, name)
}
}
return objects
}