diff --git a/Makefile b/Makefile index 757bcfd..782e680 100644 --- a/Makefile +++ b/Makefile @@ -1,149 +1,159 @@ GITREV=`git describe | cut -c 2-` LDFLAGS=-ldflags="-X 'github.com/writeas/writefreely.softwareVer=$(GITREV)'" GOCMD=go GOINSTALL=$(GOCMD) install $(LDFLAGS) GOBUILD=$(GOCMD) build $(LDFLAGS) GOTEST=$(GOCMD) test $(LDFLAGS) GOGET=$(GOCMD) get BINARY_NAME=writefreely BUILDPATH=build/$(BINARY_NAME) DOCKERCMD=docker IMAGE_NAME=writeas/writefreely TMPBIN=./tmp all : build ci: ci-assets deps cd cmd/writefreely; $(GOBUILD) -v build: assets deps cd cmd/writefreely; $(GOBUILD) -v -tags='sqlite' build-no-sqlite: assets-no-sqlite deps-no-sqlite cd cmd/writefreely; $(GOBUILD) -v -o $(BINARY_NAME) build-linux: deps @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ $(GOGET) -u github.com/karalabe/xgo; \ fi xgo --targets=linux/amd64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely build-windows: deps @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ $(GOGET) -u github.com/karalabe/xgo; \ fi xgo --targets=windows/amd64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely build-darwin: deps @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ $(GOGET) -u github.com/karalabe/xgo; \ fi xgo --targets=darwin/amd64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely build-arm7: deps @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ $(GOGET) -u github.com/karalabe/xgo; \ fi xgo --targets=linux/arm-7, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely +build-arm64: deps + @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + $(GOGET) -u github.com/karalabe/xgo; \ + fi + xgo --targets=linux/arm64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely + build-docker : $(DOCKERCMD) build -t $(IMAGE_NAME):latest -t $(IMAGE_NAME):$(GITREV) . test: $(GOTEST) -v ./... run: dev-assets $(GOINSTALL) -tags='sqlite' ./... $(BINARY_NAME) --debug deps : $(GOGET) -tags='sqlite' -d -v ./... deps-no-sqlite: $(GOGET) -d -v ./... install : build cmd/writefreely/$(BINARY_NAME) --config cmd/writefreely/$(BINARY_NAME) --gen-keys cmd/writefreely/$(BINARY_NAME) --init-db cd less/; $(MAKE) install $(MFLAGS) release : clean ui assets mkdir -p $(BUILDPATH) cp -r templates $(BUILDPATH) cp -r pages $(BUILDPATH) cp -r static $(BUILDPATH) mkdir $(BUILDPATH)/keys $(MAKE) build-linux mv build/$(BINARY_NAME)-linux-amd64 $(BUILDPATH)/$(BINARY_NAME) tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_amd64.tar.gz -C build $(BINARY_NAME) rm $(BUILDPATH)/$(BINARY_NAME) $(MAKE) build-arm7 mv build/$(BINARY_NAME)-linux-arm-7 $(BUILDPATH)/$(BINARY_NAME) tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm7.tar.gz -C build $(BINARY_NAME) rm $(BUILDPATH)/$(BINARY_NAME) + $(MAKE) build-arm64 + mv build/$(BINARY_NAME)-linux-arm64 $(BUILDPATH)/$(BINARY_NAME) + tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm64.tar.gz -C build $(BINARY_NAME) + rm $(BUILDPATH)/$(BINARY_NAME) $(MAKE) build-darwin mv build/$(BINARY_NAME)-darwin-10.6-amd64 $(BUILDPATH)/$(BINARY_NAME) tar -cvzf $(BINARY_NAME)_$(GITREV)_macos_amd64.tar.gz -C build $(BINARY_NAME) rm $(BUILDPATH)/$(BINARY_NAME) $(MAKE) build-windows mv build/$(BINARY_NAME)-windows-4.0-amd64.exe $(BUILDPATH)/$(BINARY_NAME).exe cd build; zip -r ../$(BINARY_NAME)_$(GITREV)_windows_amd64.zip ./$(BINARY_NAME) rm $(BUILDPATH)/$(BINARY_NAME) $(MAKE) build-docker $(MAKE) release-docker # This assumes you're on linux/amd64 release-linux : clean ui mkdir -p $(BUILDPATH) cp -r templates $(BUILDPATH) cp -r pages $(BUILDPATH) cp -r static $(BUILDPATH) mkdir $(BUILDPATH)/keys $(MAKE) build-no-sqlite mv cmd/writefreely/$(BINARY_NAME) $(BUILDPATH)/$(BINARY_NAME) tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_amd64.tar.gz -C build $(BINARY_NAME) release-docker : $(DOCKERCMD) push $(IMAGE_NAME) ui : force_look cd less/; $(MAKE) $(MFLAGS) assets : generate go-bindata -pkg writefreely -ignore=\\.gitignore -tags="!wflib" schema.sql sqlite.sql assets-no-sqlite: generate go-bindata -pkg writefreely -ignore=\\.gitignore -tags="!wflib" schema.sql dev-assets : generate go-bindata -pkg writefreely -ignore=\\.gitignore -debug -tags="!wflib" schema.sql sqlite.sql lib-assets : generate go-bindata -pkg writefreely -ignore=\\.gitignore -o bindata-lib.go -tags="wflib" schema.sql generate : @hash go-bindata > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ $(GOGET) -u github.com/jteeuwen/go-bindata/go-bindata; \ fi $(TMPBIN): mkdir -p $(TMPBIN) $(TMPBIN)/go-bindata: deps $(TMPBIN) $(GOBUILD) -o $(TMPBIN)/go-bindata github.com/jteeuwen/go-bindata/go-bindata $(TMPBIN)/xgo: deps $(TMPBIN) $(GOBUILD) -o $(TMPBIN)/xgo github.com/karalabe/xgo ci-assets : $(TMPBIN)/go-bindata $(TMPBIN)/go-bindata -pkg writefreely -ignore=\\.gitignore -tags="!wflib" schema.sql sqlite.sql clean : -rm -rf build -rm -rf tmp cd less/; $(MAKE) clean $(MFLAGS) force_look : true diff --git a/config/config.go b/config/config.go index 4b9586e..996c1df 100644 --- a/config/config.go +++ b/config/config.go @@ -1,199 +1,206 @@ /* * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ // Package config holds and assists in the configuration of a writefreely instance. package config import ( "gopkg.in/ini.v1" "strings" ) const ( // FileName is the default configuration file name FileName = "config.ini" UserNormal UserType = "user" UserAdmin = "admin" ) type ( UserType string // ServerCfg holds values that affect how the HTTP server runs ServerCfg struct { HiddenHost string `ini:"hidden_host"` Port int `ini:"port"` Bind string `ini:"bind"` TLSCertPath string `ini:"tls_cert_path"` TLSKeyPath string `ini:"tls_key_path"` Autocert bool `ini:"autocert"` TemplatesParentDir string `ini:"templates_parent_dir"` StaticParentDir string `ini:"static_parent_dir"` PagesParentDir string `ini:"pages_parent_dir"` KeysParentDir string `ini:"keys_parent_dir"` Dev bool `ini:"-"` } // DatabaseCfg holds values that determine how the application connects to a datastore DatabaseCfg struct { Type string `ini:"type"` FileName string `ini:"filename"` User string `ini:"username"` Password string `ini:"password"` Database string `ini:"database"` Host string `ini:"host"` Port int `ini:"port"` } + WriteAsOauthCfg struct { + ClientID string `ini:"client_id"` + ClientSecret string `ini:"client_secret"` + AuthLocation string `ini:"auth_location"` + TokenLocation string `ini:"token_location"` + InspectLocation string `ini:"inspect_location"` + } + + SlackOauthCfg struct { + ClientID string `ini:"client_id"` + ClientSecret string `ini:"client_secret"` + TeamID string `ini:"team_id"` + } + // AppCfg holds values that affect how the application functions AppCfg struct { SiteName string `ini:"site_name"` SiteDesc string `ini:"site_description"` Host string `ini:"host"` // Site appearance Theme string `ini:"theme"` Editor string `ini:"editor"` JSDisabled bool `ini:"disable_js"` WebFonts bool `ini:"webfonts"` Landing string `ini:"landing"` SimpleNav bool `ini:"simple_nav"` WFModesty bool `ini:"wf_modesty"` // Site functionality Chorus bool `ini:"chorus"` DisableDrafts bool `ini:"disable_drafts"` // Users SingleUser bool `ini:"single_user"` OpenRegistration bool `ini:"open_registration"` MinUsernameLen int `ini:"min_username_len"` MaxBlogs int `ini:"max_blogs"` // Federation Federation bool `ini:"federation"` PublicStats bool `ini:"public_stats"` // Access Private bool `ini:"private"` // Additional functions LocalTimeline bool `ini:"local_timeline"` UserInvites string `ini:"user_invites"` - // OAuth - EnableOAuth bool `ini:"enable_oauth"` - OAuthProviderAuthLocation string `ini:"oauth_auth_location"` - OAuthProviderTokenLocation string `ini:"oauth_token_location"` - OAuthProviderInspectLocation string `ini:"oauth_inspect_location"` - OAuthClientCallbackLocation string `ini:"oauth_callback_location"` - OAuthClientID string `ini:"oauth_client_id"` - OAuthClientSecret string `ini:"oauth_client_secret"` - // Defaults DefaultVisibility string `ini:"default_visibility"` } // Config holds the complete configuration for running a writefreely instance Config struct { - Server ServerCfg `ini:"server"` - Database DatabaseCfg `ini:"database"` - App AppCfg `ini:"app"` + Server ServerCfg `ini:"server"` + Database DatabaseCfg `ini:"database"` + App AppCfg `ini:"app"` + SlackOauth SlackOauthCfg `ini:"oauth.slack"` + WriteAsOauth WriteAsOauthCfg `ini:"oauth.writeas"` } ) // New creates a new Config with sane defaults func New() *Config { c := &Config{ Server: ServerCfg{ Port: 8080, Bind: "localhost", /* IPV6 support when not using localhost? */ }, App: AppCfg{ Host: "http://localhost:8080", Theme: "write", WebFonts: true, SingleUser: true, MinUsernameLen: 3, MaxBlogs: 1, Federation: true, PublicStats: true, }, } c.UseMySQL(true) return c } // UseMySQL resets the Config's Database to use default values for a MySQL setup. func (cfg *Config) UseMySQL(fresh bool) { cfg.Database.Type = "mysql" if fresh { cfg.Database.Host = "localhost" cfg.Database.Port = 3306 } } // UseSQLite resets the Config's Database to use default values for a SQLite setup. func (cfg *Config) UseSQLite(fresh bool) { cfg.Database.Type = "sqlite3" if fresh { cfg.Database.FileName = "writefreely.db" } } // IsSecureStandalone returns whether or not the application is running as a // standalone server with TLS enabled. func (cfg *Config) IsSecureStandalone() bool { return cfg.Server.Port == 443 && cfg.Server.TLSCertPath != "" && cfg.Server.TLSKeyPath != "" } func (ac *AppCfg) LandingPath() string { if !strings.HasPrefix(ac.Landing, "/") { return "/" + ac.Landing } return ac.Landing } // Load reads the given configuration file, then parses and returns it as a Config. func Load(fname string) (*Config, error) { if fname == "" { fname = FileName } cfg, err := ini.Load(fname) if err != nil { return nil, err } // Parse INI file uc := &Config{} err = cfg.MapTo(uc) if err != nil { return nil, err } return uc, nil } // Save writes the given Config to the given file. func Save(uc *Config, fname string) error { cfg := ini.Empty() err := ini.ReflectFrom(cfg, uc) if err != nil { return err } if fname == "" { fname = FileName } return cfg.SaveTo(fname) } diff --git a/config/funcs.go b/config/funcs.go index a9c82ce..9678df0 100644 --- a/config/funcs.go +++ b/config/funcs.go @@ -1,27 +1,42 @@ /* * Copyright © 2018 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package config import ( + "net/http" "strings" + "time" ) // FriendlyHost returns the app's Host sans any schema func (ac AppCfg) FriendlyHost() string { return ac.Host[strings.Index(ac.Host, "://")+len("://"):] } func (ac AppCfg) CanCreateBlogs(currentlyUsed uint64) bool { if ac.MaxBlogs <= 0 { return true } return int(currentlyUsed) < ac.MaxBlogs } + +// OrDefaultString returns input or a default value if input is empty. +func OrDefaultString(input, defaultValue string) string { + if len(input) == 0 { + return defaultValue + } + return input +} + +// DefaultHTTPClient returns a sane default HTTP client. +func DefaultHTTPClient() *http.Client { + return &http.Client{Timeout: 10 * time.Second} +} diff --git a/database.go b/database.go index 56035dd..ef52d84 100644 --- a/database.go +++ b/database.go @@ -1,2541 +1,2557 @@ /* * Copyright © 2018 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import ( "context" "database/sql" "fmt" + wf_db "github.com/writeas/writefreely/db" "net/http" "strings" "time" "github.com/guregu/null" "github.com/guregu/null/zero" uuid "github.com/nu7hatch/gouuid" "github.com/writeas/impart" "github.com/writeas/nerds/store" "github.com/writeas/web-core/activitypub" "github.com/writeas/web-core/auth" "github.com/writeas/web-core/data" "github.com/writeas/web-core/id" "github.com/writeas/web-core/log" "github.com/writeas/web-core/query" "github.com/writeas/writefreely/author" "github.com/writeas/writefreely/config" "github.com/writeas/writefreely/key" ) const ( mySQLErrDuplicateKey = 1062 driverMySQL = "mysql" driverSQLite = "sqlite3" ) var ( SQLiteEnabled bool ) type writestore interface { CreateUser(*config.Config, *User, string) error UpdateUserEmail(keys *key.Keychain, userID int64, email string) error UpdateEncryptedUserEmail(int64, []byte) error GetUserByID(int64) (*User, error) GetUserForAuth(string) (*User, error) GetUserForAuthByID(int64) (*User, error) GetUserNameFromToken(string) (string, error) GetUserDataFromToken(string) (int64, string, error) GetAPIUser(header string) (*User, error) GetUserID(accessToken string) int64 GetUserIDPrivilege(accessToken string) (userID int64, sudo bool) DeleteToken(accessToken []byte) error FetchLastAccessToken(userID int64) string GetAccessToken(userID int64) (string, error) GetTemporaryAccessToken(userID int64, validSecs int) (string, error) GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error) DeleteAccount(userID int64) (l *string, err error) ChangeSettings(app *App, u *User, s *userSettings) error ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error GetCollections(u *User, hostName string) (*[]Collection, error) GetPublishableCollections(u *User, hostName string) (*[]Collection, error) GetMeStats(u *User) userMeStats GetTotalCollections() (int64, error) GetTotalPosts() (int64, error) GetTopPosts(u *User, alias string) (*[]PublicPost, error) GetAnonymousPosts(u *User) (*[]PublicPost, error) GetUserPosts(u *User) (*[]PublicPost, error) CreateOwnedPost(post *SubmittedPost, accessToken, collAlias, hostName string) (*PublicPost, error) CreatePost(userID, collID int64, post *SubmittedPost) (*Post, error) UpdateOwnedPost(post *AuthenticatedPost, userID int64) error GetEditablePost(id, editToken string) (*PublicPost, error) PostIDExists(id string) bool GetPost(id string, collectionID int64) (*PublicPost, error) GetOwnedPost(id string, ownerID int64) (*PublicPost, error) GetPostProperty(id string, collectionID int64, property string) (interface{}, error) CreateCollectionFromToken(*config.Config, string, string, string) (*Collection, error) CreateCollection(*config.Config, string, string, int64) (*Collection, error) GetCollectionBy(condition string, value interface{}) (*Collection, error) GetCollection(alias string) (*Collection, error) GetCollectionForPad(alias string) (*Collection, error) GetCollectionByID(id int64) (*Collection, error) UpdateCollection(c *SubmittedCollection, alias string) error DeleteCollection(alias string, userID int64) error UpdatePostPinState(pinned bool, postID string, collID, ownerID, pos int64) error GetLastPinnedPostPos(collID int64) int64 GetPinnedPosts(coll *CollectionObj, includeFuture bool) (*[]PublicPost, error) RemoveCollectionRedirect(t *sql.Tx, alias string) error GetCollectionRedirect(alias string) (new string) IsCollectionAttributeOn(id int64, attr string) bool CollectionHasAttribute(id int64, attr string) bool CanCollect(cpr *ClaimPostRequest, userID int64) bool AttemptClaim(p *ClaimPostRequest, query string, params []interface{}, slugIdx int) (sql.Result, error) DispersePosts(userID int64, postIDs []string) (*[]ClaimPostResult, error) ClaimPosts(cfg *config.Config, userID int64, collAlias string, posts *[]ClaimPostRequest) (*[]ClaimPostResult, error) GetPostsCount(c *CollectionObj, includeFuture bool) GetPosts(cfg *config.Config, c *Collection, page int, includeFuture, forceRecentFirst, includePinned bool) (*[]PublicPost, error) GetPostsTagged(cfg *config.Config, c *Collection, tag string, page int, includeFuture bool) (*[]PublicPost, error) GetAPFollowers(c *Collection) (*[]RemoteUser, error) GetAPActorKeys(collectionID int64) ([]byte, []byte) CreateUserInvite(id string, userID int64, maxUses int, expires *time.Time) error GetUserInvites(userID int64) (*[]Invite, error) GetUserInvite(id string) (*Invite, error) GetUsersInvitedCount(id string) int64 CreateInvitedUser(inviteID string, userID int64) error GetDynamicContent(id string) (*instanceContent, error) UpdateDynamicContent(id, title, content, contentType string) error GetAllUsers(page uint) (*[]User, error) GetAllUsersCount() int64 GetUserLastPostTime(id int64) (*time.Time, error) GetCollectionLastPostTime(id int64) (*time.Time, error) - GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) - RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error - ValidateOAuthState(ctx context.Context, state string) error - GenerateOAuthState(ctx context.Context) (string, error) + GetIDForRemoteUser(context.Context, string, string, string) (int64, error) + RecordRemoteUserID(context.Context, int64, string, string, string, string) error + ValidateOAuthState(context.Context, string) (string, string, error) + GenerateOAuthState(context.Context, string, string) (string, error) DatabaseInitialized() bool } type datastore struct { *sql.DB driverName string } +var _ writestore = &datastore{} + func (db *datastore) now() string { if db.driverName == driverSQLite { return "strftime('%Y-%m-%d %H:%M:%S','now')" } return "NOW()" } func (db *datastore) clip(field string, l int) string { if db.driverName == driverSQLite { return fmt.Sprintf("SUBSTR(%s, 0, %d)", field, l) } return fmt.Sprintf("LEFT(%s, %d)", field, l) } func (db *datastore) upsert(indexedCols ...string) string { if db.driverName == driverSQLite { // NOTE: SQLite UPSERT syntax only works in v3.24.0 (2018-06-04) or later // Leaving this for whenever we can upgrade and include it in our binary cc := strings.Join(indexedCols, ", ") return "ON CONFLICT(" + cc + ") DO UPDATE SET" } return "ON DUPLICATE KEY UPDATE" } func (db *datastore) dateSub(l int, unit string) string { if db.driverName == driverSQLite { return fmt.Sprintf("DATETIME('now', '-%d %s')", l, unit) } return fmt.Sprintf("DATE_SUB(NOW(), INTERVAL %d %s)", l, unit) } func (db *datastore) CreateUser(cfg *config.Config, u *User, collectionTitle string) error { if db.PostIDExists(u.Username) { return impart.HTTPError{http.StatusConflict, "Invalid collection name."} } // New users get a `users` and `collections` row. t, err := db.Begin() if err != nil { return err } // 1. Add to `users` table // NOTE: Assumes User's Password is already hashed! res, err := t.Exec("INSERT INTO users (username, password, email) VALUES (?, ?, ?)", u.Username, u.HashedPass, u.Email) if err != nil { t.Rollback() if db.isDuplicateKeyErr(err) { return impart.HTTPError{http.StatusConflict, "Username is already taken."} } log.Error("Rolling back users INSERT: %v\n", err) return err } u.ID, err = res.LastInsertId() if err != nil { t.Rollback() log.Error("Rolling back after LastInsertId: %v\n", err) return err } // 2. Create user's Collection if collectionTitle == "" { collectionTitle = u.Username } res, err = t.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", u.Username, collectionTitle, "", defaultVisibility(cfg), u.ID, 0) if err != nil { t.Rollback() if db.isDuplicateKeyErr(err) { return impart.HTTPError{http.StatusConflict, "Username is already taken."} } log.Error("Rolling back collections INSERT: %v\n", err) return err } db.RemoveCollectionRedirect(t, u.Username) err = t.Commit() if err != nil { t.Rollback() log.Error("Rolling back after Commit(): %v\n", err) return err } return nil } // FIXME: We're returning errors inconsistently in this file. Do we use Errorf // for returned value, or impart? func (db *datastore) UpdateUserEmail(keys *key.Keychain, userID int64, email string) error { encEmail, err := data.Encrypt(keys.EmailKey, email) if err != nil { return fmt.Errorf("Couldn't encrypt email %s: %s\n", email, err) } return db.UpdateEncryptedUserEmail(userID, encEmail) } func (db *datastore) UpdateEncryptedUserEmail(userID int64, encEmail []byte) error { _, err := db.Exec("UPDATE users SET email = ? WHERE id = ?", encEmail, userID) if err != nil { return fmt.Errorf("Unable to update user email: %s", err) } return nil } func (db *datastore) CreateCollectionFromToken(cfg *config.Config, alias, title, accessToken string) (*Collection, error) { userID := db.GetUserID(accessToken) if userID == -1 { return nil, ErrBadAccessToken } return db.CreateCollection(cfg, alias, title, userID) } func (db *datastore) GetUserCollectionCount(userID int64) (uint64, error) { var collCount uint64 err := db.QueryRow("SELECT COUNT(*) FROM collections WHERE owner_id = ?", userID).Scan(&collCount) switch { case err == sql.ErrNoRows: return 0, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user from database."} case err != nil: log.Error("Couldn't get collections count for user %d: %v", userID, err) return 0, err } return collCount, nil } func (db *datastore) CreateCollection(cfg *config.Config, alias, title string, userID int64) (*Collection, error) { if db.PostIDExists(alias) { return nil, impart.HTTPError{http.StatusConflict, "Invalid collection name."} } // All good, so create new collection res, err := db.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", alias, title, "", defaultVisibility(cfg), userID, 0) if err != nil { if db.isDuplicateKeyErr(err) { return nil, impart.HTTPError{http.StatusConflict, "Collection already exists."} } log.Error("Couldn't add to collections: %v\n", err) return nil, err } c := &Collection{ Alias: alias, Title: title, OwnerID: userID, PublicOwner: false, Public: defaultVisibility(cfg) == CollPublic, } c.ID, err = res.LastInsertId() if err != nil { log.Error("Couldn't get collection LastInsertId: %v\n", err) } return c, nil } func (db *datastore) GetUserByID(id int64) (*User, error) { u := &User{ID: id} err := db.QueryRow("SELECT username, password, email, created, status FROM users WHERE id = ?", id).Scan(&u.Username, &u.HashedPass, &u.Email, &u.Created, &u.Status) switch { case err == sql.ErrNoRows: return nil, ErrUserNotFound case err != nil: log.Error("Couldn't SELECT user password: %v", err) return nil, err } return u, nil } // IsUserSuspended returns true if the user account associated with id is // currently suspended. func (db *datastore) IsUserSuspended(id int64) (bool, error) { u := &User{ID: id} err := db.QueryRow("SELECT status FROM users WHERE id = ?", id).Scan(&u.Status) switch { case err == sql.ErrNoRows: return false, fmt.Errorf("is user suspended: %v", ErrUserNotFound) case err != nil: log.Error("Couldn't SELECT user password: %v", err) return false, fmt.Errorf("is user suspended: %v", err) } return u.IsSilenced(), nil } // DoesUserNeedAuth returns true if the user hasn't provided any methods for // authenticating with the account, such a passphrase or email address. // Any errors are reported to admin and silently quashed, returning false as the // result. func (db *datastore) DoesUserNeedAuth(id int64) bool { var pass, email []byte // Find out if user has an email set first err := db.QueryRow("SELECT password, email FROM users WHERE id = ?", id).Scan(&pass, &email) switch { case err == sql.ErrNoRows: // ERROR. Don't give false positives on needing auth methods return false case err != nil: // ERROR. Don't give false positives on needing auth methods log.Error("Couldn't SELECT user %d from users: %v", id, err) return false } // User doesn't need auth if there's an email return len(email) == 0 && len(pass) == 0 } func (db *datastore) IsUserPassSet(id int64) (bool, error) { var pass []byte err := db.QueryRow("SELECT password FROM users WHERE id = ?", id).Scan(&pass) switch { case err == sql.ErrNoRows: return false, nil case err != nil: log.Error("Couldn't SELECT user %d from users: %v", id, err) return false, err } return len(pass) > 0, nil } func (db *datastore) GetUserForAuth(username string) (*User, error) { u := &User{Username: username} err := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE username = ?", username).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created, &u.Status) switch { case err == sql.ErrNoRows: // Check if they've entered the wrong, unnormalized username username = getSlug(username, "") if username != u.Username { err = db.QueryRow("SELECT id FROM users WHERE username = ? LIMIT 1", username).Scan(&u.ID) if err == nil { return db.GetUserForAuth(username) } } return nil, ErrUserNotFound case err != nil: log.Error("Couldn't SELECT user password: %v", err) return nil, err } return u, nil } func (db *datastore) GetUserForAuthByID(userID int64) (*User, error) { u := &User{ID: userID} err := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE id = ?", u.ID).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created, &u.Status) switch { case err == sql.ErrNoRows: return nil, ErrUserNotFound case err != nil: log.Error("Couldn't SELECT userForAuthByID: %v", err) return nil, err } return u, nil } func (db *datastore) GetUserNameFromToken(accessToken string) (string, error) { t := auth.GetToken(accessToken) if len(t) == 0 { return "", ErrNoAccessToken } var oneTime bool var username string err := db.QueryRow("SELECT username, one_time FROM accesstokens LEFT JOIN users ON user_id = id WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&username, &oneTime) switch { case err == sql.ErrNoRows: return "", ErrBadAccessToken case err != nil: return "", ErrInternalGeneral } // Delete token if it was one-time if oneTime { db.DeleteToken(t[:]) } return username, nil } func (db *datastore) GetUserDataFromToken(accessToken string) (int64, string, error) { t := auth.GetToken(accessToken) if len(t) == 0 { return 0, "", ErrNoAccessToken } var userID int64 var oneTime bool var username string err := db.QueryRow("SELECT user_id, username, one_time FROM accesstokens LEFT JOIN users ON user_id = id WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&userID, &username, &oneTime) switch { case err == sql.ErrNoRows: return 0, "", ErrBadAccessToken case err != nil: return 0, "", ErrInternalGeneral } // Delete token if it was one-time if oneTime { db.DeleteToken(t[:]) } return userID, username, nil } func (db *datastore) GetAPIUser(header string) (*User, error) { uID := db.GetUserID(header) if uID == -1 { return nil, fmt.Errorf(ErrUserNotFound.Error()) } return db.GetUserByID(uID) } // GetUserID takes a hexadecimal accessToken, parses it into its binary // representation, and gets any user ID associated with the token. If no user // is associated, -1 is returned. func (db *datastore) GetUserID(accessToken string) int64 { i, _ := db.GetUserIDPrivilege(accessToken) return i } func (db *datastore) GetUserIDPrivilege(accessToken string) (userID int64, sudo bool) { t := auth.GetToken(accessToken) if len(t) == 0 { return -1, false } var oneTime bool err := db.QueryRow("SELECT user_id, sudo, one_time FROM accesstokens WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&userID, &sudo, &oneTime) switch { case err == sql.ErrNoRows: return -1, false case err != nil: return -1, false } // Delete token if it was one-time if oneTime { db.DeleteToken(t[:]) } return } func (db *datastore) DeleteToken(accessToken []byte) error { res, err := db.Exec("DELETE FROM accesstokens WHERE token LIKE ?", accessToken) if err != nil { return err } rowsAffected, _ := res.RowsAffected() if rowsAffected == 0 { return impart.HTTPError{http.StatusNotFound, "Token is invalid or doesn't exist"} } return nil } // FetchLastAccessToken creates a new non-expiring, valid access token for the given // userID. func (db *datastore) FetchLastAccessToken(userID int64) string { var t []byte err := db.QueryRow("SELECT token FROM accesstokens WHERE user_id = ? AND (expires IS NULL OR expires > "+db.now()+") ORDER BY created DESC LIMIT 1", userID).Scan(&t) switch { case err == sql.ErrNoRows: return "" case err != nil: log.Error("Failed selecting from accesstoken: %v", err) return "" } u, err := uuid.Parse(t) if err != nil { return "" } return u.String() } // GetAccessToken creates a new non-expiring, valid access token for the given // userID. func (db *datastore) GetAccessToken(userID int64) (string, error) { return db.GetTemporaryOneTimeAccessToken(userID, 0, false) } // GetTemporaryAccessToken creates a new valid access token for the given // userID that remains valid for the given time in seconds. If validSecs is 0, // the access token doesn't automatically expire. func (db *datastore) GetTemporaryAccessToken(userID int64, validSecs int) (string, error) { return db.GetTemporaryOneTimeAccessToken(userID, validSecs, false) } // GetTemporaryOneTimeAccessToken creates a new valid access token for the given // userID that remains valid for the given time in seconds and can only be used // once if oneTime is true. If validSecs is 0, the access token doesn't // automatically expire. func (db *datastore) GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error) { u, err := uuid.NewV4() if err != nil { log.Error("Unable to generate token: %v", err) return "", err } // Insert UUID to `accesstokens` binTok := u[:] expirationVal := "NULL" if validSecs > 0 { expirationVal = fmt.Sprintf("DATE_ADD("+db.now()+", INTERVAL %d SECOND)", validSecs) } _, err = db.Exec("INSERT INTO accesstokens (token, user_id, one_time, expires) VALUES (?, ?, ?, "+expirationVal+")", string(binTok), userID, oneTime) if err != nil { log.Error("Couldn't INSERT accesstoken: %v", err) return "", err } return u.String(), nil } func (db *datastore) CreateOwnedPost(post *SubmittedPost, accessToken, collAlias, hostName string) (*PublicPost, error) { var userID, collID int64 = -1, -1 var coll *Collection var err error if accessToken != "" { userID = db.GetUserID(accessToken) if userID == -1 { return nil, ErrBadAccessToken } if collAlias != "" { coll, err = db.GetCollection(collAlias) if err != nil { return nil, err } coll.hostName = hostName if coll.OwnerID != userID { return nil, ErrForbiddenCollection } collID = coll.ID } } rp := &PublicPost{} rp.Post, err = db.CreatePost(userID, collID, post) if err != nil { return rp, err } if coll != nil { coll.ForPublic() rp.Collection = &CollectionObj{Collection: *coll} } return rp, nil } func (db *datastore) CreatePost(userID, collID int64, post *SubmittedPost) (*Post, error) { idLen := postIDLen friendlyID := store.GenerateFriendlyRandomString(idLen) // Handle appearance / font face appearance := post.Font if !post.isFontValid() { appearance = "norm" } var err error ownerID := sql.NullInt64{ Valid: false, } ownerCollID := sql.NullInt64{ Valid: false, } slug := sql.NullString{"", false} // If an alias was supplied, we'll add this to the collection as well. if userID > 0 { ownerID.Int64 = userID ownerID.Valid = true if collID > 0 { ownerCollID.Int64 = collID ownerCollID.Valid = true var slugVal string if post.Title != nil && *post.Title != "" { slugVal = getSlug(*post.Title, post.Language.String) if slugVal == "" { slugVal = getSlug(*post.Content, post.Language.String) } } else { slugVal = getSlug(*post.Content, post.Language.String) } if slugVal == "" { slugVal = friendlyID } slug = sql.NullString{slugVal, true} } } created := time.Now() if db.driverName == driverSQLite { // SQLite stores datetimes in UTC, so convert time.Now() to it here created = created.UTC() } if post.Created != nil { created, err = time.Parse("2006-01-02T15:04:05Z", *post.Created) if err != nil { log.Error("Unable to parse Created time '%s': %v", *post.Created, err) created = time.Now() if db.driverName == driverSQLite { // SQLite stores datetimes in UTC, so convert time.Now() to it here created = created.UTC() } } } stmt, err := db.Prepare("INSERT INTO posts (id, slug, title, content, text_appearance, language, rtl, privacy, owner_id, collection_id, created, updated, view_count) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, " + db.now() + ", ?)") if err != nil { return nil, err } defer stmt.Close() _, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0) if err != nil { if db.isDuplicateKeyErr(err) { // Duplicate entry error; try a new slug // TODO: make this a little more robust slug = sql.NullString{id.GenSafeUniqueSlug(slug.String), true} _, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0) if err != nil { return nil, handleFailedPostInsert(fmt.Errorf("Retried slug generation, still failed: %v", err)) } } else { return nil, handleFailedPostInsert(err) } } // TODO: return Created field in proper format return &Post{ ID: friendlyID, Slug: null.NewString(slug.String, slug.Valid), Font: appearance, Language: zero.NewString(post.Language.String, post.Language.Valid), RTL: zero.NewBool(post.IsRTL.Bool, post.IsRTL.Valid), OwnerID: null.NewInt(userID, true), CollectionID: null.NewInt(userID, true), Created: created.Truncate(time.Second).UTC(), Updated: time.Now().Truncate(time.Second).UTC(), Title: zero.NewString(*(post.Title), true), Content: *(post.Content), }, nil } // UpdateOwnedPost updates an existing post with only the given fields in the // supplied AuthenticatedPost. func (db *datastore) UpdateOwnedPost(post *AuthenticatedPost, userID int64) error { params := []interface{}{} var queryUpdates, sep, authCondition string if post.Slug != nil && *post.Slug != "" { queryUpdates += sep + "slug = ?" sep = ", " params = append(params, getSlug(*post.Slug, "")) } if post.Content != nil { queryUpdates += sep + "content = ?" sep = ", " params = append(params, post.Content) } if post.Title != nil { queryUpdates += sep + "title = ?" sep = ", " params = append(params, post.Title) } if post.Language.Valid { queryUpdates += sep + "language = ?" sep = ", " params = append(params, post.Language.String) } if post.IsRTL.Valid { queryUpdates += sep + "rtl = ?" sep = ", " params = append(params, post.IsRTL.Bool) } if post.Font != "" { queryUpdates += sep + "text_appearance = ?" sep = ", " params = append(params, post.Font) } if post.Created != nil { createTime, err := time.Parse(postMetaDateFormat, *post.Created) if err != nil { log.Error("Unable to parse Created date: %v", err) return fmt.Errorf("That's the incorrect format for Created date.") } queryUpdates += sep + "created = ?" sep = ", " params = append(params, createTime) } // WHERE parameters... // id = ? params = append(params, post.ID) // AND owner_id = ? authCondition = "(owner_id = ?)" params = append(params, userID) if queryUpdates == "" { return ErrPostNoUpdatableVals } queryUpdates += sep + "updated = " + db.now() res, err := db.Exec("UPDATE posts SET "+queryUpdates+" WHERE id = ? AND "+authCondition, params...) if err != nil { log.Error("Unable to update owned post: %v", err) return err } rowsAffected, _ := res.RowsAffected() if rowsAffected == 0 { // Show the correct error message if nothing was updated var dummy int err := db.QueryRow("SELECT 1 FROM posts WHERE id = ? AND "+authCondition, post.ID, params[len(params)-1]).Scan(&dummy) switch { case err == sql.ErrNoRows: return ErrUnauthorizedEditPost case err != nil: log.Error("Failed selecting from posts: %v", err) } return nil } return nil } func (db *datastore) GetCollectionBy(condition string, value interface{}) (*Collection, error) { c := &Collection{} // FIXME: change Collection to reflect database values. Add helper functions to get actual values var styleSheet, script, format zero.String row := db.QueryRow("SELECT id, alias, title, description, style_sheet, script, format, owner_id, privacy, view_count FROM collections WHERE "+condition, value) err := row.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &styleSheet, &script, &format, &c.OwnerID, &c.Visibility, &c.Views) switch { case err == sql.ErrNoRows: return nil, impart.HTTPError{http.StatusNotFound, "Collection doesn't exist."} case err != nil: log.Error("Failed selecting from collections: %v", err) return nil, err } c.StyleSheet = styleSheet.String c.Script = script.String c.Format = format.String c.Public = c.IsPublic() c.db = db return c, nil } func (db *datastore) GetCollection(alias string) (*Collection, error) { return db.GetCollectionBy("alias = ?", alias) } func (db *datastore) GetCollectionForPad(alias string) (*Collection, error) { c := &Collection{Alias: alias} row := db.QueryRow("SELECT id, alias, title, description, privacy FROM collections WHERE alias = ?", alias) err := row.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &c.Visibility) switch { case err == sql.ErrNoRows: return c, impart.HTTPError{http.StatusNotFound, "Collection doesn't exist."} case err != nil: log.Error("Failed selecting from collections: %v", err) return c, ErrInternalGeneral } c.Public = c.IsPublic() return c, nil } func (db *datastore) GetCollectionByID(id int64) (*Collection, error) { return db.GetCollectionBy("id = ?", id) } func (db *datastore) GetCollectionFromDomain(host string) (*Collection, error) { return db.GetCollectionBy("host = ?", host) } func (db *datastore) UpdateCollection(c *SubmittedCollection, alias string) error { q := query.NewUpdate(). SetStringPtr(c.Title, "title"). SetStringPtr(c.Description, "description"). SetNullString(c.StyleSheet, "style_sheet"). SetNullString(c.Script, "script") if c.Format != nil { cf := &CollectionFormat{Format: c.Format.String} if cf.Valid() { q.SetNullString(c.Format, "format") } } var updatePass bool if c.Visibility != nil && (collVisibility(*c.Visibility)&CollProtected == 0 || c.Pass != "") { q.SetIntPtr(c.Visibility, "privacy") if c.Pass != "" { updatePass = true } } // WHERE values q.Where("alias = ? AND owner_id = ?", alias, c.OwnerID) if q.Updates == "" { return ErrPostNoUpdatableVals } // Find any current domain var collID int64 var rowsAffected int64 var changed bool var res sql.Result err := db.QueryRow("SELECT id FROM collections WHERE alias = ?", alias).Scan(&collID) if err != nil { log.Error("Failed selecting from collections: %v. Some things won't work.", err) } // Update MathJax value if c.MathJax { if db.driverName == driverSQLite { _, err = db.Exec("INSERT OR REPLACE INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?)", collID, "render_mathjax", "1") } else { _, err = db.Exec("INSERT INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?) "+db.upsert("collection_id", "attribute")+" value = ?", collID, "render_mathjax", "1", "1") } if err != nil { log.Error("Unable to insert render_mathjax value: %v", err) return err } } else { _, err = db.Exec("DELETE FROM collectionattributes WHERE collection_id = ? AND attribute = ?", collID, "render_mathjax") if err != nil { log.Error("Unable to delete render_mathjax value: %v", err) return err } } // Update rest of the collection data res, err = db.Exec("UPDATE collections SET "+q.Updates+" WHERE "+q.Conditions, q.Params...) if err != nil { log.Error("Unable to update collection: %v", err) return err } rowsAffected, _ = res.RowsAffected() if !changed || rowsAffected == 0 { // Show the correct error message if nothing was updated var dummy int err := db.QueryRow("SELECT 1 FROM collections WHERE alias = ? AND owner_id = ?", alias, c.OwnerID).Scan(&dummy) switch { case err == sql.ErrNoRows: return ErrUnauthorizedEditPost case err != nil: log.Error("Failed selecting from collections: %v", err) } if !updatePass { return nil } } if updatePass { hashedPass, err := auth.HashPass([]byte(c.Pass)) if err != nil { log.Error("Unable to create hash: %s", err) return impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."} } if db.driverName == driverSQLite { _, err = db.Exec("INSERT OR REPLACE INTO collectionpasswords (collection_id, password) VALUES ((SELECT id FROM collections WHERE alias = ?), ?)", alias, hashedPass) } else { _, err = db.Exec("INSERT INTO collectionpasswords (collection_id, password) VALUES ((SELECT id FROM collections WHERE alias = ?), ?) "+db.upsert("collection_id")+" password = ?", alias, hashedPass, hashedPass) } if err != nil { return err } } return nil } const postCols = "id, slug, text_appearance, language, rtl, privacy, owner_id, collection_id, pinned_position, created, updated, view_count, title, content" // getEditablePost returns a PublicPost with the given ID only if the given // edit token is valid for the post. func (db *datastore) GetEditablePost(id, editToken string) (*PublicPost, error) { // FIXME: code duplicated from getPost() // TODO: add slight logic difference to getPost / one func var ownerName sql.NullString p := &Post{} row := db.QueryRow("SELECT "+postCols+", (SELECT username FROM users WHERE users.id = posts.owner_id) AS username FROM posts WHERE id = ? LIMIT 1", id) err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content, &ownerName) switch { case err == sql.ErrNoRows: return nil, ErrPostNotFound case err != nil: log.Error("Failed selecting from collections: %v", err) return nil, err } if p.Content == "" && p.Title.String == "" { return nil, ErrPostUnpublished } res := p.processPost() if ownerName.Valid { res.Owner = &PublicUser{Username: ownerName.String} } return &res, nil } func (db *datastore) PostIDExists(id string) bool { var dummy bool err := db.QueryRow("SELECT 1 FROM posts WHERE id = ?", id).Scan(&dummy) return err == nil && dummy } // GetPost gets a public-facing post object from the database. If collectionID // is > 0, the post will be retrieved by slug and collection ID, rather than // post ID. // TODO: break this into two functions: // - GetPost(id string) // - GetCollectionPost(slug string, collectionID int64) func (db *datastore) GetPost(id string, collectionID int64) (*PublicPost, error) { var ownerName sql.NullString p := &Post{} var row *sql.Row var where string params := []interface{}{id} if collectionID > 0 { where = "slug = ? AND collection_id = ?" params = append(params, collectionID) } else { where = "id = ?" } row = db.QueryRow("SELECT "+postCols+", (SELECT username FROM users WHERE users.id = posts.owner_id) AS username FROM posts WHERE "+where+" LIMIT 1", params...) err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content, &ownerName) switch { case err == sql.ErrNoRows: if collectionID > 0 { return nil, ErrCollectionPageNotFound } return nil, ErrPostNotFound case err != nil: log.Error("Failed selecting from collections: %v", err) return nil, err } if p.Content == "" && p.Title.String == "" { return nil, ErrPostUnpublished } res := p.processPost() if ownerName.Valid { res.Owner = &PublicUser{Username: ownerName.String} } return &res, nil } // TODO: don't duplicate getPost() functionality func (db *datastore) GetOwnedPost(id string, ownerID int64) (*PublicPost, error) { p := &Post{} var row *sql.Row where := "id = ? AND owner_id = ?" params := []interface{}{id, ownerID} row = db.QueryRow("SELECT "+postCols+" FROM posts WHERE "+where+" LIMIT 1", params...) err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content) switch { case err == sql.ErrNoRows: return nil, ErrPostNotFound case err != nil: log.Error("Failed selecting from collections: %v", err) return nil, err } if p.Content == "" && p.Title.String == "" { return nil, ErrPostUnpublished } res := p.processPost() return &res, nil } func (db *datastore) GetPostProperty(id string, collectionID int64, property string) (interface{}, error) { propSelects := map[string]string{ "views": "view_count AS views", } selectQuery, ok := propSelects[property] if !ok { return nil, impart.HTTPError{http.StatusBadRequest, fmt.Sprintf("Invalid property: %s.", property)} } var res interface{} var row *sql.Row if collectionID != 0 { row = db.QueryRow("SELECT "+selectQuery+" FROM posts WHERE slug = ? AND collection_id = ? LIMIT 1", id, collectionID) } else { row = db.QueryRow("SELECT "+selectQuery+" FROM posts WHERE id = ? LIMIT 1", id) } err := row.Scan(&res) switch { case err == sql.ErrNoRows: return nil, impart.HTTPError{http.StatusNotFound, "Post not found."} case err != nil: log.Error("Failed selecting post: %v", err) return nil, err } return res, nil } // GetPostsCount modifies the CollectionObj to include the correct number of // standard (non-pinned) posts. It will return future posts if `includeFuture` // is true. func (db *datastore) GetPostsCount(c *CollectionObj, includeFuture bool) { var count int64 timeCondition := "" if !includeFuture { timeCondition = "AND created <= " + db.now() } err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE collection_id = ? AND pinned_position IS NULL "+timeCondition, c.ID).Scan(&count) switch { case err == sql.ErrNoRows: c.TotalPosts = 0 case err != nil: log.Error("Failed selecting from collections: %v", err) c.TotalPosts = 0 } c.TotalPosts = int(count) } // GetPosts retrieves all posts for the given Collection. // It will return future posts if `includeFuture` is true. // It will include only standard (non-pinned) posts unless `includePinned` is true. // TODO: change includeFuture to isOwner, since that's how it's used func (db *datastore) GetPosts(cfg *config.Config, c *Collection, page int, includeFuture, forceRecentFirst, includePinned bool) (*[]PublicPost, error) { collID := c.ID cf := c.NewFormat() order := "DESC" if cf.Ascending() && !forceRecentFirst { order = "ASC" } pagePosts := cf.PostsPerPage() start := page*pagePosts - pagePosts if page == 0 { start = 0 pagePosts = 1000 } limitStr := "" if page > 0 { limitStr = fmt.Sprintf(" LIMIT %d, %d", start, pagePosts) } timeCondition := "" if !includeFuture { timeCondition = "AND created <= " + db.now() } pinnedCondition := "" if !includePinned { pinnedCondition = "AND pinned_position IS NULL" } rows, err := db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? "+pinnedCondition+" "+timeCondition+" ORDER BY created "+order+limitStr, collID) if err != nil { log.Error("Failed selecting from posts: %v", err) return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve collection posts."} } defer rows.Close() // TODO: extract this common row scanning logic for queries using `postCols` posts := []PublicPost{} for rows.Next() { p := &Post{} err = rows.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content) if err != nil { log.Error("Failed scanning row: %v", err) break } p.extractData() p.formatContent(cfg, c, includeFuture) posts = append(posts, p.processPost()) } err = rows.Err() if err != nil { log.Error("Error after Next() on rows: %v", err) } return &posts, nil } // GetPostsTagged retrieves all posts on the given Collection that contain the // given tag. // It will return future posts if `includeFuture` is true. // TODO: change includeFuture to isOwner, since that's how it's used func (db *datastore) GetPostsTagged(cfg *config.Config, c *Collection, tag string, page int, includeFuture bool) (*[]PublicPost, error) { collID := c.ID cf := c.NewFormat() order := "DESC" if cf.Ascending() { order = "ASC" } pagePosts := cf.PostsPerPage() start := page*pagePosts - pagePosts if page == 0 { start = 0 pagePosts = 1000 } limitStr := "" if page > 0 { limitStr = fmt.Sprintf(" LIMIT %d, %d", start, pagePosts) } timeCondition := "" if !includeFuture { timeCondition = "AND created <= " + db.now() } var rows *sql.Rows var err error if db.driverName == driverSQLite { rows, err = db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? AND LOWER(content) regexp ? "+timeCondition+" ORDER BY created "+order+limitStr, collID, `.*#`+strings.ToLower(tag)+`\b.*`) } else { rows, err = db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? AND LOWER(content) RLIKE ? "+timeCondition+" ORDER BY created "+order+limitStr, collID, "#"+strings.ToLower(tag)+"[[:>:]]") } if err != nil { log.Error("Failed selecting from posts: %v", err) return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve collection posts."} } defer rows.Close() // TODO: extract this common row scanning logic for queries using `postCols` posts := []PublicPost{} for rows.Next() { p := &Post{} err = rows.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content) if err != nil { log.Error("Failed scanning row: %v", err) break } p.extractData() p.formatContent(cfg, c, includeFuture) posts = append(posts, p.processPost()) } err = rows.Err() if err != nil { log.Error("Error after Next() on rows: %v", err) } return &posts, nil } func (db *datastore) GetAPFollowers(c *Collection) (*[]RemoteUser, error) { rows, err := db.Query("SELECT actor_id, inbox, shared_inbox FROM remotefollows f INNER JOIN remoteusers u ON f.remote_user_id = u.id WHERE collection_id = ?", c.ID) if err != nil { log.Error("Failed selecting from followers: %v", err) return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve followers."} } defer rows.Close() followers := []RemoteUser{} for rows.Next() { f := RemoteUser{} err = rows.Scan(&f.ActorID, &f.Inbox, &f.SharedInbox) followers = append(followers, f) } return &followers, nil } // CanCollect returns whether or not the given user can add the given post to a // collection. This is true when a post is already owned by the user. // NOTE: this is currently only used to potentially add owned posts to a // collection. This has the SIDE EFFECT of also generating a slug for the post. // FIXME: make this side effect more explicit (or extract it) func (db *datastore) CanCollect(cpr *ClaimPostRequest, userID int64) bool { var title, content string var lang sql.NullString err := db.QueryRow("SELECT title, content, language FROM posts WHERE id = ? AND owner_id = ?", cpr.ID, userID).Scan(&title, &content, &lang) switch { case err == sql.ErrNoRows: return false case err != nil: log.Error("Failed on post CanCollect(%s, %d): %v", cpr.ID, userID, err) return false } // Since we have the post content and the post is collectable, generate the // post's slug now. cpr.Slug = getSlugFromPost(title, content, lang.String) return true } func (db *datastore) AttemptClaim(p *ClaimPostRequest, query string, params []interface{}, slugIdx int) (sql.Result, error) { qRes, err := db.Exec(query, params...) if err != nil { if db.isDuplicateKeyErr(err) && slugIdx > -1 { s := id.GenSafeUniqueSlug(p.Slug) if s == p.Slug { // Sanity check to prevent infinite recursion return qRes, fmt.Errorf("GenSafeUniqueSlug generated nothing unique: %s", s) } p.Slug = s params[slugIdx] = p.Slug return db.AttemptClaim(p, query, params, slugIdx) } return qRes, fmt.Errorf("attemptClaim: %s", err) } return qRes, nil } func (db *datastore) DispersePosts(userID int64, postIDs []string) (*[]ClaimPostResult, error) { postClaimReqs := map[string]bool{} res := []ClaimPostResult{} for i := range postIDs { postID := postIDs[i] r := ClaimPostResult{Code: 0, ErrorMessage: ""} // Perform post validation if postID == "" { r.ErrorMessage = "Missing post ID. " } if _, ok := postClaimReqs[postID]; ok { r.Code = 429 r.ErrorMessage = "You've already tried anonymizing this post." r.ID = postID res = append(res, r) continue } postClaimReqs[postID] = true var err error // Get full post information to return var fullPost *PublicPost fullPost, err = db.GetPost(postID, 0) if err != nil { if err, ok := err.(impart.HTTPError); ok { r.Code = err.Status r.ErrorMessage = err.Message r.ID = postID res = append(res, r) continue } else { log.Error("Error getting post in dispersePosts: %v", err) } } if fullPost.OwnerID.Int64 != userID { r.Code = http.StatusConflict r.ErrorMessage = "Post is already owned by someone else." r.ID = postID res = append(res, r) continue } var qRes sql.Result var query string var params []interface{} // Do AND owner_id = ? for sanity. // This should've been caught and returned with a good error message // just above. query = "UPDATE posts SET collection_id = NULL WHERE id = ? AND owner_id = ?" params = []interface{}{postID, userID} qRes, err = db.Exec(query, params...) if err != nil { r.Code = http.StatusInternalServerError r.ErrorMessage = "A glitch happened on our end." r.ID = postID res = append(res, r) log.Error("dispersePosts (post %s): %v", postID, err) continue } // Post was successfully dispersed r.Code = http.StatusOK r.Post = fullPost rowsAffected, _ := qRes.RowsAffected() if rowsAffected == 0 { // This was already claimed, but return 200 r.Code = http.StatusOK } res = append(res, r) } return &res, nil } func (db *datastore) ClaimPosts(cfg *config.Config, userID int64, collAlias string, posts *[]ClaimPostRequest) (*[]ClaimPostResult, error) { postClaimReqs := map[string]bool{} res := []ClaimPostResult{} postCollAlias := collAlias for i := range *posts { p := (*posts)[i] if &p == nil { continue } r := ClaimPostResult{Code: 0, ErrorMessage: ""} // Perform post validation if p.ID == "" { r.ErrorMessage = "Missing post ID `id`. " } if _, ok := postClaimReqs[p.ID]; ok { r.Code = 429 r.ErrorMessage = "You've already tried claiming this post." r.ID = p.ID res = append(res, r) continue } postClaimReqs[p.ID] = true canCollect := db.CanCollect(&p, userID) if !canCollect && p.Token == "" { // TODO: ensure post isn't owned by anyone else when a valid modify // token is given. r.ErrorMessage += "Missing post Edit Token `token`." } if r.ErrorMessage != "" { // Post validate failed r.Code = http.StatusBadRequest r.ID = p.ID res = append(res, r) continue } var err error var qRes sql.Result var query string var params []interface{} var slugIdx int = -1 var coll *Collection if collAlias == "" { // Posts are being claimed at /posts/claim, not // /collections/{alias}/collect, so use given individual collection // to associate post with. postCollAlias = p.CollectionAlias } if postCollAlias != "" { // Associate this post with a collection if p.CreateCollection { // This is a new collection // TODO: consider removing this. This seriously complicates this // method and adds another (unnecessary?) logic path. coll, err = db.CreateCollection(cfg, postCollAlias, "", userID) if err != nil { if err, ok := err.(impart.HTTPError); ok { r.Code = err.Status r.ErrorMessage = err.Message } else { r.Code = http.StatusInternalServerError r.ErrorMessage = "Unknown error occurred creating collection" } r.ID = p.ID res = append(res, r) continue } } else { // Attempt to add to existing collection coll, err = db.GetCollection(postCollAlias) if err != nil { if err, ok := err.(impart.HTTPError); ok { if err.Status == http.StatusNotFound { // Show obfuscated "forbidden" response, as if attempting to add to an // unowned blog. r.Code = ErrForbiddenCollection.Status r.ErrorMessage = ErrForbiddenCollection.Message } else { r.Code = err.Status r.ErrorMessage = err.Message } } else { r.Code = http.StatusInternalServerError r.ErrorMessage = "Unknown error occurred claiming post with collection" } r.ID = p.ID res = append(res, r) continue } if coll.OwnerID != userID { r.Code = ErrForbiddenCollection.Status r.ErrorMessage = ErrForbiddenCollection.Message r.ID = p.ID res = append(res, r) continue } } if p.Slug == "" { p.Slug = p.ID } if canCollect { // User already owns this post, so just add it to the given // collection. query = "UPDATE posts SET collection_id = ?, slug = ? WHERE id = ? AND owner_id = ?" params = []interface{}{coll.ID, p.Slug, p.ID, userID} slugIdx = 1 } else { query = "UPDATE posts SET owner_id = ?, collection_id = ?, slug = ? WHERE id = ? AND modify_token = ? AND owner_id IS NULL" params = []interface{}{userID, coll.ID, p.Slug, p.ID, p.Token} slugIdx = 2 } } else { query = "UPDATE posts SET owner_id = ? WHERE id = ? AND modify_token = ? AND owner_id IS NULL" params = []interface{}{userID, p.ID, p.Token} } qRes, err = db.AttemptClaim(&p, query, params, slugIdx) if err != nil { r.Code = http.StatusInternalServerError r.ErrorMessage = "An unknown error occurred." r.ID = p.ID res = append(res, r) log.Error("claimPosts (post %s): %v", p.ID, err) continue } // Get full post information to return var fullPost *PublicPost if p.Token != "" { fullPost, err = db.GetEditablePost(p.ID, p.Token) } else { fullPost, err = db.GetPost(p.ID, 0) } if err != nil { if err, ok := err.(impart.HTTPError); ok { r.Code = err.Status r.ErrorMessage = err.Message r.ID = p.ID res = append(res, r) continue } } if fullPost.OwnerID.Int64 != userID { r.Code = http.StatusConflict r.ErrorMessage = "Post is already owned by someone else." r.ID = p.ID res = append(res, r) continue } // Post was successfully claimed r.Code = http.StatusOK r.Post = fullPost if coll != nil { r.Post.Collection = &CollectionObj{Collection: *coll} } rowsAffected, _ := qRes.RowsAffected() if rowsAffected == 0 { // This was already claimed, but return 200 r.Code = http.StatusOK } res = append(res, r) } return &res, nil } func (db *datastore) UpdatePostPinState(pinned bool, postID string, collID, ownerID, pos int64) error { if pos <= 0 || pos > 20 { pos = db.GetLastPinnedPostPos(collID) + 1 if pos == -1 { pos = 1 } } var err error if pinned { _, err = db.Exec("UPDATE posts SET pinned_position = ? WHERE id = ?", pos, postID) } else { _, err = db.Exec("UPDATE posts SET pinned_position = NULL WHERE id = ?", postID) } if err != nil { log.Error("Unable to update pinned post: %v", err) return err } return nil } func (db *datastore) GetLastPinnedPostPos(collID int64) int64 { var lastPos sql.NullInt64 err := db.QueryRow("SELECT MAX(pinned_position) FROM posts WHERE collection_id = ? AND pinned_position IS NOT NULL", collID).Scan(&lastPos) switch { case err == sql.ErrNoRows: return -1 case err != nil: log.Error("Failed selecting from posts: %v", err) return -1 } if !lastPos.Valid { return -1 } return lastPos.Int64 } func (db *datastore) GetPinnedPosts(coll *CollectionObj, includeFuture bool) (*[]PublicPost, error) { // FIXME: sqlite-backed instances don't include ellipsis on truncated titles timeCondition := "" if !includeFuture { timeCondition = "AND created <= " + db.now() } rows, err := db.Query("SELECT id, slug, title, "+db.clip("content", 80)+", pinned_position FROM posts WHERE collection_id = ? AND pinned_position IS NOT NULL "+timeCondition+" ORDER BY pinned_position ASC", coll.ID) if err != nil { log.Error("Failed selecting pinned posts: %v", err) return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve pinned posts."} } defer rows.Close() posts := []PublicPost{} for rows.Next() { p := &Post{} err = rows.Scan(&p.ID, &p.Slug, &p.Title, &p.Content, &p.PinnedPosition) if err != nil { log.Error("Failed scanning row: %v", err) break } p.extractData() pp := p.processPost() pp.Collection = coll posts = append(posts, pp) } return &posts, nil } func (db *datastore) GetCollections(u *User, hostName string) (*[]Collection, error) { rows, err := db.Query("SELECT id, alias, title, description, privacy, view_count FROM collections WHERE owner_id = ? ORDER BY id ASC", u.ID) if err != nil { log.Error("Failed selecting from collections: %v", err) return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user collections."} } defer rows.Close() colls := []Collection{} for rows.Next() { c := Collection{} err = rows.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &c.Visibility, &c.Views) if err != nil { log.Error("Failed scanning row: %v", err) break } c.hostName = hostName c.URL = c.CanonicalURL() c.Public = c.IsPublic() colls = append(colls, c) } err = rows.Err() if err != nil { log.Error("Error after Next() on rows: %v", err) } return &colls, nil } func (db *datastore) GetPublishableCollections(u *User, hostName string) (*[]Collection, error) { c, err := db.GetCollections(u, hostName) if err != nil { return nil, err } if len(*c) == 0 { return nil, impart.HTTPError{http.StatusInternalServerError, "You don't seem to have any blogs; they might've moved to another account. Try logging out and logging into your other account."} } return c, nil } func (db *datastore) GetMeStats(u *User) userMeStats { s := userMeStats{} // User counts colls, _ := db.GetUserCollectionCount(u.ID) s.TotalCollections = colls var articles, collPosts uint64 err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ? AND collection_id IS NULL", u.ID).Scan(&articles) if err != nil && err != sql.ErrNoRows { log.Error("Couldn't get articles count for user %d: %v", u.ID, err) } s.TotalArticles = articles err = db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ? AND collection_id IS NOT NULL", u.ID).Scan(&collPosts) if err != nil && err != sql.ErrNoRows { log.Error("Couldn't get coll posts count for user %d: %v", u.ID, err) } s.CollectionPosts = collPosts return s } func (db *datastore) GetTotalCollections() (collCount int64, err error) { err = db.QueryRow(` SELECT COUNT(*) FROM collections c LEFT JOIN users u ON u.id = c.owner_id WHERE u.status = 0`).Scan(&collCount) if err != nil { log.Error("Unable to fetch collections count: %v", err) } return } func (db *datastore) GetTotalPosts() (postCount int64, err error) { err = db.QueryRow(` SELECT COUNT(*) FROM posts p LEFT JOIN users u ON u.id = p.owner_id WHERE u.status = 0`).Scan(&postCount) if err != nil { log.Error("Unable to fetch posts count: %v", err) } return } func (db *datastore) GetTopPosts(u *User, alias string) (*[]PublicPost, error) { params := []interface{}{u.ID} where := "" if alias != "" { where = " AND alias = ?" params = append(params, alias) } rows, err := db.Query("SELECT p.id, p.slug, p.view_count, p.title, c.alias, c.title, c.description, c.view_count FROM posts p LEFT JOIN collections c ON p.collection_id = c.id WHERE p.owner_id = ?"+where+" ORDER BY p.view_count DESC, created DESC LIMIT 25", params...) if err != nil { log.Error("Failed selecting from posts: %v", err) return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user top posts."} } defer rows.Close() posts := []PublicPost{} var gotErr bool for rows.Next() { p := Post{} c := Collection{} var alias, title, description sql.NullString var views sql.NullInt64 err = rows.Scan(&p.ID, &p.Slug, &p.ViewCount, &p.Title, &alias, &title, &description, &views) if err != nil { log.Error("Failed scanning User.getPosts() row: %v", err) gotErr = true break } p.extractData() pubPost := p.processPost() if alias.Valid && alias.String != "" { c.Alias = alias.String c.Title = title.String c.Description = description.String c.Views = views.Int64 pubPost.Collection = &CollectionObj{Collection: c} } posts = append(posts, pubPost) } err = rows.Err() if err != nil { log.Error("Error after Next() on rows: %v", err) } if gotErr && len(posts) == 0 { // There were a lot of errors return nil, impart.HTTPError{http.StatusInternalServerError, "Unable to get data."} } return &posts, nil } func (db *datastore) GetAnonymousPosts(u *User) (*[]PublicPost, error) { rows, err := db.Query("SELECT id, view_count, title, created, updated, content FROM posts WHERE owner_id = ? AND collection_id IS NULL ORDER BY created DESC", u.ID) if err != nil { log.Error("Failed selecting from posts: %v", err) return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user anonymous posts."} } defer rows.Close() posts := []PublicPost{} for rows.Next() { p := Post{} err = rows.Scan(&p.ID, &p.ViewCount, &p.Title, &p.Created, &p.Updated, &p.Content) if err != nil { log.Error("Failed scanning row: %v", err) break } p.extractData() posts = append(posts, p.processPost()) } err = rows.Err() if err != nil { log.Error("Error after Next() on rows: %v", err) } return &posts, nil } func (db *datastore) GetUserPosts(u *User) (*[]PublicPost, error) { rows, err := db.Query("SELECT p.id, p.slug, p.view_count, p.title, p.created, p.updated, p.content, p.text_appearance, p.language, p.rtl, c.alias, c.title, c.description, c.view_count FROM posts p LEFT JOIN collections c ON collection_id = c.id WHERE p.owner_id = ? ORDER BY created ASC", u.ID) if err != nil { log.Error("Failed selecting from posts: %v", err) return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user posts."} } defer rows.Close() posts := []PublicPost{} var gotErr bool for rows.Next() { p := Post{} c := Collection{} var alias, title, description sql.NullString var views sql.NullInt64 err = rows.Scan(&p.ID, &p.Slug, &p.ViewCount, &p.Title, &p.Created, &p.Updated, &p.Content, &p.Font, &p.Language, &p.RTL, &alias, &title, &description, &views) if err != nil { log.Error("Failed scanning User.getPosts() row: %v", err) gotErr = true break } p.extractData() pubPost := p.processPost() if alias.Valid && alias.String != "" { c.Alias = alias.String c.Title = title.String c.Description = description.String c.Views = views.Int64 pubPost.Collection = &CollectionObj{Collection: c} } posts = append(posts, pubPost) } err = rows.Err() if err != nil { log.Error("Error after Next() on rows: %v", err) } if gotErr && len(posts) == 0 { // There were a lot of errors return nil, impart.HTTPError{http.StatusInternalServerError, "Unable to get data."} } return &posts, nil } func (db *datastore) GetUserPostsCount(userID int64) int64 { var count int64 err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ?", userID).Scan(&count) switch { case err == sql.ErrNoRows: return 0 case err != nil: log.Error("Failed selecting posts count for user %d: %v", userID, err) return 0 } return count } // ChangeSettings takes a User and applies the changes in the given // userSettings, MODIFYING THE USER with successful changes. func (db *datastore) ChangeSettings(app *App, u *User, s *userSettings) error { var errPass error q := query.NewUpdate() // Update email if given if s.Email != "" { encEmail, err := data.Encrypt(app.keys.EmailKey, s.Email) if err != nil { log.Error("Couldn't encrypt email %s: %s\n", s.Email, err) return impart.HTTPError{http.StatusInternalServerError, "Unable to encrypt email address."} } q.SetBytes(encEmail, "email") // Update the email if something goes awry updating the password defer func() { if errPass != nil { db.UpdateEncryptedUserEmail(u.ID, encEmail) } }() u.Email = zero.StringFrom(s.Email) } // Update username if given var newUsername string if s.Username != "" { var ie *impart.HTTPError newUsername, ie = getValidUsername(app, s.Username, u.Username) if ie != nil { // Username is invalid return *ie } if !author.IsValidUsername(app.cfg, newUsername) { // Ensure the username is syntactically correct. return impart.HTTPError{http.StatusPreconditionFailed, "Username isn't valid."} } t, err := db.Begin() if err != nil { log.Error("Couldn't start username change transaction: %v", err) return err } _, err = t.Exec("UPDATE users SET username = ? WHERE id = ?", newUsername, u.ID) if err != nil { t.Rollback() if db.isDuplicateKeyErr(err) { return impart.HTTPError{http.StatusConflict, "Username is already taken."} } log.Error("Unable to update users table: %v", err) return ErrInternalGeneral } _, err = t.Exec("UPDATE collections SET alias = ? WHERE alias = ? AND owner_id = ?", newUsername, u.Username, u.ID) if err != nil { t.Rollback() if db.isDuplicateKeyErr(err) { return impart.HTTPError{http.StatusConflict, "Username is already taken."} } log.Error("Unable to update collection: %v", err) return ErrInternalGeneral } // Keep track of name changes for redirection db.RemoveCollectionRedirect(t, newUsername) _, err = t.Exec("UPDATE collectionredirects SET new_alias = ? WHERE new_alias = ?", newUsername, u.Username) if err != nil { log.Error("Unable to update collectionredirects: %v", err) } _, err = t.Exec("INSERT INTO collectionredirects (prev_alias, new_alias) VALUES (?, ?)", u.Username, newUsername) if err != nil { log.Error("Unable to add new collectionredirect: %v", err) } err = t.Commit() if err != nil { t.Rollback() log.Error("Rolling back after Commit(): %v\n", err) return err } u.Username = newUsername } // Update passphrase if given if s.NewPass != "" { // Check if user has already set a password var err error u.HasPass, err = db.IsUserPassSet(u.ID) if err != nil { errPass = impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve user data."} return errPass } if u.HasPass { // Check if currently-set password is correct hashedPass := u.HashedPass if len(hashedPass) == 0 { authUser, err := db.GetUserForAuthByID(u.ID) if err != nil { errPass = err return errPass } hashedPass = authUser.HashedPass } if !auth.Authenticated(hashedPass, []byte(s.OldPass)) { errPass = impart.HTTPError{http.StatusUnauthorized, "Incorrect password."} return errPass } } hashedPass, err := auth.HashPass([]byte(s.NewPass)) if err != nil { errPass = impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."} return errPass } q.SetBytes(hashedPass, "password") } // WHERE values q.Append(u.ID) if q.Updates == "" { if s.Username == "" { return ErrPostNoUpdatableVals } // Nothing to update except username. That was successful, so return now. return nil } res, err := db.Exec("UPDATE users SET "+q.Updates+" WHERE id = ?", q.Params...) if err != nil { log.Error("Unable to update collection: %v", err) return err } rowsAffected, _ := res.RowsAffected() if rowsAffected == 0 { // Show the correct error message if nothing was updated var dummy int err := db.QueryRow("SELECT 1 FROM users WHERE id = ?", u.ID).Scan(&dummy) switch { case err == sql.ErrNoRows: return ErrUnauthorizedGeneral case err != nil: log.Error("Failed selecting from users: %v", err) } return nil } if s.NewPass != "" && !u.HasPass { u.HasPass = true } return nil } func (db *datastore) ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error { var dbPass []byte err := db.QueryRow("SELECT password FROM users WHERE id = ?", userID).Scan(&dbPass) switch { case err == sql.ErrNoRows: return ErrUserNotFound case err != nil: log.Error("Couldn't SELECT user password for change: %v", err) return err } if !sudo && !auth.Authenticated(dbPass, []byte(curPass)) { return impart.HTTPError{http.StatusUnauthorized, "Incorrect password."} } _, err = db.Exec("UPDATE users SET password = ? WHERE id = ?", hashedPass, userID) if err != nil { log.Error("Could not update passphrase: %v", err) return err } return nil } func (db *datastore) RemoveCollectionRedirect(t *sql.Tx, alias string) error { _, err := t.Exec("DELETE FROM collectionredirects WHERE prev_alias = ?", alias) if err != nil { log.Error("Unable to delete from collectionredirects: %v", err) return err } return nil } func (db *datastore) GetCollectionRedirect(alias string) (new string) { row := db.QueryRow("SELECT new_alias FROM collectionredirects WHERE prev_alias = ?", alias) err := row.Scan(&new) if err != nil && err != sql.ErrNoRows { log.Error("Failed selecting from collectionredirects: %v", err) } return } func (db *datastore) DeleteCollection(alias string, userID int64) error { c := &Collection{Alias: alias} var username string row := db.QueryRow("SELECT username FROM users WHERE id = ?", userID) err := row.Scan(&username) if err != nil { return err } // Ensure user isn't deleting their main blog if alias == username { return impart.HTTPError{http.StatusForbidden, "You cannot currently delete your primary blog."} } row = db.QueryRow("SELECT id FROM collections WHERE alias = ? AND owner_id = ?", alias, userID) err = row.Scan(&c.ID) switch { case err == sql.ErrNoRows: return impart.HTTPError{http.StatusNotFound, "Collection doesn't exist or you're not allowed to delete it."} case err != nil: log.Error("Failed selecting from collections: %v", err) return ErrInternalGeneral } t, err := db.Begin() if err != nil { return err } // Float all collection's posts _, err = t.Exec("UPDATE posts SET collection_id = NULL WHERE collection_id = ? AND owner_id = ?", c.ID, userID) if err != nil { t.Rollback() return err } // Remove redirects to or from this collection _, err = t.Exec("DELETE FROM collectionredirects WHERE prev_alias = ? OR new_alias = ?", alias, alias) if err != nil { t.Rollback() return err } // Remove any optional collection password _, err = t.Exec("DELETE FROM collectionpasswords WHERE collection_id = ?", c.ID) if err != nil { t.Rollback() return err } // Finally, delete collection itself _, err = t.Exec("DELETE FROM collections WHERE id = ?", c.ID) if err != nil { t.Rollback() return err } err = t.Commit() if err != nil { t.Rollback() return err } return nil } func (db *datastore) IsCollectionAttributeOn(id int64, attr string) bool { var v string err := db.QueryRow("SELECT value FROM collectionattributes WHERE collection_id = ? AND attribute = ?", id, attr).Scan(&v) switch { case err == sql.ErrNoRows: return false case err != nil: log.Error("Couldn't SELECT value in isCollectionAttributeOn for attribute '%s': %v", attr, err) return false } return v == "1" } func (db *datastore) CollectionHasAttribute(id int64, attr string) bool { var dummy string err := db.QueryRow("SELECT value FROM collectionattributes WHERE collection_id = ? AND attribute = ?", id, attr).Scan(&dummy) switch { case err == sql.ErrNoRows: return false case err != nil: log.Error("Couldn't SELECT value in collectionHasAttribute for attribute '%s': %v", attr, err) return false } return true } func (db *datastore) DeleteAccount(userID int64) (l *string, err error) { debug := "" l = &debug t, err := db.Begin() if err != nil { stringLogln(l, "Unable to begin: %v", err) return } // Get all collections rows, err := db.Query("SELECT id, alias FROM collections WHERE owner_id = ?", userID) if err != nil { t.Rollback() stringLogln(l, "Unable to get collections: %v", err) return } defer rows.Close() colls := []Collection{} var c Collection for rows.Next() { err = rows.Scan(&c.ID, &c.Alias) if err != nil { t.Rollback() stringLogln(l, "Unable to scan collection cols: %v", err) return } colls = append(colls, c) } var res sql.Result for _, c := range colls { // TODO: user deleteCollection() func // Delete tokens res, err = t.Exec("DELETE FROM collectionattributes WHERE collection_id = ?", c.ID) if err != nil { t.Rollback() stringLogln(l, "Unable to delete attributes on %s: %v", c.Alias, err) return } rs, _ := res.RowsAffected() stringLogln(l, "Deleted %d for %s from collectionattributes", rs, c.Alias) // Remove any optional collection password res, err = t.Exec("DELETE FROM collectionpasswords WHERE collection_id = ?", c.ID) if err != nil { t.Rollback() stringLogln(l, "Unable to delete passwords on %s: %v", c.Alias, err) return } rs, _ = res.RowsAffected() stringLogln(l, "Deleted %d for %s from collectionpasswords", rs, c.Alias) // Remove redirects to this collection res, err = t.Exec("DELETE FROM collectionredirects WHERE new_alias = ?", c.Alias) if err != nil { t.Rollback() stringLogln(l, "Unable to delete redirects on %s: %v", c.Alias, err) return } rs, _ = res.RowsAffected() stringLogln(l, "Deleted %d for %s from collectionredirects", rs, c.Alias) } // Delete collections res, err = t.Exec("DELETE FROM collections WHERE owner_id = ?", userID) if err != nil { t.Rollback() stringLogln(l, "Unable to delete collections: %v", err) return } rs, _ := res.RowsAffected() stringLogln(l, "Deleted %d from collections", rs) // Delete tokens res, err = t.Exec("DELETE FROM accesstokens WHERE user_id = ?", userID) if err != nil { t.Rollback() stringLogln(l, "Unable to delete access tokens: %v", err) return } rs, _ = res.RowsAffected() stringLogln(l, "Deleted %d from accesstokens", rs) // Delete posts res, err = t.Exec("DELETE FROM posts WHERE owner_id = ?", userID) if err != nil { t.Rollback() stringLogln(l, "Unable to delete posts: %v", err) return } rs, _ = res.RowsAffected() stringLogln(l, "Deleted %d from posts", rs) res, err = t.Exec("DELETE FROM userattributes WHERE user_id = ?", userID) if err != nil { t.Rollback() stringLogln(l, "Unable to delete attributes: %v", err) return } rs, _ = res.RowsAffected() stringLogln(l, "Deleted %d from userattributes", rs) res, err = t.Exec("DELETE FROM users WHERE id = ?", userID) if err != nil { t.Rollback() stringLogln(l, "Unable to delete user: %v", err) return } rs, _ = res.RowsAffected() stringLogln(l, "Deleted %d from users", rs) err = t.Commit() if err != nil { t.Rollback() stringLogln(l, "Unable to commit: %v", err) return } return } func (db *datastore) GetAPActorKeys(collectionID int64) ([]byte, []byte) { var pub, priv []byte err := db.QueryRow("SELECT public_key, private_key FROM collectionkeys WHERE collection_id = ?", collectionID).Scan(&pub, &priv) switch { case err == sql.ErrNoRows: // Generate keys pub, priv = activitypub.GenerateKeys() _, err = db.Exec("INSERT INTO collectionkeys (collection_id, public_key, private_key) VALUES (?, ?, ?)", collectionID, pub, priv) if err != nil { log.Error("Unable to INSERT new activitypub keypair: %v", err) return nil, nil } case err != nil: log.Error("Couldn't SELECT collectionkeys: %v", err) return nil, nil } return pub, priv } func (db *datastore) CreateUserInvite(id string, userID int64, maxUses int, expires *time.Time) error { _, err := db.Exec("INSERT INTO userinvites (id, owner_id, max_uses, created, expires, inactive) VALUES (?, ?, ?, "+db.now()+", ?, 0)", id, userID, maxUses, expires) return err } func (db *datastore) GetUserInvites(userID int64) (*[]Invite, error) { rows, err := db.Query("SELECT id, max_uses, created, expires, inactive FROM userinvites WHERE owner_id = ? ORDER BY created DESC", userID) if err != nil { log.Error("Failed selecting from userinvites: %v", err) return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user invites."} } defer rows.Close() is := []Invite{} for rows.Next() { i := Invite{} err = rows.Scan(&i.ID, &i.MaxUses, &i.Created, &i.Expires, &i.Inactive) is = append(is, i) } return &is, nil } func (db *datastore) GetUserInvite(id string) (*Invite, error) { var i Invite err := db.QueryRow("SELECT id, max_uses, created, expires, inactive FROM userinvites WHERE id = ?", id).Scan(&i.ID, &i.MaxUses, &i.Created, &i.Expires, &i.Inactive) switch { case err == sql.ErrNoRows: return nil, impart.HTTPError{http.StatusNotFound, "Invite doesn't exist."} case err != nil: log.Error("Failed selecting invite: %v", err) return nil, err } return &i, nil } // IsUsersInvite returns true if the user with ID created the invite with code // and an error other than sql no rows, if any. Will return false in the event // of an error. func (db *datastore) IsUsersInvite(code string, userID int64) (bool, error) { var id string err := db.QueryRow("SELECT id FROM userinvites WHERE id = ? AND owner_id = ?", code, userID).Scan(&id) if err != nil && err != sql.ErrNoRows { log.Error("Failed selecting invite: %v", err) return false, err } return id != "", nil } func (db *datastore) GetUsersInvitedCount(id string) int64 { var count int64 err := db.QueryRow("SELECT COUNT(*) FROM usersinvited WHERE invite_id = ?", id).Scan(&count) switch { case err == sql.ErrNoRows: return 0 case err != nil: log.Error("Failed selecting users invited count: %v", err) return 0 } return count } func (db *datastore) CreateInvitedUser(inviteID string, userID int64) error { _, err := db.Exec("INSERT INTO usersinvited (invite_id, user_id) VALUES (?, ?)", inviteID, userID) return err } func (db *datastore) GetInstancePages() ([]*instanceContent, error) { return db.GetAllDynamicContent("page") } func (db *datastore) GetAllDynamicContent(t string) ([]*instanceContent, error) { where := "" params := []interface{}{} if t != "" { where = " WHERE content_type = ?" params = append(params, t) } rows, err := db.Query("SELECT id, title, content, updated, content_type FROM appcontent"+where, params...) if err != nil { log.Error("Failed selecting from appcontent: %v", err) return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve instance pages."} } defer rows.Close() pages := []*instanceContent{} for rows.Next() { c := &instanceContent{} err = rows.Scan(&c.ID, &c.Title, &c.Content, &c.Updated, &c.Type) if err != nil { log.Error("Failed scanning row: %v", err) break } pages = append(pages, c) } err = rows.Err() if err != nil { log.Error("Error after Next() on rows: %v", err) } return pages, nil } func (db *datastore) GetDynamicContent(id string) (*instanceContent, error) { c := &instanceContent{ ID: id, } err := db.QueryRow("SELECT title, content, updated, content_type FROM appcontent WHERE id = ?", id).Scan(&c.Title, &c.Content, &c.Updated, &c.Type) switch { case err == sql.ErrNoRows: return nil, nil case err != nil: log.Error("Couldn't SELECT FROM appcontent for id '%s': %v", id, err) return nil, err } return c, nil } func (db *datastore) UpdateDynamicContent(id, title, content, contentType string) error { var err error if db.driverName == driverSQLite { _, err = db.Exec("INSERT OR REPLACE INTO appcontent (id, title, content, updated, content_type) VALUES (?, ?, ?, "+db.now()+", ?)", id, title, content, contentType) } else { _, err = db.Exec("INSERT INTO appcontent (id, title, content, updated, content_type) VALUES (?, ?, ?, "+db.now()+", ?) "+db.upsert("id")+" title = ?, content = ?, updated = "+db.now(), id, title, content, contentType, title, content) } if err != nil { log.Error("Unable to INSERT appcontent for '%s': %v", id, err) } return err } func (db *datastore) GetAllUsers(page uint) (*[]User, error) { limitStr := fmt.Sprintf("0, %d", adminUsersPerPage) if page > 1 { limitStr = fmt.Sprintf("%d, %d", (page-1)*adminUsersPerPage, adminUsersPerPage) } rows, err := db.Query("SELECT id, username, created, status FROM users ORDER BY created DESC LIMIT " + limitStr) if err != nil { log.Error("Failed selecting from users: %v", err) return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve all users."} } defer rows.Close() users := []User{} for rows.Next() { u := User{} err = rows.Scan(&u.ID, &u.Username, &u.Created, &u.Status) if err != nil { log.Error("Failed scanning GetAllUsers() row: %v", err) break } users = append(users, u) } return &users, nil } func (db *datastore) GetAllUsersCount() int64 { var count int64 err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) switch { case err == sql.ErrNoRows: return 0 case err != nil: log.Error("Failed selecting all users count: %v", err) return 0 } return count } func (db *datastore) GetUserLastPostTime(id int64) (*time.Time, error) { var t time.Time err := db.QueryRow("SELECT created FROM posts WHERE owner_id = ? ORDER BY created DESC LIMIT 1", id).Scan(&t) switch { case err == sql.ErrNoRows: return nil, nil case err != nil: log.Error("Failed selecting last post time from posts: %v", err) return nil, err } return &t, nil } // SetUserStatus changes a user's status in the database. see Users.UserStatus func (db *datastore) SetUserStatus(id int64, status UserStatus) error { _, err := db.Exec("UPDATE users SET status = ? WHERE id = ?", status, id) if err != nil { return fmt.Errorf("failed to update user status: %v", err) } return nil } func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { var t time.Time err := db.QueryRow("SELECT created FROM posts WHERE collection_id = ? ORDER BY created DESC LIMIT 1", id).Scan(&t) switch { case err == sql.ErrNoRows: return nil, nil case err != nil: log.Error("Failed selecting last post time from posts: %v", err) return nil, err } return &t, nil } -func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) { +func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) { state := store.Generate62RandomString(24) - _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state) + _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID) if err != nil { return "", fmt.Errorf("unable to record oauth client state: %w", err) } return state, nil } -func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error { - res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state) - if err != nil { - return err - } - rowsAffected, err := res.RowsAffected() +func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { + var provider string + var clientID string + err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID) + if err != nil { + return err + } + + res, err := tx.ExecContext(ctx, "UPDATE oauth_client_states SET used = TRUE WHERE state = ?", state) + if err != nil { + return err + } + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + if rowsAffected != 1 { + return fmt.Errorf("state not found") + } + return nil + }) if err != nil { - return err + return "", "", nil } - if rowsAffected != 1 { - return fmt.Errorf("state not found") - } - return nil + return provider, clientID, nil } -func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error { +func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { var err error if db.driverName == driverSQLite { - _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID) + _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?)", localUserID, remoteUserID, provider, clientID, accessToken) } else { - _, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id")+" user_id = ?", localUserID, remoteUserID, localUserID) + _, err = db.ExecContext(ctx, "INSERT INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?) "+db.upsert("user")+" access_token = ?", localUserID, remoteUserID, provider, clientID, accessToken, accessToken) } if err != nil { - log.Error("Unable to INSERT users_oauth for '%d': %v", localUserID, err) + log.Error("Unable to INSERT oauth_users for '%d': %v", localUserID, err) } return err } // GetIDForRemoteUser returns a user ID associated with a remote user ID. -func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) { +func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { var userID int64 = -1 err := db. - QueryRowContext(ctx, "SELECT user_id FROM users_oauth WHERE remote_user_id = ?", remoteUserID). + QueryRowContext(ctx, "SELECT user_id FROM oauth_users WHERE remote_user_id = ? AND provider = ? AND client_id = ?", remoteUserID, provider, clientID). Scan(&userID) // Not finding a record is OK. if err != nil && err != sql.ErrNoRows { return -1, err } return userID, nil } // DatabaseInitialized returns whether or not the current datastore has been // initialized with the correct schema. // Currently, it checks to see if the `users` table exists. func (db *datastore) DatabaseInitialized() bool { var dummy string var err error if db.driverName == driverSQLite { err = db.QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'users'").Scan(&dummy) } else { err = db.QueryRow("SHOW TABLES LIKE 'users'").Scan(&dummy) } switch { case err == sql.ErrNoRows: return false case err != nil: log.Error("Couldn't SHOW TABLES: %v", err) return false } return true } func stringLogln(log *string, s string, v ...interface{}) { *log += fmt.Sprintf(s+"\n", v...) } func handleFailedPostInsert(err error) error { log.Error("Couldn't insert into posts: %v", err) return err } diff --git a/database_test.go b/database_test.go index 4a45dd0..c4c586a 100644 --- a/database_test.go +++ b/database_test.go @@ -1,43 +1,50 @@ package writefreely import ( "context" "database/sql" "github.com/stretchr/testify/assert" "testing" ) func TestOAuthDatastore(t *testing.T) { if !runMySQLTests() { t.Skip("skipping mysql tests") } withTestDB(t, func(db *sql.DB) { ctx := context.Background() ds := &datastore{ DB: db, driverName: "", } - state, err := ds.GenerateOAuthState(ctx) + state, err := ds.GenerateOAuthState(ctx, "test", "development") assert.NoError(t, err) assert.Len(t, state, 24) - countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = false", state) + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state) - err = ds.ValidateOAuthState(ctx, state) + _, _, err = ds.ValidateOAuthState(ctx, state) assert.NoError(t, err) - countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = true", state) + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state) var localUserID int64 = 99 - var remoteUserID int64 = 100 - err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID) + var remoteUserID = "100" + err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_a") assert.NoError(t, err) - countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `users_oauth` WHERE `user_id` = ? AND `remote_user_id` = ?", localUserID, remoteUserID) + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_a'", localUserID, remoteUserID) - foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID) + err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_b") + assert.NoError(t, err) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_b'", localUserID, remoteUserID) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users`") + + foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID, "test", "test") assert.NoError(t, err) assert.Equal(t, localUserID, foundUserID) }) } diff --git a/db/alter.go b/db/alter.go new file mode 100644 index 0000000..0a4ffdd --- /dev/null +++ b/db/alter.go @@ -0,0 +1,52 @@ +package db + +import ( + "fmt" + "strings" +) + +type AlterTableSqlBuilder struct { + Dialect DialectType + Name string + Changes []string +} + +func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder { + if colVal, err := col.String(); err == nil { + b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal)) + } + return b +} + +func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder { + if colVal, err := col.String(); err == nil { + b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal)) + } + return b +} + +func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder { + b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", "))) + return b +} + +func (b *AlterTableSqlBuilder) ToSQL() (string, error) { + var str strings.Builder + + str.WriteString("ALTER TABLE ") + str.WriteString(b.Name) + str.WriteString(" ") + + if len(b.Changes) == 0 { + return "", fmt.Errorf("no changes provide for table: %s", b.Name) + } + changeCount := len(b.Changes) + for i, thing := range b.Changes { + str.WriteString(thing) + if i < changeCount-1 { + str.WriteString(", ") + } + } + + return str.String(), nil +} diff --git a/db/alter_test.go b/db/alter_test.go new file mode 100644 index 0000000..4bd58ac --- /dev/null +++ b/db/alter_test.go @@ -0,0 +1,56 @@ +package db + +import "testing" + +func TestAlterTableSqlBuilder_ToSQL(t *testing.T) { + type fields struct { + Dialect DialectType + Name string + Changes []string + } + tests := []struct { + name string + builder *AlterTableSqlBuilder + want string + wantErr bool + }{ + { + name: "MySQL add int", + builder: DialectMySQL. + AlterTable("the_table"). + AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)), + want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL", + wantErr: false, + }, + { + name: "MySQL add string", + builder: DialectMySQL. + AlterTable("the_table"). + AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})), + want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL", + wantErr: false, + }, + + { + name: "MySQL add int and string", + builder: DialectMySQL. + AlterTable("the_table"). + AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)). + AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})), + want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.builder.ToSQL() + if (err != nil) != tt.wantErr { + t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ToSQL() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/db/create.go b/db/create.go index 5e4f34e..c384778 100644 --- a/db/create.go +++ b/db/create.go @@ -1,271 +1,244 @@ package db import ( "fmt" "strings" ) -type DialectType int type ColumnType int type OptionalInt struct { Set bool Value int } type OptionalString struct { Set bool Value string } type SQLBuilder interface { ToSQL() (string, error) } type Column struct { Dialect DialectType Name string Nullable bool Default OptionalString Type ColumnType Size OptionalInt PrimaryKey bool } type CreateTableSqlBuilder struct { Dialect DialectType Name string IfNotExists bool ColumnOrder []string Columns map[string]*Column Constraints []string } -const ( - DialectSQLite DialectType = iota - DialectMySQL DialectType = iota -) - const ( ColumnTypeBool ColumnType = iota ColumnTypeSmallInt ColumnType = iota ColumnTypeInteger ColumnType = iota ColumnTypeChar ColumnType = iota ColumnTypeVarChar ColumnType = iota ColumnTypeText ColumnType = iota ColumnTypeDateTime ColumnType = iota ) var _ SQLBuilder = &CreateTableSqlBuilder{} var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} -func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column { - switch d { - case DialectSQLite: - return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size} - case DialectMySQL: - return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } -} - -func (d DialectType) Table(name string) *CreateTableSqlBuilder { - switch d { - case DialectSQLite: - return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name} - case DialectMySQL: - return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } -} - func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { if dialect != DialectMySQL && dialect != DialectSQLite { return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) } switch d { case ColumnTypeSmallInt: { if dialect == DialectSQLite { return "INTEGER", nil } mod := "" if size.Set { mod = fmt.Sprintf("(%d)", size.Value) } return "SMALLINT" + mod, nil } case ColumnTypeInteger: { if dialect == DialectSQLite { return "INTEGER", nil } mod := "" if size.Set { mod = fmt.Sprintf("(%d)", size.Value) } return "INT" + mod, nil } case ColumnTypeChar: { if dialect == DialectSQLite { return "TEXT", nil } mod := "" if size.Set { mod = fmt.Sprintf("(%d)", size.Value) } return "CHAR" + mod, nil } case ColumnTypeVarChar: { if dialect == DialectSQLite { return "TEXT", nil } mod := "" if size.Set { mod = fmt.Sprintf("(%d)", size.Value) } return "VARCHAR" + mod, nil } case ColumnTypeBool: { if dialect == DialectSQLite { return "INTEGER", nil } return "TINYINT(1)", nil } case ColumnTypeDateTime: return "DATETIME", nil case ColumnTypeText: return "TEXT", nil } return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) } func (c *Column) SetName(name string) *Column { c.Name = name return c } func (c *Column) SetNullable(nullable bool) *Column { c.Nullable = nullable return c } func (c *Column) SetPrimaryKey(pk bool) *Column { c.PrimaryKey = pk return c } func (c *Column) SetDefault(value string) *Column { c.Default = OptionalString{Set: true, Value: value} return c } func (c *Column) SetType(t ColumnType) *Column { c.Type = t return c } func (c *Column) SetSize(size int) *Column { c.Size = OptionalInt{Set: true, Value: size} return c } func (c *Column) String() (string, error) { var str strings.Builder str.WriteString(c.Name) str.WriteString(" ") typeStr, err := c.Type.Format(c.Dialect, c.Size) if err != nil { return "", err } str.WriteString(typeStr) if !c.Nullable { str.WriteString(" NOT NULL") } if c.Default.Set { str.WriteString(" DEFAULT ") str.WriteString(c.Default.Value) } if c.PrimaryKey { str.WriteString(" PRIMARY KEY") } return str.String(), nil } func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { if b.Columns == nil { b.Columns = make(map[string]*Column) } b.Columns[column.Name] = column b.ColumnOrder = append(b.ColumnOrder, column.Name) return b } func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder { for _, column := range columns { if _, ok := b.Columns[column]; !ok { // This fails silently. return b } } b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ","))) return b } func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder { b.IfNotExists = ine return b } func (b *CreateTableSqlBuilder) ToSQL() (string, error) { var str strings.Builder str.WriteString("CREATE TABLE ") if b.IfNotExists { str.WriteString("IF NOT EXISTS ") } str.WriteString(b.Name) var things []string for _, columnName := range b.ColumnOrder { column, ok := b.Columns[columnName] if !ok { return "", fmt.Errorf("column not found: %s", columnName) } columnStr, err := column.String() if err != nil { return "", err } things = append(things, columnStr) } for _, constraint := range b.Constraints { things = append(things, constraint) } if thingLen := len(things); thingLen > 0 { str.WriteString(" ( ") for i, thing := range things { str.WriteString(thing) if i < thingLen-1 { str.WriteString(", ") } } str.WriteString(" )") } return str.String(), nil } + diff --git a/db/dialect.go b/db/dialect.go new file mode 100644 index 0000000..4251465 --- /dev/null +++ b/db/dialect.go @@ -0,0 +1,76 @@ +package db + +import "fmt" + +type DialectType int + +const ( + DialectSQLite DialectType = iota + DialectMySQL DialectType = iota +) + +func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column { + switch d { + case DialectSQLite: + return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size} + case DialectMySQL: + return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) Table(name string) *CreateTableSqlBuilder { + switch d { + case DialectSQLite: + return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name} + case DialectMySQL: + return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder { + switch d { + case DialectSQLite: + return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name} + case DialectMySQL: + return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { + switch d { + case DialectSQLite: + return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns} + case DialectMySQL: + return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { + switch d { + case DialectSQLite: + return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns} + case DialectMySQL: + return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder { + switch d { + case DialectSQLite: + return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table} + case DialectMySQL: + return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} diff --git a/db/index.go b/db/index.go new file mode 100644 index 0000000..8180224 --- /dev/null +++ b/db/index.go @@ -0,0 +1,53 @@ +package db + +import ( + "fmt" + "strings" +) + +type CreateIndexSqlBuilder struct { + Dialect DialectType + Name string + Table string + Unique bool + Columns []string +} + +type DropIndexSqlBuilder struct { + Dialect DialectType + Name string + Table string +} + +func (b *CreateIndexSqlBuilder) ToSQL() (string, error) { + var str strings.Builder + + str.WriteString("CREATE ") + if b.Unique { + str.WriteString("UNIQUE ") + } + str.WriteString("INDEX ") + str.WriteString(b.Name) + str.WriteString(" on ") + str.WriteString(b.Table) + + if len(b.Columns) == 0 { + return "", fmt.Errorf("columns provided for this index: %s", b.Name) + } + + str.WriteString(" (") + columnCount := len(b.Columns) + for i, thing := range b.Columns { + str.WriteString(thing) + if i < columnCount-1 { + str.WriteString(", ") + } + } + str.WriteString(")") + + return str.String(), nil +} + +func (b *DropIndexSqlBuilder) ToSQL() (string, error) { + return fmt.Sprintf("DROP INDEX %s on %s", b.Name, b.Table), nil +} diff --git a/db/raw.go b/db/raw.go new file mode 100644 index 0000000..d0301c8 --- /dev/null +++ b/db/raw.go @@ -0,0 +1,9 @@ +package db + +type RawSqlBuilder struct { + Query string +} + +func (b *RawSqlBuilder) ToSQL() (string, error) { + return b.Query, nil +} diff --git a/go.mod b/go.mod index 358173b..5c5e936 100644 --- a/go.mod +++ b/go.mod @@ -1,59 +1,60 @@ module github.com/writeas/writefreely require ( github.com/BurntSushi/toml v0.3.1 // indirect github.com/alecthomas/gometalinter v3.0.0+incompatible // indirect github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf // indirect github.com/captncraig/cors v0.0.0-20180620154129-376d45073b49 // indirect github.com/clbanning/mxj v1.8.4 // indirect github.com/dustin/go-humanize v1.0.0 github.com/fatih/color v1.7.0 github.com/go-sql-driver/mysql v1.4.1 github.com/go-test/deep v1.0.1 // indirect github.com/golang/lint v0.0.0-20181217174547-8f45f776aaf1 // indirect github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e // indirect github.com/gorilla/feeds v1.1.0 github.com/gorilla/mux v1.7.0 github.com/gorilla/schema v1.0.2 github.com/gorilla/sessions v1.1.3 github.com/guregu/null v3.4.0+incompatible github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 github.com/jtolds/gls v4.2.1+incompatible // indirect + github.com/kr/pretty v0.1.0 github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec github.com/lunixbochs/vtclean v1.0.0 // indirect github.com/manifoldco/promptui v0.3.2 github.com/mattn/go-colorable v0.1.0 // indirect github.com/mattn/go-sqlite3 v1.10.0 github.com/microcosm-cc/bluemonday v1.0.2 github.com/mitchellh/go-wordwrap v1.0.0 github.com/nicksnyder/go-i18n v1.10.0 // indirect github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d github.com/pelletier/go-toml v1.2.0 // indirect github.com/pkg/errors v0.8.1 github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be // indirect github.com/smartystreets/assertions v0.0.0-20190116191733-b6c0e53d7304 // indirect github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c // indirect github.com/stretchr/testify v1.3.0 github.com/writeas/activity v0.1.2 github.com/writeas/go-strip-markdown v2.0.1+incompatible github.com/writeas/go-webfinger v0.0.0-20190106002315-85cf805c86d2 github.com/writeas/httpsig v1.0.0 github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219 github.com/writeas/nerds v1.0.0 github.com/writeas/saturday v1.7.1 github.com/writeas/slug v1.2.0 github.com/writeas/web-core v1.2.0 github.com/writefreely/go-nodeinfo v1.2.0 golang.org/x/crypto v0.0.0-20190208162236-193df9c0f06f golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1 // indirect golang.org/x/net v0.0.0-20190206173232-65e2d4e15006 // indirect golang.org/x/sys v0.0.0-20190209173611-3b5209105503 // indirect golang.org/x/tools v0.0.0-20190208222737-3744606dbb67 google.golang.org/appengine v1.4.0 // indirect gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20180810215634-df19058c872c // indirect gopkg.in/ini.v1 v1.41.0 gopkg.in/yaml.v2 v2.2.2 // indirect ) go 1.13 diff --git a/go.sum b/go.sum index 38e804c..0c9d05d 100644 --- a/go.sum +++ b/go.sum @@ -1,177 +1,178 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/alecthomas/gometalinter v2.0.11+incompatible/go.mod h1:qfIpQGGz3d+NmgyPBqv+LSh50emm1pt72EtcX2vKYQk= github.com/alecthomas/gometalinter v3.0.0+incompatible h1:e9Zfvfytsw/e6Kd/PYd75wggK+/kX5Xn8IYDUKyc5fU= github.com/alecthomas/gometalinter v3.0.0+incompatible/go.mod h1:qfIpQGGz3d+NmgyPBqv+LSh50emm1pt72EtcX2vKYQk= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf h1:qet1QNfXsQxTZqLG4oE62mJzwPIB8+Tee4RNCL9ulrY= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/beevik/etree v1.1.0 h1:T0xke/WvNtMoCqgzPhkX2r4rjY3GDZFi+FjpRZY2Jbs= github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= github.com/captncraig/cors v0.0.0-20180620154129-376d45073b49 h1:jWNY1NDg6a/c8RSXkai7IX6UOhir0LD39I4Dukg+4Ks= github.com/captncraig/cors v0.0.0-20180620154129-376d45073b49/go.mod h1:EIlIeMufZ8nqdUhnesledB15xLRl4wIJUppwDLPrdrQ= github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1 h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/clbanning/mxj v1.8.3/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5PVGJng= github.com/clbanning/mxj v1.8.4 h1:HuhwZtbyvyOw+3Z1AowPkU87JkJUSv751ELWaiTpj8I= github.com/clbanning/mxj v1.8.4/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5PVGJng= github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/go-fed/httpsig v0.1.0/go.mod h1:T56HUNYZUQ1AGUzhAYPugZfp36sKApVnGBgKlIY+aIE= github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-test/deep v1.0.1 h1:UQhStjbkDClarlmv0am7OXXO4/GaPdCGiUiMTvi28sg= github.com/go-test/deep v1.0.1/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/golang/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/lint v0.0.0-20181217174547-8f45f776aaf1 h1:6DVPu65tee05kY0/rciBQ47ue+AnuY8KTayV6VHikIo= github.com/golang/lint v0.0.0-20181217174547-8f45f776aaf1/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf h1:7+FW5aGwISbqUtkfmIpZJGRgNFg2ioYPvFaUxdqpDsg= github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf/go.mod h1:RpwtwJQFrIEPstU94h88MWPXP2ektJZ8cZ0YntAmXiE= github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e h1:JKmoR8x90Iww1ks85zJ1lfDGgIiMDuIptTOhJq+zKyg= github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gordonklaus/ineffassign v0.0.0-20180909121442-1003c8bd00dc h1:cJlkeAx1QYgO5N80aF5xRGstVsRQwgLR7uA2FnP1ZjY= github.com/gordonklaus/ineffassign v0.0.0-20180909121442-1003c8bd00dc/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/feeds v1.1.0 h1:pcgLJhbdYgaUESnj3AmXPcB7cS3vy63+jC/TI14AGXk= github.com/gorilla/feeds v1.1.0/go.mod h1:Nk0jZrvPFZX1OBe5NPiddPw7CfwF6Q9eqzaBbaightA= github.com/gorilla/mux v1.7.0 h1:tOSd0UKHQd6urX6ApfOn4XdBMY6Sh1MfxV3kmaazO+U= github.com/gorilla/mux v1.7.0/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/schema v1.0.2 h1:sAgNfOcNYvdDSrzGHVy9nzCQahG+qmsg+nE8dK85QRA= github.com/gorilla/schema v1.0.2/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9RU= github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= github.com/guregu/null v3.4.0+incompatible h1:a4mw37gBO7ypcBlTJeZGuMpSxxFTV9qFfFKgWxQSGaM= github.com/guregu/null v3.4.0+incompatible/go.mod h1:ePGpQaN9cw0tj45IR5E5ehMvsFlLlQZAkkOXZurJ3NM= github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 h1:wIdDEle9HEy7vBPjC6oKz6ejs3Ut+jmsYvuOoAW2pSM= github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2/go.mod h1:WtaVKD9TeruTED9ydiaOJU08qGoEPP/LyzTKiD3jEsw= github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpRVWLVmUEE= github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU2QjDaorUexogfaMgbipgYATUMU= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec h1:ZXWuspqypleMuJy4bzYEqlMhJnGAYpLrWe5p7W3CdvI= github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec/go.mod h1:voECJzdraJmolzPBgL9Z7ANwXf4oMXaTCsIkdiPpR/g= github.com/lunixbochs/vtclean v0.0.0-20180621232353-2d01aacdc34a h1:weJVJJRzAJBFRlAiJQROKQs8oC9vOxvm4rZmBBk0ONw= github.com/lunixbochs/vtclean v0.0.0-20180621232353-2d01aacdc34a/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/lunixbochs/vtclean v1.0.0 h1:xu2sLAri4lGiovBDQKxl5mrXyESr3gUr5m5SM5+LVb8= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/manifoldco/promptui v0.3.2 h1:rir7oByTERac6jhpHUPErHuopoRDvO3jxS+FdadEns8= github.com/manifoldco/promptui v0.3.2/go.mod h1:8JU+igZ+eeiiRku4T5BjtKh2ms8sziGpSYl1gN8Bazw= github.com/mattn/go-colorable v0.0.9 h1:UVL0vNpWh04HeJXV0KLcaT7r06gOH2l4OW6ddYRUIY4= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.0 h1:v2XXALHHh6zHfYTJ+cSkwtyffnaOyR1MXaA91mTrb8o= github.com/mattn/go-colorable v0.1.0/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-isatty v0.0.4 h1:bnP0vzxcAdeI1zdubAl5PjU6zsERjGZb7raWodagDYs= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/microcosm-cc/bluemonday v1.0.2 h1:5lPfLTTAvAbtS0VqT+94yOtFnGfUWYyx0+iToC3Os3s= github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= github.com/mitchellh/go-wordwrap v1.0.0 h1:6GlHJ/LTGMrIJbwgdqdl2eEH8o+Exx/0m8ir9Gns0u4= github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= github.com/nicksnyder/go-i18n v1.10.0 h1:5AzlPKvXBH4qBzmZ09Ua9Gipyruv6uApMcrNZdo96+Q= github.com/nicksnyder/go-i18n v1.10.0/go.mod h1:HrK7VCrbOvQoUAQ7Vpy7i87N7JZZZ7R2xBGjv0j365Q= github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d h1:VhgPp6v9qf9Agr/56bj7Y/xa04UccTW04VP0Qed4vnQ= github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d/go.mod h1:YUTz3bUH2ZwIWBy3CJBeOBEugqcmXREj14T+iG/4k4U= github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be h1:ta7tUOvsPHVHGom5hKW5VXNc2xZIkfCKP8iaqOyYtUQ= github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be/go.mod h1:MIDFMn7db1kT65GmV94GzpX9Qdi7N/pQlwb+AN8wh+Q= github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/smartystreets/assertions v0.0.0-20190116191733-b6c0e53d7304 h1:Jpy1PXuP99tXNrhbq2BaPz9B+jNAvH1JPQQpG/9GCXY= github.com/smartystreets/assertions v0.0.0-20190116191733-b6c0e53d7304/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c h1:Ho+uVpkel/udgjbwB5Lktg9BtvJSh2DT0Hi6LPSyI2w= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/tsenart/deadcode v0.0.0-20160724212837-210d2dc333e9 h1:vY5WqiEon0ZSTGM3ayVVi+twaHKHDFUVloaQ/wug9/c= github.com/tsenart/deadcode v0.0.0-20160724212837-210d2dc333e9/go.mod h1:q+QjxYvZ+fpjMXqs+XEriussHjSYqeXVnAdSV1tkMYk= github.com/writeas/activity v0.1.2 h1:Y12B5lIrabfqKE7e7HFCWiXrlfXljr9tlkFm2mp7DgY= github.com/writeas/activity v0.1.2/go.mod h1:mYYgiewmEM+8tlifirK/vl6tmB2EbjYaxwb+ndUw5T0= github.com/writeas/go-strip-markdown v2.0.1+incompatible h1:IIqxTM5Jr7RzhigcL6FkrCNfXkvbR+Nbu1ls48pXYcw= github.com/writeas/go-strip-markdown v2.0.1+incompatible/go.mod h1:Rsyu10ZhbEK9pXdk8V6MVnZmTzRG0alMNLMwa0J01fE= github.com/writeas/go-webfinger v0.0.0-20190106002315-85cf805c86d2 h1:DUsp4OhdfI+e6iUqcPQlwx8QYXuUDsToTz/x82D3Zuo= github.com/writeas/go-webfinger v0.0.0-20190106002315-85cf805c86d2/go.mod h1:w2VxyRO/J5vfNjJHYVubsjUGHd3RLDoVciz0DE3ApOc= github.com/writeas/httpsig v1.0.0 h1:peIAoIA3DmlP8IG8tMNZqI4YD1uEnWBmkcC9OFPjt3A= github.com/writeas/httpsig v1.0.0/go.mod h1:7ClMGSrSVXJbmiLa17bZ1LrG1oibGZmUMlh3402flPY= github.com/writeas/impart v1.1.0 h1:nPnoO211VscNkp/gnzir5UwCDEvdHThL5uELU60NFSE= github.com/writeas/impart v1.1.0/go.mod h1:g0MpxdnTOHHrl+Ca/2oMXUHJ0PcRAEWtkCzYCJUXC9Y= github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d h1:PK7DOj3JE6MGf647esPrKzXEHFjGWX2hl22uX79ixaE= github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d/go.mod h1:g0MpxdnTOHHrl+Ca/2oMXUHJ0PcRAEWtkCzYCJUXC9Y= github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219 h1:baEp0631C8sT2r/hqwypIw2snCFZa6h7U6TojoLHu/c= github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219/go.mod h1:NyM35ayknT7lzO6O/1JpfgGyv+0W9Z9q7aE0J8bXxfQ= github.com/writeas/nerds v1.0.0 h1:ZzRcCN+Sr3MWID7o/x1cr1ZbLvdpej9Y1/Ho+JKlqxo= github.com/writeas/nerds v1.0.0/go.mod h1:Gn2bHy1EwRcpXeB7ZhVmuUwiweK0e+JllNf66gvNLdU= github.com/writeas/openssl-go v1.0.0 h1:YXM1tDXeYOlTyJjoMlYLQH1xOloUimSR1WMF8kjFc5o= github.com/writeas/openssl-go v1.0.0/go.mod h1:WsKeK5jYl0B5y8ggOmtVjbmb+3rEGqSD25TppjJnETA= github.com/writeas/saturday v1.6.0/go.mod h1:ETE1EK6ogxptJpAgUbcJD0prAtX48bSloie80+tvnzQ= github.com/writeas/saturday v1.7.1 h1:lYo1EH6CYyrFObQoA9RNWHVlpZA5iYL5Opxo7PYAnZE= github.com/writeas/saturday v1.7.1/go.mod h1:ETE1EK6ogxptJpAgUbcJD0prAtX48bSloie80+tvnzQ= github.com/writeas/slug v1.2.0 h1:EMQ+cwLiOcA6EtFwUgyw3Ge18x9uflUnOnR6bp/J+/g= github.com/writeas/slug v1.2.0/go.mod h1:RE8shOqQP3YhsfsQe0L3RnuejfQ4Mk+JjY5YJQFubfQ= github.com/writeas/web-core v1.0.0 h1:5VKkCakQgdKZcbfVKJXtRpc5VHrkflusCl/KRCPzpQ0= github.com/writeas/web-core v1.0.0/go.mod h1:Si3chV7VWgY8CsV+3gRolMXSO2Vx1ZFAQ/mkrpvmyEE= github.com/writeas/web-core v1.2.0 h1:CYqvBd+byi1cK4mCr1NZ6CjILuMOFmiFecv+OACcmG0= github.com/writeas/web-core v1.2.0/go.mod h1:vTYajviuNBAxjctPp2NUYdgjofywVkxUGpeaERF3SfI= github.com/writefreely/go-nodeinfo v1.2.0 h1:La+YbTCvmpTwFhBSlebWDDL81N88Qf/SCAvRLR7F8ss= github.com/writefreely/go-nodeinfo v1.2.0/go.mod h1:UTvE78KpcjYOlRHupZIiSEFcXHioTXuacCbHU+CAcPg= golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59 h1:hk3yo72LXLapY9EXVttc3Z1rLOxT9IuAPPX3GpY2+jo= golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190208162236-193df9c0f06f h1:ETU2VEl7TnT5bl7IvuKEzTDpplg5wzGYsOCAPhdoEIg= golang.org/x/crypto v0.0.0-20190208162236-193df9c0f06f/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1 h1:rJm0LuqUjoDhSk2zO9ISMSToQxGz7Os2jRiOL8AWu4c= golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181220203305-927f97764cc3 h1:eH6Eip3UpmR+yM/qI9Ijluzb1bNv/cAU/n+6l8tRSis= golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190206173232-65e2d4e15006 h1:bfLnR+k0tq5Lqt6dflRLcZiz6UaXCMt3vhYJ1l4FQ80= golang.org/x/net v0.0.0-20190206173232-65e2d4e15006/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/sys v0.0.0-20180525142821-c11f84a56e43/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190209173611-3b5209105503 h1:5SvYFrOM3W8Mexn9/oA44Ji7vhXAZQ9hiP+1Q/DMrWg= golang.org/x/sys v0.0.0-20190209173611-3b5209105503/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20181122213734-04b5d21e00f1/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190208222737-3744606dbb67 h1:bPP/rGuN1LUM0eaEwo6vnP6OfIWJzJBulzGUiKLjjSY= golang.org/x/tools v0.0.0-20190208222737-3744606dbb67/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20180810215634-df19058c872c h1:vTxShRUnK60yd8DZU+f95p1zSLj814+5CuEh7NjF2/Y= gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20180810215634-df19058c872c/go.mod h1:3HH7i1SgMqlzxCcBmUHW657sD4Kvv9sC3HpL3YukzwA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/ini.v1 v1.41.0 h1:Ka3ViY6gNYSKiVy71zXBEqKplnV35ImDLVG+8uoIklE= gopkg.in/ini.v1 v1.41.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0 h1:POO/ycCATvegFmVuPpQzZFJ+pGZeX22Ufu6fibxDVjU= gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/migrations/migrations.go b/migrations/migrations.go index b9ead23..917d912 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -1,136 +1,137 @@ /* * Copyright © 2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ // Package migrations contains database migrations for WriteFreely package migrations import ( "database/sql" "github.com/writeas/web-core/log" ) // TODO: refactor to use the datastore struct from writefreely pkg type datastore struct { *sql.DB driverName string } func NewDatastore(db *sql.DB, dn string) *datastore { return &datastore{db, dn} } // TODO: use these consts from writefreely pkg const ( driverMySQL = "mysql" driverSQLite = "sqlite3" ) type Migration interface { Description() string Migrate(db *datastore) error } type migration struct { description string migrate func(db *datastore) error } func New(d string, fn func(db *datastore) error) Migration { return &migration{d, fn} } func (m *migration) Description() string { return m.description } func (m *migration) Migrate(db *datastore) error { return m.migrate(db) } var migrations = []Migration{ New("support user invites", supportUserInvites), // -> V1 (v0.8.0) New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0) New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0) - New("support oauth", oauth), // V3 -> V4 + New("support oauth", oauth), // V3 -> V4 + New("support slack oauth", oauthSlack), // V4 -> v5 } // CurrentVer returns the current migration version the application is on func CurrentVer() int { return len(migrations) } func SetInitialMigrations(db *datastore) error { // Included schema files represent changes up to V1, so note that in the database _, err := db.Exec("INSERT INTO appmigrations (version, migrated, result) VALUES (?, "+db.now()+", ?)", 1, "") if err != nil { return err } return nil } func Migrate(db *datastore) error { var version int var err error if db.tableExists("appmigrations") { err = db.QueryRow("SELECT MAX(version) FROM appmigrations").Scan(&version) } else { log.Info("Initializing appmigrations table...") version = 0 _, err = db.Exec(`CREATE TABLE appmigrations ( version ` + db.typeInt() + ` NOT NULL, migrated ` + db.typeDateTime() + ` NOT NULL, result ` + db.typeText() + ` NOT NULL ) ` + db.engine() + `;`) if err != nil { return err } } if len(migrations[version:]) > 0 { for i, m := range migrations[version:] { curVer := version + i + 1 log.Info("Migrating to V%d: %s", curVer, m.Description()) err = m.Migrate(db) if err != nil { return err } // Update migrations table _, err = db.Exec("INSERT INTO appmigrations (version, migrated, result) VALUES (?, "+db.now()+", ?)", curVer, "") if err != nil { return err } } } else { log.Info("Database up-to-date. No migrations to run.") } return nil } func (db *datastore) tableExists(t string) bool { var dummy string var err error if db.driverName == driverSQLite { err = db.QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", t).Scan(&dummy) } else { err = db.QueryRow("SHOW TABLES LIKE '" + t + "'").Scan(&dummy) } switch { case err == sql.ErrNoRows: return false case err != nil: log.Error("Couldn't SHOW TABLES: %v", err) return false } return true } diff --git a/migrations/v4.go b/migrations/v4.go index c123f54..c075dd8 100644 --- a/migrations/v4.go +++ b/migrations/v4.go @@ -1,46 +1,46 @@ package migrations import ( "context" "database/sql" wf_db "github.com/writeas/writefreely/db" ) func oauth(db *datastore) error { dialect := wf_db.DialectMySQL if db.driverName == driverSQLite { dialect = wf_db.DialectSQLite } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { createTableUsersOauth, err := dialect. - Table("users_oauth"). + Table("oauth_users"). SetIfNotExists(true). Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). UniqueConstraint("user_id"). UniqueConstraint("remote_user_id"). ToSQL() if err != nil { return err } createTableOauthClientState, err := dialect. - Table("oauth_client_state"). + Table("oauth_client_states"). SetIfNotExists(true). Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})). Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)). Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefault("NOW()")). UniqueConstraint("state"). ToSQL() if err != nil { return err } for _, table := range []string{createTableUsersOauth, createTableOauthClientState} { if _, err := tx.ExecContext(ctx, table); err != nil { return err } } return nil }) } diff --git a/migrations/v5.go b/migrations/v5.go new file mode 100644 index 0000000..94e3944 --- /dev/null +++ b/migrations/v5.go @@ -0,0 +1,67 @@ +package migrations + +import ( + "context" + "database/sql" + + wf_db "github.com/writeas/writefreely/db" +) + +func oauthSlack(db *datastore) error { + dialect := wf_db.DialectMySQL + if db.driverName == driverSQLite { + dialect = wf_db.DialectSQLite + } + return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + builders := []wf_db.SQLBuilder{ + dialect. + AlterTable("oauth_client_states"). + AddColumn(dialect. + Column( + "provider", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 24,})). + AddColumn(dialect. + Column( + "client_id", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 128,})), + dialect. + AlterTable("oauth_users"). + ChangeColumn("remote_user_id", + dialect. + Column( + "remote_user_id", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 128,})). + AddColumn(dialect. + Column( + "provider", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 24,})). + AddColumn(dialect. + Column( + "client_id", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 128,})). + AddColumn(dialect. + Column( + "access_token", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 512,})), + dialect.DropIndex("remote_user_id", "oauth_users"), + dialect.DropIndex("user_id", "oauth_users"), + dialect.CreateUniqueIndex("oauth_users", "oauth_users", "user_id", "provider", "client_id"), + } + for _, builder := range builders { + query, err := builder.ToSQL() + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, query); err != nil { + return err + } + } + return nil + }) +} diff --git a/oauth.go b/oauth.go index 2a4e3db..7dfc4c7 100644 --- a/oauth.go +++ b/oauth.go @@ -1,270 +1,246 @@ package writefreely import ( "context" "encoding/json" "fmt" + "github.com/gorilla/mux" "github.com/gorilla/sessions" "github.com/guregu/null/zero" "github.com/writeas/impart" "github.com/writeas/nerds/store" "github.com/writeas/web-core/auth" "github.com/writeas/web-core/log" "github.com/writeas/writefreely/config" "io" "io/ioutil" "net/http" - "net/url" - "strings" "time" ) // TokenResponse contains data returned when a token is created either // through a code exchange or using a refresh token. type TokenResponse struct { AccessToken string `json:"access_token"` ExpiresIn int `json:"expires_in"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` Error string `json:"error"` } // InspectResponse contains data returned when an access token is inspected. type InspectResponse struct { - ClientID string `json:"client_id"` - UserID int64 `json:"user_id"` - ExpiresAt time.Time `json:"expires_at"` - Username string `json:"username"` - Email string `json:"email"` + ClientID string `json:"client_id"` + UserID string `json:"user_id"` + ExpiresAt time.Time `json:"expires_at"` + Username string `json:"username"` + DisplayName string `json:"-"` + Email string `json:"email"` + Error string `json:"error"` } // tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token // endpoint. One megabyte is plenty. const tokenRequestMaxLen = 1000000 // infoRequestMaxLen is the most bytes that we'll read from the // /oauth/inspect endpoint. const infoRequestMaxLen = 1000000 // OAuthDatastoreProvider provides a minimal interface of data store, config, // and session store for use with the oauth handlers. type OAuthDatastoreProvider interface { DB() OAuthDatastore Config() *config.Config SessionStore() sessions.Store } // OAuthDatastore provides a minimal interface of data store methods used in // oauth functionality. type OAuthDatastore interface { - GenerateOAuthState(context.Context) (string, error) - ValidateOAuthState(context.Context, string) error - GetIDForRemoteUser(context.Context, int64) (int64, error) + GetIDForRemoteUser(context.Context, string, string, string) (int64, error) + RecordRemoteUserID(context.Context, int64, string, string, string, string) error + ValidateOAuthState(context.Context, string) (string, string, error) + GenerateOAuthState(context.Context, string, string) (string, error) + CreateUser(*config.Config, *User, string) error - RecordRemoteUserID(context.Context, int64, int64) error GetUserForAuthByID(int64) (*User, error) } type HttpClient interface { Do(req *http.Request) (*http.Response, error) } +type oauthClient interface { + GetProvider() string + GetClientID() string + buildLoginURL(state string) (string, error) + exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) + inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) +} + type oauthHandler struct { - Config *config.Config - DB OAuthDatastore - Store sessions.Store - HttpClient HttpClient + Config *config.Config + DB OAuthDatastore + Store sessions.Store + oauthClient oauthClient } -// buildAuthURL returns a URL used to initiate authentication. -func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) { - state, err := db.GenerateOAuthState(ctx) +func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID()) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} + } + location, err := h.oauthClient.buildLoginURL(state) if err != nil { - return "", err + return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} } + return impart.HTTPError{http.StatusTemporaryRedirect, location} +} - u, err := url.Parse(authLocation) - if err != nil { - return "", err +func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) { + if app.Config().SlackOauth.ClientID != "" { + oauthClient := slackOauthClient{ + ClientID: app.Config().SlackOauth.ClientID, + ClientSecret: app.Config().SlackOauth.ClientSecret, + TeamID: app.Config().SlackOauth.TeamID, + CallbackLocation: app.Config().App.Host + "/oauth/callback", + HttpClient: config.DefaultHTTPClient(), + } + configureOauthRoutes(parentHandler, r, app, oauthClient) } - q := u.Query() - q.Set("client_id", clientID) - q.Set("redirect_uri", callbackURL) - q.Set("response_type", "code") - q.Set("state", state) - u.RawQuery = q.Encode() +} + +func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { + if app.Config().WriteAsOauth.ClientID != "" { + oauthClient := writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation), + InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation), + AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation), + HttpClient: config.DefaultHTTPClient(), + CallbackLocation: app.Config().App.Host + "/oauth/callback", + } + if oauthClient.ExchangeLocation == "" { - return u.String(), nil + } + configureOauthRoutes(parentHandler, r, app, oauthClient) + } } -// app *App, w http.ResponseWriter, r *http.Request -func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { - location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation) - if err != nil { - return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} +func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient) { + handler := &oauthHandler{ + Config: app.Config(), + DB: app.DB(), + Store: app.SessionStore(), + oauthClient: oauthClient, } - return impart.HTTPError{http.StatusTemporaryRedirect, location} + r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET") + r.HandleFunc("/oauth/callback", parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET") } func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error { ctx := r.Context() code := r.FormValue("code") state := r.FormValue("state") - err := h.DB.ValidateOAuthState(ctx, state) + provider, clientID, err := h.DB.ValidateOAuthState(ctx, state) if err != nil { log.Error("Unable to ValidateOAuthState: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } - tokenResponse, err := h.exchangeOauthCode(ctx, code) + tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code) if err != nil { log.Error("Unable to exchangeOauthCode: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } // Now that we have the access token, let's use it real quick to make sur // it really really works. - tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) + tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) if err != nil { log.Error("Unable to inspectOauthAccessToken: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } - localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID) + localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID) if err != nil { log.Error("Unable to GetIDForRemoteUser: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } - fmt.Println("local user id", localUserID) - if localUserID == -1 { // We don't have, nor do we want, the password from the origin, so we //create a random string. If the user needs to set a password, they //can do so through the settings page or through the password reset //flow. randPass := store.Generate62RandomString(14) hashedPass, err := auth.HashPass([]byte(randPass)) if err != nil { - log.ErrorLog.Println(err) return impart.HTTPError{http.StatusInternalServerError, "unable to create password hash"} } newUser := &User{ Username: tokenInfo.Username, HashedPass: hashedPass, HasPass: true, - Email: zero.NewString("", tokenInfo.Email != ""), + Email: zero.NewString(tokenInfo.Email, tokenInfo.Email != ""), Created: time.Now().Truncate(time.Second).UTC(), } + displayName := tokenInfo.DisplayName + if len(displayName) == 0 { + displayName = tokenInfo.Username + } - err = h.DB.CreateUser(h.Config, newUser, newUser.Username) + err = h.DB.CreateUser(h.Config, newUser, displayName) if err != nil { return impart.HTTPError{http.StatusInternalServerError, err.Error()} } - err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID) + err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken) if err != nil { return impart.HTTPError{http.StatusInternalServerError, err.Error()} } if err := loginOrFail(h.Store, w, r, newUser); err != nil { return impart.HTTPError{http.StatusInternalServerError, err.Error()} } return nil } user, err := h.DB.GetUserForAuthByID(localUserID) if err != nil { return impart.HTTPError{http.StatusInternalServerError, err.Error()} } if err = loginOrFail(h.Store, w, r, user); err != nil { return impart.HTTPError{http.StatusInternalServerError, err.Error()} } return nil } -func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { - form := url.Values{} - form.Add("grant_type", "authorization_code") - form.Add("redirect_uri", h.Config.App.OAuthClientCallbackLocation) - form.Add("code", code) - req, err := http.NewRequest("POST", h.Config.App.OAuthProviderTokenLocation, strings.NewReader(form.Encode())) +func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error { + lr := io.LimitReader(body, int64(n+1)) + data, err := ioutil.ReadAll(lr) if err != nil { - return nil, err - } - req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") - req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth(h.Config.App.OAuthClientID, h.Config.App.OAuthClientSecret) - - resp, err := h.HttpClient.Do(req) - if err != nil { - return nil, err - } - - // Nick: I like using limited readers to reduce the risk of an endpoint - // being broken or compromised. - lr := io.LimitReader(resp.Body, tokenRequestMaxLen) - body, err := ioutil.ReadAll(lr) - if err != nil { - return nil, err - } - - var tokenResponse TokenResponse - err = json.Unmarshal(body, &tokenResponse) - if err != nil { - return nil, err - } - - // Check the response for an error message, and return it if there is one. - if tokenResponse.Error != "" { - return nil, fmt.Errorf(tokenResponse.Error) - } - return &tokenResponse, nil -} - -func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { - req, err := http.NewRequest("GET", h.Config.App.OAuthProviderInspectLocation, nil) - if err != nil { - return nil, err - } - req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") - req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", "Bearer "+accessToken) - - resp, err := h.HttpClient.Do(req) - if err != nil { - return nil, err - } - - // Nick: I like using limited readers to reduce the risk of an endpoint - // being broken or compromised. - lr := io.LimitReader(resp.Body, infoRequestMaxLen) - body, err := ioutil.ReadAll(lr) - if err != nil { - return nil, err + return err } - - var inspectResponse InspectResponse - err = json.Unmarshal(body, &inspectResponse) - if err != nil { - return nil, err + if len(data) == n+1 { + return fmt.Errorf("content larger than max read allowance: %d", n) } - return &inspectResponse, nil + return json.Unmarshal(data, thing) } func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error { // An error may be returned, but a valid session should always be returned. session, _ := store.Get(r, cookieName) session.Values[cookieUserVal] = user.Cookie() if err := session.Save(r, w); err != nil { fmt.Println("error saving session", err) return err } http.Redirect(w, r, "/", http.StatusTemporaryRedirect) return nil } diff --git a/oauth_slack.go b/oauth_slack.go new file mode 100644 index 0000000..066aa18 --- /dev/null +++ b/oauth_slack.go @@ -0,0 +1,164 @@ +package writefreely + +import ( + "context" + "errors" + "github.com/writeas/slug" + "net/http" + "net/url" + "strings" +) + +type slackOauthClient struct { + ClientID string + ClientSecret string + TeamID string + CallbackLocation string + HttpClient HttpClient +} + +type slackExchangeResponse struct { + OK bool `json:"ok"` + AccessToken string `json:"access_token"` + Scope string `json:"scope"` + TeamName string `json:"team_name"` + TeamID string `json:"team_id"` + Error string `json:"error"` +} + +type slackIdentity struct { + Name string `json:"name"` + ID string `json:"id"` + Email string `json:"email"` +} + +type slackTeam struct { + Name string `json:"name"` + ID string `json:"id"` +} + +type slackUserIdentityResponse struct { + OK bool `json:"ok"` + User slackIdentity `json:"user"` + Team slackTeam `json:"team"` + Error string `json:"error"` +} + +const ( + slackAuthLocation = "https://slack.com/oauth/authorize" + slackExchangeLocation = "https://slack.com/api/oauth.access" + slackIdentityLocation = "https://slack.com/api/users.identity" +) + +var _ oauthClient = slackOauthClient{} + +func (c slackOauthClient) GetProvider() string { + return "slack" +} + +func (c slackOauthClient) GetClientID() string { + return c.ClientID +} + +func (c slackOauthClient) buildLoginURL(state string) (string, error) { + u, err := url.Parse(slackAuthLocation) + if err != nil { + return "", err + } + q := u.Query() + q.Set("client_id", c.ClientID) + q.Set("scope", "identity.basic identity.email identity.team") + q.Set("redirect_uri", c.CallbackLocation) + q.Set("state", state) + + // If this param is not set, the user can select which team they + // authenticate through and then we'd have to match the configured team + // against the profile get. That is extra work in the post-auth phase + // that we don't want to do. + q.Set("team", c.TeamID) + + // The Slack OAuth docs don't explicitly list this one, but it is part of + // the spec, so we include it anyway. + q.Set("response_type", "code") + u.RawQuery = q.Encode() + return u.String(), nil +} + +func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { + form := url.Values{} + // The oauth.access documentation doesn't explicitly mention this + // parameter, but it is part of the spec, so we include it anyway. + // https://api.slack.com/methods/oauth.access + form.Add("grant_type", "authorization_code") + form.Add("redirect_uri", c.CallbackLocation) + form.Add("code", code) + req, err := http.NewRequest("POST", slackExchangeLocation, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(c.ClientID, c.ClientSecret) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to exchange code for access token") + } + + var tokenResponse slackExchangeResponse + if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { + return nil, err + } + if !tokenResponse.OK { + return nil, errors.New(tokenResponse.Error) + } + return tokenResponse.TokenResponse(), nil +} + +func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { + req, err := http.NewRequest("GET", slackIdentityLocation, nil) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to inspect access token") + } + + var inspectResponse slackUserIdentityResponse + if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { + return nil, err + } + if !inspectResponse.OK { + return nil, errors.New(inspectResponse.Error) + } + return inspectResponse.InspectResponse(), nil +} + +func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse { + return &InspectResponse{ + UserID: resp.User.ID, + Username: slug.Make(resp.User.Name), + DisplayName: resp.User.Name, + Email: resp.User.Email, + } +} + +func (resp slackExchangeResponse) TokenResponse() *TokenResponse { + return &TokenResponse{ + AccessToken: resp.AccessToken, + } +} diff --git a/oauth_test.go b/oauth_test.go index 482a846..838af9a 100644 --- a/oauth_test.go +++ b/oauth_test.go @@ -1,211 +1,243 @@ package writefreely import ( "context" "fmt" "github.com/gorilla/sessions" "github.com/stretchr/testify/assert" "github.com/writeas/nerds/store" "github.com/writeas/writefreely/config" "net/http" "net/http/httptest" "net/url" "strings" "testing" ) type MockOAuthDatastoreProvider struct { DoDB func() OAuthDatastore DoConfig func() *config.Config DoSessionStore func() sessions.Store } type MockOAuthDatastore struct { - DoGenerateOAuthState func(ctx context.Context) (string, error) - DoValidateOAuthState func(context.Context, string) error - DoGetIDForRemoteUser func(context.Context, int64) (int64, error) + DoGenerateOAuthState func(context.Context, string, string) (string, error) + DoValidateOAuthState func(context.Context, string) (string, string, error) + DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error) DoCreateUser func(*config.Config, *User, string) error - DoRecordRemoteUserID func(context.Context, int64, int64) error + DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error DoGetUserForAuthByID func(int64) (*User, error) } +var _ OAuthDatastore = &MockOAuthDatastore{} + type StringReadCloser struct { *strings.Reader } func (src *StringReadCloser) Close() error { return nil } type MockHTTPClient struct { DoDo func(req *http.Request) (*http.Response, error) } func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { if m.DoDo != nil { return m.DoDo(req) } return &http.Response{}, nil } func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store { if m.DoSessionStore != nil { return m.DoSessionStore() } return sessions.NewCookieStore([]byte("secret-key")) } func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore { if m.DoDB != nil { return m.DoDB() } return &MockOAuthDatastore{} } func (m *MockOAuthDatastoreProvider) Config() *config.Config { if m.DoConfig != nil { return m.DoConfig() } cfg := config.New() cfg.UseSQLite(true) - cfg.App.EnableOAuth = true - cfg.App.OAuthProviderAuthLocation = "https://write.as/oauth/login" - cfg.App.OAuthProviderTokenLocation = "https://write.as/oauth/token" - cfg.App.OAuthProviderInspectLocation = "https://write.as/oauth/inspect" - cfg.App.OAuthClientCallbackLocation = "http://localhost/oauth/callback" - cfg.App.OAuthClientID = "development" - cfg.App.OAuthClientSecret = "development" + cfg.WriteAsOauth = config.WriteAsOauthCfg{ + ClientID: "development", + ClientSecret: "development", + AuthLocation: "https://write.as/oauth/login", + TokenLocation: "https://write.as/oauth/token", + InspectLocation: "https://write.as/oauth/inspect", + } + cfg.SlackOauth = config.SlackOauthCfg{ + ClientID: "development", + ClientSecret: "development", + TeamID: "development", + } return cfg } -func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) error { +func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { if m.DoValidateOAuthState != nil { return m.DoValidateOAuthState(ctx, state) } - return nil + return "", "", nil } -func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) { +func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { if m.DoGetIDForRemoteUser != nil { - return m.DoGetIDForRemoteUser(ctx, remoteUserID) + return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID) } return -1, nil } func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username string) error { if m.DoCreateUser != nil { return m.DoCreateUser(cfg, u, username) } u.ID = 1 return nil } -func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID int64) error { +func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { if m.DoRecordRemoteUserID != nil { - return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID) + return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken) } return nil } func (m *MockOAuthDatastore) GetUserForAuthByID(userID int64) (*User, error) { if m.DoGetUserForAuthByID != nil { return m.DoGetUserForAuthByID(userID) } user := &User{ } return user, nil } -func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context) (string, error) { +func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) { if m.DoGenerateOAuthState != nil { - return m.DoGenerateOAuthState(ctx) + return m.DoGenerateOAuthState(ctx, provider, clientID) } return store.Generate62RandomString(14), nil } func TestViewOauthInit(t *testing.T) { t.Run("success", func(t *testing.T) { app := &MockOAuthDatastoreProvider{} h := oauthHandler{ Config: app.Config(), DB: app.DB(), Store: app.SessionStore(), + oauthClient:writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, + InspectLocation: app.Config().WriteAsOauth.InspectLocation, + AuthLocation: app.Config().WriteAsOauth.AuthLocation, + CallbackLocation: "http://localhost/oauth/callback", + HttpClient: nil, + }, } req, err := http.NewRequest("GET", "/oauth/client", nil) assert.NoError(t, err) rr := httptest.NewRecorder() - h.viewOauthInit(rr, req) + h.viewOauthInit(nil, rr, req) assert.Equal(t, http.StatusTemporaryRedirect, rr.Code) locURI, err := url.Parse(rr.Header().Get("Location")) assert.NoError(t, err) assert.Equal(t, "/oauth/login", locURI.Path) assert.Equal(t, "development", locURI.Query().Get("client_id")) assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri")) assert.Equal(t, "code", locURI.Query().Get("response_type")) assert.NotEmpty(t, locURI.Query().Get("state")) }) t.Run("state failure", func(t *testing.T) { app := &MockOAuthDatastoreProvider{ DoDB: func() OAuthDatastore { return &MockOAuthDatastore{ - DoGenerateOAuthState: func(ctx context.Context) (string, error) { + DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) { return "", fmt.Errorf("pretend unable to write state error") }, } }, } h := oauthHandler{ Config: app.Config(), DB: app.DB(), Store: app.SessionStore(), + oauthClient:writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, + InspectLocation: app.Config().WriteAsOauth.InspectLocation, + AuthLocation: app.Config().WriteAsOauth.AuthLocation, + CallbackLocation: "http://localhost/oauth/callback", + HttpClient: nil, + }, } req, err := http.NewRequest("GET", "/oauth/client", nil) assert.NoError(t, err) rr := httptest.NewRecorder() - h.viewOauthInit(rr, req) + h.viewOauthInit(nil, rr, req) assert.Equal(t, http.StatusInternalServerError, rr.Code) expected := `{"error":"could not prepare oauth redirect url"}` + "\n" assert.Equal(t, expected, rr.Body.String()) }) } func TestViewOauthCallback(t *testing.T) { t.Run("success", func(t *testing.T) { app := &MockOAuthDatastoreProvider{} h := oauthHandler{ Config: app.Config(), DB: app.DB(), Store: app.SessionStore(), - HttpClient: &MockHTTPClient{ - DoDo: func(req *http.Request) (*http.Response, error) { - switch req.URL.String() { - case "https://write.as/oauth/token": - return &http.Response{ - StatusCode: 200, - Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)}, - }, nil - case "https://write.as/oauth/inspect": + oauthClient: writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, + InspectLocation: app.Config().WriteAsOauth.InspectLocation, + AuthLocation: app.Config().WriteAsOauth.AuthLocation, + CallbackLocation: "http://localhost/oauth/callback", + HttpClient: &MockHTTPClient{ + DoDo: func(req *http.Request) (*http.Response, error) { + switch req.URL.String() { + case "https://write.as/oauth/token": + return &http.Response{ + StatusCode: 200, + Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)}, + }, nil + case "https://write.as/oauth/inspect": + return &http.Response{ + StatusCode: 200, + Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)}, + }, nil + } + return &http.Response{ - StatusCode: 200, - Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": 1, "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)}, + StatusCode: http.StatusNotFound, }, nil - } - - return &http.Response{ - StatusCode: http.StatusNotFound, - }, nil + }, }, }, } req, err := http.NewRequest("GET", "/oauth/callback", nil) assert.NoError(t, err) rr := httptest.NewRecorder() - h.viewOauthCallback(rr, req) + h.viewOauthCallback(nil, rr, req) assert.NoError(t, err) assert.Equal(t, http.StatusTemporaryRedirect, rr.Code) - }) } diff --git a/oauth_writeas.go b/oauth_writeas.go new file mode 100644 index 0000000..eb12f64 --- /dev/null +++ b/oauth_writeas.go @@ -0,0 +1,110 @@ +package writefreely + +import ( + "context" + "errors" + "net/http" + "net/url" + "strings" +) + +type writeAsOauthClient struct { + ClientID string + ClientSecret string + AuthLocation string + ExchangeLocation string + InspectLocation string + CallbackLocation string + HttpClient HttpClient +} + +var _ oauthClient = writeAsOauthClient{} + +const ( + writeAsAuthLocation = "https://write.as/oauth/login" + writeAsExchangeLocation = "https://write.as/oauth/token" + writeAsIdentityLocation = "https://write.as/oauth/inspect" +) + +func (c writeAsOauthClient) GetProvider() string { + return "write.as" +} + +func (c writeAsOauthClient) GetClientID() string { + return c.ClientID +} + +func (c writeAsOauthClient) buildLoginURL(state string) (string, error) { + u, err := url.Parse(c.AuthLocation) + if err != nil { + return "", err + } + q := u.Query() + q.Set("client_id", c.ClientID) + q.Set("redirect_uri", c.CallbackLocation) + q.Set("response_type", "code") + q.Set("state", state) + u.RawQuery = q.Encode() + return u.String(), nil +} + +func (c writeAsOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { + form := url.Values{} + form.Add("grant_type", "authorization_code") + form.Add("redirect_uri", c.CallbackLocation) + form.Add("code", code) + req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(c.ClientID, c.ClientSecret) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to exchange code for access token") + } + + var tokenResponse TokenResponse + if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { + return nil, err + } + if tokenResponse.Error != "" { + return nil, errors.New(tokenResponse.Error) + } + return &tokenResponse, nil +} + +func (c writeAsOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { + req, err := http.NewRequest("GET", c.InspectLocation, nil) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to inspect access token") + } + + var inspectResponse InspectResponse + if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { + return nil, err + } + if inspectResponse.Error != "" { + return nil, errors.New(inspectResponse.Error) + } + return &inspectResponse, nil +} diff --git a/routes.go b/routes.go index 6accc99..a3d4a1c 100644 --- a/routes.go +++ b/routes.go @@ -1,218 +1,212 @@ /* * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import ( "net/http" "path/filepath" "strings" "github.com/gorilla/mux" "github.com/writeas/go-webfinger" "github.com/writeas/web-core/log" "github.com/writefreely/go-nodeinfo" ) // InitStaticRoutes adds routes for serving static files. // TODO: this should just be a func, not method func (app *App) InitStaticRoutes(r *mux.Router) { // Handle static files fs := http.FileServer(http.Dir(filepath.Join(app.cfg.Server.StaticParentDir, staticDir))) app.shttp = http.NewServeMux() app.shttp.Handle("/", fs) r.PathPrefix("/").Handler(fs) } // InitRoutes adds dynamic routes for the given mux.Router. func InitRoutes(apper Apper, r *mux.Router) *mux.Router { // Create handler handler := NewWFHandler(apper) // Set up routes hostSubroute := apper.App().cfg.App.Host[strings.Index(apper.App().cfg.App.Host, "://")+3:] if apper.App().cfg.App.SingleUser { hostSubroute = "{domain}" } else { if strings.HasPrefix(hostSubroute, "localhost") { hostSubroute = "localhost" } } if apper.App().cfg.App.SingleUser { log.Info("Adding %s routes (single user)...", hostSubroute) } else { log.Info("Adding %s routes (multi-user)...", hostSubroute) } // Primary app routes write := r.PathPrefix("/").Subrouter() // Federation endpoint configurations wf := webfinger.Default(wfResolver{apper.App().db, apper.App().cfg}) wf.NoTLSHandler = nil // Federation endpoints // host-meta write.HandleFunc("/.well-known/host-meta", handler.Web(handleViewHostMeta, UserLevelReader)) // webfinger write.HandleFunc(webfinger.WebFingerPath, handler.LogHandlerFunc(http.HandlerFunc(wf.Webfinger))) // nodeinfo niCfg := nodeInfoConfig(apper.App().db, apper.App().cfg) ni := nodeinfo.NewService(*niCfg, nodeInfoResolver{apper.App().cfg, apper.App().db}) write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover))) write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo))) + configureSlackOauth(handler, write, apper.App()) + configureWriteAsOauth(handler, write, apper.App()) + // Set up dyamic page handlers // Handle auth auth := write.PathPrefix("/api/auth/").Subrouter() if apper.App().cfg.App.OpenRegistration { auth.HandleFunc("/signup", handler.All(apiSignup)).Methods("POST") } auth.HandleFunc("/login", handler.All(login)).Methods("POST") auth.HandleFunc("/read", handler.WebErrors(handleWebCollectionUnlock, UserLevelNone)).Methods("POST") auth.HandleFunc("/me", handler.All(handleAPILogout)).Methods("DELETE") - oauthHandler := oauthHandler{ - HttpClient: &http.Client{}, - Config: apper.App().Config(), - DB: apper.App().DB(), - Store: apper.App().SessionStore(), - } - - write.HandleFunc("/oauth/write.as", handler.OAuth(oauthHandler.viewOauthInit)).Methods("GET") - write.HandleFunc("/oauth/callback", handler.OAuth(oauthHandler.viewOauthCallback)).Methods("GET") - // Handle logged in user sections me := write.PathPrefix("/me").Subrouter() me.HandleFunc("/", handler.Redirect("/me", UserLevelUser)) me.HandleFunc("/c", handler.Redirect("/me/c/", UserLevelUser)).Methods("GET") me.HandleFunc("/c/", handler.User(viewCollections)).Methods("GET") me.HandleFunc("/c/{collection}", handler.User(viewEditCollection)).Methods("GET") me.HandleFunc("/c/{collection}/stats", handler.User(viewStats)).Methods("GET") me.HandleFunc("/posts", handler.Redirect("/me/posts/", UserLevelUser)).Methods("GET") me.HandleFunc("/posts/", handler.User(viewArticles)).Methods("GET") me.HandleFunc("/posts/export.csv", handler.Download(viewExportPosts, UserLevelUser)).Methods("GET") me.HandleFunc("/posts/export.zip", handler.Download(viewExportPosts, UserLevelUser)).Methods("GET") me.HandleFunc("/posts/export.json", handler.Download(viewExportPosts, UserLevelUser)).Methods("GET") me.HandleFunc("/export", handler.User(viewExportOptions)).Methods("GET") me.HandleFunc("/export.json", handler.Download(viewExportFull, UserLevelUser)).Methods("GET") me.HandleFunc("/settings", handler.User(viewSettings)).Methods("GET") me.HandleFunc("/invites", handler.User(handleViewUserInvites)).Methods("GET") me.HandleFunc("/logout", handler.Web(viewLogout, UserLevelNone)).Methods("GET") write.HandleFunc("/api/me", handler.All(viewMeAPI)).Methods("GET") apiMe := write.PathPrefix("/api/me/").Subrouter() apiMe.HandleFunc("/", handler.All(viewMeAPI)).Methods("GET") apiMe.HandleFunc("/posts", handler.UserAPI(viewMyPostsAPI)).Methods("GET") apiMe.HandleFunc("/collections", handler.UserAPI(viewMyCollectionsAPI)).Methods("GET") apiMe.HandleFunc("/password", handler.All(updatePassphrase)).Methods("POST") apiMe.HandleFunc("/self", handler.All(updateSettings)).Methods("POST") apiMe.HandleFunc("/invites", handler.User(handleCreateUserInvite)).Methods("POST") // Sign up validation write.HandleFunc("/api/alias", handler.All(handleUsernameCheck)).Methods("POST") // Handle collections write.HandleFunc("/api/collections", handler.All(newCollection)).Methods("POST") apiColls := write.PathPrefix("/api/collections/").Subrouter() apiColls.HandleFunc("/{alias:[0-9a-zA-Z\\-]+}", handler.AllReader(fetchCollection)).Methods("GET") apiColls.HandleFunc("/{alias:[0-9a-zA-Z\\-]+}", handler.All(existingCollection)).Methods("POST", "DELETE") apiColls.HandleFunc("/{alias}/posts", handler.AllReader(fetchCollectionPosts)).Methods("GET") apiColls.HandleFunc("/{alias}/posts", handler.All(newPost)).Methods("POST") apiColls.HandleFunc("/{alias}/posts/{post}", handler.AllReader(fetchPost)).Methods("GET") apiColls.HandleFunc("/{alias}/posts/{post:[a-zA-Z0-9]{10}}", handler.All(existingPost)).Methods("POST") apiColls.HandleFunc("/{alias}/posts/{post}/{property}", handler.AllReader(fetchPostProperty)).Methods("GET") apiColls.HandleFunc("/{alias}/collect", handler.All(addPost)).Methods("POST") apiColls.HandleFunc("/{alias}/pin", handler.All(pinPost)).Methods("POST") apiColls.HandleFunc("/{alias}/unpin", handler.All(pinPost)).Methods("POST") apiColls.HandleFunc("/{alias}/inbox", handler.All(handleFetchCollectionInbox)).Methods("POST") apiColls.HandleFunc("/{alias}/outbox", handler.AllReader(handleFetchCollectionOutbox)).Methods("GET") apiColls.HandleFunc("/{alias}/following", handler.AllReader(handleFetchCollectionFollowing)).Methods("GET") apiColls.HandleFunc("/{alias}/followers", handler.AllReader(handleFetchCollectionFollowers)).Methods("GET") // Handle posts write.HandleFunc("/api/posts", handler.All(newPost)).Methods("POST") posts := write.PathPrefix("/api/posts/").Subrouter() posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.AllReader(fetchPost)).Methods("GET") posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(existingPost)).Methods("POST", "PUT") posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(deletePost)).Methods("DELETE") posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}/{property}", handler.AllReader(fetchPostProperty)).Methods("GET") posts.HandleFunc("/claim", handler.All(addPost)).Methods("POST") posts.HandleFunc("/disperse", handler.All(dispersePost)).Methods("POST") write.HandleFunc("/auth/signup", handler.Web(handleWebSignup, UserLevelNoneRequired)).Methods("POST") write.HandleFunc("/auth/login", handler.Web(webLogin, UserLevelNoneRequired)).Methods("POST") write.HandleFunc("/admin", handler.Admin(handleViewAdminDash)).Methods("GET") write.HandleFunc("/admin/users", handler.Admin(handleViewAdminUsers)).Methods("GET") write.HandleFunc("/admin/user/{username}", handler.Admin(handleViewAdminUser)).Methods("GET") write.HandleFunc("/admin/user/{username}/status", handler.Admin(handleAdminToggleUserStatus)).Methods("POST") write.HandleFunc("/admin/user/{username}/passphrase", handler.Admin(handleAdminResetUserPass)).Methods("POST") write.HandleFunc("/admin/pages", handler.Admin(handleViewAdminPages)).Methods("GET") write.HandleFunc("/admin/page/{slug}", handler.Admin(handleViewAdminPage)).Methods("GET") write.HandleFunc("/admin/update/config", handler.AdminApper(handleAdminUpdateConfig)).Methods("POST") write.HandleFunc("/admin/update/{page}", handler.Admin(handleAdminUpdateSite)).Methods("POST") // Handle special pages first write.HandleFunc("/login", handler.Web(viewLogin, UserLevelNoneRequired)) write.HandleFunc("/signup", handler.Web(handleViewLanding, UserLevelNoneRequired)) write.HandleFunc("/invite/{code}", handler.Web(handleViewInvite, UserLevelOptional)).Methods("GET") // TODO: show a reader-specific 404 page if the function is disabled write.HandleFunc("/read", handler.Web(viewLocalTimeline, UserLevelReader)) RouteRead(handler, UserLevelReader, write.PathPrefix("/read").Subrouter()) draftEditPrefix := "" if apper.App().cfg.App.SingleUser { draftEditPrefix = "/d" write.HandleFunc("/me/new", handler.Web(handleViewPad, UserLevelOptional)).Methods("GET") } else { write.HandleFunc("/new", handler.Web(handleViewPad, UserLevelOptional)).Methods("GET") } // All the existing stuff write.HandleFunc(draftEditPrefix+"/{action}/edit", handler.Web(handleViewPad, UserLevelOptional)).Methods("GET") write.HandleFunc(draftEditPrefix+"/{action}/meta", handler.Web(handleViewMeta, UserLevelOptional)).Methods("GET") // Collections if apper.App().cfg.App.SingleUser { RouteCollections(handler, write.PathPrefix("/").Subrouter()) } else { write.HandleFunc("/{prefix:[@~$!\\-+]}{collection}", handler.Web(handleViewCollection, UserLevelReader)) write.HandleFunc("/{collection}/", handler.Web(handleViewCollection, UserLevelReader)) RouteCollections(handler, write.PathPrefix("/{prefix:[@~$!\\-+]?}{collection}").Subrouter()) // Posts } write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional)) write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional)) + return r } func RouteCollections(handler *Handler, r *mux.Router) { r.HandleFunc("/page/{page:[0-9]+}", handler.Web(handleViewCollection, UserLevelReader)) r.HandleFunc("/tag:{tag}", handler.Web(handleViewCollectionTag, UserLevelReader)) r.HandleFunc("/tag:{tag}/feed/", handler.Web(ViewFeed, UserLevelReader)) r.HandleFunc("/tags/{tag}", handler.Web(handleViewCollectionTag, UserLevelReader)) r.HandleFunc("/sitemap.xml", handler.AllReader(handleViewSitemap)) r.HandleFunc("/feed/", handler.AllReader(ViewFeed)) r.HandleFunc("/{slug}", handler.CollectionPostOrStatic) r.HandleFunc("/{slug}/edit", handler.Web(handleViewPad, UserLevelUser)) r.HandleFunc("/{slug}/edit/meta", handler.Web(handleViewMeta, UserLevelUser)) r.HandleFunc("/{slug}/", handler.Web(handleCollectionPostRedirect, UserLevelReader)).Methods("GET") } func RouteRead(handler *Handler, readPerm UserLevelFunc, r *mux.Router) { r.HandleFunc("/api/posts", handler.Web(viewLocalTimelineAPI, readPerm)) r.HandleFunc("/p/{page}", handler.Web(viewLocalTimeline, readPerm)) r.HandleFunc("/feed/", handler.Web(viewLocalTimelineFeed, readPerm)) r.HandleFunc("/t/{tag}", handler.Web(viewLocalTimeline, readPerm)) r.HandleFunc("/a/{post}", handler.Web(handlePostIDRedirect, readPerm)) r.HandleFunc("/{author}", handler.Web(viewLocalTimeline, readPerm)) r.HandleFunc("/", handler.Web(viewLocalTimeline, readPerm)) }