diff --git a/db/alter.go b/db/alter.go index 0a4ffdd..14e92ac 100644 --- a/db/alter.go +++ b/db/alter.go @@ -1,52 +1,73 @@ package db import ( "fmt" "strings" ) +type alterTableQueryType int + +const ( + alter alterTableQueryType = iota + update alterTableQueryType = iota +) + +type change struct { + queryType alterTableQueryType + queryString string +} + type AlterTableSqlBuilder struct { Dialect DialectType Name string - Changes []string + Changes []change } func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder { if colVal, err := col.String(); err == nil { - b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal)) + b.Changes = append(b.Changes, change{queryType: alter, queryString: fmt.Sprintf("ADD COLUMN %s", colVal)}) } return b } func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder { if colVal, err := col.String(); err == nil { - b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal)) + b.Changes = append(b.Changes, change{queryType: alter, queryString: fmt.Sprintf("RENAME COLUMN %s TO %s_old", name, name)}) + b.Changes = append(b.Changes, change{queryType: alter, queryString: fmt.Sprintf("ADD COLUMN %s", colVal)}) + b.Changes = append(b.Changes, change{queryType: update, queryString: fmt.Sprintf("SET %s = %s_old", col.Name, name)}) + b.Changes = append(b.Changes, change{queryType: alter, queryString: fmt.Sprintf("DROP COLUMN %s", name)}) } return b } func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder { - b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", "))) + b.Changes = append(b.Changes, change{ + queryType: alter, + queryString: fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", ")), + }) return b } func (b *AlterTableSqlBuilder) ToSQL() (string, error) { - var str strings.Builder - - str.WriteString("ALTER TABLE ") - str.WriteString(b.Name) - str.WriteString(" ") + str := "" + changeCount := len(b.Changes) - if len(b.Changes) == 0 { + if changeCount == 0 { return "", fmt.Errorf("no changes provide for table: %s", b.Name) } - changeCount := len(b.Changes) - for i, thing := range b.Changes { - str.WriteString(thing) - if i < changeCount-1 { - str.WriteString(", ") + + for i, change := range b.Changes { + switch change.queryType { + case alter: + str += fmt.Sprintf("ALTER TABLE %s ", b.Name) + case update: + str += fmt.Sprintf("UPDATE %s ", b.Name) + } + str += change.queryString + if i < changeCount - 1 { + str += "; " } } - return str.String(), nil + return str, nil } diff --git a/db/create.go b/db/create.go index 1e9e679..870bfd7 100644 --- a/db/create.go +++ b/db/create.go @@ -1,263 +1,284 @@ /* * Copyright © 2019-2020 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package db import ( "fmt" "strings" ) type ColumnType int type OptionalInt struct { Set bool Value int } type OptionalString struct { Set bool Value string } type SQLBuilder interface { ToSQL() (string, error) } type Column struct { Dialect DialectType Name string Nullable bool Default OptionalString Type ColumnType Size OptionalInt PrimaryKey bool } type CreateTableSqlBuilder struct { Dialect DialectType Name string IfNotExists bool ColumnOrder []string Columns map[string]*Column Constraints []string } const ( ColumnTypeBool ColumnType = iota ColumnTypeSmallInt ColumnType = iota ColumnTypeInteger ColumnType = iota ColumnTypeChar ColumnType = iota ColumnTypeVarChar ColumnType = iota ColumnTypeText ColumnType = iota ColumnTypeDateTime ColumnType = iota ) var _ SQLBuilder = &CreateTableSqlBuilder{} var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { - if dialect != DialectMySQL && dialect != DialectSQLite { + if dialect != DialectMySQL && dialect != DialectPostgres && dialect != DialectSQLite { return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) } + + mod := "" + + if size.Set { + mod = fmt.Sprintf("(%d)", size.Value) + } + switch d { case ColumnTypeSmallInt: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } + switch dialect { + case DialectSQLite: + return "INTEGER", nil + case DialectMySQL: return "SMALLINT" + mod, nil + case DialectPostgres: + return "SMALLINT", nil } case ColumnTypeInteger: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } + switch dialect { + case DialectSQLite: + return "INTEGER", nil + case DialectMySQL: return "INT" + mod, nil + case DialectPostgres: + return "INT", nil } case ColumnTypeChar: - { - if dialect == DialectSQLite { - return "TEXT", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } + switch dialect { + case DialectSQLite: + return "TEXT", nil + case DialectMySQL: + return "CHAR" + mod, nil + case DialectPostgres: return "CHAR" + mod, nil } case ColumnTypeVarChar: - { - if dialect == DialectSQLite { - return "TEXT", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } + switch dialect { + case DialectSQLite: + return "TEXT", nil + case DialectMySQL: + return "VARCHAR" + mod, nil + case DialectPostgres: return "VARCHAR" + mod, nil } case ColumnTypeBool: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } + switch dialect { + case DialectSQLite: + return "INTEGER", nil + case DialectMySQL: return "TINYINT(1)", nil + case DialectPostgres: + return "BOOLEAN", nil } case ColumnTypeDateTime: - return "DATETIME", nil + switch dialect { + case DialectSQLite: + return "DATETIME", nil + case DialectMySQL: + return "DATETIME", nil + case DialectPostgres: + return "TIMESTAMP", nil + } case ColumnTypeText: - return "TEXT", nil + switch dialect { + case DialectSQLite: + return "TEXT", nil + case DialectMySQL: + return "TEXT", nil + case DialectPostgres: + return "TEXT", nil + } } return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) } func (c *Column) SetName(name string) *Column { c.Name = name return c } func (c *Column) SetNullable(nullable bool) *Column { c.Nullable = nullable return c } func (c *Column) SetPrimaryKey(pk bool) *Column { c.PrimaryKey = pk return c } func (c *Column) SetDefault(value string) *Column { c.Default = OptionalString{Set: true, Value: value} return c } func (c *Column) SetDefaultCurrentTimestamp() *Column { - def := "NOW()" - if c.Dialect == DialectSQLite { + def := "" + + switch c.Dialect { + case DialectSQLite: def = "CURRENT_TIMESTAMP" + case DialectMySQL: + def = "NOW()" + case DialectPostgres: + def = "NOW()" } c.Default = OptionalString{Set: true, Value: def} return c } func (c *Column) SetType(t ColumnType) *Column { c.Type = t return c } func (c *Column) SetSize(size int) *Column { c.Size = OptionalInt{Set: true, Value: size} return c } func (c *Column) String() (string, error) { var str strings.Builder str.WriteString(c.Name) str.WriteString(" ") typeStr, err := c.Type.Format(c.Dialect, c.Size) if err != nil { return "", err } str.WriteString(typeStr) if !c.Nullable { str.WriteString(" NOT NULL") } if c.Default.Set { str.WriteString(" DEFAULT ") val := c.Default.Value if val == "" { val = "''" } str.WriteString(val) } if c.PrimaryKey { str.WriteString(" PRIMARY KEY") } return str.String(), nil } func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { if b.Columns == nil { b.Columns = make(map[string]*Column) } b.Columns[column.Name] = column b.ColumnOrder = append(b.ColumnOrder, column.Name) return b } func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder { for _, column := range columns { if _, ok := b.Columns[column]; !ok { // This fails silently. return b } } b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ","))) return b } func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder { b.IfNotExists = ine return b } func (b *CreateTableSqlBuilder) ToSQL() (string, error) { var str strings.Builder str.WriteString("CREATE TABLE ") if b.IfNotExists { str.WriteString("IF NOT EXISTS ") } str.WriteString(b.Name) var things []string for _, columnName := range b.ColumnOrder { column, ok := b.Columns[columnName] if !ok { return "", fmt.Errorf("column not found: %s", columnName) } columnStr, err := column.String() if err != nil { return "", err } things = append(things, columnStr) } things = append(things, b.Constraints...) if thingLen := len(things); thingLen > 0 { str.WriteString(" ( ") for i, thing := range things { str.WriteString(thing) if i < thingLen-1 { str.WriteString(", ") } } str.WriteString(" )") } return str.String(), nil } diff --git a/db/dialect.go b/db/dialect.go index 4251465..5606ea0 100644 --- a/db/dialect.go +++ b/db/dialect.go @@ -1,76 +1,89 @@ package db import "fmt" type DialectType int const ( - DialectSQLite DialectType = iota - DialectMySQL DialectType = iota + DialectSQLite DialectType = iota + DialectMySQL DialectType = iota + DialectPostgres DialectType = iota ) func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column { switch d { case DialectSQLite: return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size} case DialectMySQL: return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size} + case DialectPostgres: + return &Column{Dialect: DialectPostgres, Name: name, Type: t, Size: size} default: panic(fmt.Sprintf("unexpected dialect: %d", d)) } } func (d DialectType) Table(name string) *CreateTableSqlBuilder { switch d { case DialectSQLite: return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name} case DialectMySQL: return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name} + case DialectPostgres: + return &CreateTableSqlBuilder{Dialect: DialectPostgres, Name: name} default: panic(fmt.Sprintf("unexpected dialect: %d", d)) } } func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder { switch d { case DialectSQLite: return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name} case DialectMySQL: return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name} + case DialectPostgres: + return &AlterTableSqlBuilder{Dialect: DialectPostgres, Name: name} default: panic(fmt.Sprintf("unexpected dialect: %d", d)) } } func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { switch d { case DialectSQLite: return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns} case DialectMySQL: return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns} + case DialectPostgres: + return &CreateIndexSqlBuilder{Dialect: DialectPostgres, Name: name, Table: table, Unique: true, Columns: columns} default: panic(fmt.Sprintf("unexpected dialect: %d", d)) } } func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { switch d { case DialectSQLite: return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns} case DialectMySQL: return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns} + case DialectPostgres: + return &CreateIndexSqlBuilder{Dialect: DialectPostgres, Name: name, Table: table, Unique: false, Columns: columns} default: panic(fmt.Sprintf("unexpected dialect: %d", d)) } } func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder { switch d { case DialectSQLite: return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table} case DialectMySQL: return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table} + case DialectPostgres: + return &DropIndexSqlBuilder{Dialect: DialectPostgres, Name: name, Table: table} default: panic(fmt.Sprintf("unexpected dialect: %d", d)) } }