diff --git a/db/create.go b/db/create.go index c384778..648f93a 100644 --- a/db/create.go +++ b/db/create.go @@ -1,244 +1,266 @@ +/* + * Copyright © 2019-2020 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + package db import ( "fmt" "strings" ) type ColumnType int type OptionalInt struct { Set bool Value int } type OptionalString struct { Set bool Value string } type SQLBuilder interface { ToSQL() (string, error) } type Column struct { Dialect DialectType Name string Nullable bool Default OptionalString Type ColumnType Size OptionalInt PrimaryKey bool } type CreateTableSqlBuilder struct { Dialect DialectType Name string IfNotExists bool ColumnOrder []string Columns map[string]*Column Constraints []string } const ( ColumnTypeBool ColumnType = iota ColumnTypeSmallInt ColumnType = iota ColumnTypeInteger ColumnType = iota ColumnTypeChar ColumnType = iota ColumnTypeVarChar ColumnType = iota ColumnTypeText ColumnType = iota ColumnTypeDateTime ColumnType = iota ) var _ SQLBuilder = &CreateTableSqlBuilder{} var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { if dialect != DialectMySQL && dialect != DialectSQLite { return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) } switch d { case ColumnTypeSmallInt: { if dialect == DialectSQLite { return "INTEGER", nil } mod := "" if size.Set { mod = fmt.Sprintf("(%d)", size.Value) } return "SMALLINT" + mod, nil } case ColumnTypeInteger: { if dialect == DialectSQLite { return "INTEGER", nil } mod := "" if size.Set { mod = fmt.Sprintf("(%d)", size.Value) } return "INT" + mod, nil } case ColumnTypeChar: { if dialect == DialectSQLite { return "TEXT", nil } mod := "" if size.Set { mod = fmt.Sprintf("(%d)", size.Value) } return "CHAR" + mod, nil } case ColumnTypeVarChar: { if dialect == DialectSQLite { return "TEXT", nil } mod := "" if size.Set { mod = fmt.Sprintf("(%d)", size.Value) } return "VARCHAR" + mod, nil } case ColumnTypeBool: { if dialect == DialectSQLite { return "INTEGER", nil } return "TINYINT(1)", nil } case ColumnTypeDateTime: return "DATETIME", nil case ColumnTypeText: return "TEXT", nil } return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) } func (c *Column) SetName(name string) *Column { c.Name = name return c } func (c *Column) SetNullable(nullable bool) *Column { c.Nullable = nullable return c } func (c *Column) SetPrimaryKey(pk bool) *Column { c.PrimaryKey = pk return c } func (c *Column) SetDefault(value string) *Column { c.Default = OptionalString{Set: true, Value: value} return c } +func (c *Column) SetDefaultCurrentTimestamp() *Column { + def := "NOW()" + if c.Dialect == DialectSQLite { + def = "CURRENT_TIMESTAMP" + } + c.Default = OptionalString{Set: true, Value: def} + return c +} + func (c *Column) SetType(t ColumnType) *Column { c.Type = t return c } func (c *Column) SetSize(size int) *Column { c.Size = OptionalInt{Set: true, Value: size} return c } func (c *Column) String() (string, error) { var str strings.Builder str.WriteString(c.Name) str.WriteString(" ") typeStr, err := c.Type.Format(c.Dialect, c.Size) if err != nil { return "", err } str.WriteString(typeStr) if !c.Nullable { str.WriteString(" NOT NULL") } if c.Default.Set { str.WriteString(" DEFAULT ") - str.WriteString(c.Default.Value) + val := c.Default.Value + if val == "" { + val = "''" + } + str.WriteString(val) } if c.PrimaryKey { str.WriteString(" PRIMARY KEY") } return str.String(), nil } func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { if b.Columns == nil { b.Columns = make(map[string]*Column) } b.Columns[column.Name] = column b.ColumnOrder = append(b.ColumnOrder, column.Name) return b } func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder { for _, column := range columns { if _, ok := b.Columns[column]; !ok { // This fails silently. return b } } b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ","))) return b } func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder { b.IfNotExists = ine return b } func (b *CreateTableSqlBuilder) ToSQL() (string, error) { var str strings.Builder str.WriteString("CREATE TABLE ") if b.IfNotExists { str.WriteString("IF NOT EXISTS ") } str.WriteString(b.Name) var things []string for _, columnName := range b.ColumnOrder { column, ok := b.Columns[columnName] if !ok { return "", fmt.Errorf("column not found: %s", columnName) } columnStr, err := column.String() if err != nil { return "", err } things = append(things, columnStr) } for _, constraint := range b.Constraints { things = append(things, constraint) } if thingLen := len(things); thingLen > 0 { str.WriteString(" ( ") for i, thing := range things { str.WriteString(thing) if i < thingLen-1 { str.WriteString(", ") } } str.WriteString(" )") } return str.String(), nil } - diff --git a/migrations/v4.go b/migrations/v4.go index c075dd8..7d73f96 100644 --- a/migrations/v4.go +++ b/migrations/v4.go @@ -1,46 +1,54 @@ +/* + * Copyright © 2019-2020 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + package migrations import ( "context" "database/sql" wf_db "github.com/writeas/writefreely/db" ) func oauth(db *datastore) error { dialect := wf_db.DialectMySQL if db.driverName == driverSQLite { dialect = wf_db.DialectSQLite } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { createTableUsersOauth, err := dialect. Table("oauth_users"). - SetIfNotExists(true). + SetIfNotExists(false). Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). - UniqueConstraint("user_id"). - UniqueConstraint("remote_user_id"). ToSQL() if err != nil { return err } createTableOauthClientState, err := dialect. Table("oauth_client_states"). - SetIfNotExists(true). + SetIfNotExists(false). Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})). Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)). - Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefault("NOW()")). + Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefaultCurrentTimestamp()). UniqueConstraint("state"). ToSQL() if err != nil { return err } for _, table := range []string{createTableUsersOauth, createTableOauthClientState} { if _, err := tx.ExecContext(ctx, table); err != nil { return err } } return nil }) } diff --git a/migrations/v5.go b/migrations/v5.go index 94e3944..f93d067 100644 --- a/migrations/v5.go +++ b/migrations/v5.go @@ -1,67 +1,88 @@ +/* + * Copyright © 2019-2020 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + package migrations import ( "context" "database/sql" wf_db "github.com/writeas/writefreely/db" ) func oauthSlack(db *datastore) error { dialect := wf_db.DialectMySQL if db.driverName == driverSQLite { dialect = wf_db.DialectSQLite } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { builders := []wf_db.SQLBuilder{ dialect. AlterTable("oauth_client_states"). AddColumn(dialect. Column( "provider", wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 24,})). + wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")), + dialect. + AlterTable("oauth_client_states"). AddColumn(dialect. Column( "client_id", wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128,})), + wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")), dialect. AlterTable("oauth_users"). - ChangeColumn("remote_user_id", - dialect. - Column( - "remote_user_id", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128,})). AddColumn(dialect. Column( "provider", wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 24,})). + wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")), + dialect. + AlterTable("oauth_users"). AddColumn(dialect. Column( "client_id", wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128,})). + wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")), + dialect. + AlterTable("oauth_users"). AddColumn(dialect. Column( "access_token", wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 512,})), - dialect.DropIndex("remote_user_id", "oauth_users"), - dialect.DropIndex("user_id", "oauth_users"), - dialect.CreateUniqueIndex("oauth_users", "oauth_users", "user_id", "provider", "client_id"), + wf_db.OptionalInt{Set: true, Value: 512}).SetDefault("")), + dialect.CreateUniqueIndex("oauth_users_uk", "oauth_users", "user_id", "provider", "client_id"), + } + + if dialect != wf_db.DialectSQLite { + // This updates the length of the `remote_user_id` column. It isn't needed for SQLite databases. + builders = append(builders, dialect. + AlterTable("oauth_users"). + ChangeColumn("remote_user_id", + dialect. + Column( + "remote_user_id", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 128}))) } + for _, builder := range builders { query, err := builder.ToSQL() if err != nil { return err } if _, err := tx.ExecContext(ctx, query); err != nil { return err } } return nil }) }