diff --git a/migrations/drivers.go b/migrations/drivers.go index 5c6958a..486ea4d 100644 --- a/migrations/drivers.go +++ b/migrations/drivers.go @@ -1,110 +1,242 @@ /* * Copyright © 2019 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package migrations import ( "fmt" ) // TODO: use now() from writefreely pkg func (db *datastore) now() string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: return "strftime('%Y-%m-%d %H:%M:%S','now')" + case driverMySQL: + return "NOW()" + case driverPostgres: + return "NOW()" } - return "NOW()" + + return "" // placeholder } func (db *datastore) typeInt() string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: return "INTEGER" + case driverMySQL: + return "INT" + case driverPostgres: + return "INT" } - return "INT" + + return "" // placeholder } func (db *datastore) typeSmallInt() string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: return "INTEGER" + case driverMySQL: + return "SMALLINT" + case driverPostgres: + return "SMALLINT" } - return "SMALLINT" + + return "" // placeholder } func (db *datastore) typeTinyInt() string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: return "INTEGER" + case driverMySQL: + return "TINYINT" + case driverPostgres: + return "SMALLINT" } - return "TINYINT" + + return "" // placeholder } func (db *datastore) typeText() string { return "TEXT" } func (db *datastore) typeChar(l int) string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: return "TEXT" + case driverMySQL: + return fmt.Sprintf("CHAR(%d)", l) + case driverPostgres: + return fmt.Sprintf("CHAR(%d)", l) } - return fmt.Sprintf("CHAR(%d)", l) + + return "" // placeholder } func (db *datastore) typeVarChar(l int) string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: return "TEXT" + case driverMySQL: + return fmt.Sprintf("VARCHAR(%d)", l) + case driverPostgres: + return fmt.Sprintf("VARCHAR(%d)", l) } - return fmt.Sprintf("VARCHAR(%d)", l) + + return "" // placeholder } func (db *datastore) typeVarBinary(l int) string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: return "BLOB" + case driverMySQL: + return fmt.Sprintf("VARBINARY(%d)", l) + case driverPostgres: + return "BYTEA" } - return fmt.Sprintf("VARBINARY(%d)", l) + + return "" // placeholder } func (db *datastore) typeBool() string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: return "INTEGER" + case driverMySQL: + return "TINYINT(1)" + case driverPostgres: + return "BOOLEAN" } - return "TINYINT(1)" + + return "" // placeholder } func (db *datastore) typeDateTime() string { - return "DATETIME" + switch db.driverName { + case driverSQLite: + return "DATETIME" + case driverMySQL: + return "DATETIME" + case driverPostgres: + return "TIMESTAMP" + } + + return "" // placeholder } func (db *datastore) typeIntPrimaryKey() string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: // From docs: "In SQLite, a column with type INTEGER PRIMARY KEY is an alias for the ROWID (except in WITHOUT // ROWID tables) which is always a 64-bit signed integer." return "INTEGER PRIMARY KEY" + case driverMySQL: + return "INT AUTO_INCREMENT PRIMARY KEY" + case driverPostgres: + return "SERIAL PRIMARY KEY" } - return "INT AUTO_INCREMENT PRIMARY KEY" + + return "" // placeholder } func (db *datastore) collateMultiByte() string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: + return "" + case driverMySQL: + return " COLLATE utf8_bin" + case driverPostgres: return "" } - return " COLLATE utf8_bin" + + return "" // placeholder } func (db *datastore) engine() string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: + return "" + case driverMySQL: + return " ENGINE = InnoDB" + case driverPostgres: return "" } - return " ENGINE = InnoDB" + + return "" // placeholder } func (db *datastore) after(colName string) string { - if db.driverName == driverSQLite { + switch db.driverName { + case driverSQLite: return "" + case driverMySQL: + return fmt.Sprintf(" AFTER %s", colName) + case driverPostgres: + return "" + } + + return "" // placeholder +} + +func (db *datastore) boolTrue() string { + switch db.driverName { + case driverSQLite: + return "1" + case driverMySQL: + return "1" + case driverPostgres: + return "TRUE" } - return " AFTER " + colName + + return "" // placeholder } + +func (db *datastore) boolFalse() string { + switch db.driverName { + case driverSQLite: + return "0" + case driverMySQL: + return "0" + case driverPostgres: + return "FALSE" + } + + return "" // placeholder +} + +func (db *datastore) QueryWrap(q string) string { + if db.driverName != driverPostgres { + return q + } + + output := "" + escape := false + ctr := 0 + + for i := range len(q) { + if q[i] == '\'' || q[i] == '`' { + escape = !escape + } + + if q[i] == '?' && !escape { + ctr += 1 + output += fmt.Sprintf("$%d", ctr) + } else { + output += string(q[i]) + } + } + + return output +} \ No newline at end of file diff --git a/migrations/migrations.go b/migrations/migrations.go index 6b5b094..0e99696 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -1,151 +1,160 @@ /* * Copyright © 2019 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ // Package migrations contains database migrations for WriteFreely package migrations import ( "database/sql" + "fmt" "github.com/writeas/web-core/log" ) // TODO: refactor to use the datastore struct from writefreely pkg type datastore struct { *sql.DB driverName string } func NewDatastore(db *sql.DB, dn string) *datastore { return &datastore{db, dn} } // TODO: use these consts from writefreely pkg const ( - driverMySQL = "mysql" - driverSQLite = "sqlite3" + driverMySQL = "mysql" + driverPostgres = "postgres" + driverSQLite = "sqlite3" ) type Migration interface { Description() string Migrate(db *datastore) error } type migration struct { description string migrate func(db *datastore) error } func New(d string, fn func(db *datastore) error) Migration { return &migration{d, fn} } func (m *migration) Description() string { return m.description } func (m *migration) Migrate(db *datastore) error { return m.migrate(db) } var migrations = []Migration{ New("support user invites", supportUserInvites), // -> V1 (v0.8.0) New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0) New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0) New("support oauth", oauth), // V3 -> V4 New("support slack oauth", oauthSlack), // V4 -> v5 New("support ActivityPub mentions", supportActivityPubMentions), // V5 -> V6 New("support oauth attach", oauthAttach), // V6 -> V7 New("support oauth via invite", oauthInvites), // V7 -> V8 (v0.12.0) New("optimize drafts retrieval", optimizeDrafts), // V8 -> V9 New("support post signatures", supportPostSignatures), // V9 -> V10 (v0.13.0) New("Widen oauth_users.access_token", widenOauthAcceesToken), // V10 -> V11 New("support verifying fedi profile", fediverseVerifyProfile), // V11 -> V12 (v0.14.0) New("support newsletters", supportLetters), // V12 -> V13 New("support password resetting", supportPassReset), // V13 -> V14 New("speed up blog post retrieval", addPostRetrievalIndex), // V14 -> V15 New("support ActivityPub likes", supportRemoteLikes), // V15 -> V16 (v0.16.0) } // CurrentVer returns the current migration version the application is on func CurrentVer() int { return len(migrations) } func SetInitialMigrations(db *datastore) error { // Included schema files represent changes up to V1, so note that in the database - _, err := db.Exec("INSERT INTO appmigrations (version, migrated, result) VALUES (?, "+db.now()+", ?)", 1, "") + _, err := db.Exec(db.QueryWrap("INSERT INTO appmigrations (version, migrated, result) VALUES (?, "+db.now()+", ?)"), 1, "") if err != nil { return err } return nil } func Migrate(db *datastore) error { var version int var err error if db.tableExists("appmigrations") { err = db.QueryRow("SELECT MAX(version) FROM appmigrations").Scan(&version) if err != nil { return err } } else { log.Info("Initializing appmigrations table...") version = 0 _, err = db.Exec(`CREATE TABLE appmigrations ( version ` + db.typeInt() + ` NOT NULL, migrated ` + db.typeDateTime() + ` NOT NULL, result ` + db.typeText() + ` NOT NULL ) ` + db.engine() + `;`) if err != nil { return err } } if len(migrations[version:]) > 0 { for i, m := range migrations[version:] { curVer := version + i + 1 log.Info("Migrating to V%d: %s", curVer, m.Description()) err = m.Migrate(db) if err != nil { return err } // Update migrations table - _, err = db.Exec("INSERT INTO appmigrations (version, migrated, result) VALUES (?, "+db.now()+", ?)", curVer, "") + _, err = db.Exec(db.QueryWrap("INSERT INTO appmigrations (version, migrated, result) VALUES (?, "+db.now()+", ?)"), curVer, "") if err != nil { return err } } } else { log.Info("Database up-to-date. No migrations to run.") } return nil } func (db *datastore) tableExists(t string) bool { var dummy string var err error - if db.driverName == driverSQLite { - err = db.QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", t).Scan(&dummy) - } else { - err = db.QueryRow("SHOW TABLES LIKE '" + t + "'").Scan(&dummy) + var q string + switch db.driverName { + case driverSQLite: + q = fmt.Sprintf("SELECT name FROM sqlite_master WHERE type = 'table' AND name = '%s'", t) + case driverMySQL: + q = fmt.Sprintf("SHOW TABLES LIKE '%s'", t) + case driverPostgres: + q = fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = 'public' AND tablename = '%s'", t) } + + err = db.QueryRow(q).Scan(&dummy) + switch { case err == sql.ErrNoRows: return false case err != nil: - log.Error("Couldn't SHOW TABLES: %v", err) + log.Error("Couldn't SHOW TABLES: %v", err.Error()) return false } return true }