diff --git a/db/alter.go b/db/alter.go index 0564d3e..6028700 100644 --- a/db/alter.go +++ b/db/alter.go @@ -1,52 +1,52 @@ package db import ( "fmt" "strings" ) type AlterTableSqlBuilder struct { Dialect DialectType Name string Changes []string } func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder { - if colVal, err := col.ToSQL(b.Dialect); err == nil { + if colVal, err := col.CreateSQL(b.Dialect); err == nil { b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal)) } return b } func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder { - if colVal, err := col.ToSQL(b.Dialect); err == nil { - b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal)) + if colActions, err := col.AlterSQL(b.Dialect, name); err == nil { + b.Changes = append(b.Changes, colActions...) } return b } func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder { b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", "))) return b } func (b *AlterTableSqlBuilder) ToSQL() (string, error) { var str strings.Builder str.WriteString("ALTER TABLE ") str.WriteString(b.Name) str.WriteString(" ") if len(b.Changes) == 0 { return "", fmt.Errorf("no changes provide for table: %s", b.Name) } changeCount := len(b.Changes) for i, thing := range b.Changes { str.WriteString(thing) if i < changeCount-1 { str.WriteString(", ") } } return str.String(), nil } diff --git a/db/alter_test.go b/db/alter_test.go index 4bd58ac..4d47821 100644 --- a/db/alter_test.go +++ b/db/alter_test.go @@ -1,56 +1,71 @@ package db import "testing" func TestAlterTableSqlBuilder_ToSQL(t *testing.T) { type fields struct { Dialect DialectType Name string Changes []string } tests := []struct { name string builder *AlterTableSqlBuilder want string wantErr bool }{ { name: "MySQL add int", builder: DialectMySQL. AlterTable("the_table"). - AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)), - want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL", + AddColumn(NonNullableColumn("the_col", ColumnTypeInt{MaxBytes: 4})), + want: "ALTER TABLE the_table ADD COLUMN the_col INTEGER NOT NULL", wantErr: false, }, { name: "MySQL add string", builder: DialectMySQL. AlterTable("the_table"). - AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})), + AddColumn(NonNullableColumn("the_col", ColumnTypeString{MaxChars: 128})), want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL", wantErr: false, }, - { name: "MySQL add int and string", builder: DialectMySQL. AlterTable("the_table"). - AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)). - AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})), + AddColumn(NonNullableColumn("first_col", ColumnTypeInt{MaxBytes: 4})). + AddColumn(NonNullableColumn("second_col", ColumnTypeString{MaxChars: 128})), want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL", wantErr: false, }, + { + name: "MySQL change to string", + builder: DialectMySQL. + AlterTable("the_table"). + ChangeColumn("old_col", NonNullableColumn("new_col", ColumnTypeString{})), + want: "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, MODIFY COLUMN new_col TEXT NOT NULL", + wantErr: false, + }, + { + name: "PostgreSQL change to int", + builder: DialectMySQL. + AlterTable("the_table"). + ChangeColumn("old_col", NullableColumn("new_col", ColumnTypeInt{MaxBytes: 4})), + want: "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, ALTER COLUMN new_col TYPE INTEGER, ALTER COLUMN new_col DROP NOT NULL", + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.builder.ToSQL() if (err != nil) != tt.wantErr { t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("ToSQL() got = %v, want %v", got, tt.want) } }) } } diff --git a/db/column.go b/db/column.go index 4e0bb6d..5bf64e4 100644 --- a/db/column.go +++ b/db/column.go @@ -1,280 +1,328 @@ /* * Copyright © 2019-2022 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package db import ( "fmt" "strings" ) type Column struct { Name string Type ColumnType Nullable bool PrimaryKey bool } func NullableColumn(name string, ty ColumnType) *Column { return &Column{ Name: name, Type: ty, Nullable: true, PrimaryKey: false, } } func NonNullableColumn(name string, ty ColumnType) *Column { return &Column{ Name: name, Type: ty, Nullable: false, PrimaryKey: false, } } func PrimaryKeyColumn(name string, ty ColumnType) *Column { return &Column{ Name: name, Type: ty, Nullable: false, PrimaryKey: true, } } type ColumnType interface { Name(DialectType) (string, error) Default(DialectType) (string, error) } type ColumnTypeInt struct { IsSigned bool MaxBytes int MaxDigits int HasDefault bool DefaultVal int } type ColumnTypeString struct { IsFixedLength bool MaxChars int HasDefault bool DefaultVal string } type ColumnDefault int type ColumnTypeBool struct { DefaultVal ColumnDefault } const ( NoDefault ColumnDefault = iota DefaultFalse ColumnDefault = iota DefaultTrue ColumnDefault = iota DefaultNow ColumnDefault = iota ) type ColumnTypeDateTime struct { DefaultVal ColumnDefault } func (intCol ColumnTypeInt) Name(d DialectType) (string, error) { switch d { case DialectSQLite: return "INTEGER", nil - case DialectMySQL: + case DialectMySQL, DialectPostgreSQL: var colName string switch intCol.MaxBytes { case 1: - colName = "TINYINT" + if d == DialectMySQL { + colName = "TINYINT" + } else { + colName = "SMALLINT" + } case 2: colName = "SMALLINT" case 3: - colName = "MEDIUMINT" + if d == DialectMySQL { + colName = "MEDIUMINT" + } else { + colName = "INTEGER" + } case 4: - colName = "INT" + colName = "INTEGER" default: colName = "BIGINT" } - if intCol.MaxDigits > 0 { - colName = fmt.Sprintf("%s(%d)", colName, intCol.MaxDigits) - } - if !intCol.IsSigned { - colName += " UNSIGNED" + if d == DialectMySQL { + if intCol.MaxDigits > 0 { + colName = fmt.Sprintf("%s(%d)", colName, intCol.MaxDigits) + } + if !intCol.IsSigned { + colName += " UNSIGNED" + } } return colName, nil default: return "", fmt.Errorf("dialect %d does not support integer columns", d) } } func (intCol ColumnTypeInt) Default(d DialectType) (string, error) { - switch d { - case DialectSQLite, DialectMySQL: - if intCol.HasDefault { - return fmt.Sprintf("%d", intCol.DefaultVal), nil - } - return "", nil - default: - return "", fmt.Errorf("dialect %d does not support defaulted integer columns", d) + if intCol.HasDefault { + return fmt.Sprintf("%d", intCol.DefaultVal), nil } + return "", nil } func (strCol ColumnTypeString) Name(d DialectType) (string, error) { switch d { case DialectSQLite: return "TEXT", nil - case DialectMySQL: + case DialectMySQL, DialectPostgreSQL: if strCol.IsFixedLength { if strCol.MaxChars > 0 { return fmt.Sprintf("CHAR(%d)", strCol.MaxChars), nil } return "CHAR", nil } if strCol.MaxChars <= 0 { return "TEXT", nil } if strCol.MaxChars < (1 << 16) { return fmt.Sprintf("VARCHAR(%d)", strCol.MaxChars), nil } return "TEXT", nil default: return "", fmt.Errorf("dialect %d does not support string columns", d) } } func (strCol ColumnTypeString) Default(d DialectType) (string, error) { - switch d { - case DialectSQLite, DialectMySQL: - if strCol.HasDefault { - return EscapeSimple.SQLEscape(d, strCol.DefaultVal) - } - return "", nil - default: - return "", fmt.Errorf("dialect %d does not support defaulted string columns", d) + if strCol.HasDefault { + return EscapeSimple.SQLEscape(d, strCol.DefaultVal) } + return "", nil } func (boolCol ColumnTypeBool) Name(d DialectType) (string, error) { switch d { case DialectSQLite: return "INTEGER", nil - case DialectMySQL: + case DialectMySQL, DialectPostgreSQL: return "BOOL", nil default: return "", fmt.Errorf("boolean column type not supported for dialect %d", d) } } func (boolCol ColumnTypeBool) Default(d DialectType) (string, error) { - switch d { - case DialectSQLite, DialectMySQL: - switch boolCol.DefaultVal { - case NoDefault: - return "", nil - case DefaultFalse: - return "0", nil - case DefaultTrue: - return "1", nil - default: - return "", fmt.Errorf("boolean columns cannot default to %d for dialect %d", boolCol.DefaultVal, d) - } + switch boolCol.DefaultVal { + case NoDefault: + return "", nil + case DefaultFalse: + return "0", nil + case DefaultTrue: + return "1", nil default: - return "", fmt.Errorf("dialect %d does not support defaulted boolean columns", d) + return "", fmt.Errorf("boolean columns cannot default to %d for dialect %d", boolCol.DefaultVal, d) } } func (dateTimeCol ColumnTypeDateTime) Name(d DialectType) (string, error) { switch d { case DialectSQLite, DialectMySQL: return "DATETIME", nil + case DialectPostgreSQL: + return "TIMESTAMP", nil default: return "", fmt.Errorf("datetime column type not supported for dialect %d", d) } } func (dateTimeCol ColumnTypeDateTime) Default(d DialectType) (string, error) { switch d { case DialectSQLite, DialectMySQL: switch dateTimeCol.DefaultVal { case NoDefault: return "", nil case DefaultNow: switch d { - case DialectSQLite: + case DialectSQLite, DialectPostgreSQL: return "CURRENT_TIMESTAMP", nil case DialectMySQL: return "NOW()", nil } } return "", fmt.Errorf("datetime columns cannot default to %d for dialect %d", dateTimeCol.DefaultVal, d) default: return "", fmt.Errorf("dialect %d does not support defaulted datetime columns", d) } } func (c *Column) SetName(name string) *Column { c.Name = name return c } func (c *Column) SetNullable(nullable bool) *Column { c.Nullable = nullable return c } func (c *Column) SetPrimaryKey(pk bool) *Column { c.PrimaryKey = pk return c } func (c *Column) SetType(t ColumnType) *Column { c.Type = t return c } -func (c *Column) ToSQL(d DialectType) (string, error) { +func (c *Column) AlterSQL(d DialectType, oldName string) ([]string, error) { + var actions []string = make([]string, 0) + + switch d { + // MySQL does all modifications at once + case DialectMySQL: + sql, err := c.CreateSQL(d) + if err != nil { + return make([]string, 0), err + } + actions = append(actions, fmt.Sprintf("CHANGE COLUMN %s %s", oldName, sql)) + + // PostgreSQL does modifications piece by piece + case DialectPostgreSQL: + if oldName != c.Name { + actions = append(actions, fmt.Sprintf("RENAME COLUMN %s TO %s", oldName, c.Name)) + } + + typeStr, err := c.Type.Name(d) + if err != nil { + return make([]string, 0), err + } + + actions = append(actions, fmt.Sprintf("ALTER COLUMN %s TYPE %s", c.Name, typeStr)) + var nullAction string + if c.Nullable { + nullAction = "DROP" + } else { + nullAction = "SET" + } + actions = append(actions, fmt.Sprintf("ALTER COLUMN %s %s NOT NULL", c.Name, nullAction)) + + defaultStr, err := c.Type.Default(d) + if err != nil { + return make([]string, 0), err + } + if len(defaultStr) > 0 { + actions = append(actions, fmt.Sprintf("ALTER COLUMN %s SET DEFAULT %s", c.Name, defaultStr)) + } + + if c.PrimaryKey { + actions = append(actions, fmt.Sprintf("ADD PRIMARY KEY (%s)", c.Name)) + } + + default: + return make([]string, 0), fmt.Errorf("dialect %d doesn't support altering column data type", d) + } + + return actions, nil +} + +func (c *Column) CreateSQL(d DialectType) (string, error) { var str strings.Builder str.WriteString(c.Name) str.WriteString(" ") typeStr, err := c.Type.Name(d) if err != nil { return "", err } str.WriteString(typeStr) if !c.Nullable { str.WriteString(" NOT NULL") } defaultStr, err := c.Type.Default(d) if err != nil { return "", err } if len(defaultStr) > 0 { str.WriteString(" DEFAULT ") str.WriteString(defaultStr) } if c.PrimaryKey { str.WriteString(" PRIMARY KEY") } return str.String(), nil } diff --git a/db/column_test.go b/db/column_test.go new file mode 100644 index 0000000..175ec3a --- /dev/null +++ b/db/column_test.go @@ -0,0 +1,151 @@ +package db + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestColumnType_Name(t *testing.T) { + tests := []struct { + name string + ty ColumnType + d DialectType + want string + wantErr bool + }{ + {"SQLite bool", ColumnTypeBool{}, DialectSQLite, "INTEGER", false}, + {"SQLite int", ColumnTypeInt{}, DialectSQLite, "INTEGER", false}, + {"SQLite string ", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "TEXT DEFAULT 'that''s a default'", false}, + {"SQLite datetime", ColumnTypeDateTime{}, DialectSQLite, "DATETIME", false}, + + {"MySQL bool", ColumnTypeBool{}, DialectMySQL, "BOOL", false}, + {"MySQL tiny int", ColumnTypeInt{MaxBytes: 1}, DialectMySQL, "TINYINT", false}, + {"MySQL tiny int with digits", ColumnTypeInt{MaxBytes: 1, MaxDigits: 2}, DialectMySQL, "TINYINT(2)", false}, + {"MySQL small int", ColumnTypeInt{MaxBytes: 2}, DialectMySQL, "SMALLINT", false}, + {"MySQL small int with digits", ColumnTypeInt{MaxBytes: 2, MaxDigits: 3}, DialectMySQL, "SMALLINT(3)", false}, + {"MySQL medium int", ColumnTypeInt{MaxBytes: 3}, DialectMySQL, "MEDIUMINT", false}, + {"MySQL medium int with digits", ColumnTypeInt{MaxBytes: 3, MaxDigits: 6}, DialectMySQL, "MEDIUMINT(6)", false}, + {"MySQL int", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "INTEGER", false}, + {"MySQL int with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 11}, DialectMySQL, "INTEGER(11)", false}, + {"MySQL bigint", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "BIGINT", false}, + {"MySQL bigint with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 15}, DialectMySQL, "BIGINT(15)", false}, + {"MySQL char", ColumnTypeString{IsFixedLength: true}, DialectMySQL, "CHAR", false}, + {"MySQL char with length", ColumnTypeString{IsFixedLength: true, MaxChars: 4}, DialectMySQL, "CHAR(4)", false}, + {"MySQL varchar with length", ColumnTypeString{MaxChars: 25}, DialectMySQL, "VARCHAR(25)", false}, + {"MySQL text", ColumnTypeString{}, DialectMySQL, "TEXT", false}, + {"MySQL datetime", ColumnTypeDateTime{}, DialectMySQL, "DATETIME", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.ty.Name(tt.d) + if (err != nil) != tt.wantErr { + t.Errorf("Name() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Name() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestColumnType_Default(t *testing.T) { + tests := []struct { + name string + ty ColumnType + d DialectType + want string + wantErr bool + }{ + {"SQLite bool none", ColumnTypeBool{}, DialectSQLite, "", false}, + {"SQLite bool false", ColumnTypeBool{}, DialectSQLite, "0", false}, + {"SQLite bool true", ColumnTypeBool{}, DialectSQLite, "1", false}, + {"SQLite int none", ColumnTypeInt{}, DialectSQLite, "", false}, + {"SQLite int empty", ColumnTypeInt{HasDefault: true}, DialectSQLite, "0", false}, + {"SQLite int", ColumnTypeInt{HasDefault: true, DefaultVal: 10}, DialectSQLite, "10", false}, + {"SQLite string none", ColumnTypeString{}, DialectSQLite, "", false}, + {"SQLite string empty", ColumnTypeString{HasDefault: true}, DialectSQLite, "''", false}, + {"SQLite string", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "'that''s a default'", false}, + {"MySQL string", ColumnTypeString{HasDefault: true, DefaultVal: "%that's a default%"}, DialectMySQL, "'%that\\'s a default%'", false}, + + {"SQLite datetime none", ColumnTypeDateTime{}, DialectSQLite, "", false}, + {"SQLite datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectSQLite, "CURRENT_TIMESTAMP", false}, + {"MySQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectMySQL, "NOW()", false}, + {"PostgreSQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectPostgreSQL, "CURRENT_TIMESTAMP", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.ty.Default(tt.d) + if (err != nil) != tt.wantErr { + t.Errorf("Default() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Default() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestColumn_CreateSQL(t *testing.T) { + type fields struct { + Dialect DialectType + Name string + Nullable bool + Type ColumnType + PrimaryKey bool + } + tests := []struct { + name string + fields fields + want string + wantErr bool + }{ + {"SQLite bool", fields{DialectSQLite, "foo", false, ColumnTypeBool{}, false}, "foo INTEGER NOT NULL", false}, + {"SQLite bool nullable", fields{DialectSQLite, "foo", true, ColumnTypeBool{}, false}, "foo INTEGER", false}, + {"SQLite int", fields{DialectSQLite, "foo", false, ColumnTypeInt{}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, + {"SQLite int nullable", fields{DialectSQLite, "foo", true, ColumnTypeInt{}, false}, "foo INTEGER", false}, + {"SQLite text", fields{DialectSQLite, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false}, + {"SQLite text nullable", fields{DialectSQLite, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false}, + {"SQLite datetime", fields{DialectSQLite, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false}, + {"SQLite datetime nullable", fields{DialectSQLite, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false}, + + {"MySQL bool", fields{DialectMySQL, "foo", false, ColumnTypeBool{}, false}, "foo TINYINT(1) NOT NULL", false}, + {"MySQL bool nullable", fields{DialectMySQL, "foo", true, ColumnTypeBool{}, false}, "foo TINYINT(1)", false}, + {"MySQL tiny int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 1}, true}, "foo TINYINT NOT NULL PRIMARY KEY", false}, + {"MySQL tiny int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 1}, false}, "foo TINYINT", false}, + {"MySQL small int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 2}, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false}, + {"MySQL small int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 2}, false}, "foo SMALLINT", false}, + {"MySQL int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 4}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, + {"MySQL int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 4}, false}, "foo INTEGER", false}, + {"MySQL big int", fields{DialectMySQL, "foo", false, ColumnTypeInt{}, true}, "foo BIGINT NOT NULL PRIMARY KEY", false}, + {"MySQL big int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{}, false}, "foo BIGINT", false}, + {"MySQL char", fields{DialectMySQL, "foo", false, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR NOT NULL", false}, + {"MySQL char nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR", false}, + {"MySQL varchar", fields{DialectMySQL, "foo", false, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255) NOT NULL", false}, + {"MySQL varchar nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255)", false}, + {"MySQL text", fields{DialectMySQL, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false}, + {"MySQL text nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false}, + {"MySQL datetime", fields{DialectMySQL, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false}, + {"MySQL datetime nullable", fields{DialectMySQL, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Column{ + Name: tt.fields.Name, + Nullable: tt.fields.Nullable, + Type: tt.fields.Type, + PrimaryKey: tt.fields.PrimaryKey, + } + if got, err := c.CreateSQL(tt.fields.Dialect); got != tt.want { + if (err != nil) != tt.wantErr { + t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("String() got = %v, want %v", got, tt.want) + } + } + }) + } +} diff --git a/db/create.go b/db/create.go index a9fad98..1bef61e 100644 --- a/db/create.go +++ b/db/create.go @@ -1,89 +1,89 @@ /* * Copyright © 2019-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package db import ( "fmt" "strings" ) type CreateTableSqlBuilder struct { Dialect DialectType Name string IfNotExists bool ColumnOrder []string Columns map[string]*Column Constraints []string } func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { if b.Columns == nil { b.Columns = make(map[string]*Column) } b.Columns[column.Name] = column b.ColumnOrder = append(b.ColumnOrder, column.Name) return b } func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder { for _, column := range columns { if _, ok := b.Columns[column]; !ok { // This fails silently. return b } } b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ","))) return b } func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder { b.IfNotExists = ine return b } func (b *CreateTableSqlBuilder) ToSQL() (string, error) { var str strings.Builder str.WriteString("CREATE TABLE ") if b.IfNotExists { str.WriteString("IF NOT EXISTS ") } str.WriteString(b.Name) var things []string for _, columnName := range b.ColumnOrder { column, ok := b.Columns[columnName] if !ok { return "", fmt.Errorf("column not found: %s", columnName) } - columnStr, err := column.ToSQL(b.Dialect) + columnStr, err := column.CreateSQL(b.Dialect) if err != nil { return "", err } things = append(things, columnStr) } for _, constraint := range b.Constraints { things = append(things, constraint) } if thingLen := len(things); thingLen > 0 { str.WriteString(" ( ") for i, thing := range things { str.WriteString(thing) if i < thingLen-1 { str.WriteString(", ") } } str.WriteString(" )") } return str.String(), nil } diff --git a/db/create_test.go b/db/create_test.go index 369d5c1..09efd18 100644 --- a/db/create_test.go +++ b/db/create_test.go @@ -1,146 +1,20 @@ package db import ( "github.com/stretchr/testify/assert" "testing" ) -func TestDialect_Column(t *testing.T) { - c1 := DialectSQLite.Column("foo", ColumnTypeBool, UnsetSize) - assert.Equal(t, DialectSQLite, c1.Dialect) - c2 := DialectMySQL.Column("foo", ColumnTypeBool, UnsetSize) - assert.Equal(t, DialectMySQL, c2.Dialect) -} - -func TestColumnType_Format(t *testing.T) { - type args struct { - dialect DialectType - size OptionalInt - } - tests := []struct { - name string - d ColumnType - args args - want string - wantErr bool - }{ - {"Sqlite bool", ColumnTypeBool, args{dialect: DialectSQLite}, "INTEGER", false}, - {"Sqlite small int", ColumnTypeSmallInt, args{dialect: DialectSQLite}, "INTEGER", false}, - {"Sqlite int", ColumnTypeInteger, args{dialect: DialectSQLite}, "INTEGER", false}, - {"Sqlite char", ColumnTypeChar, args{dialect: DialectSQLite}, "TEXT", false}, - {"Sqlite varchar", ColumnTypeVarChar, args{dialect: DialectSQLite}, "TEXT", false}, - {"Sqlite text", ColumnTypeText, args{dialect: DialectSQLite}, "TEXT", false}, - {"Sqlite datetime", ColumnTypeDateTime, args{dialect: DialectSQLite}, "DATETIME", false}, - - {"MySQL bool", ColumnTypeBool, args{dialect: DialectMySQL}, "TINYINT(1)", false}, - {"MySQL small int", ColumnTypeSmallInt, args{dialect: DialectMySQL}, "SMALLINT", false}, - {"MySQL small int with param", ColumnTypeSmallInt, args{dialect: DialectMySQL, size: OptionalInt{true, 3}}, "SMALLINT(3)", false}, - {"MySQL int", ColumnTypeInteger, args{dialect: DialectMySQL}, "INT", false}, - {"MySQL int with param", ColumnTypeInteger, args{dialect: DialectMySQL, size: OptionalInt{true, 11}}, "INT(11)", false}, - {"MySQL char", ColumnTypeChar, args{dialect: DialectMySQL}, "CHAR", false}, - {"MySQL char with param", ColumnTypeChar, args{dialect: DialectMySQL, size: OptionalInt{true, 4}}, "CHAR(4)", false}, - {"MySQL varchar", ColumnTypeVarChar, args{dialect: DialectMySQL}, "VARCHAR", false}, - {"MySQL varchar with param", ColumnTypeVarChar, args{dialect: DialectMySQL, size: OptionalInt{true, 25}}, "VARCHAR(25)", false}, - {"MySQL text", ColumnTypeText, args{dialect: DialectMySQL}, "TEXT", false}, - {"MySQL datetime", ColumnTypeDateTime, args{dialect: DialectMySQL}, "DATETIME", false}, - - {"invalid column type", 10000, args{dialect: DialectMySQL}, "", true}, - {"invalid dialect", ColumnTypeBool, args{dialect: 10000}, "", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.d.Format(tt.args.dialect, tt.args.size) - if (err != nil) != tt.wantErr { - t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("Format() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestColumn_Build(t *testing.T) { - type fields struct { - Dialect DialectType - Name string - Nullable bool - Default OptionalString - Type ColumnType - Size OptionalInt - PrimaryKey bool - } - tests := []struct { - name string - fields fields - want string - wantErr bool - }{ - {"Sqlite bool", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER NOT NULL", false}, - {"Sqlite bool nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER", false}, - {"Sqlite small int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, - {"Sqlite small int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo INTEGER", false}, - {"Sqlite int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER NOT NULL", false}, - {"Sqlite int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER", false}, - {"Sqlite char", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"Sqlite char nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT", false}, - {"Sqlite varchar", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"Sqlite varchar nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT", false}, - {"Sqlite text", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"Sqlite text nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, - {"Sqlite datetime", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, - {"Sqlite datetime nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, - - {"MySQL bool", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1) NOT NULL", false}, - {"MySQL bool nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1)", false}, - {"MySQL small int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false}, - {"MySQL small int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo SMALLINT", false}, - {"MySQL int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT NOT NULL", false}, - {"MySQL int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT", false}, - {"MySQL char", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR NOT NULL", false}, - {"MySQL char nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR", false}, - {"MySQL varchar", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR NOT NULL", false}, - {"MySQL varchar nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR", false}, - {"MySQL text", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"MySQL text nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, - {"MySQL datetime", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, - {"MySQL datetime nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := &Column{ - Dialect: tt.fields.Dialect, - Name: tt.fields.Name, - Nullable: tt.fields.Nullable, - Default: tt.fields.Default, - Type: tt.fields.Type, - Size: tt.fields.Size, - PrimaryKey: tt.fields.PrimaryKey, - } - if got, err := c.String(); got != tt.want { - if (err != nil) != tt.wantErr { - t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("String() got = %v, want %v", got, tt.want) - } - } - }) - } -} - func TestCreateTableSqlBuilder_ToSQL(t *testing.T) { sql, err := DialectMySQL. Table("foo"). SetIfNotExists(true). - Column(DialectMySQL.Column("bar", ColumnTypeInteger, UnsetSize).SetPrimaryKey(true)). - Column(DialectMySQL.Column("baz", ColumnTypeText, UnsetSize)). - Column(DialectMySQL.Column("qux", ColumnTypeDateTime, UnsetSize).SetDefault("NOW()")). + Column(PrimaryKeyColumn("bar", ColumnTypeInt{MaxBytes: 4})). + Column(NonNullableColumn("baz", ColumnTypeString{})). + Column(NonNullableColumn("qux", ColumnTypeDateTime{DefaultVal: DefaultNow})). UniqueConstraint("bar"). UniqueConstraint("bar", "baz"). ToSQL() assert.NoError(t, err) assert.Equal(t, "CREATE TABLE IF NOT EXISTS foo ( bar INT NOT NULL PRIMARY KEY, baz TEXT NOT NULL, qux DATETIME NOT NULL DEFAULT NOW(), UNIQUE(bar), UNIQUE(bar,baz) )", sql) } diff --git a/db/dialect.go b/db/dialect.go index ee1eb0f..3e2b90b 100644 --- a/db/dialect.go +++ b/db/dialect.go @@ -1,50 +1,51 @@ package db import "fmt" type DialectType int const ( - DialectSQLite DialectType = iota - DialectMySQL DialectType = iota + DialectSQLite DialectType = iota + DialectMySQL DialectType = iota + DialectPostgreSQL DialectType = iota ) func (d DialectType) IsKnown() bool { switch d { - case DialectSQLite, DialectMySQL: + case DialectSQLite, DialectMySQL, DialectPostgreSQL: return true default: return false } } func (d DialectType) AssertKnown() { if !d.IsKnown() { panic(fmt.Sprintf("unexpected dialect: %d", d)) } } func (d DialectType) Table(name string) *CreateTableSqlBuilder { d.AssertKnown() return &CreateTableSqlBuilder{Dialect: d, Name: name} } func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder { d.AssertKnown() return &AlterTableSqlBuilder{Dialect: d, Name: name} } func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { d.AssertKnown() return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: true, Columns: columns} } func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { d.AssertKnown() return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: false, Columns: columns} } func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder { d.AssertKnown() return &DropIndexSqlBuilder{Dialect: d, Name: name, Table: table} }