Page MenuHomeMusing Studio

No OneTemporary

diff --git a/db/alter.go b/db/alter.go
index 0a4ffdd..6028700 100644
--- a/db/alter.go
+++ b/db/alter.go
@@ -1,52 +1,52 @@
package db
import (
"fmt"
"strings"
)
+// AlterTableSqlBuilder collects ALTER TABLE clauses for one table; ToSQL joins
+// them into a single statement.
type AlterTableSqlBuilder struct {
+	// Dialect is passed to Column.CreateSQL / Column.AlterSQL when clauses are
+	// rendered.
Dialect DialectType
+	// Name is the table being altered; it is written into the SQL verbatim.
Name string
+	// Changes holds already-rendered clauses in insertion order.
Changes []string
}
+// AddColumn appends an "ADD COLUMN <definition>" clause, rendering col with
+// the builder's dialect, and returns b for chaining.
+//
+// NOTE(review): if col.CreateSQL returns an error the clause is silently
+// skipped, so ToSQL can produce a statement that lacks this column (or fail
+// with "no changes"). Consider recording the error and returning it from ToSQL.
func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder {
- if colVal, err := col.String(); err == nil {
+ if colVal, err := col.CreateSQL(b.Dialect); err == nil {
b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal))
}
return b
}
+// ChangeColumn appends every clause returned by col.AlterSQL for the column
+// currently called name (col.Name is its name afterwards) and returns b for
+// chaining. Depending on the dialect this is one clause or several.
+//
+// NOTE(review): errors from AlterSQL (e.g. a dialect AlterSQL does not handle,
+// such as SQLite) are silently dropped, leaving the statement without this
+// change.
func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder {
- if colVal, err := col.String(); err == nil {
- b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal))
+ if colActions, err := col.AlterSQL(b.Dialect, name); err == nil {
+ b.Changes = append(b.Changes, colActions...)
}
return b
}
+// AddUniqueConstraint appends "ADD CONSTRAINT <name> UNIQUE (<columns>)" and
+// returns b for chaining. name and columns are interpolated unquoted, so they
+// must be trusted identifiers.
func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder {
b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", ")))
return b
}
+// ToSQL renders "ALTER TABLE <Name> <change>, <change>, ...". It returns an
+// error when no changes have been added.
func (b *AlterTableSqlBuilder) ToSQL() (string, error) {
var str strings.Builder
str.WriteString("ALTER TABLE ")
str.WriteString(b.Name)
str.WriteString(" ")
if len(b.Changes) == 0 {
return "", fmt.Errorf("no changes provide for table: %s", b.Name)
}
changeCount := len(b.Changes)
+	// Comma-separate the clauses, with no separator after the last one.
for i, thing := range b.Changes {
str.WriteString(thing)
if i < changeCount-1 {
str.WriteString(", ")
}
}
return str.String(), nil
}
diff --git a/db/alter_test.go b/db/alter_test.go
index 4bd58ac..4d47821 100644
--- a/db/alter_test.go
+++ b/db/alter_test.go
@@ -1,56 +1,71 @@
package db
import "testing"
func TestAlterTableSqlBuilder_ToSQL(t *testing.T) {
type fields struct {
Dialect DialectType
Name string
Changes []string
}
tests := []struct {
name string
builder *AlterTableSqlBuilder
want string
wantErr bool
}{
{
name: "MySQL add int",
builder: DialectMySQL.
AlterTable("the_table").
- AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)),
- want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL",
+ // IsSigned is explicit because a zero-value ColumnTypeInt renders as
+ // UNSIGNED on MySQL.
+ AddColumn(NonNullableColumn("the_col", ColumnTypeInt{IsSigned: true, MaxBytes: 4})),
+ want: "ALTER TABLE the_table ADD COLUMN the_col INTEGER NOT NULL",
wantErr: false,
},
{
name: "MySQL add string",
builder: DialectMySQL.
AlterTable("the_table").
- AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})),
+ AddColumn(NonNullableColumn("the_col", ColumnTypeString{MaxChars: 128})),
want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL",
wantErr: false,
},
-
{
name: "MySQL add int and string",
builder: DialectMySQL.
AlterTable("the_table").
- AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)).
- AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})),
+ AddColumn(NonNullableColumn("first_col", ColumnTypeInt{IsSigned: true, MaxBytes: 4})).
+ AddColumn(NonNullableColumn("second_col", ColumnTypeString{MaxChars: 128})),
-want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL",
+want: "ALTER TABLE the_table ADD COLUMN first_col INTEGER NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL",
wantErr: false,
},
+	{
+		name: "MySQL change to string",
+		builder: DialectMySQL.
+			AlterTable("the_table").
+			ChangeColumn("old_col", NonNullableColumn("new_col", ColumnTypeString{})),
+		// MySQL renames and redefines the column in one CHANGE COLUMN clause.
+		want:    "ALTER TABLE the_table CHANGE COLUMN old_col new_col TEXT NOT NULL",
+		wantErr: false,
+	},
+	{
+		name: "PostgreSQL change to int",
+		builder: DialectPostgreSQL.
+			AlterTable("the_table").
+			ChangeColumn("old_col", NullableColumn("new_col", ColumnTypeInt{MaxBytes: 4})),
+		want:    "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, ALTER COLUMN new_col TYPE INTEGER, ALTER COLUMN new_col DROP NOT NULL",
+		wantErr: false,
+	},
+	{
+		name:    "no changes",
+		builder: DialectMySQL.AlterTable("the_table"),
+		want:    "",
+		wantErr: true,
+	},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.builder.ToSQL()
if (err != nil) != tt.wantErr {
t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ToSQL() got = %v, want %v", got, tt.want)
}
})
}
}
diff --git a/db/builder.go b/db/builder.go
new file mode 100644
index 0000000..d0f4fe4
--- /dev/null
+++ b/db/builder.go
@@ -0,0 +1,15 @@
+/*
+ * Copyright © 2019-2022 A Bunch Tell LLC.
+ *
+ * This file is part of WriteFreely.
+ *
+ * WriteFreely is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, included
+ * in the LICENSE file in this source code package.
+ */
+
+package db
+
+// SQLBuilder is implemented by statement builders that render a complete SQL
+// statement, or report why they cannot.
+type SQLBuilder interface {
+	ToSQL() (string, error)
+}
+
+// Compile-time checks that the statement builders implement SQLBuilder.
+var (
+	_ SQLBuilder = (*CreateTableSqlBuilder)(nil)
+	_ SQLBuilder = (*AlterTableSqlBuilder)(nil)
+)
diff --git a/db/column.go b/db/column.go
new file mode 100644
index 0000000..5bf64e4
--- /dev/null
+++ b/db/column.go
@@ -0,0 +1,328 @@
+/*
+ * Copyright © 2019-2022 A Bunch Tell LLC.
+ *
+ * This file is part of WriteFreely.
+ *
+ * WriteFreely is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, included
+ * in the LICENSE file in this source code package.
+ */
+
+package db
+
+import (
+ "fmt"
+ "strings"
+)
+
+// Column describes one table column independently of any SQL dialect; the
+// dialect is supplied when the column is rendered (CreateSQL, AlterSQL).
+type Column struct {
+	Name       string
+	Type       ColumnType
+	Nullable   bool
+	PrimaryKey bool
+}
+
+// NullableColumn returns a column named name of type ty that accepts NULL.
+func NullableColumn(name string, ty ColumnType) *Column {
+	col := &Column{Name: name, Type: ty}
+	col.Nullable = true
+	return col
+}
+
+// NonNullableColumn returns a column named name of type ty declared NOT NULL.
+func NonNullableColumn(name string, ty ColumnType) *Column {
+	return &Column{Name: name, Type: ty}
+}
+
+// PrimaryKeyColumn returns a NOT NULL column named name of type ty that is
+// also declared as the primary key.
+func PrimaryKeyColumn(name string, ty ColumnType) *Column {
+	col := NonNullableColumn(name, ty)
+	col.PrimaryKey = true
+	return col
+}
+
+// ColumnType is a dialect-independent column type. Name renders the type for
+// a dialect and Default renders its default-value expression ("" when none).
+type ColumnType interface {
+	Name(DialectType) (string, error)
+	Default(DialectType) (string, error)
+}
+
+// ColumnTypeInt is an integer column.
+//
+// MaxBytes selects the storage width: 1-4 map to the narrowest type the
+// dialect offers that holds that many bytes, anything else becomes BIGINT
+// (SQLite always uses INTEGER). On MySQL, MaxDigits adds a display width and
+// the column is UNSIGNED unless IsSigned is true, so the zero value is
+// unsigned there. DefaultVal is used only when HasDefault is set.
+type ColumnTypeInt struct {
+	IsSigned   bool
+	MaxBytes   int
+	MaxDigits  int
+	HasDefault bool
+	DefaultVal int
+}
+
+// ColumnTypeString is a character column. Fixed-length columns render as
+// CHAR; otherwise VARCHAR is used when 0 < MaxChars < 65536 and TEXT for
+// anything else (SQLite always uses TEXT). DefaultVal is used only when
+// HasDefault is set.
+type ColumnTypeString struct {
+	IsFixedLength bool
+	MaxChars      int
+	HasDefault    bool
+	DefaultVal    string
+}
+
+// ColumnDefault selects a symbolic default for boolean and datetime columns.
+type ColumnDefault int
+
+// ColumnTypeBool is a boolean column; DefaultVal may be NoDefault,
+// DefaultFalse or DefaultTrue.
+type ColumnTypeBool struct {
+	DefaultVal ColumnDefault
+}
+
+const (
+	// NoDefault means no DEFAULT clause is emitted.
+	NoDefault ColumnDefault = iota
+	// DefaultFalse and DefaultTrue are accepted only by ColumnTypeBool.
+	DefaultFalse
+	DefaultTrue
+	// DefaultNow is accepted only by ColumnTypeDateTime.
+	DefaultNow
+)
+
+// ColumnTypeDateTime is a date-and-time column; DefaultVal may be NoDefault
+// or DefaultNow.
+type ColumnTypeDateTime struct {
+	DefaultVal ColumnDefault
+}
+
+// Name returns the integer type for dialect d. SQLite always uses INTEGER.
+// MySQL and PostgreSQL pick the narrowest type holding MaxBytes bytes (BIGINT
+// for 0 or anything above 4); MySQL additionally appends the MaxDigits display
+// width and UNSIGNED unless IsSigned is set.
+func (intCol ColumnTypeInt) Name(d DialectType) (string, error) {
+	if d == DialectSQLite {
+		return "INTEGER", nil
+	}
+	if d != DialectMySQL && d != DialectPostgreSQL {
+		return "", fmt.Errorf("dialect %d does not support integer columns", d)
+	}
+
+	mysql := d == DialectMySQL
+	colName := "BIGINT"
+	switch n := intCol.MaxBytes; {
+	case n == 1 && mysql:
+		colName = "TINYINT"
+	case n == 1, n == 2:
+		colName = "SMALLINT"
+	case n == 3 && mysql:
+		colName = "MEDIUMINT"
+	case n == 3, n == 4:
+		colName = "INTEGER"
+	}
+
+	if !mysql {
+		return colName, nil
+	}
+	if intCol.MaxDigits > 0 {
+		colName += fmt.Sprintf("(%d)", intCol.MaxDigits)
+	}
+	if !intCol.IsSigned {
+		colName += " UNSIGNED"
+	}
+	return colName, nil
+}
+
+// Default returns DefaultVal as a decimal literal when HasDefault is set and
+// "" otherwise; the result does not depend on the dialect.
+func (intCol ColumnTypeInt) Default(d DialectType) (string, error) {
+	if !intCol.HasDefault {
+		return "", nil
+	}
+	return fmt.Sprint(intCol.DefaultVal), nil
+}
+
+// Name returns TEXT on SQLite. On MySQL and PostgreSQL a fixed-length column
+// is CHAR or CHAR(MaxChars); otherwise it is VARCHAR(MaxChars) when
+// 0 < MaxChars < 65536 and TEXT for anything else.
+func (strCol ColumnTypeString) Name(d DialectType) (string, error) {
+	switch {
+	case d == DialectSQLite:
+		return "TEXT", nil
+	case d != DialectMySQL && d != DialectPostgreSQL:
+		return "", fmt.Errorf("dialect %d does not support string columns", d)
+	}
+
+	size := strCol.MaxChars
+	if strCol.IsFixedLength {
+		if size <= 0 {
+			return "CHAR", nil
+		}
+		return fmt.Sprintf("CHAR(%d)", size), nil
+	}
+	if size > 0 && size < 1<<16 {
+		return fmt.Sprintf("VARCHAR(%d)", size), nil
+	}
+	return "TEXT", nil
+}
+
+// Default returns DefaultVal as an escaped SQL string literal for dialect d
+// when HasDefault is set, and "" otherwise.
+func (strCol ColumnTypeString) Default(d DialectType) (string, error) {
+	if !strCol.HasDefault {
+		return "", nil
+	}
+	return EscapeSimple.SQLEscape(d, strCol.DefaultVal)
+}
+
+// Name returns INTEGER on SQLite and BOOL on MySQL and PostgreSQL.
+func (boolCol ColumnTypeBool) Name(d DialectType) (string, error) {
+	switch d {
+	case DialectSQLite:
+		return "INTEGER", nil
+	case DialectMySQL, DialectPostgreSQL:
+		return "BOOL", nil
+	default:
+		return "", fmt.Errorf("boolean column type not supported for dialect %d", d)
+	}
+}
+
+// Default returns the literal for the column's default value, or "" for
+// NoDefault. Any other DefaultVal (e.g. DefaultNow) is an error.
+//
+// PostgreSQL's BOOLEAN type rejects integer literals ("column is of type
+// boolean but default expression is of type integer"), so TRUE/FALSE are
+// emitted there; SQLite and MySQL accept 1/0.
+func (boolCol ColumnTypeBool) Default(d DialectType) (string, error) {
+	switch boolCol.DefaultVal {
+	case NoDefault:
+		return "", nil
+	case DefaultFalse:
+		if d == DialectPostgreSQL {
+			return "FALSE", nil
+		}
+		return "0", nil
+	case DefaultTrue:
+		if d == DialectPostgreSQL {
+			return "TRUE", nil
+		}
+		return "1", nil
+	default:
+		return "", fmt.Errorf("boolean columns cannot default to %d for dialect %d", boolCol.DefaultVal, d)
+	}
+}
+
+// Name returns DATETIME for SQLite and MySQL and TIMESTAMP for PostgreSQL.
+func (dateTimeCol ColumnTypeDateTime) Name(d DialectType) (string, error) {
+	switch d {
+	case DialectSQLite, DialectMySQL:
+		return "DATETIME", nil
+	case DialectPostgreSQL:
+		return "TIMESTAMP", nil
+	default:
+		return "", fmt.Errorf("datetime column type not supported for dialect %d", d)
+	}
+}
+
+// Default returns the DEFAULT expression for the column: "" for NoDefault,
+// and for DefaultNow NOW() on MySQL or CURRENT_TIMESTAMP on SQLite and
+// PostgreSQL. Any other DefaultVal, or an unknown dialect, is an error.
+func (dateTimeCol ColumnTypeDateTime) Default(d DialectType) (string, error) {
+	// PostgreSQL must be accepted here: CreateSQL calls Default for every
+	// column, so rejecting it would make every PostgreSQL datetime column
+	// unrenderable, even one without a default.
+	if d != DialectSQLite && d != DialectMySQL && d != DialectPostgreSQL {
+		return "", fmt.Errorf("dialect %d does not support defaulted datetime columns", d)
+	}
+
+	switch dateTimeCol.DefaultVal {
+	case NoDefault:
+		return "", nil
+	case DefaultNow:
+		if d == DialectMySQL {
+			return "NOW()", nil
+		}
+		return "CURRENT_TIMESTAMP", nil
+	default:
+		return "", fmt.Errorf("datetime columns cannot default to %d for dialect %d", dateTimeCol.DefaultVal, d)
+	}
+}
+
+// SetName sets the column name and returns c for chaining.
+func (c *Column) SetName(name string) *Column {
+ c.Name = name
+ return c
+}
+
+// SetNullable sets whether the column accepts NULL and returns c for chaining.
+func (c *Column) SetNullable(nullable bool) *Column {
+ c.Nullable = nullable
+ return c
+}
+
+// SetPrimaryKey sets whether the column is declared PRIMARY KEY and returns c
+// for chaining. It does not change Nullable.
+func (c *Column) SetPrimaryKey(pk bool) *Column {
+ c.PrimaryKey = pk
+ return c
+}
+
+// SetType replaces the column type and returns c for chaining.
+func (c *Column) SetType(t ColumnType) *Column {
+ c.Type = t
+ return c
+}
+
+// AlterSQL returns the ALTER TABLE clauses that turn the existing column
+// oldName into c (renaming it to c.Name when they differ) for dialect d.
+//
+// MySQL gets a single "CHANGE COLUMN oldName <definition>" clause. PostgreSQL
+// has no equivalent, so the rename, type, nullability, default and primary key
+// are emitted as separate clauses. Other dialects return an error. On error
+// the returned slice is empty (non-nil).
+//
+// NOTE(review): on PostgreSQL an existing DEFAULT is left in place when c has
+// none (MySQL's CHANGE COLUMN replaces it), a type change may need a USING
+// clause, and ADD PRIMARY KEY fails if the table already has a primary key.
+func (c *Column) AlterSQL(d DialectType, oldName string) ([]string, error) {
+	switch d {
+	case DialectMySQL:
+		colDef, err := c.CreateSQL(d)
+		if err != nil {
+			return []string{}, err
+		}
+		return []string{fmt.Sprintf("CHANGE COLUMN %s %s", oldName, colDef)}, nil
+
+	case DialectPostgreSQL:
+		var clauses []string
+		if c.Name != oldName {
+			clauses = append(clauses, fmt.Sprintf("RENAME COLUMN %s TO %s", oldName, c.Name))
+		}
+
+		typeName, err := c.Type.Name(d)
+		if err != nil {
+			return []string{}, err
+		}
+		nullAction := "SET"
+		if c.Nullable {
+			nullAction = "DROP"
+		}
+		clauses = append(clauses,
+			fmt.Sprintf("ALTER COLUMN %s TYPE %s", c.Name, typeName),
+			fmt.Sprintf("ALTER COLUMN %s %s NOT NULL", c.Name, nullAction),
+		)
+
+		defaultSQL, err := c.Type.Default(d)
+		if err != nil {
+			return []string{}, err
+		}
+		if defaultSQL != "" {
+			clauses = append(clauses, fmt.Sprintf("ALTER COLUMN %s SET DEFAULT %s", c.Name, defaultSQL))
+		}
+		if c.PrimaryKey {
+			clauses = append(clauses, fmt.Sprintf("ADD PRIMARY KEY (%s)", c.Name))
+		}
+		return clauses, nil
+
+	default:
+		return []string{}, fmt.Errorf("dialect %d doesn't support altering column data type", d)
+	}
+}
+
+// CreateSQL renders the column definition used in CREATE TABLE and
+// ADD COLUMN clauses:
+//
+//	<name> <type>[ NOT NULL][ DEFAULT <expr>][ PRIMARY KEY]
+//
+// The type and default expression come from c.Type for dialect d; any error
+// they return is passed through unchanged.
+func (c *Column) CreateSQL(d DialectType) (string, error) {
+	typeName, err := c.Type.Name(d)
+	if err != nil {
+		return "", err
+	}
+	defaultSQL, err := c.Type.Default(d)
+	if err != nil {
+		return "", err
+	}
+
+	parts := []string{c.Name, typeName}
+	if !c.Nullable {
+		parts = append(parts, "NOT NULL")
+	}
+	if defaultSQL != "" {
+		parts = append(parts, "DEFAULT", defaultSQL)
+	}
+	if c.PrimaryKey {
+		parts = append(parts, "PRIMARY KEY")
+	}
+	return strings.Join(parts, " "), nil
+}
diff --git a/db/column_test.go b/db/column_test.go
new file mode 100644
index 0000000..175ec3a
--- /dev/null
+++ b/db/column_test.go
@@ -0,0 +1,151 @@
+package db
+
+import (
+ "github.com/stretchr/testify/assert"
+ "testing"
+)
+
+// TestColumnType_Name checks the type name each ColumnType renders per
+// dialect. Name covers only the type; DEFAULT expressions are tested in
+// TestColumnType_Default.
+func TestColumnType_Name(t *testing.T) {
+	const unknownDialect DialectType = 1000
+
+	tests := []struct {
+		name    string
+		ty      ColumnType
+		d       DialectType
+		want    string
+		wantErr bool
+	}{
+		{"SQLite bool", ColumnTypeBool{}, DialectSQLite, "INTEGER", false},
+		{"SQLite int", ColumnTypeInt{}, DialectSQLite, "INTEGER", false},
+		{"SQLite string with default", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "TEXT", false},
+		{"SQLite datetime", ColumnTypeDateTime{}, DialectSQLite, "DATETIME", false},
+
+		{"MySQL bool", ColumnTypeBool{}, DialectMySQL, "BOOL", false},
+		{"MySQL tiny int", ColumnTypeInt{IsSigned: true, MaxBytes: 1}, DialectMySQL, "TINYINT", false},
+		{"MySQL tiny int with digits", ColumnTypeInt{IsSigned: true, MaxBytes: 1, MaxDigits: 2}, DialectMySQL, "TINYINT(2)", false},
+		{"MySQL small int", ColumnTypeInt{IsSigned: true, MaxBytes: 2}, DialectMySQL, "SMALLINT", false},
+		{"MySQL small int with digits", ColumnTypeInt{IsSigned: true, MaxBytes: 2, MaxDigits: 3}, DialectMySQL, "SMALLINT(3)", false},
+		{"MySQL medium int", ColumnTypeInt{IsSigned: true, MaxBytes: 3}, DialectMySQL, "MEDIUMINT", false},
+		{"MySQL medium int with digits", ColumnTypeInt{IsSigned: true, MaxBytes: 3, MaxDigits: 6}, DialectMySQL, "MEDIUMINT(6)", false},
+		{"MySQL int", ColumnTypeInt{IsSigned: true, MaxBytes: 4}, DialectMySQL, "INTEGER", false},
+		{"MySQL int with digits", ColumnTypeInt{IsSigned: true, MaxBytes: 4, MaxDigits: 11}, DialectMySQL, "INTEGER(11)", false},
+		// A zero-value IsSigned means UNSIGNED on MySQL.
+		{"MySQL unsigned int", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "INTEGER UNSIGNED", false},
+		{"MySQL unsigned int with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 10}, DialectMySQL, "INTEGER(10) UNSIGNED", false},
+		{"MySQL bigint", ColumnTypeInt{IsSigned: true, MaxBytes: 8}, DialectMySQL, "BIGINT", false},
+		{"MySQL bigint with digits", ColumnTypeInt{IsSigned: true, MaxBytes: 8, MaxDigits: 15}, DialectMySQL, "BIGINT(15)", false},
+		{"MySQL char", ColumnTypeString{IsFixedLength: true}, DialectMySQL, "CHAR", false},
+		{"MySQL char with length", ColumnTypeString{IsFixedLength: true, MaxChars: 4}, DialectMySQL, "CHAR(4)", false},
+		{"MySQL varchar with length", ColumnTypeString{MaxChars: 25}, DialectMySQL, "VARCHAR(25)", false},
+		{"MySQL oversized varchar", ColumnTypeString{MaxChars: 1 << 16}, DialectMySQL, "TEXT", false},
+		{"MySQL text", ColumnTypeString{}, DialectMySQL, "TEXT", false},
+		{"MySQL datetime", ColumnTypeDateTime{}, DialectMySQL, "DATETIME", false},
+
+		{"PostgreSQL bool", ColumnTypeBool{}, DialectPostgreSQL, "BOOL", false},
+		{"PostgreSQL 1-byte int", ColumnTypeInt{MaxBytes: 1}, DialectPostgreSQL, "SMALLINT", false},
+		{"PostgreSQL 3-byte int ignores digits", ColumnTypeInt{MaxBytes: 3, MaxDigits: 5}, DialectPostgreSQL, "INTEGER", false},
+		{"PostgreSQL bigint", ColumnTypeInt{}, DialectPostgreSQL, "BIGINT", false},
+		{"PostgreSQL varchar", ColumnTypeString{MaxChars: 25}, DialectPostgreSQL, "VARCHAR(25)", false},
+		{"PostgreSQL datetime", ColumnTypeDateTime{}, DialectPostgreSQL, "TIMESTAMP", false},
+
+		{"unknown dialect int", ColumnTypeInt{}, unknownDialect, "", true},
+		{"unknown dialect string", ColumnTypeString{}, unknownDialect, "", true},
+		{"unknown dialect bool", ColumnTypeBool{}, unknownDialect, "", true},
+		{"unknown dialect datetime", ColumnTypeDateTime{}, unknownDialect, "", true},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := tt.ty.Name(tt.d)
+			if tt.wantErr {
+				assert.Error(t, err)
+				return
+			}
+			assert.NoError(t, err)
+			assert.Equal(t, tt.want, got)
+		})
+	}
+}
+
+// TestColumnType_Default checks the DEFAULT expression each ColumnType renders,
+// including combinations that must be rejected.
+func TestColumnType_Default(t *testing.T) {
+	tests := []struct {
+		name    string
+		ty      ColumnType
+		d       DialectType
+		want    string
+		wantErr bool
+	}{
+		{"SQLite bool none", ColumnTypeBool{}, DialectSQLite, "", false},
+		{"SQLite bool false", ColumnTypeBool{DefaultVal: DefaultFalse}, DialectSQLite, "0", false},
+		{"SQLite bool true", ColumnTypeBool{DefaultVal: DefaultTrue}, DialectSQLite, "1", false},
+		{"SQLite bool now is invalid", ColumnTypeBool{DefaultVal: DefaultNow}, DialectSQLite, "", true},
+		{"SQLite int none", ColumnTypeInt{}, DialectSQLite, "", false},
+		{"SQLite int empty", ColumnTypeInt{HasDefault: true}, DialectSQLite, "0", false},
+		{"SQLite int", ColumnTypeInt{HasDefault: true, DefaultVal: 10}, DialectSQLite, "10", false},
+		{"SQLite string none", ColumnTypeString{}, DialectSQLite, "", false},
+		{"SQLite string empty", ColumnTypeString{HasDefault: true}, DialectSQLite, "''", false},
+		{"SQLite string", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "'that''s a default'", false},
+		{"MySQL string", ColumnTypeString{HasDefault: true, DefaultVal: "%that's a default%"}, DialectMySQL, "'%that\\'s a default%'", false},
+
+		{"SQLite datetime none", ColumnTypeDateTime{}, DialectSQLite, "", false},
+		{"SQLite datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectSQLite, "CURRENT_TIMESTAMP", false},
+		{"SQLite datetime false is invalid", ColumnTypeDateTime{DefaultVal: DefaultFalse}, DialectSQLite, "", true},
+		{"MySQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectMySQL, "NOW()", false},
+		{"PostgreSQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectPostgreSQL, "CURRENT_TIMESTAMP", false},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := tt.ty.Default(tt.d)
+			if tt.wantErr {
+				assert.Error(t, err)
+				return
+			}
+			assert.NoError(t, err)
+			assert.Equal(t, tt.want, got)
+		})
+	}
+}
+
+// TestColumn_CreateSQL checks full column definitions:
+// "<name> <type>[ NOT NULL][ DEFAULT x][ PRIMARY KEY]".
+func TestColumn_CreateSQL(t *testing.T) {
+	const unknownDialect DialectType = 1000
+
+	type fields struct {
+		Dialect    DialectType
+		Name       string
+		Nullable   bool
+		Type       ColumnType
+		PrimaryKey bool
+	}
+	tests := []struct {
+		name    string
+		fields  fields
+		want    string
+		wantErr bool
+	}{
+		{"SQLite bool", fields{DialectSQLite, "foo", false, ColumnTypeBool{}, false}, "foo INTEGER NOT NULL", false},
+		{"SQLite bool nullable", fields{DialectSQLite, "foo", true, ColumnTypeBool{}, false}, "foo INTEGER", false},
+		{"SQLite int", fields{DialectSQLite, "foo", false, ColumnTypeInt{}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
+		{"SQLite int nullable", fields{DialectSQLite, "foo", true, ColumnTypeInt{}, false}, "foo INTEGER", false},
+		{"SQLite int with default", fields{DialectSQLite, "foo", false, ColumnTypeInt{HasDefault: true, DefaultVal: 7}, false}, "foo INTEGER NOT NULL DEFAULT 7", false},
+		{"SQLite text", fields{DialectSQLite, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false},
+		{"SQLite text nullable", fields{DialectSQLite, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false},
+		{"SQLite datetime", fields{DialectSQLite, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false},
+		{"SQLite datetime nullable", fields{DialectSQLite, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false},
+
+		{"MySQL bool", fields{DialectMySQL, "foo", false, ColumnTypeBool{}, false}, "foo BOOL NOT NULL", false},
+		{"MySQL bool nullable", fields{DialectMySQL, "foo", true, ColumnTypeBool{}, false}, "foo BOOL", false},
+		{"MySQL tiny int", fields{DialectMySQL, "foo", false, ColumnTypeInt{IsSigned: true, MaxBytes: 1}, true}, "foo TINYINT NOT NULL PRIMARY KEY", false},
+		{"MySQL tiny int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{IsSigned: true, MaxBytes: 1}, false}, "foo TINYINT", false},
+		{"MySQL small int", fields{DialectMySQL, "foo", false, ColumnTypeInt{IsSigned: true, MaxBytes: 2}, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false},
+		{"MySQL small int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{IsSigned: true, MaxBytes: 2}, false}, "foo SMALLINT", false},
+		{"MySQL int", fields{DialectMySQL, "foo", false, ColumnTypeInt{IsSigned: true, MaxBytes: 4}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
+		{"MySQL int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{IsSigned: true, MaxBytes: 4}, false}, "foo INTEGER", false},
+		{"MySQL unsigned int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 4}, false}, "foo INTEGER UNSIGNED NOT NULL", false},
+		{"MySQL big int", fields{DialectMySQL, "foo", false, ColumnTypeInt{IsSigned: true}, true}, "foo BIGINT NOT NULL PRIMARY KEY", false},
+		{"MySQL big int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{IsSigned: true}, false}, "foo BIGINT", false},
+		{"MySQL char", fields{DialectMySQL, "foo", false, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR NOT NULL", false},
+		{"MySQL char nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR", false},
+		{"MySQL varchar", fields{DialectMySQL, "foo", false, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255) NOT NULL", false},
+		{"MySQL varchar nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255)", false},
+		{"MySQL text", fields{DialectMySQL, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false},
+		{"MySQL text nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false},
+		{"MySQL datetime", fields{DialectMySQL, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false},
+		{"MySQL datetime nullable", fields{DialectMySQL, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false},
+
+		{"PostgreSQL int", fields{DialectPostgreSQL, "foo", false, ColumnTypeInt{MaxBytes: 4}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
+		{"PostgreSQL bool nullable", fields{DialectPostgreSQL, "foo", true, ColumnTypeBool{}, false}, "foo BOOL", false},
+		{"PostgreSQL text", fields{DialectPostgreSQL, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false},
+
+		{"unknown dialect", fields{unknownDialect, "foo", false, ColumnTypeInt{}, false}, "", true},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			c := &Column{
+				Name:       tt.fields.Name,
+				Nullable:   tt.fields.Nullable,
+				Type:       tt.fields.Type,
+				PrimaryKey: tt.fields.PrimaryKey,
+			}
+			got, err := c.CreateSQL(tt.fields.Dialect)
+			if tt.wantErr {
+				assert.Error(t, err)
+				return
+			}
+			assert.NoError(t, err)
+			assert.Equal(t, tt.want, got)
+		})
+	}
+}
diff --git a/db/create.go b/db/create.go
index 8728d5d..9462bbe 100644
--- a/db/create.go
+++ b/db/create.go
@@ -1,266 +1,89 @@
/*
* Copyright © 2019-2020 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package db
import (
"fmt"
"strings"
)
-type ColumnType int
-
-type OptionalInt struct {
- Set bool
- Value int
-}
-
-type OptionalString struct {
- Set bool
- Value string
-}
-
-type SQLBuilder interface {
- ToSQL() (string, error)
-}
-
-type Column struct {
- Dialect DialectType
- Name string
- Nullable bool
- Default OptionalString
- Type ColumnType
- Size OptionalInt
- PrimaryKey bool
-}
-
+// CreateTableSqlBuilder accumulates columns and table constraints for a
+// CREATE TABLE statement; ToSQL renders it using Dialect.
type CreateTableSqlBuilder struct {
Dialect DialectType
Name string
IfNotExists bool
+	// ColumnOrder records column names in the order Column was called; ToSQL
+	// emits columns in this order and looks each one up in Columns.
ColumnOrder []string
Columns map[string]*Column
+	// Constraints holds pre-rendered constraint clauses such as "UNIQUE(a,b)".
Constraints []string
}
-const (
- ColumnTypeBool ColumnType = iota
- ColumnTypeSmallInt ColumnType = iota
- ColumnTypeInteger ColumnType = iota
- ColumnTypeChar ColumnType = iota
- ColumnTypeVarChar ColumnType = iota
- ColumnTypeText ColumnType = iota
- ColumnTypeDateTime ColumnType = iota
-)
-
-var _ SQLBuilder = &CreateTableSqlBuilder{}
-
-var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
-var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
-
-func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
- if dialect != DialectMySQL && dialect != DialectSQLite {
- return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
- }
- switch d {
- case ColumnTypeSmallInt:
- {
- if dialect == DialectSQLite {
- return "INTEGER", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
- return "SMALLINT" + mod, nil
- }
- case ColumnTypeInteger:
- {
- if dialect == DialectSQLite {
- return "INTEGER", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
- return "INT" + mod, nil
- }
- case ColumnTypeChar:
- {
- if dialect == DialectSQLite {
- return "TEXT", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
- return "CHAR" + mod, nil
- }
- case ColumnTypeVarChar:
- {
- if dialect == DialectSQLite {
- return "TEXT", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
- return "VARCHAR" + mod, nil
- }
- case ColumnTypeBool:
- {
- if dialect == DialectSQLite {
- return "INTEGER", nil
- }
- return "TINYINT(1)", nil
- }
- case ColumnTypeDateTime:
- return "DATETIME", nil
- case ColumnTypeText:
- return "TEXT", nil
- }
- return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
-}
-
-func (c *Column) SetName(name string) *Column {
- c.Name = name
- return c
-}
-
-func (c *Column) SetNullable(nullable bool) *Column {
- c.Nullable = nullable
- return c
-}
-
-func (c *Column) SetPrimaryKey(pk bool) *Column {
- c.PrimaryKey = pk
- return c
-}
-
-func (c *Column) SetDefault(value string) *Column {
- c.Default = OptionalString{Set: true, Value: value}
- return c
-}
-
-func (c *Column) SetDefaultCurrentTimestamp() *Column {
- def := "NOW()"
- if c.Dialect == DialectSQLite {
- def = "CURRENT_TIMESTAMP"
- }
- c.Default = OptionalString{Set: true, Value: def}
- return c
-}
-
-func (c *Column) SetType(t ColumnType) *Column {
- c.Type = t
- return c
-}
-
-func (c *Column) SetSize(size int) *Column {
- c.Size = OptionalInt{Set: true, Value: size}
- return c
-}
-
-func (c *Column) String() (string, error) {
- var str strings.Builder
-
- str.WriteString(c.Name)
-
- str.WriteString(" ")
- typeStr, err := c.Type.Format(c.Dialect, c.Size)
- if err != nil {
- return "", err
- }
-
- str.WriteString(typeStr)
-
- if !c.Nullable {
- str.WriteString(" NOT NULL")
- }
-
- if c.Default.Set {
- str.WriteString(" DEFAULT ")
- val := c.Default.Value
- if val == "" {
- val = "''"
- }
- str.WriteString(val)
- }
-
- if c.PrimaryKey {
- str.WriteString(" PRIMARY KEY")
- }
-
- return str.String(), nil
-}
-
+// Column adds column to the table definition (creating the Columns map on
+// first use) and returns b for chaining.
+//
+// NOTE(review): adding two columns with the same name overwrites the map
+// entry but appends the name to ColumnOrder twice, so ToSQL emits it twice.
func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder {
if b.Columns == nil {
b.Columns = make(map[string]*Column)
}
b.Columns[column.Name] = column
b.ColumnOrder = append(b.ColumnOrder, column.Name)
return b
}
+// UniqueConstraint adds a "UNIQUE(col1,col2,...)" constraint and returns b for
+// chaining. Every column must already have been added with Column; if any is
+// missing, the constraint is not added and no error is reported.
func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder {
for _, column := range columns {
if _, ok := b.Columns[column]; !ok {
// This fails silently.
return b
}
}
b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ",")))
return b
}
+// SetIfNotExists controls whether ToSQL emits "IF NOT EXISTS" and returns b
+// for chaining.
func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder {
b.IfNotExists = ine
return b
}
+// ToSQL renders
+// "CREATE TABLE [IF NOT EXISTS] <Name> ( <column>, ..., <constraint>, ... )".
+// Columns come first, in ColumnOrder, followed by Constraints; when there are
+// neither, the parenthesised list is omitted. It returns an error if a name in
+// ColumnOrder has no entry in Columns or a column cannot be rendered for
+// Dialect.
func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
var str strings.Builder
str.WriteString("CREATE TABLE ")
if b.IfNotExists {
str.WriteString("IF NOT EXISTS ")
}
str.WriteString(b.Name)
var things []string
for _, columnName := range b.ColumnOrder {
column, ok := b.Columns[columnName]
if !ok {
return "", fmt.Errorf("column not found: %s", columnName)
}
- columnStr, err := column.String()
+ columnStr, err := column.CreateSQL(b.Dialect)
if err != nil {
return "", err
}
things = append(things, columnStr)
}
for _, constraint := range b.Constraints {
things = append(things, constraint)
}
+	// Join column definitions and constraints with ", " inside "( ... )".
if thingLen := len(things); thingLen > 0 {
str.WriteString(" ( ")
for i, thing := range things {
str.WriteString(thing)
if i < thingLen-1 {
str.WriteString(", ")
}
}
str.WriteString(" )")
}
return str.String(), nil
}
diff --git a/db/create_test.go b/db/create_test.go
index 369d5c1..09efd18 100644
--- a/db/create_test.go
+++ b/db/create_test.go
@@ -1,146 +1,20 @@
package db
import (
"github.com/stretchr/testify/assert"
"testing"
)
-func TestDialect_Column(t *testing.T) {
- c1 := DialectSQLite.Column("foo", ColumnTypeBool, UnsetSize)
- assert.Equal(t, DialectSQLite, c1.Dialect)
- c2 := DialectMySQL.Column("foo", ColumnTypeBool, UnsetSize)
- assert.Equal(t, DialectMySQL, c2.Dialect)
-}
-
-func TestColumnType_Format(t *testing.T) {
- type args struct {
- dialect DialectType
- size OptionalInt
- }
- tests := []struct {
- name string
- d ColumnType
- args args
- want string
- wantErr bool
- }{
- {"Sqlite bool", ColumnTypeBool, args{dialect: DialectSQLite}, "INTEGER", false},
- {"Sqlite small int", ColumnTypeSmallInt, args{dialect: DialectSQLite}, "INTEGER", false},
- {"Sqlite int", ColumnTypeInteger, args{dialect: DialectSQLite}, "INTEGER", false},
- {"Sqlite char", ColumnTypeChar, args{dialect: DialectSQLite}, "TEXT", false},
- {"Sqlite varchar", ColumnTypeVarChar, args{dialect: DialectSQLite}, "TEXT", false},
- {"Sqlite text", ColumnTypeText, args{dialect: DialectSQLite}, "TEXT", false},
- {"Sqlite datetime", ColumnTypeDateTime, args{dialect: DialectSQLite}, "DATETIME", false},
-
- {"MySQL bool", ColumnTypeBool, args{dialect: DialectMySQL}, "TINYINT(1)", false},
- {"MySQL small int", ColumnTypeSmallInt, args{dialect: DialectMySQL}, "SMALLINT", false},
- {"MySQL small int with param", ColumnTypeSmallInt, args{dialect: DialectMySQL, size: OptionalInt{true, 3}}, "SMALLINT(3)", false},
- {"MySQL int", ColumnTypeInteger, args{dialect: DialectMySQL}, "INT", false},
- {"MySQL int with param", ColumnTypeInteger, args{dialect: DialectMySQL, size: OptionalInt{true, 11}}, "INT(11)", false},
- {"MySQL char", ColumnTypeChar, args{dialect: DialectMySQL}, "CHAR", false},
- {"MySQL char with param", ColumnTypeChar, args{dialect: DialectMySQL, size: OptionalInt{true, 4}}, "CHAR(4)", false},
- {"MySQL varchar", ColumnTypeVarChar, args{dialect: DialectMySQL}, "VARCHAR", false},
- {"MySQL varchar with param", ColumnTypeVarChar, args{dialect: DialectMySQL, size: OptionalInt{true, 25}}, "VARCHAR(25)", false},
- {"MySQL text", ColumnTypeText, args{dialect: DialectMySQL}, "TEXT", false},
- {"MySQL datetime", ColumnTypeDateTime, args{dialect: DialectMySQL}, "DATETIME", false},
-
- {"invalid column type", 10000, args{dialect: DialectMySQL}, "", true},
- {"invalid dialect", ColumnTypeBool, args{dialect: 10000}, "", true},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := tt.d.Format(tt.args.dialect, tt.args.size)
- if (err != nil) != tt.wantErr {
- t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("Format() got = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestColumn_Build(t *testing.T) {
- type fields struct {
- Dialect DialectType
- Name string
- Nullable bool
- Default OptionalString
- Type ColumnType
- Size OptionalInt
- PrimaryKey bool
- }
- tests := []struct {
- name string
- fields fields
- want string
- wantErr bool
- }{
- {"Sqlite bool", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER NOT NULL", false},
- {"Sqlite bool nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER", false},
- {"Sqlite small int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
- {"Sqlite small int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo INTEGER", false},
- {"Sqlite int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER NOT NULL", false},
- {"Sqlite int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER", false},
- {"Sqlite char", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT NOT NULL", false},
- {"Sqlite char nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT", false},
- {"Sqlite varchar", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT NOT NULL", false},
- {"Sqlite varchar nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT", false},
- {"Sqlite text", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false},
- {"Sqlite text nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false},
- {"Sqlite datetime", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false},
- {"Sqlite datetime nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false},
-
- {"MySQL bool", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1) NOT NULL", false},
- {"MySQL bool nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1)", false},
- {"MySQL small int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false},
- {"MySQL small int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo SMALLINT", false},
- {"MySQL int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT NOT NULL", false},
- {"MySQL int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT", false},
- {"MySQL char", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR NOT NULL", false},
- {"MySQL char nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR", false},
- {"MySQL varchar", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR NOT NULL", false},
- {"MySQL varchar nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR", false},
- {"MySQL text", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false},
- {"MySQL text nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false},
- {"MySQL datetime", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false},
- {"MySQL datetime nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- c := &Column{
- Dialect: tt.fields.Dialect,
- Name: tt.fields.Name,
- Nullable: tt.fields.Nullable,
- Default: tt.fields.Default,
- Type: tt.fields.Type,
- Size: tt.fields.Size,
- PrimaryKey: tt.fields.PrimaryKey,
- }
- if got, err := c.String(); got != tt.want {
- if (err != nil) != tt.wantErr {
- t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("String() got = %v, want %v", got, tt.want)
- }
- }
- })
- }
-}
-
+// TestCreateTableSqlBuilder_ToSQL checks the MySQL CREATE TABLE statement
+// built from a primary-key 4-byte INT column, a non-nullable ColumnTypeString
+// with no MaxChars (expected to render as TEXT), a non-nullable DATETIME
+// defaulting to NOW(), and two UNIQUE constraints.
func TestCreateTableSqlBuilder_ToSQL(t *testing.T) {
sql, err := DialectMySQL.
Table("foo").
SetIfNotExists(true).
- Column(DialectMySQL.Column("bar", ColumnTypeInteger, UnsetSize).SetPrimaryKey(true)).
- Column(DialectMySQL.Column("baz", ColumnTypeText, UnsetSize)).
- Column(DialectMySQL.Column("qux", ColumnTypeDateTime, UnsetSize).SetDefault("NOW()")).
+ Column(PrimaryKeyColumn("bar", ColumnTypeInt{MaxBytes: 4})).
+ Column(NonNullableColumn("baz", ColumnTypeString{})).
+ Column(NonNullableColumn("qux", ColumnTypeDateTime{DefaultVal: DefaultNow})).
UniqueConstraint("bar").
UniqueConstraint("bar", "baz").
ToSQL()
assert.NoError(t, err)
assert.Equal(t, "CREATE TABLE IF NOT EXISTS foo ( bar INT NOT NULL PRIMARY KEY, baz TEXT NOT NULL, qux DATETIME NOT NULL DEFAULT NOW(), UNIQUE(bar), UNIQUE(bar,baz) )", sql)
}
diff --git a/db/dialect.go b/db/dialect.go
index 4251465..3e2b90b 100644
--- a/db/dialect.go
+++ b/db/dialect.go
@@ -1,76 +1,51 @@
package db
import "fmt"
+// DialectType identifies the SQL dialect that a builder generates statements for.
type DialectType int
+// The known dialects. Values are assigned by iota in declaration order, so
+// reordering these lines would change their numeric values.
const (
- DialectSQLite DialectType = iota
- DialectMySQL DialectType = iota
+ DialectSQLite DialectType = iota
+ DialectMySQL DialectType = iota
+ DialectPostgreSQL DialectType = iota
)
-func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
+// IsKnown reports whether d is one of DialectSQLite, DialectMySQL or
+// DialectPostgreSQL.
+func (d DialectType) IsKnown() bool {
switch d {
- case DialectSQLite:
- return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
- case DialectMySQL:
- return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
+ case DialectSQLite, DialectMySQL, DialectPostgreSQL:
+ return true
default:
- panic(fmt.Sprintf("unexpected dialect: %d", d))
+ return false
}
}
-func (d DialectType) Table(name string) *CreateTableSqlBuilder {
- switch d {
- case DialectSQLite:
- return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
- case DialectMySQL:
- return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
- default:
+// AssertKnown panics with "unexpected dialect: <value>" unless d.IsKnown().
+// The builder constructors on DialectType call it first, so an unknown
+// dialect fails immediately instead of yielding a builder.
+func (d DialectType) AssertKnown() {
+ if !d.IsKnown() {
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
+// Table returns a CREATE TABLE builder for a table called name, targeting
+// dialect d. It panics (via AssertKnown) if d is not a known dialect.
+func (d DialectType) Table(name string) *CreateTableSqlBuilder {
+	d.AssertKnown()
+	builder := &CreateTableSqlBuilder{
+		Name:    name,
+		Dialect: d,
+	}
+	return builder
+}
+
+// AlterTable returns an ALTER TABLE builder for the table called name in
+// dialect d. It panics if d is not a known dialect.
func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder {
- switch d {
- case DialectSQLite:
- return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name}
- case DialectMySQL:
- return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name}
- default:
- panic(fmt.Sprintf("unexpected dialect: %d", d))
- }
+ d.AssertKnown()
+ return &AlterTableSqlBuilder{Dialect: d, Name: name}
}
+// CreateUniqueIndex returns a builder for a UNIQUE index called name on table
+// over columns, in dialect d. It panics if d is not a known dialect.
func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
- switch d {
- case DialectSQLite:
- return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns}
- case DialectMySQL:
- return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns}
- default:
- panic(fmt.Sprintf("unexpected dialect: %d", d))
- }
+ d.AssertKnown()
+ return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: true, Columns: columns}
}
+// CreateIndex returns a builder for a non-unique index called name on table
+// over columns, in dialect d. It panics if d is not a known dialect.
func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
- switch d {
- case DialectSQLite:
- return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns}
- case DialectMySQL:
- return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns}
- default:
- panic(fmt.Sprintf("unexpected dialect: %d", d))
- }
+ d.AssertKnown()
+ return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: false, Columns: columns}
}
+// DropIndex returns a builder that drops the index called name on table, in
+// dialect d. It panics if d is not a known dialect.
func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder {
- switch d {
- case DialectSQLite:
- return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table}
- case DialectMySQL:
- return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table}
- default:
- panic(fmt.Sprintf("unexpected dialect: %d", d))
- }
+ d.AssertKnown()
+ return &DropIndexSqlBuilder{Dialect: d, Name: name, Table: table}
}
diff --git a/db/escape.go b/db/escape.go
new file mode 100644
index 0000000..53b8ef3
--- /dev/null
+++ b/db/escape.go
@@ -0,0 +1,63 @@
+/*
+ * Copyright © 2019-2022 A Bunch Tell LLC.
+ *
+ * This file is part of WriteFreely.
+ *
+ * WriteFreely is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, included
+ * in the LICENSE file in this source code package.
+ */
+
+package db
+
+import (
+ "fmt"
+ "strings"
+)
+
+// EscapeContext selects how a value is escaped for embedding in SQL text.
+// EscapeSimple, currently the only context, produces a quoted string literal.
+type EscapeContext int
+
+const (
+	EscapeSimple EscapeContext = iota
+)
+
+// SQLEscape returns s as a quoted SQL string literal for dialect d.
+//
+// SQLite: s is wrapped in single quotes and embedded single quotes are
+// doubled ('').
+//
+// MySQL: s is wrapped in single quotes and NUL, single and double quotes,
+// backspace, newline, carriage return, tab and backslash are
+// backslash-escaped. NOTE(review): this assumes the server is not running
+// with the NO_BACKSLASH_ESCAPES SQL mode.
+//
+// PostgreSQL: s is emitted as an escape string constant (E'...') with
+// backslashes doubled and single quotes doubled, so the result does not
+// depend on the server's standard_conforming_strings setting. PostgreSQL
+// string constants cannot contain a zero byte, so such input is an error.
+//
+// Any other dialect returns an error rather than an empty literal.
+//
+// Input is processed byte by byte: every escaped character is ASCII, and
+// ranging over runes would silently replace invalid UTF-8 with U+FFFD.
+func (EscapeContext) SQLEscape(d DialectType, s string) (string, error) {
+	var b strings.Builder
+	b.Grow(len(s) + 3)
+	switch d {
+	case DialectSQLite:
+		b.WriteByte('\'')
+		b.WriteString(strings.ReplaceAll(s, "'", "''"))
+		b.WriteByte('\'')
+	case DialectMySQL:
+		b.WriteByte('\'')
+		for i := 0; i < len(s); i++ {
+			switch c := s[i]; c {
+			case 0:
+				b.WriteString("\\0")
+			case '\'':
+				b.WriteString("\\'")
+			case '"':
+				b.WriteString("\\\"")
+			case '\b':
+				b.WriteString("\\b")
+			case '\n':
+				b.WriteString("\\n")
+			case '\r':
+				b.WriteString("\\r")
+			case '\t':
+				b.WriteString("\\t")
+			case '\\':
+				b.WriteString("\\\\")
+			default:
+				b.WriteByte(c)
+			}
+		}
+		b.WriteByte('\'')
+	case DialectPostgreSQL:
+		if strings.IndexByte(s, 0) >= 0 {
+			return "", fmt.Errorf("cannot escape string containing a NUL byte for dialect: %d", d)
+		}
+		b.WriteString("E'")
+		for i := 0; i < len(s); i++ {
+			switch c := s[i]; c {
+			case '\'':
+				b.WriteString("''")
+			case '\\':
+				b.WriteString("\\\\")
+			default:
+				b.WriteByte(c)
+			}
+		}
+		b.WriteByte('\'')
+	default:
+		return "", fmt.Errorf("sql escaping not supported for dialect: %d", d)
+	}
+	return b.String(), nil
+}
diff --git a/migrations/v4.go b/migrations/v4.go
index 4ae267d..6d0c9f2 100644
--- a/migrations/v4.go
+++ b/migrations/v4.go
@@ -1,54 +1,54 @@
/*
* Copyright © 2019-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writefreely/writefreely/db"
)
+// oauth creates the oauth_users and oauth_client_states tables. The DDL is
+// generated for SQLite when the driver is SQLite and for MySQL otherwise, and
+// both statements are executed inside one transaction.
+// NOTE(review): MySQL implicitly commits around CREATE TABLE, so on MySQL the
+// transaction presumably does not make the two statements atomic — confirm.
func oauth(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
+ // oauth_users: two non-nullable 4-byte integer columns.
createTableUsersOauth, err := dialect.
Table("oauth_users").
SetIfNotExists(false).
- Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
- Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
+ Column(wf_db.NonNullableColumn("user_id", wf_db.ColumnTypeInt{MaxBytes: 4})).
+ Column(wf_db.NonNullableColumn("remote_user_id", wf_db.ColumnTypeInt{MaxBytes: 4})).
ToSQL()
if err != nil {
return err
}
+ // oauth_client_states: a unique 255-character state, a used flag, and a
+ // created_at DATETIME defaulting to DefaultNow.
createTableOauthClientState, err := dialect.
Table("oauth_client_states").
SetIfNotExists(false).
- Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})).
- Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)).
- Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefaultCurrentTimestamp()).
+ Column(wf_db.NonNullableColumn("state", wf_db.ColumnTypeString{MaxChars: 255})).
+ Column(wf_db.NonNullableColumn("used", wf_db.ColumnTypeBool{})).
+ Column(wf_db.NonNullableColumn("created_at", wf_db.ColumnTypeDateTime{DefaultVal: wf_db.DefaultNow})).
UniqueConstraint("state").
ToSQL()
if err != nil {
return err
}
+ // Both statements are built before either runs, so a builder error returns
+ // before any DDL is executed.
for _, table := range []string{createTableUsersOauth, createTableOauthClientState} {
if _, err := tx.ExecContext(ctx, table); err != nil {
return err
}
}
return nil
})
}
diff --git a/migrations/v5.go b/migrations/v5.go
index db18fa1..4508b18 100644
--- a/migrations/v5.go
+++ b/migrations/v5.go
@@ -1,88 +1,105 @@
/*
* Copyright © 2019-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writefreely/writefreely/db"
)
+// oauthSlack adds provider and client_id columns to oauth_client_states and
+// oauth_users, adds an access_token column to oauth_users, creates the
+// oauth_users_uk unique index on (user_id, provider, client_id), and, on
+// non-SQLite databases, changes oauth_users.remote_user_id to a non-nullable
+// 128-character string.
+// NOTE(review): MySQL implicitly commits around ALTER TABLE / CREATE INDEX, so
+// the surrounding transaction is presumably not atomic there — confirm.
func oauthSlack(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
+ // Every column added here is a non-nullable string with an empty-string default.
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NonNullableColumn(
"provider",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")),
+ wf_db.ColumnTypeString{
+ MaxChars: 24,
+ HasDefault: true,
+ DefaultVal: "",
+ })),
dialect.
AlterTable("oauth_client_states").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NonNullableColumn(
"client_id",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")),
+ wf_db.ColumnTypeString{
+ MaxChars: 128,
+ HasDefault: true,
+ DefaultVal: "",
+ },
+ )),
dialect.
AlterTable("oauth_users").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NonNullableColumn(
"provider",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")),
+ wf_db.ColumnTypeString{
+ MaxChars: 24,
+ HasDefault: true,
+ DefaultVal: "",
+ })),
dialect.
AlterTable("oauth_users").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NonNullableColumn(
"client_id",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")),
+ wf_db.ColumnTypeString{
+ MaxChars: 128,
+ HasDefault: true,
+ DefaultVal: "",
+ })),
dialect.
AlterTable("oauth_users").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NonNullableColumn(
"access_token",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 512}).SetDefault("")),
+ wf_db.ColumnTypeString{
+ MaxChars: 512,
+ HasDefault: true,
+ DefaultVal: "",
+ })),
dialect.CreateUniqueIndex("oauth_users_uk", "oauth_users", "user_id", "provider", "client_id"),
}
+ // dialect is either MySQL or SQLite (chosen above), so this runs only on MySQL.
if dialect != wf_db.DialectSQLite {
// This updates the length of the `remote_user_id` column. It isn't needed for SQLite databases.
builders = append(builders, dialect.
AlterTable("oauth_users").
ChangeColumn("remote_user_id",
- dialect.
- Column(
+ wf_db.
+ NonNullableColumn(
"remote_user_id",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 128})))
+ wf_db.ColumnTypeString{
+ MaxChars: 128,
+ })))
}
+ // Statements are built and executed one at a time, stopping at the first error.
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}
diff --git a/migrations/v7.go b/migrations/v7.go
index 2056aa0..a9af405 100644
--- a/migrations/v7.go
+++ b/migrations/v7.go
@@ -1,46 +1,48 @@
/*
* Copyright © 2020-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writefreely/writefreely/db"
)
+// oauthAttach adds a nullable integer attach_user_id column to
+// oauth_client_states, building the DDL for SQLite when the driver is SQLite
+// and for MySQL otherwise.
func oauthAttach(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NullableColumn(
"attach_user_id",
- wf_db.ColumnTypeInteger,
- wf_db.OptionalInt{Set: true, Value: 24}).SetNullable(true)),
+ wf_db.ColumnTypeInt{
+ MaxBytes: 4,
+ // 24 carries over the size the previous version of this migration
+ // passed (OptionalInt{Value: 24}), even though a 4-byte integer holds
+ // at most 10 digits; presumably kept so the generated column
+ // definition is unchanged — confirm before altering.
+ MaxDigits: 24,
+ })),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}
diff --git a/migrations/v8.go b/migrations/v8.go
index 36001af..ded61c9 100644
--- a/migrations/v8.go
+++ b/migrations/v8.go
@@ -1,45 +1,45 @@
/*
* Copyright © 2020-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writefreely/writefreely/db"
)
+// oauthInvites adds a nullable, fixed-length, 6-character invite_code column to
+// oauth_client_states, building the DDL for SQLite when the driver is SQLite
+// and for MySQL otherwise.
func oauthInvites(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
- AddColumn(dialect.Column("invite_code", wf_db.ColumnTypeChar, wf_db.OptionalInt{
- Set: true,
- Value: 6,
- }).SetNullable(true)),
+ AddColumn(wf_db.NullableColumn("invite_code", wf_db.ColumnTypeString{
+ IsFixedLength: true,
+ MaxChars: 6,
+ })),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}

File Metadata

Mime Type
text/x-diff
Expires
Sat, Feb 22, 7:13 AM (16 h, 9 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3155533

Event Timeline