diff --git a/db/create.go b/db/create.go index 870bfd7..c98370f 100644 --- a/db/create.go +++ b/db/create.go @@ -1,284 +1,276 @@ /* * Copyright © 2019-2020 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package db import ( "fmt" "strings" ) type ColumnType int type OptionalInt struct { Set bool Value int } type OptionalString struct { Set bool Value string } type SQLBuilder interface { ToSQL() (string, error) } type Column struct { Dialect DialectType Name string Nullable bool Default OptionalString Type ColumnType Size OptionalInt PrimaryKey bool } type CreateTableSqlBuilder struct { Dialect DialectType Name string IfNotExists bool ColumnOrder []string Columns map[string]*Column Constraints []string } const ( ColumnTypeBool ColumnType = iota ColumnTypeSmallInt ColumnType = iota ColumnTypeInteger ColumnType = iota ColumnTypeChar ColumnType = iota ColumnTypeVarChar ColumnType = iota ColumnTypeText ColumnType = iota ColumnTypeDateTime ColumnType = iota ) var _ SQLBuilder = &CreateTableSqlBuilder{} var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { if dialect != DialectMySQL && dialect != DialectPostgres && dialect != DialectSQLite { return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) } mod := "" if size.Set { mod = fmt.Sprintf("(%d)", size.Value) } switch d { case ColumnTypeSmallInt: switch dialect { case DialectSQLite: return "INTEGER", nil case DialectMySQL: return "SMALLINT" + mod, nil case DialectPostgres: return "SMALLINT", nil } case ColumnTypeInteger: switch dialect { - case DialectSQLite: + case DialectSQLite, DialectPostgres: return "INTEGER", nil case DialectMySQL: return "INT" + mod, nil - case DialectPostgres: - return "INT", nil } case ColumnTypeChar: switch dialect { case DialectSQLite: return "TEXT", nil case DialectMySQL: return "CHAR" + mod, nil case DialectPostgres: return "CHAR" + mod, nil } case ColumnTypeVarChar: switch dialect { case DialectSQLite: return "TEXT", nil case DialectMySQL: return "VARCHAR" + mod, nil case DialectPostgres: return "VARCHAR" + mod, nil } case ColumnTypeBool: switch dialect { case DialectSQLite: return "INTEGER", nil case DialectMySQL: return "TINYINT(1)", nil case DialectPostgres: return "BOOLEAN", nil } case ColumnTypeDateTime: switch dialect { case DialectSQLite: return "DATETIME", nil case DialectMySQL: return "DATETIME", nil case DialectPostgres: return "TIMESTAMP", nil } case ColumnTypeText: switch dialect { - case DialectSQLite: - return "TEXT", nil - case DialectMySQL: - return "TEXT", nil - case DialectPostgres: + case DialectSQLite, DialectMySQL, DialectPostgres: return "TEXT", nil } } return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) } func (c *Column) SetName(name string) *Column { c.Name = name return c } func (c *Column) SetNullable(nullable bool) *Column { c.Nullable = nullable return c } func (c *Column) SetPrimaryKey(pk bool) *Column { c.PrimaryKey = pk return c } func (c *Column) SetDefault(value string) *Column { c.Default = OptionalString{Set: true, Value: value} return c } func (c *Column) SetDefaultCurrentTimestamp() *Column { def := "" switch c.Dialect { - case DialectSQLite: + case DialectSQLite, DialectPostgres: def = "CURRENT_TIMESTAMP" case DialectMySQL: def = "NOW()" - case DialectPostgres: - def = "NOW()" } c.Default = OptionalString{Set: true, Value: def} return c } func (c *Column) SetType(t ColumnType) *Column { c.Type = t return c } func (c *Column) SetSize(size int) *Column { c.Size = OptionalInt{Set: true, Value: size} return c } func (c *Column) String() (string, error) { var str strings.Builder str.WriteString(c.Name) str.WriteString(" ") typeStr, err := c.Type.Format(c.Dialect, c.Size) if err != nil { return "", err } str.WriteString(typeStr) if !c.Nullable { str.WriteString(" NOT NULL") } if c.Default.Set { str.WriteString(" DEFAULT ") val := c.Default.Value if val == "" { val = "''" } str.WriteString(val) } if c.PrimaryKey { str.WriteString(" PRIMARY KEY") } return str.String(), nil } func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { if b.Columns == nil { b.Columns = make(map[string]*Column) } b.Columns[column.Name] = column b.ColumnOrder = append(b.ColumnOrder, column.Name) return b } func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder { for _, column := range columns { if _, ok := b.Columns[column]; !ok { // This fails silently. return b } } b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ","))) return b } func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder { b.IfNotExists = ine return b } func (b *CreateTableSqlBuilder) ToSQL() (string, error) { var str strings.Builder str.WriteString("CREATE TABLE ") if b.IfNotExists { str.WriteString("IF NOT EXISTS ") } str.WriteString(b.Name) var things []string for _, columnName := range b.ColumnOrder { column, ok := b.Columns[columnName] if !ok { return "", fmt.Errorf("column not found: %s", columnName) } columnStr, err := column.String() if err != nil { return "", err } things = append(things, columnStr) } things = append(things, b.Constraints...) if thingLen := len(things); thingLen > 0 { str.WriteString(" ( ") for i, thing := range things { str.WriteString(thing) if i < thingLen-1 { str.WriteString(", ") } } str.WriteString(" )") } return str.String(), nil } diff --git a/migrations/drivers.go b/migrations/drivers.go index be99439..539e7f1 100644 --- a/migrations/drivers.go +++ b/migrations/drivers.go @@ -1,253 +1,245 @@ /* * Copyright © 2019 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package migrations import ( "fmt" ) // TODO: use now() from writefreely pkg func (db *datastore) now() string { switch db.driverName { case driverSQLite: return "strftime('%Y-%m-%d %H:%M:%S','now')" - case driverMySQL: - return "NOW()" - case driverPostgres: + case driverMySQL, driverPostgres: return "NOW()" } return "" // placeholder } func (db *datastore) typeInt() string { switch db.driverName { - case driverSQLite: + case driverSQLite, driverMySQL, driverPostgres: return "INTEGER" - case driverMySQL: - return "INT" - case driverPostgres: - return "INT" } return "" // placeholder } func (db *datastore) typeSmallInt() string { switch db.driverName { case driverSQLite: return "INTEGER" - case driverMySQL: - return "SMALLINT" - case driverPostgres: + case driverMySQL, driverPostgres: return "SMALLINT" } return "" // placeholder } func (db *datastore) typeTinyInt() string { switch db.driverName { case driverSQLite: return "INTEGER" case driverMySQL: return "TINYINT" case driverPostgres: return "SMALLINT" } return "" // placeholder } func (db *datastore) typeText() string { return "TEXT" } func (db *datastore) typeChar(l int) string { switch db.driverName { case driverSQLite: return "TEXT" case driverMySQL: return fmt.Sprintf("CHAR(%d)", l) case driverPostgres: return fmt.Sprintf("CHAR(%d)", l) } return "" // placeholder } func (db *datastore) typeVarChar(l int) string { switch db.driverName { case driverSQLite: return "TEXT" case driverMySQL: return fmt.Sprintf("VARCHAR(%d)", l) case driverPostgres: return fmt.Sprintf("VARCHAR(%d)", l) } return "" // placeholder } func (db *datastore) typeVarBinary(l int) string { switch db.driverName { case driverSQLite: return "BLOB" case driverMySQL: return fmt.Sprintf("VARBINARY(%d)", l) case driverPostgres: return "BYTEA" } return "" // placeholder } func (db *datastore) typeBool() string { switch db.driverName { case driverSQLite: return "INTEGER" case driverMySQL: return "TINYINT(1)" case driverPostgres: return "BOOLEAN" } return "" // placeholder } func (db *datastore) typeDateTime() string { switch db.driverName { case driverSQLite: return "DATETIME" case driverMySQL: return "DATETIME" case driverPostgres: return "TIMESTAMP" } return "" // placeholder } func (db *datastore) typeIntPrimaryKey() string { switch db.driverName { case driverSQLite: // From docs: "In SQLite, a column with type INTEGER PRIMARY KEY is an alias for the ROWID (except in WITHOUT // ROWID tables) which is always a 64-bit signed integer." return "INTEGER PRIMARY KEY" case driverMySQL: return "INT AUTO_INCREMENT PRIMARY KEY" case driverPostgres: return "INTEGER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY" } return "" // placeholder } func (db *datastore) collateMultiByte() string { switch db.driverName { case driverSQLite: return "" case driverMySQL: return " COLLATE utf8_bin" case driverPostgres: return "" } return "" // placeholder } func (db *datastore) engine() string { switch db.driverName { case driverSQLite: return "" case driverMySQL: return " ENGINE = InnoDB" case driverPostgres: return "" } return "" // placeholder } func (db *datastore) after(colName string) string { switch db.driverName { case driverSQLite: return "" case driverMySQL: return fmt.Sprintf(" AFTER %s", colName) case driverPostgres: return "" } return "" // placeholder } func (db *datastore) boolTrue() string { switch db.driverName { case driverSQLite: return "1" case driverMySQL: return "1" case driverPostgres: return "TRUE" } return "" // placeholder } func (db *datastore) boolFalse() string { switch db.driverName { case driverSQLite: return "0" case driverMySQL: return "0" case driverPostgres: return "FALSE" } return "" // placeholder } func (db *datastore) limit(offset int, size int) string { switch db.driverName { case driverSQLite, driverMySQL: return fmt.Sprintf(" LIMIT %d, %d", offset, size) case driverPostgres: return fmt.Sprintf(" LIMIT %d OFFSET %d", size, offset) } return "" // placeholder } func (db *datastore) QueryWrap(q string) string { if db.driverName != driverPostgres { return q } output := "" escape := false ctr := 0 for i := range len(q) { if q[i] == '\'' || q[i] == '`' { escape = !escape } if q[i] == '?' && !escape { ctr += 1 output += fmt.Sprintf("$%d", ctr) } else { output += string(q[i]) } } return output } \ No newline at end of file