Page MenuHomeMusing Studio

No OneTemporary

diff --git a/db/alter.go b/db/alter.go
index 0a4ffdd..14e92ac 100644
--- a/db/alter.go
+++ b/db/alter.go
@@ -1,52 +1,73 @@
package db
import (
"fmt"
"strings"
)
// alterTableQueryType tags a queued change with the kind of SQL statement it
// belongs to. AlterTableSqlBuilder.ToSQL switches on this value to choose the
// statement prefix: alter yields "ALTER TABLE <name> ", update yields
// "UPDATE <name> ".
+type alterTableQueryType int
+
+const (
+ alter alterTableQueryType = iota
+ update alterTableQueryType = iota
+)
+
// change is one queued statement of an AlterTableSqlBuilder: queryType selects
// the statement prefix and queryString is the text emitted after that prefix.
+type change struct {
+ queryType alterTableQueryType
+ queryString string
+}
+
// AlterTableSqlBuilder accumulates changes for the table called Name and
// renders them with ToSQL. With this patch Changes holds typed change values
// instead of raw SQL fragments.
type AlterTableSqlBuilder struct {
Dialect DialectType
Name string
- Changes []string
+ Changes []change
}
// AddColumn queues an "ADD COLUMN <definition>" change, where the definition
// comes from col.String(). If col.String() returns an error nothing is queued
// and the error is discarded. The builder is returned for chaining.
func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder {
if colVal, err := col.String(); err == nil {
- b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal))
+ b.Changes = append(b.Changes, change{queryType: alter, queryString: fmt.Sprintf("ADD COLUMN %s", colVal)})
}
return b
}
// ChangeColumn queues the changes that replace column name with the definition
// in col. The patched version no longer emits a single "CHANGE COLUMN" clause;
// it queues four changes instead:
//
//  1. rename the existing column to <name>_old,
//  2. add the new column from col.String(),
//  3. copy the data across with an UPDATE that sets col.Name = <name>_old,
//  4. drop the renamed <name>_old column.
//
// If col.String() returns an error nothing is queued and the error is
// discarded. The builder is returned for chaining.
func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder {
if colVal, err := col.String(); err == nil {
- b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal))
+ b.Changes = append(b.Changes, change{queryType: alter, queryString: fmt.Sprintf("RENAME COLUMN %s TO %s_old", name, name)})
+ b.Changes = append(b.Changes, change{queryType: alter, queryString: fmt.Sprintf("ADD COLUMN %s", colVal)})
+ b.Changes = append(b.Changes, change{queryType: update, queryString: fmt.Sprintf("SET %s = %s_old", col.Name, name)})
// The original data now lives in <name>_old, so that is the column to drop;
// dropping <name> would remove the freshly added column (or fail outright).
+ b.Changes = append(b.Changes, change{queryType: alter, queryString: fmt.Sprintf("DROP COLUMN %s_old", name)})
}
return b
}
// AddUniqueConstraint queues an "ADD CONSTRAINT <name> UNIQUE (<columns>)"
// change, with the column names joined by ", ". The builder is returned for
// chaining.
func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder {
- b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", ")))
+ b.Changes = append(b.Changes, change{
+ queryType: alter,
+ queryString: fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", ")),
+ })
return b
}
// ToSQL renders the queued changes as SQL text. It returns an error when no
// changes have been queued.
//
// Before this patch all changes were emitted as one "ALTER TABLE <name> ..."
// statement with clauses separated by ", ". After it, every change becomes its
// own statement, prefixed with "ALTER TABLE <name> " or "UPDATE <name> "
// according to its queryType, and the statements are separated by "; " with no
// trailing separator.
func (b *AlterTableSqlBuilder) ToSQL() (string, error) {
- var str strings.Builder
-
- str.WriteString("ALTER TABLE ")
- str.WriteString(b.Name)
- str.WriteString(" ")
+ str := ""
+ changeCount := len(b.Changes)
- if len(b.Changes) == 0 {
+ if changeCount == 0 {
return "", fmt.Errorf("no changes provide for table: %s", b.Name)
}
- changeCount := len(b.Changes)
- for i, thing := range b.Changes {
- str.WriteString(thing)
- if i < changeCount-1 {
- str.WriteString(", ")
+
+ for i, change := range b.Changes {
+ switch change.queryType {
+ case alter:
+ str += fmt.Sprintf("ALTER TABLE %s ", b.Name)
+ case update:
+ str += fmt.Sprintf("UPDATE %s ", b.Name)
+ }
+ str += change.queryString
+ if i < changeCount - 1 {
+ str += "; "
}
}
- return str.String(), nil
+ return str, nil
}
diff --git a/db/create.go b/db/create.go
index 1e9e679..870bfd7 100644
--- a/db/create.go
+++ b/db/create.go
@@ -1,263 +1,284 @@
/*
* Copyright © 2019-2020 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package db
import (
"fmt"
"strings"
)
// ColumnType enumerates the abstract column types that Format translates into
// a dialect-specific SQL type name.
type ColumnType int
// OptionalInt is an int that may be absent; Value is only meaningful when Set
// is true.
type OptionalInt struct {
Set bool
Value int
}
// OptionalString is a string that may be absent; Value is only meaningful when
// Set is true.
type OptionalString struct {
Set bool
Value string
}
// SQLBuilder is implemented by builders that can render themselves as SQL
// text, or report an error when they cannot.
type SQLBuilder interface {
ToSQL() (string, error)
}
// Column describes one table column; String renders it as a SQL column
// definition.
type Column struct {
// Dialect and Size are passed to Type.Format to obtain the SQL type name.
Dialect DialectType
Name string
// Nullable false makes String emit " NOT NULL".
Nullable bool
// Default, when Set, makes String emit " DEFAULT <value>".
Default OptionalString
Type ColumnType
Size OptionalInt
// PrimaryKey true makes String emit " PRIMARY KEY".
PrimaryKey bool
}
// CreateTableSqlBuilder collects the columns and constraints of a table and
// renders a CREATE TABLE statement through ToSQL.
type CreateTableSqlBuilder struct {
Dialect DialectType
Name string
// IfNotExists makes ToSQL emit "IF NOT EXISTS" before the table name.
IfNotExists bool
// ColumnOrder lists column names in the order ToSQL renders them.
ColumnOrder []string
// Columns maps a column name to its definition.
Columns map[string]*Column
// Constraints are rendered after all columns.
Constraints []string
}
// Abstract column types understood by ColumnType.Format, numbered from zero in
// declaration order.
const (
	ColumnTypeBool ColumnType = iota
	ColumnTypeSmallInt
	ColumnTypeInteger
	ColumnTypeChar
	ColumnTypeVarChar
	ColumnTypeText
	ColumnTypeDateTime
)
var (
	// Compile-time check that *CreateTableSqlBuilder satisfies SQLBuilder.
	_ SQLBuilder = (*CreateTableSqlBuilder)(nil)

	// UnsetSize is the OptionalInt meaning "no size given".
	UnsetSize = OptionalInt{}

	// UnsetDefault is the OptionalString meaning "no default given".
	UnsetDefault = OptionalString{}
)
// Format returns the SQL type name for column type d in the given dialect.
//
// The patch adds DialectPostgres alongside DialectMySQL and DialectSQLite and
// replaces the per-case SQLite checks with a switch on the dialect. The size
// modifier "(N)" is computed once up front when size.Set is true; in the
// patched code it is appended for MySQL SMALLINT, INT, CHAR and VARCHAR and for
// Postgres CHAR and VARCHAR only.
//
// An error is returned when the dialect is not one of the three supported
// ones, or when no case produces a type name for d.
func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
- if dialect != DialectMySQL && dialect != DialectSQLite {
+ if dialect != DialectMySQL && dialect != DialectPostgres && dialect != DialectSQLite {
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
}
+
+ mod := ""
+
+ if size.Set {
+ mod = fmt.Sprintf("(%d)", size.Value)
+ }
+
switch d {
case ColumnTypeSmallInt:
- {
- if dialect == DialectSQLite {
- return "INTEGER", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
+ switch dialect {
+ case DialectSQLite:
+ return "INTEGER", nil
+ case DialectMySQL:
return "SMALLINT" + mod, nil
+ case DialectPostgres:
+ return "SMALLINT", nil
}
case ColumnTypeInteger:
- {
- if dialect == DialectSQLite {
- return "INTEGER", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
+ switch dialect {
+ case DialectSQLite:
+ return "INTEGER", nil
+ case DialectMySQL:
return "INT" + mod, nil
+ case DialectPostgres:
+ return "INT", nil
}
case ColumnTypeChar:
- {
- if dialect == DialectSQLite {
- return "TEXT", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
+ switch dialect {
+ case DialectSQLite:
+ return "TEXT", nil
+ case DialectMySQL:
+ return "CHAR" + mod, nil
+ case DialectPostgres:
return "CHAR" + mod, nil
}
case ColumnTypeVarChar:
- {
- if dialect == DialectSQLite {
- return "TEXT", nil
- }
- mod := ""
- if size.Set {
- mod = fmt.Sprintf("(%d)", size.Value)
- }
+ switch dialect {
+ case DialectSQLite:
+ return "TEXT", nil
+ case DialectMySQL:
+ return "VARCHAR" + mod, nil
+ case DialectPostgres:
return "VARCHAR" + mod, nil
}
case ColumnTypeBool:
- {
- if dialect == DialectSQLite {
- return "INTEGER", nil
- }
+ switch dialect {
+ case DialectSQLite:
+ return "INTEGER", nil
+ case DialectMySQL:
return "TINYINT(1)", nil
+ case DialectPostgres:
+ return "BOOLEAN", nil
}
case ColumnTypeDateTime:
- return "DATETIME", nil
+ switch dialect {
+ case DialectSQLite:
+ return "DATETIME", nil
+ case DialectMySQL:
+ return "DATETIME", nil
+ case DialectPostgres:
+ return "TIMESTAMP", nil
+ }
case ColumnTypeText:
- return "TEXT", nil
+ switch dialect {
+ case DialectSQLite:
+ return "TEXT", nil
+ case DialectMySQL:
+ return "TEXT", nil
+ case DialectPostgres:
+ return "TEXT", nil
+ }
}
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
}
// SetName replaces the column's name and returns the same column so calls can
// be chained.
func (c *Column) SetName(name string) *Column {
	c.Name = name
	return c
}
// SetNullable records whether the column may hold NULL and returns the same
// column so calls can be chained.
func (c *Column) SetNullable(nullable bool) *Column {
	c.Nullable = nullable
	return c
}
// SetPrimaryKey records whether the column is a primary key and returns the
// same column so calls can be chained.
func (c *Column) SetPrimaryKey(pk bool) *Column {
	c.PrimaryKey = pk
	return c
}
// SetDefault marks the column as having a default and stores value as that
// default. The same column is returned so calls can be chained.
func (c *Column) SetDefault(value string) *Column {
	c.Default.Set = true
	c.Default.Value = value
	return c
}
// SetDefaultCurrentTimestamp sets the column default to the dialect's
// "current time" expression: CURRENT_TIMESTAMP for SQLite, NOW() for MySQL and
// Postgres. In the patched version any other dialect leaves the default value
// as the empty string (the pre-patch code fell back to NOW()). The column is
// returned for chaining.
func (c *Column) SetDefaultCurrentTimestamp() *Column {
- def := "NOW()"
- if c.Dialect == DialectSQLite {
+ def := ""
+
+ switch c.Dialect {
+ case DialectSQLite:
def = "CURRENT_TIMESTAMP"
+ case DialectMySQL:
+ def = "NOW()"
+ case DialectPostgres:
+ def = "NOW()"
}
c.Default = OptionalString{Set: true, Value: def}
return c
}
// SetType replaces the column's abstract type and returns the same column so
// calls can be chained.
func (c *Column) SetType(t ColumnType) *Column {
	c.Type = t
	return c
}
// SetSize marks the column as having an explicit size and stores it. The same
// column is returned so calls can be chained.
func (c *Column) SetSize(size int) *Column {
	c.Size.Set = true
	c.Size.Value = size
	return c
}
// String renders the column as a SQL column definition:
//
//	<name> <type>[ NOT NULL][ DEFAULT <value>][ PRIMARY KEY]
//
// The type name comes from c.Type.Format; if that fails, the error is returned
// together with an empty string.
func (c *Column) String() (string, error) {
	sqlType, err := c.Type.Format(c.Dialect, c.Size)
	if err != nil {
		return "", err
	}

	var def strings.Builder
	def.WriteString(c.Name + " " + sqlType)

	if !c.Nullable {
		def.WriteString(" NOT NULL")
	}

	if c.Default.Set {
		// An empty default is written as an explicit empty SQL string literal.
		value := c.Default.Value
		if value == "" {
			value = "''"
		}
		def.WriteString(" DEFAULT " + value)
	}

	if c.PrimaryKey {
		def.WriteString(" PRIMARY KEY")
	}

	return def.String(), nil
}
// Column registers column under its name and appends that name to the render
// order. The column map is created on first use. The builder is returned for
// chaining.
func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder {
	if b.Columns == nil {
		b.Columns = map[string]*Column{}
	}

	key := column.Name
	b.Columns[key] = column
	b.ColumnOrder = append(b.ColumnOrder, key)
	return b
}
// UniqueConstraint appends a "UNIQUE(<columns>)" constraint with the column
// names joined by ",". If any named column has not been registered on the
// builder, the constraint is dropped without reporting an error. The builder
// is returned for chaining.
func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder {
	for i := range columns {
		if _, known := b.Columns[columns[i]]; !known {
			// Unknown column: silently skip the whole constraint.
			return b
		}
	}

	constraint := "UNIQUE(" + strings.Join(columns, ",") + ")"
	b.Constraints = append(b.Constraints, constraint)
	return b
}
// SetIfNotExists controls whether ToSQL emits "IF NOT EXISTS" and returns the
// builder for chaining.
func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder {
	b.IfNotExists = ine
	return b
}
// ToSQL renders the builder as a CREATE TABLE statement:
//
//	CREATE TABLE [IF NOT EXISTS ]<name>[ ( <col>, ..., <constraint>, ... )]
//
// Columns are rendered in ColumnOrder, followed by the constraints. The
// parenthesised list is omitted when there are neither columns nor
// constraints. An error is returned when a name in ColumnOrder has no entry in
// Columns, or when a column fails to render.
func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
	parts := make([]string, 0, len(b.ColumnOrder)+len(b.Constraints))
	for _, name := range b.ColumnOrder {
		col, found := b.Columns[name]
		if !found {
			return "", fmt.Errorf("column not found: %s", name)
		}
		def, err := col.String()
		if err != nil {
			return "", err
		}
		parts = append(parts, def)
	}
	parts = append(parts, b.Constraints...)

	sql := "CREATE TABLE "
	if b.IfNotExists {
		sql += "IF NOT EXISTS "
	}
	sql += b.Name
	if len(parts) > 0 {
		sql += " ( " + strings.Join(parts, ", ") + " )"
	}
	return sql, nil
}
diff --git a/db/dialect.go b/db/dialect.go
index 4251465..5606ea0 100644
--- a/db/dialect.go
+++ b/db/dialect.go
@@ -1,76 +1,89 @@
package db
import "fmt"
// DialectType identifies the SQL dialect a builder or column targets. The
// patch adds DialectPostgres after the existing SQLite and MySQL values.
type DialectType int
const (
- DialectSQLite DialectType = iota
- DialectMySQL DialectType = iota
+ DialectSQLite DialectType = iota
+ DialectMySQL DialectType = iota
+ DialectPostgres DialectType = iota
)
// Column returns a new Column for dialect d with the given name, type and
// size. It panics when d is not one of the dialects handled by the switch.
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
+ case DialectPostgres:
+ return &Column{Dialect: DialectPostgres, Name: name, Type: t, Size: size}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
// Table returns a new CreateTableSqlBuilder for the named table in dialect d.
// It panics when d is not one of the dialects handled by the switch.
func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
+ case DialectPostgres:
+ return &CreateTableSqlBuilder{Dialect: DialectPostgres, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
// AlterTable returns a new AlterTableSqlBuilder for the named table in dialect
// d. It panics when d is not one of the dialects handled by the switch.
func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder {
switch d {
case DialectSQLite:
return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name}
+ case DialectPostgres:
+ return &AlterTableSqlBuilder{Dialect: DialectPostgres, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
// CreateUniqueIndex returns a CreateIndexSqlBuilder with Unique set to true for
// the named index over columns of table, in dialect d. It panics when d is not
// one of the dialects handled by the switch.
func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
switch d {
case DialectSQLite:
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns}
case DialectMySQL:
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns}
+ case DialectPostgres:
+ return &CreateIndexSqlBuilder{Dialect: DialectPostgres, Name: name, Table: table, Unique: true, Columns: columns}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
// CreateIndex returns a CreateIndexSqlBuilder with Unique set to false for the
// named index over columns of table, in dialect d. It panics when d is not one
// of the dialects handled by the switch.
func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
switch d {
case DialectSQLite:
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns}
case DialectMySQL:
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns}
+ case DialectPostgres:
+ return &CreateIndexSqlBuilder{Dialect: DialectPostgres, Name: name, Table: table, Unique: false, Columns: columns}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
// DropIndex returns a DropIndexSqlBuilder for the named index on table, in
// dialect d. It panics when d is not one of the dialects handled by the switch.
func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder {
switch d {
case DialectSQLite:
return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table}
case DialectMySQL:
return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table}
+ case DialectPostgres:
+ return &DropIndexSqlBuilder{Dialect: DialectPostgres, Name: name, Table: table}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}

File Metadata

Mime Type
text/x-diff
Expires
Thu, Oct 30, 2:14 AM (10 h, 44 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3464871

Event Timeline