diff --git a/db/alter.go b/db/alter.go index 0a4ffdd..0564d3e 100644 --- a/db/alter.go +++ b/db/alter.go @@ -1,52 +1,52 @@ package db import ( "fmt" "strings" ) type AlterTableSqlBuilder struct { Dialect DialectType Name string Changes []string } func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder { - if colVal, err := col.String(); err == nil { + if colVal, err := col.ToSQL(b.Dialect); err == nil { b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal)) } return b } func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder { - if colVal, err := col.String(); err == nil { + if colVal, err := col.ToSQL(b.Dialect); err == nil { b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal)) } return b } func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder { b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", "))) return b } func (b *AlterTableSqlBuilder) ToSQL() (string, error) { var str strings.Builder str.WriteString("ALTER TABLE ") str.WriteString(b.Name) str.WriteString(" ") if len(b.Changes) == 0 { return "", fmt.Errorf("no changes provide for table: %s", b.Name) } changeCount := len(b.Changes) for i, thing := range b.Changes { str.WriteString(thing) if i < changeCount-1 { str.WriteString(", ") } } return str.String(), nil } diff --git a/db/builder.go b/db/builder.go new file mode 100644 index 0000000..d0f4fe4 --- /dev/null +++ b/db/builder.go @@ -0,0 +1,15 @@ +/* + * Copyright © 2019-2022 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package db + +type SQLBuilder interface { + ToSQL() (string, error) +} diff --git a/db/column.go b/db/column.go new file mode 100644 index 0000000..4e0bb6d --- /dev/null +++ b/db/column.go @@ -0,0 +1,280 @@ +/* + * Copyright © 2019-2022 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package db + +import ( + "fmt" + "strings" +) + +type Column struct { + Name string + Type ColumnType + Nullable bool + PrimaryKey bool +} + +func NullableColumn(name string, ty ColumnType) *Column { + return &Column{ + Name: name, + Type: ty, + Nullable: true, + PrimaryKey: false, + } +} + +func NonNullableColumn(name string, ty ColumnType) *Column { + return &Column{ + Name: name, + Type: ty, + Nullable: false, + PrimaryKey: false, + } +} + +func PrimaryKeyColumn(name string, ty ColumnType) *Column { + return &Column{ + Name: name, + Type: ty, + Nullable: false, + PrimaryKey: true, + } +} + +type ColumnType interface { + Name(DialectType) (string, error) + Default(DialectType) (string, error) +} + +type ColumnTypeInt struct { + IsSigned bool + MaxBytes int + MaxDigits int + HasDefault bool + DefaultVal int +} + +type ColumnTypeString struct { + IsFixedLength bool + MaxChars int + HasDefault bool + DefaultVal string +} + +type ColumnDefault int + +type ColumnTypeBool struct { + DefaultVal ColumnDefault +} + +const ( + NoDefault ColumnDefault = iota + DefaultFalse ColumnDefault = iota + DefaultTrue ColumnDefault = iota + DefaultNow ColumnDefault = iota +) + +type ColumnTypeDateTime struct { + DefaultVal ColumnDefault +} + +func (intCol ColumnTypeInt) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite: + return "INTEGER", nil + + case DialectMySQL: + var colName string + switch intCol.MaxBytes { + case 1: + colName = "TINYINT" + case 2: + colName = "SMALLINT" + case 3: + colName = "MEDIUMINT" + case 4: + colName = "INT" + default: + colName = "BIGINT" + } + if intCol.MaxDigits > 0 { + colName = fmt.Sprintf("%s(%d)", colName, intCol.MaxDigits) + } + if !intCol.IsSigned { + colName += " UNSIGNED" + } + return colName, nil + + default: + return "", fmt.Errorf("dialect %d does not support integer columns", d) + } +} + +func (intCol ColumnTypeInt) Default(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + if intCol.HasDefault { + return fmt.Sprintf("%d", intCol.DefaultVal), nil + } + return "", nil + default: + return "", fmt.Errorf("dialect %d does not support defaulted integer columns", d) + } +} + +func (strCol ColumnTypeString) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite: + return "TEXT", nil + + case DialectMySQL: + if strCol.IsFixedLength { + if strCol.MaxChars > 0 { + return fmt.Sprintf("CHAR(%d)", strCol.MaxChars), nil + } + return "CHAR", nil + } + + if strCol.MaxChars <= 0 { + return "TEXT", nil + } + if strCol.MaxChars < (1 << 16) { + return fmt.Sprintf("VARCHAR(%d)", strCol.MaxChars), nil + } + return "TEXT", nil + + default: + return "", fmt.Errorf("dialect %d does not support string columns", d) + } +} + +func (strCol ColumnTypeString) Default(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + if strCol.HasDefault { + return EscapeSimple.SQLEscape(d, strCol.DefaultVal) + } + return "", nil + default: + return "", fmt.Errorf("dialect %d does not support defaulted string columns", d) + } +} + +func (boolCol ColumnTypeBool) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite: + return "INTEGER", nil + case DialectMySQL: + return "BOOL", nil + default: + return "", fmt.Errorf("boolean column type not supported for dialect %d", d) + } +} + +func (boolCol ColumnTypeBool) Default(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + switch boolCol.DefaultVal { + case NoDefault: + return "", nil + case DefaultFalse: + return "0", nil + case DefaultTrue: + return "1", nil + default: + return "", fmt.Errorf("boolean columns cannot default to %d for dialect %d", boolCol.DefaultVal, d) + } + default: + return "", fmt.Errorf("dialect %d does not support defaulted boolean columns", d) + } +} + +func (dateTimeCol ColumnTypeDateTime) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + return "DATETIME", nil + default: + return "", fmt.Errorf("datetime column type not supported for dialect %d", d) + } +} + +func (dateTimeCol ColumnTypeDateTime) Default(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + switch dateTimeCol.DefaultVal { + case NoDefault: + return "", nil + case DefaultNow: + switch d { + case DialectSQLite: + return "CURRENT_TIMESTAMP", nil + case DialectMySQL: + return "NOW()", nil + } + } + return "", fmt.Errorf("datetime columns cannot default to %d for dialect %d", dateTimeCol.DefaultVal, d) + default: + return "", fmt.Errorf("dialect %d does not support defaulted datetime columns", d) + } +} + +func (c *Column) SetName(name string) *Column { + c.Name = name + return c +} + +func (c *Column) SetNullable(nullable bool) *Column { + c.Nullable = nullable + return c +} + +func (c *Column) SetPrimaryKey(pk bool) *Column { + c.PrimaryKey = pk + return c +} + +func (c *Column) SetType(t ColumnType) *Column { + c.Type = t + return c +} + +func (c *Column) ToSQL(d DialectType) (string, error) { + var str strings.Builder + + str.WriteString(c.Name) + + str.WriteString(" ") + typeStr, err := c.Type.Name(d) + if err != nil { + return "", err + } + + str.WriteString(typeStr) + + if !c.Nullable { + str.WriteString(" NOT NULL") + } + + defaultStr, err := c.Type.Default(d) + if err != nil { + return "", err + } + if len(defaultStr) > 0 { + str.WriteString(" DEFAULT ") + str.WriteString(defaultStr) + } + + if c.PrimaryKey { + str.WriteString(" PRIMARY KEY") + } + + return str.String(), nil +} diff --git a/db/create.go b/db/create.go index 648f93a..a9fad98 100644 --- a/db/create.go +++ b/db/create.go @@ -1,266 +1,89 @@ /* * Copyright © 2019-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package db import ( "fmt" "strings" ) -type ColumnType int - -type OptionalInt struct { - Set bool - Value int -} - -type OptionalString struct { - Set bool - Value string -} - -type SQLBuilder interface { - ToSQL() (string, error) -} - -type Column struct { - Dialect DialectType - Name string - Nullable bool - Default OptionalString - Type ColumnType - Size OptionalInt - PrimaryKey bool -} - type CreateTableSqlBuilder struct { Dialect DialectType Name string IfNotExists bool ColumnOrder []string Columns map[string]*Column Constraints []string } -const ( - ColumnTypeBool ColumnType = iota - ColumnTypeSmallInt ColumnType = iota - ColumnTypeInteger ColumnType = iota - ColumnTypeChar ColumnType = iota - ColumnTypeVarChar ColumnType = iota - ColumnTypeText ColumnType = iota - ColumnTypeDateTime ColumnType = iota -) - -var _ SQLBuilder = &CreateTableSqlBuilder{} - -var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} -var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} - -func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { - if dialect != DialectMySQL && dialect != DialectSQLite { - return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) - } - switch d { - case ColumnTypeSmallInt: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "SMALLINT" + mod, nil - } - case ColumnTypeInteger: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "INT" + mod, nil - } - case ColumnTypeChar: - { - if dialect == DialectSQLite { - return "TEXT", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "CHAR" + mod, nil - } - case ColumnTypeVarChar: - { - if dialect == DialectSQLite { - return "TEXT", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "VARCHAR" + mod, nil - } - case ColumnTypeBool: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } - return "TINYINT(1)", nil - } - case ColumnTypeDateTime: - return "DATETIME", nil - case ColumnTypeText: - return "TEXT", nil - } - return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) -} - -func (c *Column) SetName(name string) *Column { - c.Name = name - return c -} - -func (c *Column) SetNullable(nullable bool) *Column { - c.Nullable = nullable - return c -} - -func (c *Column) SetPrimaryKey(pk bool) *Column { - c.PrimaryKey = pk - return c -} - -func (c *Column) SetDefault(value string) *Column { - c.Default = OptionalString{Set: true, Value: value} - return c -} - -func (c *Column) SetDefaultCurrentTimestamp() *Column { - def := "NOW()" - if c.Dialect == DialectSQLite { - def = "CURRENT_TIMESTAMP" - } - c.Default = OptionalString{Set: true, Value: def} - return c -} - -func (c *Column) SetType(t ColumnType) *Column { - c.Type = t - return c -} - -func (c *Column) SetSize(size int) *Column { - c.Size = OptionalInt{Set: true, Value: size} - return c -} - -func (c *Column) String() (string, error) { - var str strings.Builder - - str.WriteString(c.Name) - - str.WriteString(" ") - typeStr, err := c.Type.Format(c.Dialect, c.Size) - if err != nil { - return "", err - } - - str.WriteString(typeStr) - - if !c.Nullable { - str.WriteString(" NOT NULL") - } - - if c.Default.Set { - str.WriteString(" DEFAULT ") - val := c.Default.Value - if val == "" { - val = "''" - } - str.WriteString(val) - } - - if c.PrimaryKey { - str.WriteString(" PRIMARY KEY") - } - - return str.String(), nil -} - func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { if b.Columns == nil { b.Columns = make(map[string]*Column) } b.Columns[column.Name] = column b.ColumnOrder = append(b.ColumnOrder, column.Name) return b } func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder { for _, column := range columns { if _, ok := b.Columns[column]; !ok { // This fails silently. return b } } b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ","))) return b } func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder { b.IfNotExists = ine return b } func (b *CreateTableSqlBuilder) ToSQL() (string, error) { var str strings.Builder str.WriteString("CREATE TABLE ") if b.IfNotExists { str.WriteString("IF NOT EXISTS ") } str.WriteString(b.Name) var things []string for _, columnName := range b.ColumnOrder { column, ok := b.Columns[columnName] if !ok { return "", fmt.Errorf("column not found: %s", columnName) } - columnStr, err := column.String() + columnStr, err := column.ToSQL(b.Dialect) if err != nil { return "", err } things = append(things, columnStr) } for _, constraint := range b.Constraints { things = append(things, constraint) } if thingLen := len(things); thingLen > 0 { str.WriteString(" ( ") for i, thing := range things { str.WriteString(thing) if i < thingLen-1 { str.WriteString(", ") } } str.WriteString(" )") } return str.String(), nil } diff --git a/db/dialect.go b/db/dialect.go index 4251465..ee1eb0f 100644 --- a/db/dialect.go +++ b/db/dialect.go @@ -1,76 +1,50 @@ package db import "fmt" type DialectType int const ( DialectSQLite DialectType = iota DialectMySQL DialectType = iota ) -func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column { +func (d DialectType) IsKnown() bool { switch d { - case DialectSQLite: - return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size} - case DialectMySQL: - return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size} + case DialectSQLite, DialectMySQL: + return true default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) + return false } } -func (d DialectType) Table(name string) *CreateTableSqlBuilder { - switch d { - case DialectSQLite: - return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name} - case DialectMySQL: - return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name} - default: +func (d DialectType) AssertKnown() { + if !d.IsKnown() { panic(fmt.Sprintf("unexpected dialect: %d", d)) } } +func (d DialectType) Table(name string) *CreateTableSqlBuilder { + d.AssertKnown() + return &CreateTableSqlBuilder{Dialect: d, Name: name} +} + func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder { - switch d { - case DialectSQLite: - return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name} - case DialectMySQL: - return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &AlterTableSqlBuilder{Dialect: d, Name: name} } func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { - switch d { - case DialectSQLite: - return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns} - case DialectMySQL: - return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: true, Columns: columns} } func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { - switch d { - case DialectSQLite: - return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns} - case DialectMySQL: - return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: false, Columns: columns} } func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder { - switch d { - case DialectSQLite: - return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table} - case DialectMySQL: - return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &DropIndexSqlBuilder{Dialect: d, Name: name, Table: table} } diff --git a/db/escape.go b/db/escape.go new file mode 100644 index 0000000..53b8ef3 --- /dev/null +++ b/db/escape.go @@ -0,0 +1,63 @@ +/* + * Copyright © 2019-2022 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package db + +import ( + "strings" +) + +type EscapeContext int + +const ( + EscapeSimple EscapeContext = iota +) + +func (_ EscapeContext) SQLEscape(d DialectType, s string) (string, error) { + builder := strings.Builder{} + switch d { + case DialectSQLite: + builder.WriteRune('\'') + for _, c := range s { + if c == '\'' { + builder.WriteString("''") + } else { + builder.WriteRune(c) + } + } + builder.WriteRune('\'') + case DialectMySQL: + builder.WriteRune('\'') + for _, c := range s { + switch c { + case 0: + builder.WriteString("\\0") + case '\'': + builder.WriteString("\\'") + case '"': + builder.WriteString("\\\"") + case '\b': + builder.WriteString("\\b") + case '\n': + builder.WriteString("\\n") + case '\r': + builder.WriteString("\\r") + case '\t': + builder.WriteString("\\t") + case '\\': + builder.WriteString("\\\\") + default: + builder.WriteRune(c) + } + } + builder.WriteRune('\'') + } + return builder.String(), nil +} diff --git a/migrations/v4.go b/migrations/v4.go index c69dce1..25533dd 100644 --- a/migrations/v4.go +++ b/migrations/v4.go @@ -1,54 +1,54 @@ /* * Copyright © 2019-2021 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package migrations import ( "context" "database/sql" wf_db "github.com/writefreely/writefreely/db" ) func oauth(db *datastore) error { dialect := wf_db.DialectMySQL if db.driverName == driverSQLite { dialect = wf_db.DialectSQLite } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { createTableUsersOauth, err := dialect. Table("oauth_users"). SetIfNotExists(false). - Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). - Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). + Column(wf_db.NonNullableColumn("user_id", wf_db.ColumnTypeInt{MaxBytes: 4})). + Column(wf_db.NonNullableColumn("remote_user_id", wf_db.ColumnTypeInt{MaxBytes: 4})). ToSQL() if err != nil { return err } createTableOauthClientState, err := dialect. Table("oauth_client_states"). SetIfNotExists(false). - Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})). - Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)). - Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefaultCurrentTimestamp()). + Column(wf_db.NonNullableColumn("state", wf_db.ColumnTypeString{MaxChars: 255})). + Column(wf_db.NonNullableColumn("used", wf_db.ColumnTypeBool{})). + Column(wf_db.NonNullableColumn("created_at", wf_db.ColumnTypeDateTime{DefaultVal: wf_db.DefaultNow})). UniqueConstraint("state"). ToSQL() if err != nil { return err } for _, table := range []string{createTableUsersOauth, createTableOauthClientState} { if _, err := tx.ExecContext(ctx, table); err != nil { return err } } return nil }) } diff --git a/migrations/v5.go b/migrations/v5.go index 1fe3e30..01ad2a7 100644 --- a/migrations/v5.go +++ b/migrations/v5.go @@ -1,88 +1,105 @@ /* * Copyright © 2019-2021 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package migrations import ( "context" "database/sql" wf_db "github.com/writefreely/writefreely/db" ) func oauthSlack(db *datastore) error { dialect := wf_db.DialectMySQL if db.driverName == driverSQLite { dialect = wf_db.DialectSQLite } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { builders := []wf_db.SQLBuilder{ dialect. AlterTable("oauth_client_states"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "provider", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 24, + HasDefault: true, + DefaultVal: "", + })), dialect. AlterTable("oauth_client_states"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "client_id", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 128, + HasDefault: true, + DefaultVal: "", + }, + )), dialect. AlterTable("oauth_users"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "provider", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 24, + HasDefault: true, + DefaultVal: "", + })), dialect. AlterTable("oauth_users"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "client_id", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 128, + HasDefault: true, + DefaultVal: "", + })), dialect. AlterTable("oauth_users"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "access_token", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 512}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 512, + HasDefault: true, + DefaultVal: "", + })), dialect.CreateUniqueIndex("oauth_users_uk", "oauth_users", "user_id", "provider", "client_id"), } if dialect != wf_db.DialectSQLite { // This updates the length of the `remote_user_id` column. It isn't needed for SQLite databases. builders = append(builders, dialect. AlterTable("oauth_users"). ChangeColumn("remote_user_id", - dialect. - Column( + wf_db. + NonNullableColumn( "remote_user_id", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128}))) + wf_db.ColumnTypeString{ + MaxChars: 128, + }))) } for _, builder := range builders { query, err := builder.ToSQL() if err != nil { return err } if _, err := tx.ExecContext(ctx, query); err != nil { return err } } return nil }) } diff --git a/migrations/v7.go b/migrations/v7.go index 5737b21..7eb8910 100644 --- a/migrations/v7.go +++ b/migrations/v7.go @@ -1,46 +1,48 @@ /* * Copyright © 2020-2021 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package migrations import ( "context" "database/sql" wf_db "github.com/writefreely/writefreely/db" ) func oauthAttach(db *datastore) error { dialect := wf_db.DialectMySQL if db.driverName == driverSQLite { dialect = wf_db.DialectSQLite } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { builders := []wf_db.SQLBuilder{ dialect. AlterTable("oauth_client_states"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NullableColumn( "attach_user_id", - wf_db.ColumnTypeInteger, - wf_db.OptionalInt{Set: true, Value: 24}).SetNullable(true)), + wf_db.ColumnTypeInt{ + MaxBytes: 4, + MaxDigits: 24, + })), } for _, builder := range builders { query, err := builder.ToSQL() if err != nil { return err } if _, err := tx.ExecContext(ctx, query); err != nil { return err } } return nil }) } diff --git a/migrations/v8.go b/migrations/v8.go index 28af523..00a95ca 100644 --- a/migrations/v8.go +++ b/migrations/v8.go @@ -1,45 +1,45 @@ /* * Copyright © 2020-2021 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package migrations import ( "context" "database/sql" wf_db "github.com/writefreely/writefreely/db" ) func oauthInvites(db *datastore) error { dialect := wf_db.DialectMySQL if db.driverName == driverSQLite { dialect = wf_db.DialectSQLite } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { builders := []wf_db.SQLBuilder{ dialect. AlterTable("oauth_client_states"). - AddColumn(dialect.Column("invite_code", wf_db.ColumnTypeChar, wf_db.OptionalInt{ - Set: true, - Value: 6, - }).SetNullable(true)), + AddColumn(wf_db.NullableColumn("invite_code", wf_db.ColumnTypeString{ + IsFixedLength: true, + MaxChars: 6, + })), } for _, builder := range builders { query, err := builder.ToSQL() if err != nil { return err } if _, err := tx.ExecContext(ctx, query); err != nil { return err } } return nil }) }