Page Menu
Home
Musing Studio
Search
Configure Global Search
Log In
Files
F10472704
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
47 KB
Subscribers
None
View Options
diff --git a/db/alter.go b/db/alter.go
index 0a4ffdd..6028700 100644
--- a/db/alter.go
+++ b/db/alter.go
@@ -1,52 +1,52 @@
package db
import (
"fmt"
"strings"
)
+// AlterTableSqlBuilder accumulates ALTER TABLE clauses for one table and
+// renders them as a single statement with ToSQL.
type AlterTableSqlBuilder struct {
Dialect DialectType
Name string
Changes []string
}
+// AddColumn appends an "ADD COLUMN <definition>" clause, rendering col for
+// the builder's dialect.
+// NOTE(review): if col.CreateSQL fails, the clause is silently dropped and
+// the error never reaches ToSQL.
func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder {
- if colVal, err := col.String(); err == nil {
+ if colVal, err := col.CreateSQL(b.Dialect); err == nil {
b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal))
}
return b
}
+// ChangeColumn appends the clauses from col.AlterSQL that turn the existing
+// column called name into col (rename plus redefinition).
+// NOTE(review): AlterSQL errors (e.g. for SQLite, which AlterSQL rejects) are
+// silently dropped, so the change simply disappears from the statement.
func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder {
- if colVal, err := col.String(); err == nil {
- b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal))
+ if colActions, err := col.AlterSQL(b.Dialect, name); err == nil {
+ b.Changes = append(b.Changes, colActions...)
}
return b
}
+// AddUniqueConstraint appends "ADD CONSTRAINT <name> UNIQUE (<columns>)".
+// Names are inserted verbatim, without quoting or validation.
func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder {
b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", ")))
return b
}
+// ToSQL renders "ALTER TABLE <name> <change>, <change>, ..." from the
+// accumulated changes, in the order they were added. It returns an error
+// when no changes have been added.
func (b *AlterTableSqlBuilder) ToSQL() (string, error) {
var str strings.Builder
str.WriteString("ALTER TABLE ")
str.WriteString(b.Name)
str.WriteString(" ")
if len(b.Changes) == 0 {
- return "", fmt.Errorf("no changes provide for table: %s", b.Name)
+ return "", fmt.Errorf("no changes provided for table: %s", b.Name)
}
changeCount := len(b.Changes)
for i, thing := range b.Changes {
str.WriteString(thing)
if i < changeCount-1 {
str.WriteString(", ")
}
}
return str.String(), nil
}
diff --git a/db/alter_test.go b/db/alter_test.go
index 4bd58ac..4d47821 100644
--- a/db/alter_test.go
+++ b/db/alter_test.go
@@ -1,56 +1,71 @@
package db
import "testing"
func TestAlterTableSqlBuilder_ToSQL(t *testing.T) {
type fields struct {
Dialect DialectType
Name string
Changes []string
}
tests := []struct {
name string
builder *AlterTableSqlBuilder
want string
wantErr bool
}{
{
name: "MySQL add int",
builder: DialectMySQL.
AlterTable("the_table").
- AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)),
- want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL",
+ // IsSigned is set because MySQL ints render as UNSIGNED by default.
+ AddColumn(NonNullableColumn("the_col", ColumnTypeInt{IsSigned: true, MaxBytes: 4})),
+ want: "ALTER TABLE the_table ADD COLUMN the_col INTEGER NOT NULL",
wantErr: false,
},
{
name: "MySQL add string",
builder: DialectMySQL.
AlterTable("the_table").
- AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})),
+ AddColumn(NonNullableColumn("the_col", ColumnTypeString{MaxChars: 128})),
want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL",
wantErr: false,
},
-
{
name: "MySQL add int and string",
builder: DialectMySQL.
AlterTable("the_table").
- AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)).
- AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})),
- want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL",
+ AddColumn(NonNullableColumn("first_col", ColumnTypeInt{IsSigned: true, MaxBytes: 4})).
+ AddColumn(NonNullableColumn("second_col", ColumnTypeString{MaxChars: 128})),
+ want: "ALTER TABLE the_table ADD COLUMN first_col INTEGER NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL",
wantErr: false,
},
+ {
+ name: "MySQL change to string",
+ builder: DialectMySQL.
+ AlterTable("the_table").
+ ChangeColumn("old_col", NonNullableColumn("new_col", ColumnTypeString{})),
+ // MySQL renames and redefines the column in a single CHANGE COLUMN clause.
+ want: "ALTER TABLE the_table CHANGE COLUMN old_col new_col TEXT NOT NULL",
+ wantErr: false,
+ },
+ {
+ name: "PostgreSQL change to int",
+ builder: DialectPostgreSQL.
+ AlterTable("the_table").
+ ChangeColumn("old_col", NullableColumn("new_col", ColumnTypeInt{MaxBytes: 4})),
+ want: "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, ALTER COLUMN new_col TYPE INTEGER, ALTER COLUMN new_col DROP NOT NULL",
+ wantErr: false,
+ },
+ {
+ name: "no changes",
+ builder: DialectMySQL.AlterTable("the_table"),
+ want: "",
+ wantErr: true,
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.builder.ToSQL()
if (err != nil) != tt.wantErr {
t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ToSQL() got = %v, want %v", got, tt.want)
}
})
}
}
diff --git a/db/builder.go b/db/builder.go
new file mode 100644
index 0000000..d0f4fe4
--- /dev/null
+++ b/db/builder.go
@@ -0,0 +1,15 @@
+/*
+ * Copyright © 2019-2022 A Bunch Tell LLC.
+ *
+ * This file is part of WriteFreely.
+ *
+ * WriteFreely is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, included
+ * in the LICENSE file in this source code package.
+ */
+
+package db
+
+// SQLBuilder is implemented by statement builders that render themselves as
+// a single SQL string.
+type SQLBuilder interface {
+ // ToSQL returns the rendered statement, or an error if the builder's
+ // current state cannot be rendered.
+ ToSQL() (string, error)
+}
+
+// Compile-time checks that this package's builders satisfy SQLBuilder.
+var (
+ _ SQLBuilder = (*CreateTableSqlBuilder)(nil)
+ _ SQLBuilder = (*AlterTableSqlBuilder)(nil)
+)
diff --git a/db/column.go b/db/column.go
new file mode 100644
index 0000000..5bf64e4
--- /dev/null
+++ b/db/column.go
@@ -0,0 +1,328 @@
+/*
+ * Copyright © 2019-2022 A Bunch Tell LLC.
+ *
+ * This file is part of WriteFreely.
+ *
+ * WriteFreely is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, included
+ * in the LICENSE file in this source code package.
+ */
+
+package db
+
+import (
+ "fmt"
+ "strings"
+)
+
+// Column is a dialect-independent column definition; the SQL dialect is
+// supplied when it is rendered with CreateSQL or AlterSQL.
+type Column struct {
+ Name string
+ Type ColumnType
+ // Nullable false makes CreateSQL emit NOT NULL.
+ Nullable bool
+ // PrimaryKey adds PRIMARY KEY (CreateSQL) or ADD PRIMARY KEY (PostgreSQL AlterSQL).
+ PrimaryKey bool
+}
+
+// NullableColumn builds a column named name of type ty that accepts NULL and
+// is not part of the primary key.
+func NullableColumn(name string, ty ColumnType) *Column {
+ return (&Column{Name: name, Type: ty}).SetNullable(true)
+}
+
+// NonNullableColumn builds a NOT NULL column named name of type ty that is
+// not part of the primary key. Nullable and PrimaryKey keep their zero value.
+func NonNullableColumn(name string, ty ColumnType) *Column {
+ col := Column{Name: name, Type: ty}
+ return &col
+}
+
+// PrimaryKeyColumn builds a NOT NULL primary-key column named name of type ty.
+func PrimaryKeyColumn(name string, ty ColumnType) *Column {
+ col := &Column{Name: name, Type: ty}
+ col.Nullable = false
+ col.PrimaryKey = true
+ return col
+}
+
+// ColumnType renders a column's SQL type and DEFAULT expression for a given
+// dialect. Implementations return an error for dialects they do not support.
+type ColumnType interface {
+ // Name returns the SQL type, e.g. INTEGER or VARCHAR(128).
+ Name(DialectType) (string, error)
+ // Default returns the DEFAULT expression, or "" when there is none.
+ Default(DialectType) (string, error)
+}
+
+// ColumnTypeInt is an integer column type.
+type ColumnTypeInt struct {
+ // IsSigned is only consulted for MySQL, where false appends UNSIGNED.
+ IsSigned bool
+ // MaxBytes picks the width on MySQL/PostgreSQL (1-4); any other value,
+ // including 0, yields BIGINT. SQLite always uses INTEGER.
+ MaxBytes int
+ // MaxDigits > 0 adds a MySQL display width, e.g. INTEGER(11).
+ MaxDigits int
+ // HasDefault makes DefaultVal the column's DEFAULT.
+ HasDefault bool
+ DefaultVal int
+}
+
+// ColumnTypeString is a character column type.
+type ColumnTypeString struct {
+ // IsFixedLength selects CHAR rather than VARCHAR/TEXT on MySQL/PostgreSQL.
+ IsFixedLength bool
+ // MaxChars is the declared length; variable-length columns fall back to
+ // TEXT when it is <= 0 or >= 65536. SQLite always uses TEXT.
+ MaxChars int
+ // HasDefault makes DefaultVal (escaped via EscapeSimple) the DEFAULT.
+ HasDefault bool
+ DefaultVal string
+}
+
+// ColumnDefault is a symbolic default for column types whose defaults are not
+// free-form values (booleans and datetimes).
+type ColumnDefault int
+
+// Symbolic defaults. NoDefault is the zero value, so an unset DefaultVal
+// means "no DEFAULT clause".
+const (
+ NoDefault ColumnDefault = iota
+ DefaultFalse
+ DefaultTrue
+ DefaultNow
+)
+
+// ColumnTypeBool is a boolean column type.
+type ColumnTypeBool struct {
+ DefaultVal ColumnDefault
+}
+
+// ColumnTypeDateTime is a date-and-time column type.
+type ColumnTypeDateTime struct {
+ DefaultVal ColumnDefault
+}
+
+// Name returns the integer type for dialect d.
+//
+// SQLite always gets INTEGER. MySQL and PostgreSQL get the type matching
+// MaxBytes (1, 2, 3 or 4), with BIGINT for any other value. PostgreSQL lacks
+// 1- and 3-byte types, so those round up to SMALLINT and INTEGER. On MySQL
+// only, MaxDigits > 0 adds a display width and IsSigned == false adds UNSIGNED.
+func (intCol ColumnTypeInt) Name(d DialectType) (string, error) {
+ if d == DialectSQLite {
+ return "INTEGER", nil
+ }
+ if d != DialectMySQL && d != DialectPostgreSQL {
+ return "", fmt.Errorf("dialect %d does not support integer columns", d)
+ }
+
+ isMySQL := d == DialectMySQL
+ typeName := "BIGINT"
+ switch {
+ case intCol.MaxBytes == 1 && isMySQL:
+ typeName = "TINYINT"
+ case intCol.MaxBytes == 1, intCol.MaxBytes == 2:
+ typeName = "SMALLINT"
+ case intCol.MaxBytes == 3 && isMySQL:
+ typeName = "MEDIUMINT"
+ case intCol.MaxBytes == 3, intCol.MaxBytes == 4:
+ typeName = "INTEGER"
+ }
+
+ if !isMySQL {
+ return typeName, nil
+ }
+ if intCol.MaxDigits > 0 {
+ typeName += fmt.Sprintf("(%d)", intCol.MaxDigits)
+ }
+ if !intCol.IsSigned {
+ typeName += " UNSIGNED"
+ }
+ return typeName, nil
+}
+
+// Default returns DefaultVal as a decimal literal when HasDefault is set and
+// "" otherwise. The dialect is not consulted.
+func (intCol ColumnTypeInt) Default(d DialectType) (string, error) {
+ if !intCol.HasDefault {
+ return "", nil
+ }
+ return fmt.Sprint(intCol.DefaultVal), nil
+}
+
+// Name returns the character type for dialect d.
+//
+// SQLite always gets TEXT. MySQL and PostgreSQL get CHAR or CHAR(n) for
+// fixed-length columns. Variable-length columns get VARCHAR(n) when
+// 0 < MaxChars < 65536, and TEXT otherwise.
+func (strCol ColumnTypeString) Name(d DialectType) (string, error) {
+ if d == DialectSQLite {
+ return "TEXT", nil
+ }
+ if d != DialectMySQL && d != DialectPostgreSQL {
+ return "", fmt.Errorf("dialect %d does not support string columns", d)
+ }
+
+ n := strCol.MaxChars
+ switch {
+ case strCol.IsFixedLength && n > 0:
+ return fmt.Sprintf("CHAR(%d)", n), nil
+ case strCol.IsFixedLength:
+ return "CHAR", nil
+ case n > 0 && n < 1<<16:
+ return fmt.Sprintf("VARCHAR(%d)", n), nil
+ default:
+ return "TEXT", nil
+ }
+}
+
+// Default returns DefaultVal escaped by EscapeSimple.SQLEscape for dialect d
+// when HasDefault is set, and "" otherwise.
+func (strCol ColumnTypeString) Default(d DialectType) (string, error) {
+ if !strCol.HasDefault {
+ return "", nil
+ }
+ return EscapeSimple.SQLEscape(d, strCol.DefaultVal)
+}
+
+// Name returns INTEGER for SQLite and BOOL for MySQL and PostgreSQL.
+func (boolCol ColumnTypeBool) Name(d DialectType) (string, error) {
+ if d == DialectSQLite {
+ return "INTEGER", nil
+ }
+ if d == DialectMySQL || d == DialectPostgreSQL {
+ return "BOOL", nil
+ }
+ return "", fmt.Errorf("boolean column type not supported for dialect %d", d)
+}
+
+// Default renders DefaultVal for dialect d: "" for NoDefault, and a
+// false/true literal for DefaultFalse/DefaultTrue. Any other value is an error.
+//
+// SQLite (INTEGER column) and MySQL accept 0 / 1. A PostgreSQL BOOL column
+// rejects an integer DEFAULT expression, so FALSE / TRUE are emitted there.
+func (boolCol ColumnTypeBool) Default(d DialectType) (string, error) {
+ switch boolCol.DefaultVal {
+ case NoDefault:
+ return "", nil
+ case DefaultFalse:
+ if d == DialectPostgreSQL {
+ return "FALSE", nil
+ }
+ return "0", nil
+ case DefaultTrue:
+ if d == DialectPostgreSQL {
+ return "TRUE", nil
+ }
+ return "1", nil
+ default:
+ return "", fmt.Errorf("boolean columns cannot default to %d for dialect %d", boolCol.DefaultVal, d)
+ }
+}
+
+// Name returns DATETIME for SQLite and MySQL and TIMESTAMP for PostgreSQL.
+func (dateTimeCol ColumnTypeDateTime) Name(d DialectType) (string, error) {
+ if d == DialectPostgreSQL {
+ return "TIMESTAMP", nil
+ }
+ if d == DialectSQLite || d == DialectMySQL {
+ return "DATETIME", nil
+ }
+ return "", fmt.Errorf("datetime column type not supported for dialect %d", d)
+}
+
+// Default renders DefaultVal for dialect d: "" for NoDefault, and the
+// current-time expression for DefaultNow (NOW() on MySQL, CURRENT_TIMESTAMP
+// on SQLite and PostgreSQL). Any other DefaultVal, or an unknown dialect, is
+// an error.
+//
+// PostgreSQL must pass the dialect check. Otherwise CreateSQL and AlterSQL
+// fail for every PostgreSQL datetime column, even one without a default.
+func (dateTimeCol ColumnTypeDateTime) Default(d DialectType) (string, error) {
+ if d != DialectSQLite && d != DialectMySQL && d != DialectPostgreSQL {
+ return "", fmt.Errorf("dialect %d does not support defaulted datetime columns", d)
+ }
+
+ switch dateTimeCol.DefaultVal {
+ case NoDefault:
+ return "", nil
+ case DefaultNow:
+ if d == DialectMySQL {
+ return "NOW()", nil
+ }
+ return "CURRENT_TIMESTAMP", nil
+ default:
+ return "", fmt.Errorf("datetime columns cannot default to %d for dialect %d", dateTimeCol.DefaultVal, d)
+ }
+}
+
+// SetName sets the column name and returns c for chaining.
+func (c *Column) SetName(name string) *Column {
+ c.Name = name
+ return c
+}
+
+// SetNullable sets whether the column accepts NULL and returns c.
+func (c *Column) SetNullable(nullable bool) *Column {
+ c.Nullable = nullable
+ return c
+}
+
+// SetPrimaryKey sets whether the column is the primary key and returns c.
+func (c *Column) SetPrimaryKey(pk bool) *Column {
+ c.PrimaryKey = pk
+ return c
+}
+
+// SetType replaces the column's type and returns c.
+func (c *Column) SetType(t ColumnType) *Column {
+ c.Type = t
+ return c
+}
+
+// AlterSQL returns the ALTER TABLE clauses that turn the existing column
+// oldName into c for dialect d.
+//
+// MySQL gets one "CHANGE COLUMN <oldName> <definition>" clause. PostgreSQL
+// gets, in order: RENAME COLUMN (only if the name changes), ALTER COLUMN
+// TYPE, SET/DROP NOT NULL, then SET DEFAULT and ADD PRIMARY KEY when they
+// apply. Any other dialect is an error. On error the slice is empty, not nil.
+func (c *Column) AlterSQL(d DialectType, oldName string) ([]string, error) {
+ fail := func(err error) ([]string, error) {
+ return []string{}, err
+ }
+
+ if d == DialectMySQL {
+ // MySQL renames and redefines a column in one clause.
+ definition, err := c.CreateSQL(d)
+ if err != nil {
+ return fail(err)
+ }
+ return []string{fmt.Sprintf("CHANGE COLUMN %s %s", oldName, definition)}, nil
+ }
+ if d != DialectPostgreSQL {
+ return fail(fmt.Errorf("dialect %d doesn't support altering column data type", d))
+ }
+
+ // PostgreSQL needs a separate clause for each aspect of the column.
+ var clauses []string
+ if oldName != c.Name {
+ clauses = append(clauses, fmt.Sprintf("RENAME COLUMN %s TO %s", oldName, c.Name))
+ }
+
+ typeName, err := c.Type.Name(d)
+ if err != nil {
+ return fail(err)
+ }
+ clauses = append(clauses, fmt.Sprintf("ALTER COLUMN %s TYPE %s", c.Name, typeName))
+
+ nullOp := "SET"
+ if c.Nullable {
+ nullOp = "DROP"
+ }
+ clauses = append(clauses, fmt.Sprintf("ALTER COLUMN %s %s NOT NULL", c.Name, nullOp))
+
+ defaultExpr, err := c.Type.Default(d)
+ if err != nil {
+ return fail(err)
+ }
+ if defaultExpr != "" {
+ clauses = append(clauses, fmt.Sprintf("ALTER COLUMN %s SET DEFAULT %s", c.Name, defaultExpr))
+ }
+
+ if c.PrimaryKey {
+ clauses = append(clauses, fmt.Sprintf("ADD PRIMARY KEY (%s)", c.Name))
+ }
+ return clauses, nil
+}
+
+// CreateSQL renders c as a column definition for CREATE TABLE / ADD COLUMN:
+// "<name> <type>[ NOT NULL][ DEFAULT <expr>][ PRIMARY KEY]". Errors from the
+// column type's Name or Default are returned as-is.
+func (c *Column) CreateSQL(d DialectType) (string, error) {
+ typeName, err := c.Type.Name(d)
+ if err != nil {
+ return "", err
+ }
+ defaultExpr, err := c.Type.Default(d)
+ if err != nil {
+ return "", err
+ }
+
+ parts := []string{c.Name, typeName}
+ if !c.Nullable {
+ parts = append(parts, "NOT NULL")
+ }
+ if defaultExpr != "" {
+ parts = append(parts, "DEFAULT", defaultExpr)
+ }
+ if c.PrimaryKey {
+ parts = append(parts, "PRIMARY KEY")
+ }
+ return strings.Join(parts, " "), nil
+}
diff --git a/db/column_test.go b/db/column_test.go
new file mode 100644
index 0000000..175ec3a
--- /dev/null
+++ b/db/column_test.go
@@ -0,0 +1,151 @@
+package db
+
+import (
+ "github.com/stretchr/testify/assert"
+ "testing"
+)
+
+func TestColumnType_Name(t *testing.T) {
+ tests := []struct {
+ name string
+ ty ColumnType
+ d DialectType
+ want string
+ wantErr bool
+ }{
+ {"SQLite bool", ColumnTypeBool{}, DialectSQLite, "INTEGER", false},
+ {"SQLite int", ColumnTypeInt{}, DialectSQLite, "INTEGER", false},
+ // Name renders only the type; defaults are covered by TestColumnType_Default.
+ {"SQLite string", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "TEXT", false},
+ {"SQLite datetime", ColumnTypeDateTime{}, DialectSQLite, "DATETIME", false},
+
+ {"MySQL bool", ColumnTypeBool{}, DialectMySQL, "BOOL", false},
+ // MySQL integers are UNSIGNED unless IsSigned is set.
+ {"MySQL tiny int", ColumnTypeInt{IsSigned: true, MaxBytes: 1}, DialectMySQL, "TINYINT", false},
+ {"MySQL tiny int with digits", ColumnTypeInt{IsSigned: true, MaxBytes: 1, MaxDigits: 2}, DialectMySQL, "TINYINT(2)", false},
+ {"MySQL small int", ColumnTypeInt{IsSigned: true, MaxBytes: 2}, DialectMySQL, "SMALLINT", false},
+ {"MySQL small int with digits", ColumnTypeInt{IsSigned: true, MaxBytes: 2, MaxDigits: 3}, DialectMySQL, "SMALLINT(3)", false},
+ {"MySQL medium int", ColumnTypeInt{IsSigned: true, MaxBytes: 3}, DialectMySQL, "MEDIUMINT", false},
+ {"MySQL medium int with digits", ColumnTypeInt{IsSigned: true, MaxBytes: 3, MaxDigits: 6}, DialectMySQL, "MEDIUMINT(6)", false},
+ {"MySQL int", ColumnTypeInt{IsSigned: true, MaxBytes: 4}, DialectMySQL, "INTEGER", false},
+ {"MySQL int with digits", ColumnTypeInt{IsSigned: true, MaxBytes: 4, MaxDigits: 11}, DialectMySQL, "INTEGER(11)", false},
+ {"MySQL unsigned int", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "INTEGER UNSIGNED", false},
+ {"MySQL unsigned int with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 10}, DialectMySQL, "INTEGER(10) UNSIGNED", false},
+ {"MySQL bigint", ColumnTypeInt{IsSigned: true, MaxBytes: 8}, DialectMySQL, "BIGINT", false},
+ {"MySQL bigint with digits", ColumnTypeInt{IsSigned: true, MaxBytes: 8, MaxDigits: 15}, DialectMySQL, "BIGINT(15)", false},
+ {"MySQL char", ColumnTypeString{IsFixedLength: true}, DialectMySQL, "CHAR", false},
+ {"MySQL char with length", ColumnTypeString{IsFixedLength: true, MaxChars: 4}, DialectMySQL, "CHAR(4)", false},
+ {"MySQL varchar with length", ColumnTypeString{MaxChars: 25}, DialectMySQL, "VARCHAR(25)", false},
+ {"MySQL text", ColumnTypeString{}, DialectMySQL, "TEXT", false},
+ {"MySQL datetime", ColumnTypeDateTime{}, DialectMySQL, "DATETIME", false},
+
+ {"PostgreSQL bool", ColumnTypeBool{}, DialectPostgreSQL, "BOOL", false},
+ {"PostgreSQL tiny int", ColumnTypeInt{MaxBytes: 1}, DialectPostgreSQL, "SMALLINT", false},
+ {"PostgreSQL medium int", ColumnTypeInt{MaxBytes: 3}, DialectPostgreSQL, "INTEGER", false},
+ {"PostgreSQL bigint", ColumnTypeInt{}, DialectPostgreSQL, "BIGINT", false},
+ {"PostgreSQL varchar", ColumnTypeString{MaxChars: 25}, DialectPostgreSQL, "VARCHAR(25)", false},
+ {"PostgreSQL long string", ColumnTypeString{MaxChars: 1 << 16}, DialectPostgreSQL, "TEXT", false},
+ {"PostgreSQL datetime", ColumnTypeDateTime{}, DialectPostgreSQL, "TIMESTAMP", false},
+
+ {"unknown dialect", ColumnTypeBool{}, DialectType(-1), "", true},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := tt.ty.Name(tt.d)
+ if tt.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ assert.NoError(t, err)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestColumnType_Default(t *testing.T) {
+ tests := []struct {
+ name string
+ ty ColumnType
+ d DialectType
+ want string
+ wantErr bool
+ }{
+ {"SQLite bool none", ColumnTypeBool{}, DialectSQLite, "", false},
+ {"SQLite bool false", ColumnTypeBool{DefaultVal: DefaultFalse}, DialectSQLite, "0", false},
+ {"SQLite bool true", ColumnTypeBool{DefaultVal: DefaultTrue}, DialectSQLite, "1", false},
+ {"SQLite bool now", ColumnTypeBool{DefaultVal: DefaultNow}, DialectSQLite, "", true},
+ {"SQLite int none", ColumnTypeInt{}, DialectSQLite, "", false},
+ {"SQLite int empty", ColumnTypeInt{HasDefault: true}, DialectSQLite, "0", false},
+ {"SQLite int", ColumnTypeInt{HasDefault: true, DefaultVal: 10}, DialectSQLite, "10", false},
+ {"SQLite string none", ColumnTypeString{}, DialectSQLite, "", false},
+ {"SQLite string empty", ColumnTypeString{HasDefault: true}, DialectSQLite, "''", false},
+ {"SQLite string", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "'that''s a default'", false},
+ {"MySQL string", ColumnTypeString{HasDefault: true, DefaultVal: "%that's a default%"}, DialectMySQL, "'%that\\'s a default%'", false},
+
+ {"SQLite datetime none", ColumnTypeDateTime{}, DialectSQLite, "", false},
+ {"SQLite datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectSQLite, "CURRENT_TIMESTAMP", false},
+ {"SQLite datetime true", ColumnTypeDateTime{DefaultVal: DefaultTrue}, DialectSQLite, "", true},
+ {"MySQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectMySQL, "NOW()", false},
+ {"PostgreSQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectPostgreSQL, "CURRENT_TIMESTAMP", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := tt.ty.Default(tt.d)
+ if tt.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ assert.NoError(t, err)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestColumn_CreateSQL(t *testing.T) {
+ type fields struct {
+ Dialect DialectType
+ Name string
+ Nullable bool
+ Type ColumnType
+ PrimaryKey bool
+ }
+ tests := []struct {
+ name string
+ fields fields
+ want string
+ wantErr bool
+ }{
+ {"SQLite bool", fields{DialectSQLite, "foo", false, ColumnTypeBool{}, false}, "foo INTEGER NOT NULL", false},
+ {"SQLite bool nullable", fields{DialectSQLite, "foo", true, ColumnTypeBool{}, false}, "foo INTEGER", false},
+ {"SQLite int", fields{DialectSQLite, "foo", false, ColumnTypeInt{}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
+ {"SQLite int nullable", fields{DialectSQLite, "foo", true, ColumnTypeInt{}, false}, "foo INTEGER", false},
+ {"SQLite int with default", fields{DialectSQLite, "foo", false, ColumnTypeInt{HasDefault: true, DefaultVal: 5}, false}, "foo INTEGER NOT NULL DEFAULT 5", false},
+ {"SQLite text", fields{DialectSQLite, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false},
+ {"SQLite text nullable", fields{DialectSQLite, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false},
+ {"SQLite datetime", fields{DialectSQLite, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false},
+ {"SQLite datetime nullable", fields{DialectSQLite, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false},
+ {"SQLite datetime default now", fields{DialectSQLite, "foo", false, ColumnTypeDateTime{DefaultVal: DefaultNow}, false}, "foo DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP", false},
+
+ {"MySQL bool", fields{DialectMySQL, "foo", false, ColumnTypeBool{}, false}, "foo BOOL NOT NULL", false},
+ {"MySQL bool nullable", fields{DialectMySQL, "foo", true, ColumnTypeBool{}, false}, "foo BOOL", false},
+ {"MySQL tiny int", fields{DialectMySQL, "foo", false, ColumnTypeInt{IsSigned: true, MaxBytes: 1}, true}, "foo TINYINT NOT NULL PRIMARY KEY", false},
+ {"MySQL tiny int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{IsSigned: true, MaxBytes: 1}, false}, "foo TINYINT", false},
+ {"MySQL small int", fields{DialectMySQL, "foo", false, ColumnTypeInt{IsSigned: true, MaxBytes: 2}, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false},
+ {"MySQL small int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{IsSigned: true, MaxBytes: 2}, false}, "foo SMALLINT", false},
+ {"MySQL int", fields{DialectMySQL, "foo", false, ColumnTypeInt{IsSigned: true, MaxBytes: 4}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
+ {"MySQL int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{IsSigned: true, MaxBytes: 4}, false}, "foo INTEGER", false},
+ {"MySQL unsigned int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 4}, false}, "foo INTEGER UNSIGNED NOT NULL", false},
+ {"MySQL big int", fields{DialectMySQL, "foo", false, ColumnTypeInt{IsSigned: true}, true}, "foo BIGINT NOT NULL PRIMARY KEY", false},
+ {"MySQL big int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{IsSigned: true}, false}, "foo BIGINT", false},
+ {"MySQL char", fields{DialectMySQL, "foo", false, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR NOT NULL", false},
+ {"MySQL char nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR", false},
+ {"MySQL varchar", fields{DialectMySQL, "foo", false, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255) NOT NULL", false},
+ {"MySQL varchar nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255)", false},
+ {"MySQL text", fields{DialectMySQL, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false},
+ {"MySQL text nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false},
+ {"MySQL datetime", fields{DialectMySQL, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false},
+ {"MySQL datetime nullable", fields{DialectMySQL, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false},
+ {"MySQL datetime default now", fields{DialectMySQL, "foo", false, ColumnTypeDateTime{DefaultVal: DefaultNow}, false}, "foo DATETIME NOT NULL DEFAULT NOW()", false},
+
+ {"unknown dialect", fields{DialectType(-1), "foo", false, ColumnTypeBool{}, false}, "", true},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ c := &Column{
+ Name: tt.fields.Name,
+ Nullable: tt.fields.Nullable,
+ Type: tt.fields.Type,
+ PrimaryKey: tt.fields.PrimaryKey,
+ }
+ got, err := c.CreateSQL(tt.fields.Dialect)
+ if tt.wantErr {
+ assert.Error(t, err, "CreateSQL()")
+ return
+ }
+ assert.NoError(t, err, "CreateSQL()")
+ assert.Equal(t, tt.want, got, "CreateSQL()")
+ })
+ }
+}
diff --git a/db/create.go b/db/create.go
index 8728d5d..9462bbe 100644
--- a/db/create.go
+++ b/db/create.go
@@ -1,266 +1,89 @@
/*
* Copyright © 2019-2020 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package db
import (
"fmt"
"strings"
)
-type ColumnType int
-
-type OptionalInt struct {
- Set bool
- Value int
-}
-
-type OptionalString struct {
- Set bool
- Value string
-}
-
-type SQLBuilder interface {
- ToSQL() (string, error)
-}
-
-type Column struct {
- Dialect DialectType
- Name string
- Nullable bool
- Default OptionalString
- Type ColumnType
- Size OptionalInt
- PrimaryKey bool
-}
-
+// CreateTableSqlBuilder assembles a CREATE TABLE statement from columns,
+// rendered in the order they were added, followed by table constraints.
type CreateTableSqlBuilder struct {
Dialect DialectType
Name string
IfNotExists bool
+ // ColumnOrder records column names in insertion order; ToSQL follows it.
ColumnOrder []string
+ // Columns maps a column name to its definition.
Columns map[string]*Column
+ // Constraints holds raw table-constraint clauses such as UNIQUE(a,b).
Constraints []string
}
-const (
- ColumnTypeBool ColumnType = iota
- ColumnTypeSmallInt ColumnType = iota
- ColumnTypeInteger ColumnType = iota
- ColumnTypeChar ColumnType = iota
- ColumnTypeVarChar ColumnType = iota
- ColumnTypeText ColumnType = iota
- ColumnTypeDateTime ColumnType = iota
-)
-
-var _ SQLBuilder = &CreateTableSqlBuilder{}
-
-var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
-var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
-
-func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
- if dialect != DialectMySQL && dialect != DialectSQLite {
- return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
- }
- switch d {
- case ColumnTypeSmallInt:
- {
- if dialect == DialectSQLite {
- return "INTEGER", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
- return "SMALLINT" + mod, nil
- }
- case ColumnTypeInteger:
- {
- if dialect == DialectSQLite {
- return "INTEGER", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
- return "INT" + mod, nil
- }
- case ColumnTypeChar:
- {
- if dialect == DialectSQLite {
- return "TEXT", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
- return "CHAR" + mod, nil
- }
- case ColumnTypeVarChar:
- {
- if dialect == DialectSQLite {
- return "TEXT", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
- return "VARCHAR" + mod, nil
- }
- case ColumnTypeBool:
- {
- if dialect == DialectSQLite {
- return "INTEGER", nil
- }
- return "TINYINT(1)", nil
- }
- case ColumnTypeDateTime:
- return "DATETIME", nil
- case ColumnTypeText:
- return "TEXT", nil
- }
- return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
-}
-
-func (c *Column) SetName(name string) *Column {
- c.Name = name
- return c
-}
-
-func (c *Column) SetNullable(nullable bool) *Column {
- c.Nullable = nullable
- return c
-}
-
-func (c *Column) SetPrimaryKey(pk bool) *Column {
- c.PrimaryKey = pk
- return c
-}
-
-func (c *Column) SetDefault(value string) *Column {
- c.Default = OptionalString{Set: true, Value: value}
- return c
-}
-
-func (c *Column) SetDefaultCurrentTimestamp() *Column {
- def := "NOW()"
- if c.Dialect == DialectSQLite {
- def = "CURRENT_TIMESTAMP"
- }
- c.Default = OptionalString{Set: true, Value: def}
- return c
-}
-
-func (c *Column) SetType(t ColumnType) *Column {
- c.Type = t
- return c
-}
-
-func (c *Column) SetSize(size int) *Column {
- c.Size = OptionalInt{Set: true, Value: size}
- return c
-}
-
-func (c *Column) String() (string, error) {
- var str strings.Builder
-
- str.WriteString(c.Name)
-
- str.WriteString(" ")
- typeStr, err := c.Type.Format(c.Dialect, c.Size)
- if err != nil {
- return "", err
- }
-
- str.WriteString(typeStr)
-
- if !c.Nullable {
- str.WriteString(" NOT NULL")
- }
-
- if c.Default.Set {
- str.WriteString(" DEFAULT ")
- val := c.Default.Value
- if val == "" {
- val = "''"
- }
- str.WriteString(val)
- }
-
- if c.PrimaryKey {
- str.WriteString(" PRIMARY KEY")
- }
-
- return str.String(), nil
-}
-
+// Column adds column to the table definition, keyed by its Name.
+// NOTE(review): adding two columns with the same Name overwrites the map
+// entry but appends the name to ColumnOrder twice, so ToSQL emits it twice.
func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder {
if b.Columns == nil {
b.Columns = make(map[string]*Column)
}
b.Columns[column.Name] = column
b.ColumnOrder = append(b.ColumnOrder, column.Name)
return b
}
+// UniqueConstraint adds a UNIQUE(col,...) table constraint. If any named
+// column has not been added via Column, the constraint is skipped without
+// any error.
func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder {
for _, column := range columns {
if _, ok := b.Columns[column]; !ok {
// This fails silently.
return b
}
}
b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ",")))
return b
}
+// SetIfNotExists controls whether ToSQL emits CREATE TABLE IF NOT EXISTS.
func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder {
b.IfNotExists = ine
return b
}
+// ToSQL renders "CREATE TABLE [IF NOT EXISTS] <name> ( <columns>, <constraints> )",
+// rendering each column for the builder's Dialect. The parenthesized list is
+// omitted when there are no columns and no constraints.
func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
var str strings.Builder
str.WriteString("CREATE TABLE ")
if b.IfNotExists {
str.WriteString("IF NOT EXISTS ")
}
str.WriteString(b.Name)
var things []string
for _, columnName := range b.ColumnOrder {
+ // ColumnOrder and Columns are exported and can be edited directly, so
+ // they may disagree; report that instead of rendering a partial table.
column, ok := b.Columns[columnName]
if !ok {
return "", fmt.Errorf("column not found: %s", columnName)
}
- columnStr, err := column.String()
+ columnStr, err := column.CreateSQL(b.Dialect)
if err != nil {
return "", err
}
things = append(things, columnStr)
}
for _, constraint := range b.Constraints {
things = append(things, constraint)
}
if thingLen := len(things); thingLen > 0 {
str.WriteString(" ( ")
for i, thing := range things {
str.WriteString(thing)
if i < thingLen-1 {
str.WriteString(", ")
}
}
str.WriteString(" )")
}
return str.String(), nil
}
diff --git a/db/create_test.go b/db/create_test.go
index 369d5c1..09efd18 100644
--- a/db/create_test.go
+++ b/db/create_test.go
@@ -1,146 +1,20 @@
package db
import (
"github.com/stretchr/testify/assert"
"testing"
)
-func TestDialect_Column(t *testing.T) {
- c1 := DialectSQLite.Column("foo", ColumnTypeBool, UnsetSize)
- assert.Equal(t, DialectSQLite, c1.Dialect)
- c2 := DialectMySQL.Column("foo", ColumnTypeBool, UnsetSize)
- assert.Equal(t, DialectMySQL, c2.Dialect)
-}
-
-func TestColumnType_Format(t *testing.T) {
- type args struct {
- dialect DialectType
- size OptionalInt
- }
- tests := []struct {
- name string
- d ColumnType
- args args
- want string
- wantErr bool
- }{
- {"Sqlite bool", ColumnTypeBool, args{dialect: DialectSQLite}, "INTEGER", false},
- {"Sqlite small int", ColumnTypeSmallInt, args{dialect: DialectSQLite}, "INTEGER", false},
- {"Sqlite int", ColumnTypeInteger, args{dialect: DialectSQLite}, "INTEGER", false},
- {"Sqlite char", ColumnTypeChar, args{dialect: DialectSQLite}, "TEXT", false},
- {"Sqlite varchar", ColumnTypeVarChar, args{dialect: DialectSQLite}, "TEXT", false},
- {"Sqlite text", ColumnTypeText, args{dialect: DialectSQLite}, "TEXT", false},
- {"Sqlite datetime", ColumnTypeDateTime, args{dialect: DialectSQLite}, "DATETIME", false},
-
- {"MySQL bool", ColumnTypeBool, args{dialect: DialectMySQL}, "TINYINT(1)", false},
- {"MySQL small int", ColumnTypeSmallInt, args{dialect: DialectMySQL}, "SMALLINT", false},
- {"MySQL small int with param", ColumnTypeSmallInt, args{dialect: DialectMySQL, size: OptionalInt{true, 3}}, "SMALLINT(3)", false},
- {"MySQL int", ColumnTypeInteger, args{dialect: DialectMySQL}, "INT", false},
- {"MySQL int with param", ColumnTypeInteger, args{dialect: DialectMySQL, size: OptionalInt{true, 11}}, "INT(11)", false},
- {"MySQL char", ColumnTypeChar, args{dialect: DialectMySQL}, "CHAR", false},
- {"MySQL char with param", ColumnTypeChar, args{dialect: DialectMySQL, size: OptionalInt{true, 4}}, "CHAR(4)", false},
- {"MySQL varchar", ColumnTypeVarChar, args{dialect: DialectMySQL}, "VARCHAR", false},
- {"MySQL varchar with param", ColumnTypeVarChar, args{dialect: DialectMySQL, size: OptionalInt{true, 25}}, "VARCHAR(25)", false},
- {"MySQL text", ColumnTypeText, args{dialect: DialectMySQL}, "TEXT", false},
- {"MySQL datetime", ColumnTypeDateTime, args{dialect: DialectMySQL}, "DATETIME", false},
-
- {"invalid column type", 10000, args{dialect: DialectMySQL}, "", true},
- {"invalid dialect", ColumnTypeBool, args{dialect: 10000}, "", true},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := tt.d.Format(tt.args.dialect, tt.args.size)
- if (err != nil) != tt.wantErr {
- t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("Format() got = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestColumn_Build(t *testing.T) {
- type fields struct {
- Dialect DialectType
- Name string
- Nullable bool
- Default OptionalString
- Type ColumnType
- Size OptionalInt
- PrimaryKey bool
- }
- tests := []struct {
- name string
- fields fields
- want string
- wantErr bool
- }{
- {"Sqlite bool", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER NOT NULL", false},
- {"Sqlite bool nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER", false},
- {"Sqlite small int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
- {"Sqlite small int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo INTEGER", false},
- {"Sqlite int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER NOT NULL", false},
- {"Sqlite int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER", false},
- {"Sqlite char", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT NOT NULL", false},
- {"Sqlite char nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT", false},
- {"Sqlite varchar", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT NOT NULL", false},
- {"Sqlite varchar nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT", false},
- {"Sqlite text", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false},
- {"Sqlite text nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false},
- {"Sqlite datetime", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false},
- {"Sqlite datetime nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false},
-
- {"MySQL bool", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1) NOT NULL", false},
- {"MySQL bool nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1)", false},
- {"MySQL small int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false},
- {"MySQL small int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo SMALLINT", false},
- {"MySQL int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT NOT NULL", false},
- {"MySQL int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT", false},
- {"MySQL char", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR NOT NULL", false},
- {"MySQL char nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR", false},
- {"MySQL varchar", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR NOT NULL", false},
- {"MySQL varchar nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR", false},
- {"MySQL text", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false},
- {"MySQL text nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false},
- {"MySQL datetime", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false},
- {"MySQL datetime nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- c := &Column{
- Dialect: tt.fields.Dialect,
- Name: tt.fields.Name,
- Nullable: tt.fields.Nullable,
- Default: tt.fields.Default,
- Type: tt.fields.Type,
- Size: tt.fields.Size,
- PrimaryKey: tt.fields.PrimaryKey,
- }
- if got, err := c.String(); got != tt.want {
- if (err != nil) != tt.wantErr {
- t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("String() got = %v, want %v", got, tt.want)
- }
- }
- })
- }
-}
-
+// TestCreateTableSqlBuilder_ToSQL builds a MySQL CREATE TABLE IF NOT EXISTS
+// statement for table foo with a primary-key int column (bar), a non-nullable
+// string column (baz), a non-nullable datetime column defaulting to NOW() (qux)
+// and two unique constraints, then checks that ToSQL succeeds and returns
+// exactly the expected SQL text.
func TestCreateTableSqlBuilder_ToSQL(t *testing.T) {
sql, err := DialectMySQL.
Table("foo").
SetIfNotExists(true).
- Column(DialectMySQL.Column("bar", ColumnTypeInteger, UnsetSize).SetPrimaryKey(true)).
- Column(DialectMySQL.Column("baz", ColumnTypeText, UnsetSize)).
- Column(DialectMySQL.Column("qux", ColumnTypeDateTime, UnsetSize).SetDefault("NOW()")).
+ Column(PrimaryKeyColumn("bar", ColumnTypeInt{MaxBytes: 4})).
+ Column(NonNullableColumn("baz", ColumnTypeString{})).
+ Column(NonNullableColumn("qux", ColumnTypeDateTime{DefaultVal: DefaultNow})).
UniqueConstraint("bar").
UniqueConstraint("bar", "baz").
ToSQL()
assert.NoError(t, err)
assert.Equal(t, "CREATE TABLE IF NOT EXISTS foo ( bar INT NOT NULL PRIMARY KEY, baz TEXT NOT NULL, qux DATETIME NOT NULL DEFAULT NOW(), UNIQUE(bar), UNIQUE(bar,baz) )", sql)
}
diff --git a/db/dialect.go b/db/dialect.go
index 4251465..3e2b90b 100644
--- a/db/dialect.go
+++ b/db/dialect.go
@@ -1,76 +1,51 @@
package db
import "fmt"
+// DialectType selects the SQL dialect that the builders in this package emit.
type DialectType int
+
+// Dialects understood by the builders. Only the first constant spells out the
+// type and iota; Go repeats that expression implicitly on the following lines,
+// so the values are DialectSQLite=0, DialectMySQL=1, DialectPostgreSQL=2.
const (
- DialectSQLite DialectType = iota
- DialectMySQL DialectType = iota
+	DialectSQLite DialectType = iota
+	DialectMySQL
+	DialectPostgreSQL
)
-func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
+// IsKnown reports whether d is one of DialectSQLite, DialectMySQL or
+// DialectPostgreSQL; any other value yields false.
+func (d DialectType) IsKnown() bool {
switch d {
- case DialectSQLite:
- return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
- case DialectMySQL:
- return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
+ case DialectSQLite, DialectMySQL, DialectPostgreSQL:
+ return true
default:
- panic(fmt.Sprintf("unexpected dialect: %d", d))
+ return false
}
}
-func (d DialectType) Table(name string) *CreateTableSqlBuilder {
- switch d {
- case DialectSQLite:
- return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
- case DialectMySQL:
- return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
- default:
+// AssertKnown panics with "unexpected dialect: <n>" when d.IsKnown() is false.
+// The builder constructors on DialectType call it before allocating a builder.
+func (d DialectType) AssertKnown() {
+ if !d.IsKnown() {
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
+// Table returns a CREATE TABLE builder for the table called name, bound to
+// dialect d. It panics (via AssertKnown) if d is not a known dialect.
+func (d DialectType) Table(name string) *CreateTableSqlBuilder {
+	d.AssertKnown()
+	builder := CreateTableSqlBuilder{Name: name, Dialect: d}
+	return &builder
+}
+
+// AlterTable returns an ALTER TABLE builder for the table called name, bound to
+// dialect d. It panics (via AssertKnown) if d is not a known dialect.
func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder {
- switch d {
- case DialectSQLite:
- return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name}
- case DialectMySQL:
- return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name}
- default:
- panic(fmt.Sprintf("unexpected dialect: %d", d))
- }
+ d.AssertKnown()
+ return &AlterTableSqlBuilder{Dialect: d, Name: name}
}
+// CreateUniqueIndex returns an index builder with Unique set, for an index
+// called name on table covering columns, bound to dialect d. It panics (via
+// AssertKnown) if d is not a known dialect.
func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
- switch d {
- case DialectSQLite:
- return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns}
- case DialectMySQL:
- return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns}
- default:
- panic(fmt.Sprintf("unexpected dialect: %d", d))
- }
+ d.AssertKnown()
+ return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: true, Columns: columns}
}
+// CreateIndex returns a non-unique index builder for an index called name on
+// table covering columns, bound to dialect d. It panics (via AssertKnown) if d
+// is not a known dialect.
func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
- switch d {
- case DialectSQLite:
- return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns}
- case DialectMySQL:
- return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns}
- default:
- panic(fmt.Sprintf("unexpected dialect: %d", d))
- }
+ d.AssertKnown()
+ return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: false, Columns: columns}
}
+// DropIndex returns a builder that drops the index called name on table, bound
+// to dialect d. It panics (via AssertKnown) if d is not a known dialect.
func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder {
- switch d {
- case DialectSQLite:
- return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table}
- case DialectMySQL:
- return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table}
- default:
- panic(fmt.Sprintf("unexpected dialect: %d", d))
- }
+ d.AssertKnown()
+ return &DropIndexSqlBuilder{Dialect: d, Name: name, Table: table}
}
diff --git a/db/escape.go b/db/escape.go
new file mode 100644
index 0000000..53b8ef3
--- /dev/null
+++ b/db/escape.go
@@ -0,0 +1,63 @@
+/*
+ * Copyright © 2019-2022 A Bunch Tell LLC.
+ *
+ * This file is part of WriteFreely.
+ *
+ * WriteFreely is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, included
+ * in the LICENSE file in this source code package.
+ */
+
+package db
+
+import (
+ "fmt"
+ "strings"
+)
+
+// EscapeContext selects the quoting strategy used by SQLEscape.
+type EscapeContext int
+
+const (
+	// EscapeSimple renders a value as a single-quoted SQL string literal.
+	EscapeSimple EscapeContext = iota
+)
+
+// SQLEscape returns s as a single-quoted SQL string literal for dialect d.
+//
+// SQLite and PostgreSQL follow the SQL standard: inside '...' only the single
+// quote needs treatment, and it is doubled. MySQL also interprets backslash
+// sequences in string literals, so NUL, quotes, backslash and a few control
+// characters are backslash-escaped.
+//
+// s is processed rune by rune, so any invalid UTF-8 byte is written out as
+// U+FFFD. An error is returned for a dialect without an escaping rule, instead
+// of an empty string that would silently produce malformed SQL.
+func (EscapeContext) SQLEscape(d DialectType, s string) (string, error) {
+	var builder strings.Builder
+	switch d {
+	case DialectSQLite, DialectPostgreSQL:
+		// NOTE(review): for PostgreSQL this assumes standard_conforming_strings
+		// is on (the server default since 9.1); with it off, backslashes in s
+		// would be interpreted as escape sequences.
+		builder.WriteRune('\'')
+		for _, c := range s {
+			if c == '\'' {
+				builder.WriteString("''")
+			} else {
+				builder.WriteRune(c)
+			}
+		}
+		builder.WriteRune('\'')
+	case DialectMySQL:
+		builder.WriteRune('\'')
+		for _, c := range s {
+			switch c {
+			case 0:
+				builder.WriteString("\\0")
+			case '\'':
+				builder.WriteString("\\'")
+			case '"':
+				builder.WriteString("\\\"")
+			case '\b':
+				builder.WriteString("\\b")
+			case '\n':
+				builder.WriteString("\\n")
+			case '\r':
+				builder.WriteString("\\r")
+			case '\t':
+				builder.WriteString("\\t")
+			case '\\':
+				builder.WriteString("\\\\")
+			default:
+				builder.WriteRune(c)
+			}
+		}
+		builder.WriteRune('\'')
+	default:
+		// Previously unknown dialects (including DialectPostgreSQL) fell through
+		// and returned "" with a nil error.
+		return "", fmt.Errorf("unsupported dialect for SQL escaping: %d", d)
+	}
+	return builder.String(), nil
+}
diff --git a/migrations/v4.go b/migrations/v4.go
index 4ae267d..6d0c9f2 100644
--- a/migrations/v4.go
+++ b/migrations/v4.go
@@ -1,54 +1,54 @@
/*
* Copyright © 2019-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writefreely/writefreely/db"
)
+// oauth creates the oauth_users table (user_id, remote_user_id) and the
+// oauth_client_states table (state, used, created_at defaulting to now, with a
+// unique constraint on state), executing both CREATE TABLE statements on one
+// transaction. The tables are created without IF NOT EXISTS.
func oauth(db *datastore) error {
+	// NOTE(review): every driver other than SQLite gets the MySQL dialect here;
+	// confirm no other driver (e.g. PostgreSQL) can reach this migration.
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
createTableUsersOauth, err := dialect.
Table("oauth_users").
SetIfNotExists(false).
- Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
- Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
+ Column(wf_db.NonNullableColumn("user_id", wf_db.ColumnTypeInt{MaxBytes: 4})).
+ Column(wf_db.NonNullableColumn("remote_user_id", wf_db.ColumnTypeInt{MaxBytes: 4})).
ToSQL()
if err != nil {
return err
}
createTableOauthClientState, err := dialect.
Table("oauth_client_states").
SetIfNotExists(false).
- Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})).
- Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)).
- Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefaultCurrentTimestamp()).
+ Column(wf_db.NonNullableColumn("state", wf_db.ColumnTypeString{MaxChars: 255})).
+ Column(wf_db.NonNullableColumn("used", wf_db.ColumnTypeBool{})).
+ Column(wf_db.NonNullableColumn("created_at", wf_db.ColumnTypeDateTime{DefaultVal: wf_db.DefaultNow})).
UniqueConstraint("state").
ToSQL()
if err != nil {
return err
}
+		// NOTE(review): MySQL implicitly commits around DDL such as CREATE TABLE,
+		// so on MySQL the enclosing transaction does not make these atomic.
for _, table := range []string{createTableUsersOauth, createTableOauthClientState} {
if _, err := tx.ExecContext(ctx, table); err != nil {
return err
}
}
return nil
})
}
diff --git a/migrations/v5.go b/migrations/v5.go
index db18fa1..4508b18 100644
--- a/migrations/v5.go
+++ b/migrations/v5.go
@@ -1,88 +1,105 @@
/*
* Copyright © 2019-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writefreely/writefreely/db"
)
+// oauthSlack adds provider and client_id columns to oauth_client_states and
+// provider, client_id and access_token columns to oauth_users, creates the
+// unique index oauth_users_uk on oauth_users(user_id, provider, client_id), and
+// on non-SQLite databases changes remote_user_id to a non-nullable string of at
+// most 128 characters. Statements are built first, then executed in order on
+// one transaction; the first build or execution error is returned.
func oauthSlack(db *datastore) error {
+	// NOTE(review): every driver other than SQLite gets the MySQL dialect here;
+	// confirm no other driver (e.g. PostgreSQL) can reach this migration.
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
+		// Every added column is NOT NULL with an empty-string default, so the
+		// ALTERs succeed on tables that already contain rows.
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NonNullableColumn(
"provider",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")),
+ wf_db.ColumnTypeString{
+ MaxChars: 24,
+ HasDefault: true,
+ DefaultVal: "",
+ })),
dialect.
AlterTable("oauth_client_states").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NonNullableColumn(
"client_id",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")),
+ wf_db.ColumnTypeString{
+ MaxChars: 128,
+ HasDefault: true,
+ DefaultVal: "",
+ },
+ )),
dialect.
AlterTable("oauth_users").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NonNullableColumn(
"provider",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")),
+ wf_db.ColumnTypeString{
+ MaxChars: 24,
+ HasDefault: true,
+ DefaultVal: "",
+ })),
dialect.
AlterTable("oauth_users").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NonNullableColumn(
"client_id",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")),
+ wf_db.ColumnTypeString{
+ MaxChars: 128,
+ HasDefault: true,
+ DefaultVal: "",
+ })),
dialect.
AlterTable("oauth_users").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NonNullableColumn(
"access_token",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 512}).SetDefault("")),
+ wf_db.ColumnTypeString{
+ MaxChars: 512,
+ HasDefault: true,
+ DefaultVal: "",
+ })),
dialect.CreateUniqueIndex("oauth_users_uk", "oauth_users", "user_id", "provider", "client_id"),
}
if dialect != wf_db.DialectSQLite {
// This updates the length of the `remote_user_id` column. It isn't needed for SQLite databases.
+			// (SQLite ignores the length argument of a declared VARCHAR(n) type.)
builders = append(builders, dialect.
AlterTable("oauth_users").
ChangeColumn("remote_user_id",
- dialect.
- Column(
+ wf_db.
+ NonNullableColumn(
"remote_user_id",
- wf_db.ColumnTypeVarChar,
- wf_db.OptionalInt{Set: true, Value: 128})))
+ wf_db.ColumnTypeString{
+ MaxChars: 128,
+ })))
}
+		// NOTE(review): MySQL implicitly commits around DDL (ALTER TABLE, CREATE
+		// INDEX), so on MySQL a failure part-way leaves earlier changes applied.
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}
diff --git a/migrations/v7.go b/migrations/v7.go
index 2056aa0..a9af405 100644
--- a/migrations/v7.go
+++ b/migrations/v7.go
@@ -1,46 +1,48 @@
/*
* Copyright © 2020-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writefreely/writefreely/db"
)
+// oauthAttach adds the nullable integer column attach_user_id to
+// oauth_client_states, executing the generated ALTER TABLE on a transaction.
func oauthAttach(db *datastore) error {
+	// NOTE(review): every driver other than SQLite gets the MySQL dialect here;
+	// confirm no other driver (e.g. PostgreSQL) can reach this migration.
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
+		// NOTE(review): MaxDigits 24 carries over the old OptionalInt size of 24,
+		// but a 4-byte int holds at most 10 decimal digits — confirm how
+		// ColumnTypeInt renders MaxDigits for each dialect.
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
- AddColumn(dialect.
- Column(
+ AddColumn(wf_db.
+ NullableColumn(
"attach_user_id",
- wf_db.ColumnTypeInteger,
- wf_db.OptionalInt{Set: true, Value: 24}).SetNullable(true)),
+ wf_db.ColumnTypeInt{
+ MaxBytes: 4,
+ MaxDigits: 24,
+ })),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}
diff --git a/migrations/v8.go b/migrations/v8.go
index 36001af..ded61c9 100644
--- a/migrations/v8.go
+++ b/migrations/v8.go
@@ -1,45 +1,45 @@
/*
* Copyright © 2020-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writefreely/writefreely/db"
)
+// oauthInvites adds the nullable, fixed-length (6 character) string column
+// invite_code to oauth_client_states, executing the ALTER TABLE on a transaction.
func oauthInvites(db *datastore) error {
+	// NOTE(review): every driver other than SQLite gets the MySQL dialect here;
+	// confirm no other driver (e.g. PostgreSQL) can reach this migration.
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
- AddColumn(dialect.Column("invite_code", wf_db.ColumnTypeChar, wf_db.OptionalInt{
- Set: true,
- Value: 6,
- }).SetNullable(true)),
+ AddColumn(wf_db.NullableColumn("invite_code", wf_db.ColumnTypeString{
+ IsFixedLength: true,
+ MaxChars: 6,
+ })),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}
File Metadata
Details
Attached
Mime Type
text/x-diff
Expires
Sat, Feb 22, 7:13 AM (16 h, 9 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3155533
Attached To
rWF WriteFreely
Event Timeline
Log In to Comment