diff --git a/admin.go b/admin.go index 6e0f3d5..1478dfd 100644 --- a/admin.go +++ b/admin.go @@ -1,693 +1,693 @@ /* * Copyright © 2018-2021 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import ( "database/sql" "fmt" "html/template" "net/http" "runtime" "strconv" "strings" "time" "github.com/gorilla/mux" "github.com/writeas/impart" "github.com/writeas/web-core/auth" "github.com/writeas/web-core/log" "github.com/writeas/web-core/passgen" "github.com/writefreely/writefreely/appstats" "github.com/writefreely/writefreely/config" ) var ( appStartTime = time.Now() sysStatus systemStatus ) const adminUsersPerPage = 30 type systemStatus struct { Uptime string NumGoroutine int // General statistics. MemAllocated string // bytes allocated and still in use MemTotal string // bytes allocated (even if freed) MemSys string // bytes obtained from system (sum of XxxSys below) Lookups uint64 // number of pointer lookups MemMallocs uint64 // number of mallocs MemFrees uint64 // number of frees // Main allocation heap statistics. HeapAlloc string // bytes allocated and still in use HeapSys string // bytes obtained from system HeapIdle string // bytes in idle spans HeapInuse string // bytes in non-idle span HeapReleased string // bytes released to the OS HeapObjects uint64 // total number of allocated objects // Low-level fixed-size structure allocator statistics. // Inuse is bytes used now. // Sys is bytes obtained from system. StackInuse string // bootstrap stacks StackSys string MSpanInuse string // mspan structures MSpanSys string MCacheInuse string // mcache structures MCacheSys string BuckHashSys string // profiling bucket hash table GCSys string // GC metadata OtherSys string // other system allocations // Garbage collector statistics. NextGC string // next run in HeapAlloc time (bytes) LastGC string // last run in absolute time (ns) PauseTotalNs string PauseNs string // circular buffer of recent GC pause times, most recent at [(NumGC+255)%256] NumGC uint32 } type inspectedCollection struct { CollectionObj Followers int LastPost string } type instanceContent struct { ID string Type string Title sql.NullString Content string Updated time.Time } type AdminPage struct { UpdateAvailable bool } func NewAdminPage(app *App) *AdminPage { ap := &AdminPage{} if app.updates != nil { ap.UpdateAvailable = app.updates.AreAvailableNoCheck() } return ap } func (c instanceContent) UpdatedFriendly() template.HTML { /* // TODO: accept a locale in this method and use that for the format var loc monday.Locale = monday.LocaleEnUS return monday.Format(u.Created, monday.DateTimeFormatsByLocale[loc], loc) */ if c.Updated.IsZero() { return "Never" } return template.HTML(c.Updated.Format("January 2, 2006, 3:04 PM")) } func handleViewAdminDash(app *App, u *User, w http.ResponseWriter, r *http.Request) error { p := struct { *UserPage *AdminPage Message string UsersCount, CollectionsCount, PostsCount int64 }{ UserPage: NewUserPage(app, r, u, "Admin", nil), AdminPage: NewAdminPage(app), Message: r.FormValue("m"), } // Get user stats p.UsersCount = app.db.GetAllUsersCount() var err error p.CollectionsCount, err = app.db.GetTotalCollections() if err != nil { return err } p.PostsCount, err = app.db.GetTotalPosts() if err != nil { return err } showUserPage(w, "admin", p) return nil } func handleViewAdminMonitor(app *App, u *User, w http.ResponseWriter, r *http.Request) error { updateAppStats() p := struct { *UserPage *AdminPage SysStatus systemStatus Config config.AppCfg Message, ConfigMessage string }{ UserPage: NewUserPage(app, r, u, "Admin", nil), AdminPage: NewAdminPage(app), SysStatus: sysStatus, Config: app.cfg.App, Message: r.FormValue("m"), ConfigMessage: r.FormValue("cm"), } showUserPage(w, "monitor", p) return nil } func handleViewAdminSettings(app *App, u *User, w http.ResponseWriter, r *http.Request) error { p := struct { *UserPage *AdminPage Config config.AppCfg Message, ConfigMessage string }{ UserPage: NewUserPage(app, r, u, "Admin", nil), AdminPage: NewAdminPage(app), Config: app.cfg.App, Message: r.FormValue("m"), ConfigMessage: r.FormValue("cm"), } showUserPage(w, "app-settings", p) return nil } func handleViewAdminUsers(app *App, u *User, w http.ResponseWriter, r *http.Request) error { p := struct { *UserPage *AdminPage Config config.AppCfg Message string Flashes []string Users *[]User CurPage int TotalUsers int64 TotalPages []int }{ UserPage: NewUserPage(app, r, u, "Users", nil), AdminPage: NewAdminPage(app), Config: app.cfg.App, Message: r.FormValue("m"), } p.Flashes, _ = getSessionFlashes(app, w, r, nil) p.TotalUsers = app.db.GetAllUsersCount() - ttlPages := (p.TotalUsers - 1) / adminUsersPerPage + 1 + ttlPages := (p.TotalUsers-1)/adminUsersPerPage + 1 p.TotalPages = []int{} for i := 1; i <= int(ttlPages); i++ { p.TotalPages = append(p.TotalPages, i) } var err error p.CurPage, err = strconv.Atoi(r.FormValue("p")) if err != nil || p.CurPage < 1 { p.CurPage = 1 } else if p.CurPage > int(ttlPages) { p.CurPage = int(ttlPages) } p.Users, err = app.db.GetAllUsers(uint(p.CurPage)) if err != nil { return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get users: %v", err)} } showUserPage(w, "users", p) return nil } func handleViewAdminUser(app *App, u *User, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) username := vars["username"] if username == "" { return impart.HTTPError{http.StatusFound, "/admin/users"} } p := struct { *UserPage *AdminPage Config config.AppCfg Message string User *User Colls []inspectedCollection LastPost string NewPassword string TotalPosts int64 ClearEmail string }{ AdminPage: NewAdminPage(app), Config: app.cfg.App, Message: r.FormValue("m"), Colls: []inspectedCollection{}, } var err error p.User, err = app.db.GetUserForAuth(username) if err != nil { if err == ErrUserNotFound { return err } log.Error("Could not get user: %v", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } flashes, _ := getSessionFlashes(app, w, r, nil) for _, flash := range flashes { if strings.HasPrefix(flash, "SUCCESS: ") { p.NewPassword = strings.TrimPrefix(flash, "SUCCESS: ") p.ClearEmail = p.User.EmailClear(app.keys) } } p.UserPage = NewUserPage(app, r, u, p.User.Username, nil) p.TotalPosts = app.db.GetUserPostsCount(p.User.ID) lp, err := app.db.GetUserLastPostTime(p.User.ID) if err != nil { return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get user's last post time: %v", err)} } if lp != nil { p.LastPost = lp.Format("January 2, 2006, 3:04 PM") } colls, err := app.db.GetCollections(p.User, app.cfg.App.Host) if err != nil { return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get user's collections: %v", err)} } for _, c := range *colls { ic := inspectedCollection{ CollectionObj: CollectionObj{Collection: c}, } if app.cfg.App.Federation { folls, err := app.db.GetAPFollowers(&c) if err == nil { // TODO: handle error here (at least log it) ic.Followers = len(*folls) } } app.db.GetPostsCount(&ic.CollectionObj, true) lp, err := app.db.GetCollectionLastPostTime(c.ID) if err != nil { log.Error("Didn't get last post time for collection %d: %v", c.ID, err) } if lp != nil { ic.LastPost = lp.Format("January 2, 2006, 3:04 PM") } p.Colls = append(p.Colls, ic) } showUserPage(w, "view-user", p) return nil } func handleAdminDeleteUser(app *App, u *User, w http.ResponseWriter, r *http.Request) error { if !u.IsAdmin() { return impart.HTTPError{http.StatusForbidden, "Administrator privileges required for this action"} } vars := mux.Vars(r) username := vars["username"] confirmUsername := r.PostFormValue("confirm-username") if confirmUsername != username { return impart.HTTPError{http.StatusBadRequest, "Username was not confirmed"} } user, err := app.db.GetUserForAuth(username) if err == ErrUserNotFound { return impart.HTTPError{http.StatusNotFound, fmt.Sprintf("User '%s' was not found", username)} } else if err != nil { log.Error("get user for deletion: %v", err) return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get user with username '%s': %v", username, err)} } err = app.db.DeleteAccount(user.ID) if err != nil { log.Error("delete user %s: %v", user.Username, err) return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not delete user account for '%s': %v", username, err)} } _ = addSessionFlash(app, w, r, fmt.Sprintf("User \"%s\" was deleted successfully.", username), nil) return impart.HTTPError{http.StatusFound, "/admin/users"} } func handleAdminToggleUserStatus(app *App, u *User, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) username := vars["username"] if username == "" { return impart.HTTPError{http.StatusFound, "/admin/users"} } user, err := app.db.GetUserForAuth(username) if err != nil { log.Error("failed to get user: %v", err) return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get user from username: %v", err)} } if user.IsSilenced() { err = app.db.SetUserStatus(user.ID, UserActive) } else { err = app.db.SetUserStatus(user.ID, UserSilenced) // reset the cache to removed silence user posts updateTimelineCache(app.timeline, true) } if err != nil { log.Error("toggle user silenced: %v", err) return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not toggle user status: %v", err)} } return impart.HTTPError{http.StatusFound, fmt.Sprintf("/admin/user/%s#status", username)} } func handleAdminResetUserPass(app *App, u *User, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) username := vars["username"] if username == "" { return impart.HTTPError{http.StatusFound, "/admin/users"} } // Generate new random password since none supplied pass := passgen.NewWordish() hashedPass, err := auth.HashPass([]byte(pass)) if err != nil { return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not create password hash: %v", err)} } userIDVal := r.FormValue("user") log.Info("ADMIN: Changing user %s password", userIDVal) id, err := strconv.Atoi(userIDVal) if err != nil { return impart.HTTPError{http.StatusBadRequest, fmt.Sprintf("Invalid user ID: %v", err)} } err = app.db.ChangePassphrase(int64(id), true, "", hashedPass) if err != nil { return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not update passphrase: %v", err)} } log.Info("ADMIN: Successfully changed.") addSessionFlash(app, w, r, fmt.Sprintf("SUCCESS: %s", pass), nil) return impart.HTTPError{http.StatusFound, fmt.Sprintf("/admin/user/%s", username)} } func handleViewAdminPages(app *App, u *User, w http.ResponseWriter, r *http.Request) error { p := struct { *UserPage *AdminPage Config config.AppCfg Message string Pages []*instanceContent }{ UserPage: NewUserPage(app, r, u, "Pages", nil), AdminPage: NewAdminPage(app), Config: app.cfg.App, Message: r.FormValue("m"), } var err error p.Pages, err = app.db.GetInstancePages() if err != nil { return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get pages: %v", err)} } // Add in default pages var hasAbout, hasContact, hasPrivacy bool for i, c := range p.Pages { if hasAbout && hasContact && hasPrivacy { break } if c.ID == "about" { hasAbout = true if !c.Title.Valid { p.Pages[i].Title = defaultAboutTitle(app.cfg) } } else if c.ID == "contact" { hasContact = true if !c.Title.Valid { p.Pages[i].Title = defaultContactTitle() } } else if c.ID == "privacy" { hasPrivacy = true if !c.Title.Valid { p.Pages[i].Title = defaultPrivacyTitle() } } } if !hasAbout { p.Pages = append(p.Pages, &instanceContent{ ID: "about", Title: defaultAboutTitle(app.cfg), Content: defaultAboutPage(app.cfg), Updated: defaultPageUpdatedTime, }) } if !hasContact { p.Pages = append(p.Pages, &instanceContent{ ID: "contact", Title: defaultContactTitle(), Content: defaultContactPage(app), }) } if !hasPrivacy { p.Pages = append(p.Pages, &instanceContent{ ID: "privacy", Title: defaultPrivacyTitle(), Content: defaultPrivacyPolicy(app.cfg), Updated: defaultPageUpdatedTime, }) } showUserPage(w, "pages", p) return nil } func handleViewAdminPage(app *App, u *User, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) slug := vars["slug"] if slug == "" { return impart.HTTPError{http.StatusFound, "/admin/pages"} } p := struct { *UserPage *AdminPage Config config.AppCfg Message string Banner *instanceContent Content *instanceContent }{ AdminPage: NewAdminPage(app), Config: app.cfg.App, Message: r.FormValue("m"), } var err error // Get pre-defined pages, or select slug if slug == "about" { p.Content, err = getAboutPage(app) } else if slug == "contact" { p.Content, err = getContactPage(app) } else if slug == "privacy" { p.Content, err = getPrivacyPage(app) } else if slug == "landing" { p.Banner, err = getLandingBanner(app) if err != nil { return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get banner: %v", err)} } p.Content, err = getLandingBody(app) p.Content.ID = "landing" } else if slug == "reader" { p.Content, err = getReaderSection(app) } else { p.Content, err = app.db.GetDynamicContent(slug) } if err != nil { return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get page: %v", err)} } title := "New page" if p.Content != nil { title = "Edit " + p.Content.ID } else { p.Content = &instanceContent{} } p.UserPage = NewUserPage(app, r, u, title, nil) showUserPage(w, "view-page", p) return nil } func handleAdminUpdateSite(app *App, u *User, w http.ResponseWriter, r *http.Request) error { vars := mux.Vars(r) id := vars["page"] // Validate if id != "about" && id != "contact" && id != "privacy" && id != "landing" && id != "reader" { return impart.HTTPError{http.StatusNotFound, "No such page."} } var err error m := "" if id == "landing" { // Handle special landing page err = app.db.UpdateDynamicContent("landing-banner", "", r.FormValue("banner"), "section") if err != nil { m = "?m=" + err.Error() return impart.HTTPError{http.StatusFound, "/admin/page/" + id + m} } err = app.db.UpdateDynamicContent("landing-body", "", r.FormValue("content"), "section") } else if id == "reader" { // Update sections with titles err = app.db.UpdateDynamicContent(id, r.FormValue("title"), r.FormValue("content"), "section") } else { // Update page err = app.db.UpdateDynamicContent(id, r.FormValue("title"), r.FormValue("content"), "page") } if err != nil { m = "?m=" + err.Error() } return impart.HTTPError{http.StatusFound, "/admin/page/" + id + m} } func handleAdminUpdateConfig(apper Apper, u *User, w http.ResponseWriter, r *http.Request) error { apper.App().cfg.App.SiteName = r.FormValue("site_name") apper.App().cfg.App.SiteDesc = r.FormValue("site_desc") apper.App().cfg.App.Landing = r.FormValue("landing") apper.App().cfg.App.OpenRegistration = r.FormValue("open_registration") == "on" apper.App().cfg.App.OpenDeletion = r.FormValue("open_deletion") == "on" mul, err := strconv.Atoi(r.FormValue("min_username_len")) if err == nil { apper.App().cfg.App.MinUsernameLen = mul } mb, err := strconv.Atoi(r.FormValue("max_blogs")) if err == nil { apper.App().cfg.App.MaxBlogs = mb } apper.App().cfg.App.Federation = r.FormValue("federation") == "on" apper.App().cfg.App.PublicStats = r.FormValue("public_stats") == "on" apper.App().cfg.App.Monetization = r.FormValue("monetization") == "on" apper.App().cfg.App.Private = r.FormValue("private") == "on" apper.App().cfg.App.LocalTimeline = r.FormValue("local_timeline") == "on" if apper.App().cfg.App.LocalTimeline && apper.App().timeline == nil { log.Info("Initializing local timeline...") initLocalTimeline(apper.App()) } apper.App().cfg.App.UserInvites = r.FormValue("user_invites") if apper.App().cfg.App.UserInvites == "none" { apper.App().cfg.App.UserInvites = "" } apper.App().cfg.App.DefaultVisibility = r.FormValue("default_visibility") m := "?cm=Configuration+saved." err = apper.SaveConfig(apper.App().cfg) if err != nil { m = "?cm=" + err.Error() } return impart.HTTPError{http.StatusFound, "/admin/settings" + m + "#config"} } func updateAppStats() { sysStatus.Uptime = appstats.TimeSincePro(appStartTime) m := new(runtime.MemStats) runtime.ReadMemStats(m) sysStatus.NumGoroutine = runtime.NumGoroutine() sysStatus.MemAllocated = appstats.FileSize(int64(m.Alloc)) sysStatus.MemTotal = appstats.FileSize(int64(m.TotalAlloc)) sysStatus.MemSys = appstats.FileSize(int64(m.Sys)) sysStatus.Lookups = m.Lookups sysStatus.MemMallocs = m.Mallocs sysStatus.MemFrees = m.Frees sysStatus.HeapAlloc = appstats.FileSize(int64(m.HeapAlloc)) sysStatus.HeapSys = appstats.FileSize(int64(m.HeapSys)) sysStatus.HeapIdle = appstats.FileSize(int64(m.HeapIdle)) sysStatus.HeapInuse = appstats.FileSize(int64(m.HeapInuse)) sysStatus.HeapReleased = appstats.FileSize(int64(m.HeapReleased)) sysStatus.HeapObjects = m.HeapObjects sysStatus.StackInuse = appstats.FileSize(int64(m.StackInuse)) sysStatus.StackSys = appstats.FileSize(int64(m.StackSys)) sysStatus.MSpanInuse = appstats.FileSize(int64(m.MSpanInuse)) sysStatus.MSpanSys = appstats.FileSize(int64(m.MSpanSys)) sysStatus.MCacheInuse = appstats.FileSize(int64(m.MCacheInuse)) sysStatus.MCacheSys = appstats.FileSize(int64(m.MCacheSys)) sysStatus.BuckHashSys = appstats.FileSize(int64(m.BuckHashSys)) sysStatus.GCSys = appstats.FileSize(int64(m.GCSys)) sysStatus.OtherSys = appstats.FileSize(int64(m.OtherSys)) sysStatus.NextGC = appstats.FileSize(int64(m.NextGC)) sysStatus.LastGC = fmt.Sprintf("%.1fs", float64(time.Now().UnixNano()-int64(m.LastGC))/1000/1000/1000) sysStatus.PauseTotalNs = fmt.Sprintf("%.1fs", float64(m.PauseTotalNs)/1000/1000/1000) sysStatus.PauseNs = fmt.Sprintf("%.3fs", float64(m.PauseNs[(m.NumGC+255)%256])/1000/1000/1000) sysStatus.NumGC = m.NumGC } func adminResetPassword(app *App, u *User, newPass string) error { hashedPass, err := auth.HashPass([]byte(newPass)) if err != nil { return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not create password hash: %v", err)} } err = app.db.ChangePassphrase(u.ID, true, "", hashedPass) if err != nil { return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not update passphrase: %v", err)} } return nil } func handleViewAdminUpdates(app *App, u *User, w http.ResponseWriter, r *http.Request) error { check := r.URL.Query().Get("check") if check == "now" && app.cfg.App.UpdateChecks { app.updates.CheckNow() } p := struct { *UserPage *AdminPage CurReleaseNotesURL string LastChecked string LastChecked8601 string LatestVersion string LatestReleaseURL string LatestReleaseNotesURL string CheckFailed bool }{ UserPage: NewUserPage(app, r, u, "Updates", nil), AdminPage: NewAdminPage(app), } p.CurReleaseNotesURL = wfReleaseNotesURL(p.Version) if app.cfg.App.UpdateChecks { p.LastChecked = app.updates.lastCheck.Format("January 2, 2006, 3:04 PM") p.LastChecked8601 = app.updates.lastCheck.Format("2006-01-02T15:04:05Z") p.LatestVersion = app.updates.LatestVersion() p.LatestReleaseURL = app.updates.ReleaseURL() p.LatestReleaseNotesURL = app.updates.ReleaseNotesURL() p.UpdateAvailable = app.updates.AreAvailable() p.CheckFailed = app.updates.checkError != nil } showUserPage(w, "app-updates", p) return nil } diff --git a/app.go b/app.go index 6d71cfc..c08090d 100644 --- a/app.go +++ b/app.go @@ -1,1023 +1,1023 @@ /* * Copyright © 2018-2021 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import ( "crypto/tls" "database/sql" _ "embed" "fmt" "html/template" "net" "net/http" "net/url" "os" "os/signal" "path/filepath" "regexp" "strings" "syscall" "time" "github.com/gorilla/mux" "github.com/gorilla/schema" "github.com/gorilla/sessions" "github.com/manifoldco/promptui" stripmd "github.com/writeas/go-strip-markdown/v2" "github.com/writeas/impart" "github.com/writeas/web-core/auth" "github.com/writeas/web-core/converter" "github.com/writeas/web-core/log" "golang.org/x/crypto/acme/autocert" "github.com/writefreely/writefreely/author" "github.com/writefreely/writefreely/config" "github.com/writefreely/writefreely/key" "github.com/writefreely/writefreely/migrations" "github.com/writefreely/writefreely/page" ) const ( staticDir = "static" assumedTitleLen = 80 postsPerPage = 10 postsPerArchPage = 40 serverSoftware = "WriteFreely" softwareURL = "https://writefreely.org" ) var ( debugging bool // Software version can be set from git env using -ldflags softwareVer = "0.16.0" // DEPRECATED VARS isSingleUser bool // Canonical URL helpers for code paths that still build URLs from shared globals. - canonicalAppHost string - canonicalSubdir string + canonicalAppHost string + canonicalSubdir string ) // App holds data and configuration for an individual WriteFreely instance. type App struct { router *mux.Router shttp *http.ServeMux db *datastore cfg *config.Config cfgFile string keys *key.Keychain sessionStore sessions.Store formDecoder *schema.Decoder updates *updatesCache timeline *localTimeline } // DB returns the App's datastore func (app *App) DB() *datastore { return app.db } // Router returns the App's router func (app *App) Router() *mux.Router { return app.router } // Config returns the App's current configuration. func (app *App) Config() *config.Config { return app.cfg } // SetConfig updates the App's Config to the given value. func (app *App) SetConfig(cfg *config.Config) { app.cfg = cfg } // SetKeys updates the App's Keychain to the given value. func (app *App) SetKeys(k *key.Keychain) { app.keys = k } func (app *App) SessionStore() sessions.Store { return app.sessionStore } func (app *App) SetSessionStore(s sessions.Store) { app.sessionStore = s } // Apper is the interface for getting data into and out of a WriteFreely // instance (or "App"). // // App returns the App for the current instance. // // LoadConfig reads an app configuration into the App, returning any error // encountered. // // SaveConfig persists the current App configuration. // // LoadKeys reads the App's encryption keys and loads them into its // key.Keychain. type Apper interface { App() *App LoadConfig() error SaveConfig(*config.Config) error LoadKeys() error ReqLog(r *http.Request, status int, timeSince time.Duration) string } // App returns the App func (app *App) App() *App { return app } // LoadConfig loads and parses a config file. func (app *App) LoadConfig() error { log.Info("Loading %s configuration...", app.cfgFile) cfg, err := config.Load(app.cfgFile) if err != nil { log.Error("Unable to load configuration: %v", err) os.Exit(1) return err } app.cfg = cfg return nil } // SaveConfig saves the given Config to disk -- namely, to the App's cfgFile. func (app *App) SaveConfig(c *config.Config) error { return config.Save(c, app.cfgFile) } // LoadKeys reads all needed keys from disk into the App. In order to use the // configured `Server.KeysParentDir`, you must call initKeyPaths(App) before // this. func (app *App) LoadKeys() error { var err error app.keys = &key.Keychain{} if debugging { log.Info(" %s", emailKeyPath) } executable, err := os.Executable() if err != nil { executable = "writefreely" } else { executable = filepath.Base(executable) } app.keys.EmailKey, err = os.ReadFile(emailKeyPath) if err != nil { return err } if debugging { log.Info(" %s", cookieAuthKeyPath) } app.keys.CookieAuthKey, err = os.ReadFile(cookieAuthKeyPath) if err != nil { return err } if debugging { log.Info(" %s", cookieKeyPath) } app.keys.CookieKey, err = os.ReadFile(cookieKeyPath) if err != nil { return err } if debugging { log.Info(" %s", csrfKeyPath) } app.keys.CSRFKey, err = os.ReadFile(csrfKeyPath) if err != nil { if os.IsNotExist(err) { log.Error(`Missing key: %s. Run this command to generate missing keys: %s keys generate `, csrfKeyPath, executable) } return err } return nil } func (app *App) ReqLog(r *http.Request, status int, timeSince time.Duration) string { return fmt.Sprintf("\"%s %s\" %d %s \"%s\"", r.Method, r.RequestURI, status, timeSince, r.UserAgent()) } // handleViewHome shows page at root path. It checks the configuration and // authentication state to show the correct page. func handleViewHome(app *App, w http.ResponseWriter, r *http.Request) error { if app.cfg.App.SingleUser { // Render blog index return handleViewCollection(app, w, r) } // Multi-user instance forceLanding := r.FormValue("landing") == "1" if !forceLanding { // Show correct page based on user auth status and configured landing path u := getUserSession(app, r) if app.cfg.App.Chorus { // This instance is focused on reading, so show Reader on home route if not // private or a private-instance user is logged in. if !app.cfg.App.Private || u != nil { return viewLocalTimeline(app, w, r) } } if u != nil { // User is logged in, so show the Pad return handleViewPad(app, w, r) } if app.cfg.App.Private { return viewLogin(app, w, r) } if land := app.cfg.App.LandingPath(); land != "/" { return impart.HTTPError{http.StatusFound, land} } } return handleViewLanding(app, w, r) } func handleViewLanding(app *App, w http.ResponseWriter, r *http.Request) error { forceLanding := r.FormValue("landing") == "1" p := struct { page.StaticPage *OAuthButtons Flashes []template.HTML Banner template.HTML Content template.HTML ForcedLanding bool }{ StaticPage: pageForReq(app, r), OAuthButtons: NewOAuthButtons(app.Config()), ForcedLanding: forceLanding, } banner, err := getLandingBanner(app) if err != nil { log.Error("unable to get landing banner: %v", err) return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get banner: %v", err)} } p.Banner = template.HTML(applyMarkdown([]byte(banner.Content), "", app.cfg)) content, err := getLandingBody(app) if err != nil { log.Error("unable to get landing content: %v", err) return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get content: %v", err)} } p.Content = template.HTML(applyMarkdown([]byte(content.Content), "", app.cfg)) // Get error messages session, err := app.sessionStore.Get(r, cookieName) if err != nil { // Ignore this log.Error("Unable to get session in handleViewHome; ignoring: %v", err) } flashes, _ := getSessionFlashes(app, w, r, session) for _, flash := range flashes { p.Flashes = append(p.Flashes, template.HTML(flash)) } // Show landing page return renderPage(w, "landing.tmpl", p) } func handleTemplatedPage(app *App, w http.ResponseWriter, r *http.Request, t *template.Template) error { p := struct { page.StaticPage ContentTitle string Content template.HTML PlainContent string Updated string AboutStats *InstanceStats }{ StaticPage: pageForReq(app, r), } path := app.cfg.App.StripSubdirectory(r.URL.Path) if path == "/about" || path == "/contact" || path == "/privacy" { var c *instanceContent var err error if path == "/about" { c, err = getAboutPage(app) // Fetch stats p.AboutStats = &InstanceStats{} p.AboutStats.NumPosts, _ = app.db.GetTotalPosts() p.AboutStats.NumBlogs, _ = app.db.GetTotalCollections() } else if path == "/contact" { c, err = getContactPage(app) if c.Updated.IsZero() { // Page was never set up, so return 404 return ErrPostNotFound } } else { c, err = getPrivacyPage(app) } if err != nil { return err } p.ContentTitle = c.Title.String p.Content = template.HTML(applyMarkdown([]byte(c.Content), "", app.cfg)) p.PlainContent = shortPostDescription(stripmd.Strip(c.Content)) if !c.Updated.IsZero() { p.Updated = c.Updated.Format("January 2, 2006") } } // Serve templated page err := t.ExecuteTemplate(w, "base", p) if err != nil { log.Error("Unable to render page: %v", err) } return nil } func pageForReq(app *App, r *http.Request) page.StaticPage { p := page.StaticPage{ AppCfg: app.cfg.App, Path: app.cfg.App.StripSubdirectory(r.URL.Path), Version: "v" + softwareVer, } // Use custom style, if file exists if _, err := os.Stat(filepath.Join(app.cfg.Server.StaticParentDir, staticDir, "local", "custom.css")); err == nil { p.CustomCSS = true } // Add user information, if given var u *User accessToken := r.FormValue("t") if accessToken != "" { userID := app.db.GetUserID(accessToken) if userID != -1 { var err error u, err = app.db.GetUserByID(userID) if err == nil { p.Username = u.Username } } } else { u = getUserSession(app, r) if u != nil { p.Username = u.Username p.IsAdmin = u != nil && u.IsAdmin() p.CanInvite = canUserInvite(app.cfg, p.IsAdmin) } } p.CanViewReader = !app.cfg.App.Private || u != nil return p } var fileRegex = regexp.MustCompile("/([^/]*\\.[^/]*)$") // Initialize loads the app configuration and initializes templates, keys, // session, route handlers, and the database connection. func Initialize(apper Apper, debug bool) (*App, error) { debugging = debug apper.LoadConfig() // Load templates err := InitTemplates(apper.App().Config()) if err != nil { return nil, fmt.Errorf("load templates: %s", err) } // Load keys and set up session initKeyPaths(apper.App()) // TODO: find a better way to do this, since it's unneeded in all Apper implementations err = InitKeys(apper) if err != nil { return nil, fmt.Errorf("init keys: %s", err) } apper.App().InitUpdates() apper.App().InitSession() apper.App().InitDecoder() err = ConnectToDatabase(apper.App()) if err != nil { return nil, fmt.Errorf("connect to DB: %s", err) } initActivityPub(apper.App()) if apper.App().cfg.Email.Enabled() { log.Info("Starting publish jobs queue...") go startPublishJobsQueue(apper.App()) } else { log.Info("[jobs] Not starting publish jobs queue: no email provider is configured.") } // Handle local timeline, if enabled if apper.App().cfg.App.LocalTimeline { log.Info("Initializing local timeline...") initLocalTimeline(apper.App()) } return apper.App(), nil } func Serve(app *App, r *mux.Router) { log.Info("Going to serve...") isSingleUser = app.cfg.App.SingleUser canonicalAppHost = strings.TrimSuffix(app.cfg.App.Host, "/") canonicalSubdir = app.cfg.App.SubdirectoryPath() app.cfg.Server.Dev = debugging // Handle shutdown c := make(chan os.Signal, 2) signal.Notify(c, os.Interrupt, syscall.SIGTERM) go func() { <-c log.Info("Shutting down...") shutdown(app) log.Info("Done.") os.Exit(0) }() // Start gopher server if app.cfg.Server.GopherPort > 0 && !app.cfg.App.Private { go initGopher(app) } // Start web application server var bindAddress = app.cfg.Server.Bind if bindAddress == "" { bindAddress = "localhost" } var err error if app.cfg.IsSecureStandalone() { if app.cfg.Server.Autocert { m := &autocert.Manager{ Prompt: autocert.AcceptTOS, Cache: autocert.DirCache(app.cfg.Server.TLSCertPath), } host, err := url.Parse(app.cfg.App.Host) if err != nil { log.Error("[WARNING] Unable to parse configured host! %s", err) log.Error(`[WARNING] ALL hosts are allowed, which can open you to an attack where clients connect to a server by IP address and pretend to be asking for an incorrect host name, and cause you to reach the CA's rate limit for certificate requests. We recommend supplying a valid host name.`) log.Info("Using autocert on ANY host") } else { log.Info("Using autocert on host %s", host.Host) m.HostPolicy = autocert.HostWhitelist(host.Host) } s := &http.Server{ Addr: ":https", Handler: r, TLSConfig: &tls.Config{ GetCertificate: m.GetCertificate, }, } s.SetKeepAlivesEnabled(false) go func() { log.Info("Serving redirects on http://%s:80", bindAddress) err = http.ListenAndServe(":80", m.HTTPHandler(nil)) log.Error("Unable to start redirect server: %v", err) }() log.Info("Serving on https://%s:443", bindAddress) log.Info("---") err = s.ListenAndServeTLS("", "") } else { go func() { log.Info("Serving redirects on http://%s:80", bindAddress) err = http.ListenAndServe(fmt.Sprintf("%s:80", bindAddress), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { loc := app.cfg.App.AbsoluteURL(r.URL.Path) if r.URL.RawQuery != "" { loc += "?" + r.URL.RawQuery } http.Redirect(w, r, loc, http.StatusMovedPermanently) })) log.Error("Unable to start redirect server: %v", err) }() log.Info("Serving on https://%s:443", bindAddress) log.Info("Using manual certificates") log.Info("---") err = http.ListenAndServeTLS(fmt.Sprintf("%s:443", bindAddress), app.cfg.Server.TLSCertPath, app.cfg.Server.TLSKeyPath, r) } } else { network := "tcp" protocol := "http" if strings.HasPrefix(bindAddress, "/") { network = "unix" protocol = "http+unix" // old sockets will remain after server closes; // we need to delete them in order to open new ones err = os.Remove(bindAddress) if err != nil && !os.IsNotExist(err) { log.Error("%s already exists but could not be removed: %v", bindAddress, err) os.Exit(1) } } else { bindAddress = fmt.Sprintf("%s:%d", bindAddress, app.cfg.Server.Port) } log.Info("Serving on %s://%s", protocol, bindAddress) log.Info("---") listener, err := net.Listen(network, bindAddress) if err != nil { log.Error("Could not bind to address: %v", err) os.Exit(1) } if network == "unix" { err = os.Chmod(bindAddress, 0o666) if err != nil { log.Error("Could not update socket permissions: %v", err) os.Exit(1) } } defer listener.Close() err = http.Serve(listener, r) } if err != nil { log.Error("Unable to start: %v", err) os.Exit(1) } } func (app *App) InitDecoder() { // TODO: do this at the package level, instead of the App level // Initialize modules app.formDecoder = schema.NewDecoder() app.formDecoder.RegisterConverter(converter.NullJSONString{}, converter.ConvertJSONNullString) app.formDecoder.RegisterConverter(converter.NullJSONBool{}, converter.ConvertJSONNullBool) app.formDecoder.RegisterConverter(sql.NullString{}, converter.ConvertSQLNullString) app.formDecoder.RegisterConverter(sql.NullBool{}, converter.ConvertSQLNullBool) app.formDecoder.RegisterConverter(sql.NullInt64{}, converter.ConvertSQLNullInt64) app.formDecoder.RegisterConverter(sql.NullFloat64{}, converter.ConvertSQLNullFloat64) } // ConnectToDatabase validates and connects to the configured database, then // tests the connection. func ConnectToDatabase(app *App) error { // Check database configuration if app.cfg.Database.Type == driverMySQL && app.cfg.Database.User == "" { return fmt.Errorf("Database user not set.") } if app.cfg.Database.Host == "" { app.cfg.Database.Host = "localhost" } if app.cfg.Database.Database == "" { app.cfg.Database.Database = "writefreely" } // TODO: check err connectToDatabase(app) // Test database connection err := app.db.Ping() if err != nil { return fmt.Errorf("Database ping failed: %s", err) } log.Info("Connected to database.") ver, err := app.db.version() if err != nil { log.Error("Unable to get DB version: %v", err) } else { log.Info("Database version: %v", ver) if app.cfg.Database.Type == driverMySQL && strings.HasPrefix(ver, "5.") { log.Info("Enabling compatibility for MySQL v5.x") app.db.useSpencerRegex = true } } return nil } // FormatVersion constructs the version string for the application func FormatVersion() string { return serverSoftware + " " + softwareVer } // OutputVersion prints out the version of the application. func OutputVersion() { fmt.Println(FormatVersion()) } // NewApp creates a new app instance. func NewApp(cfgFile string) *App { return &App{ cfgFile: cfgFile, } } // CreateConfig creates a default configuration and saves it to the app's cfgFile. func CreateConfig(app *App) error { log.Info("Creating configuration...") c := config.New() log.Info("Saving configuration %s...", app.cfgFile) err := config.Save(c, app.cfgFile) if err != nil { return fmt.Errorf("Unable to save configuration: %v", err) } return nil } // DoConfig runs the interactive configuration process. func DoConfig(app *App, configSections string) { if configSections == "" { configSections = "server db app" } // let's check there aren't any garbage in the list configSectionsArray := strings.Split(configSections, " ") for _, element := range configSectionsArray { if element != "server" && element != "db" && element != "app" { log.Error("Invalid argument to --sections. Valid arguments are only \"server\", \"db\" and \"app\"") os.Exit(1) } } d, err := config.Configure(app.cfgFile, configSections) if err != nil { log.Error("Unable to configure: %v", err) os.Exit(1) } app.cfg = d.Config connectToDatabase(app) defer shutdown(app) if !app.db.DatabaseInitialized() { err = adminInitDatabase(app) if err != nil { log.Error(err.Error()) os.Exit(1) } } else { log.Info("Database already initialized.") } if d.User != nil { u := &User{ Username: d.User.Username, HashedPass: d.User.HashedPass, Created: time.Now().Truncate(time.Second).UTC(), } // Create blog log.Info("Creating user %s...\n", u.Username) err = app.db.CreateUser(app.cfg, u, app.cfg.App.SiteName, "") if err != nil { log.Error("Unable to create user: %s", err) os.Exit(1) } log.Info("Done!") } os.Exit(0) } // GenerateKeyFiles creates app encryption keys and saves them into the configured KeysParentDir. func GenerateKeyFiles(app *App) error { // Read keys path from config app.LoadConfig() // Create keys dir if it doesn't exist yet fullKeysDir := filepath.Join(app.cfg.Server.KeysParentDir, keysDir) if _, err := os.Stat(fullKeysDir); os.IsNotExist(err) { err = os.Mkdir(fullKeysDir, 0700) if err != nil { return err } } // Generate keys initKeyPaths(app) // TODO: use something like https://github.com/hashicorp/go-multierror to return errors var keyErrs error err := generateKey(emailKeyPath) if err != nil { keyErrs = err } err = generateKey(cookieAuthKeyPath) if err != nil { keyErrs = err } err = generateKey(cookieKeyPath) if err != nil { keyErrs = err } err = generateKey(csrfKeyPath) if err != nil { keyErrs = err } return keyErrs } // CreateSchema creates all database tables needed for the application. func CreateSchema(apper Apper) error { apper.LoadConfig() connectToDatabase(apper.App()) defer shutdown(apper.App()) err := adminInitDatabase(apper.App()) if err != nil { return err } return nil } // Migrate runs all necessary database migrations. func Migrate(apper Apper) error { apper.LoadConfig() connectToDatabase(apper.App()) defer shutdown(apper.App()) err := migrations.Migrate(migrations.NewDatastore(apper.App().db.DB, apper.App().db.driverName)) if err != nil { return fmt.Errorf("migrate: %s", err) } return nil } // ResetPassword runs the interactive password reset process. func ResetPassword(apper Apper, username string) error { // Connect to the database apper.LoadConfig() connectToDatabase(apper.App()) defer shutdown(apper.App()) // Fetch user u, err := apper.App().db.GetUserForAuth(username) if err != nil { log.Error("Get user: %s", err) os.Exit(1) } // Prompt for new password prompt := promptui.Prompt{ Templates: &promptui.PromptTemplates{ Success: "{{ . | bold | faint }}: ", }, Label: "New password", Mask: '*', } newPass, err := prompt.Run() if err != nil { log.Error("%s", err) os.Exit(1) } // Do the update log.Info("Updating...") err = adminResetPassword(apper.App(), u, newPass) if err != nil { log.Error("%s", err) os.Exit(1) } log.Info("Success.") return nil } // DoDeleteAccount runs the confirmation and account delete process. func DoDeleteAccount(apper Apper, username string) error { // Connect to the database apper.LoadConfig() connectToDatabase(apper.App()) defer shutdown(apper.App()) // check user exists u, err := apper.App().db.GetUserForAuth(username) if err != nil { log.Error("%s", err) os.Exit(1) } userID := u.ID // do not delete the admin account // TODO: check for other admins and skip? if u.IsAdmin() { log.Error("Can not delete admin account") os.Exit(1) } // confirm deletion, w/ w/out posts prompt := promptui.Prompt{ Templates: &promptui.PromptTemplates{ Success: "{{ . | bold | faint }}: ", }, Label: fmt.Sprintf("Really delete user : %s", username), IsConfirm: true, } _, err = prompt.Run() if err != nil { log.Info("Aborted...") os.Exit(0) } log.Info("Deleting...") err = apper.App().db.DeleteAccount(userID) if err != nil { log.Error("%s", err) os.Exit(1) } log.Info("Success.") return nil } func connectToDatabase(app *App) { log.Info("Connecting to %s database...", app.cfg.Database.Type) var db *sql.DB var err error if app.cfg.Database.Type == driverMySQL { db, err = sql.Open(app.cfg.Database.Type, fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=%s&tls=%t", app.cfg.Database.User, app.cfg.Database.Password, app.cfg.Database.Host, app.cfg.Database.Port, app.cfg.Database.Database, url.QueryEscape(time.Local.String()), app.cfg.Database.TLS)) db.SetMaxOpenConns(50) } else if app.cfg.Database.Type == driverSQLite { if !SQLiteEnabled { log.Error("Invalid database type '%s'. Binary wasn't compiled with SQLite3 support.", app.cfg.Database.Type) os.Exit(1) } if app.cfg.Database.FileName == "" { log.Error("SQLite database filename value in config.ini is empty.") os.Exit(1) } db, err = sql.Open("sqlite3_with_regex", app.cfg.Database.FileName+"?parseTime=true&cached=shared") db.SetMaxOpenConns(2) } else { log.Error("Invalid database type '%s'. Only 'mysql' and 'sqlite3' are supported right now.", app.cfg.Database.Type) os.Exit(1) } if err != nil { log.Error("%s", err) os.Exit(1) } app.db = &datastore{DB: db, driverName: app.cfg.Database.Type} } func shutdown(app *App) { log.Info("Closing database connection...") app.db.Close() if strings.HasPrefix(app.cfg.Server.Bind, "/") { // Clean up socket log.Info("Removing socket file...") err := os.Remove(app.cfg.Server.Bind) if err != nil { log.Error("Unable to remove socket: %s", err) os.Exit(1) } log.Info("Success.") } } // CreateUser creates a new admin or normal user from the given credentials. func CreateUser(apper Apper, username, password string, isAdmin bool) error { // Create an admin user with --create-admin apper.LoadConfig() connectToDatabase(apper.App()) defer shutdown(apper.App()) // Ensure an admin / first user doesn't already exist firstUser, _ := apper.App().db.GetUserByID(1) if isAdmin { // Abort if trying to create admin user, but one already exists if firstUser != nil { return fmt.Errorf("Admin user already exists (%s). Create a regular user with: writefreely user create [USER]:[PASSWORD]", firstUser.Username) } } else { // Abort if trying to create regular user, but no admin exists yet if firstUser == nil { return fmt.Errorf("No admin user exists yet. Create an admin first with: writefreely user create --admin [USER]:[PASSWORD]") } } // Create the user // Normalize and validate username desiredUsername := username username = getSlug(username, "") usernameDesc := username if username != desiredUsername { usernameDesc += " (originally: " + desiredUsername + ")" } if !author.IsValidUsername(apper.App().cfg, username) { return fmt.Errorf("Username %s is invalid, reserved, or shorter than configured minimum length (%d characters).", usernameDesc, apper.App().cfg.App.MinUsernameLen) } // Hash the password hashedPass, err := auth.HashPass([]byte(password)) if err != nil { return fmt.Errorf("Unable to hash password: %v", err) } u := &User{ Username: username, HashedPass: hashedPass, Created: time.Now().Truncate(time.Second).UTC(), } userType := "user" if isAdmin { userType = "admin" } log.Info("Creating %s %s...", userType, usernameDesc) err = apper.App().db.CreateUser(apper.App().Config(), u, desiredUsername, "") if err != nil { return fmt.Errorf("Unable to create user: %s", err) } log.Info("Done!") return nil } //go:embed schema.sql var schemaSql string //go:embed sqlite.sql var sqliteSql string func adminInitDatabase(app *App) error { var schema string if app.cfg.Database.Type == driverSQLite { schema = sqliteSql } else { schema = schemaSql } tblReg := regexp.MustCompile("CREATE TABLE (IF NOT EXISTS )?`([a-z_]+)`") queries := strings.Split(string(schema), ";\n") for _, q := range queries { if strings.TrimSpace(q) == "" { continue } parts := tblReg.FindStringSubmatch(q) if len(parts) >= 3 { log.Info("Creating table %s...", parts[2]) } else { log.Info("Creating table ??? (Weird query) No match in: %v", parts) } _, err := app.db.Exec(q) if err != nil { log.Error("%s", err) } else { log.Info("Created.") } } // Set up migrations table log.Info("Initializing appmigrations table...") err := migrations.SetInitialMigrations(migrations.NewDatastore(app.db.DB, app.db.driverName)) if err != nil { return fmt.Errorf("Unable to set initial migrations: %v", err) } log.Info("Running migrations...") err = migrations.Migrate(migrations.NewDatastore(app.db.DB, app.db.driverName)) if err != nil { return fmt.Errorf("migrate: %s", err) } log.Info("Done.") return nil } // ServerUserAgent returns a User-Agent string to use in external requests. The // hostName parameter may be left empty. func ServerUserAgent(hostName string) string { hostUAStr := "" if hostName != "" { hostUAStr = "; +" + hostName } return "Go (" + serverSoftware + "/" + softwareVer + hostUAStr + ")" } diff --git a/appstats/appstats_test.go b/appstats/appstats_test.go new file mode 100644 index 0000000..36cc2d8 --- /dev/null +++ b/appstats/appstats_test.go @@ -0,0 +1,131 @@ +package appstats + +import ( + "strings" + "testing" + "time" +) + +func TestTimeSincePro(t *testing.T) { + now := time.Now() + tests := []struct { + name string + then time.Time + contains []string // all substrings expected in result + }{ + { + name: "future time returns future", + then: now.Add(1 * time.Hour), + contains: []string{"future"}, + }, + { + name: "zero diff returns empty string", + then: now, + contains: []string{}, + }, + { + name: "1 second ago", + then: now.Add(-1 * time.Second), + contains: []string{"1 second"}, + }, + { + name: "30 seconds ago", + then: now.Add(-30 * time.Second), + contains: []string{"seconds"}, + }, + { + name: "1 minute ago", + then: now.Add(-1 * time.Minute), + contains: []string{"1 minute"}, + }, + { + name: "45 minutes ago", + then: now.Add(-45 * time.Minute), + contains: []string{"minutes"}, + }, + { + name: "1 hour ago", + then: now.Add(-1 * time.Hour), + contains: []string{"1 hour"}, + }, + { + name: "5 hours ago", + then: now.Add(-5 * time.Hour), + contains: []string{"hours"}, + }, + { + name: "1 day ago", + then: now.Add(-24 * time.Hour), + contains: []string{"1 day"}, + }, + { + name: "3 days ago", + then: now.Add(-72 * time.Hour), + contains: []string{"days"}, + }, + { + name: "1 week ago", + then: now.Add(-7 * 24 * time.Hour), + contains: []string{"1 week"}, + }, + { + name: "2 weeks ago", + then: now.Add(-14 * 24 * time.Hour), + contains: []string{"weeks"}, + }, + { + name: "1 month ago", + then: now.Add(-30 * 24 * time.Hour), + contains: []string{"1 month"}, + }, + { + name: "6 months ago", + then: now.Add(-180 * 24 * time.Hour), + contains: []string{"months"}, + }, + { + name: "1 year ago", + then: now.Add(-365 * 24 * time.Hour), + contains: []string{"1 year"}, + }, + { + name: "3 years ago", + then: now.Add(-3 * 365 * 24 * time.Hour), + contains: []string{"years"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := TimeSincePro(tt.then) + for _, want := range tt.contains { + if !strings.Contains(got, want) { + t.Errorf("TimeSincePro() = %q, want it to contain %q", got, want) + } + } + }) + } +} + +func TestFileSize(t *testing.T) { + tests := []struct { + name string + size int64 + expected string + }{ + {"bytes", 5, "5 B"}, + {"kilobytes", 1024, "1.0 KB"}, + {"megabytes", 1024 * 1024, "1.0 MB"}, + {"gigabytes", 1024 * 1024 * 1024, "1.0 GB"}, + {"large megabytes", 50 * 1024 * 1024, "50 MB"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FileSize(tt.size) + if got != tt.expected { + t.Errorf("FileSize(%d) = %q, want %q", tt.size, got, tt.expected) + } + }) + } +} diff --git a/author/author_test.go b/author/author_test.go new file mode 100644 index 0000000..9b76daa --- /dev/null +++ b/author/author_test.go @@ -0,0 +1,96 @@ +/* + * Copyright © 2018-2021 Musing Studio LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package author + +import ( + "os" + "path/filepath" + "testing" + + "github.com/writefreely/writefreely/config" +) + +// newCfg creates a config pointing at a temp pages dir with the given min username length. +func newCfg(t *testing.T, minLen int) *config.Config { + t.Helper() + dir := t.TempDir() + pagesDir := filepath.Join(dir, "pages") + if err := os.MkdirAll(pagesDir, 0o700); err != nil { + t.Fatalf("mkdir pages: %v", err) + } + cfg := &config.Config{} + cfg.App.MinUsernameLen = minLen + cfg.Server.PagesParentDir = dir + return cfg +} + +func TestIsValidUsername_TooShort(t *testing.T) { + cfg := newCfg(t, 3) + if IsValidUsername(cfg, "ab") { + t.Error("expected 'ab' (len 2) to be invalid when MinUsernameLen=3") + } +} + +func TestIsValidUsername_AtMinLength(t *testing.T) { + cfg := newCfg(t, 3) + if !IsValidUsername(cfg, "abc") { + t.Error("expected 'abc' (len 3) to be valid when MinUsernameLen=3") + } +} + +func TestIsValidUsername_Reserved(t *testing.T) { + cfg := newCfg(t, 1) + reserved := []string{"admin", "about", "login", "logout", "signup", "api", "read"} + for _, name := range reserved { + t.Run(name, func(t *testing.T) { + if IsValidUsername(cfg, name) { + t.Errorf("expected %q to be reserved/invalid", name) + } + }) + } +} + +func TestIsValidUsername_ValidNames(t *testing.T) { + cfg := newCfg(t, 1) + valid := []string{"alice", "bob123", "My-Blog", "A", "user-name"} + for _, name := range valid { + t.Run(name, func(t *testing.T) { + if !IsValidUsername(cfg, name) { + t.Errorf("expected %q to be a valid username", name) + } + }) + } +} + +func TestIsValidUsername_InvalidChars(t *testing.T) { + cfg := newCfg(t, 1) + invalid := []string{"user name", "user@name", "user.name", "user/name", "-leading-dash"} + for _, name := range invalid { + t.Run(name, func(t *testing.T) { + if IsValidUsername(cfg, name) { + t.Errorf("expected %q to be invalid due to bad characters", name) + } + }) + } +} + +func TestIsValidUsername_PageNameIsReserved(t *testing.T) { + cfg := newCfg(t, 1) + // Create a custom page file so that filename becomes reserved + pagesDir := filepath.Join(cfg.Server.PagesParentDir, "pages") + pageFile := filepath.Join(pagesDir, "mypage") + if err := os.WriteFile(pageFile, []byte("content"), 0o600); err != nil { + t.Fatalf("create page file: %v", err) + } + if IsValidUsername(cfg, "mypage") { + t.Error("expected 'mypage' to be reserved because a page with that name exists") + } +} diff --git a/config/config_test.go b/config/config_test.go index fa9270a..fb7f9b2 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,67 +1,448 @@ package config import ( "os" "path/filepath" "testing" ) func writeTempConfig(t *testing.T, appSection string) string { t.Helper() dir := t.TempDir() p := filepath.Join(dir, "config.ini") contents := "[app]\n" + appSection + "\n" if err := os.WriteFile(p, []byte(contents), 0o600); err != nil { t.Fatalf("write config: %v", err) } return p } func TestLoadMovesHostPathToSubdirectory(t *testing.T) { f := writeTempConfig(t, "host = https://example.com/blog") cfg, err := Load(f) if err != nil { t.Fatalf("load config: %v", err) } if cfg.App.Host != "https://example.com" { t.Fatalf("expected host without path; got %q", cfg.App.Host) } if got := cfg.App.SubdirectoryPath(); got != "/blog" { t.Fatalf("expected subdirectory /blog; got %q", got) } } func TestLoadStripsHostPathWhenMatchingSubdirectory(t *testing.T) { f := writeTempConfig(t, "host = https://example.com/blog\nsubdirectory = /blog") cfg, err := Load(f) if err != nil { t.Fatalf("load config: %v", err) } if cfg.App.Host != "https://example.com" { t.Fatalf("expected host without path; got %q", cfg.App.Host) } if got := cfg.App.SubdirectoryPath(); got != "/blog" { t.Fatalf("expected subdirectory /blog; got %q", got) } } func TestLoadPrefersExplicitSubdirectoryOverHostPath(t *testing.T) { f := writeTempConfig(t, "host = https://example.com/blog\nsubdirectory = /site") cfg, err := Load(f) if err != nil { t.Fatalf("load config: %v", err) } if cfg.App.Host != "https://example.com" { t.Fatalf("expected host without path; got %q", cfg.App.Host) } if got := cfg.App.SubdirectoryPath(); got != "/site" { t.Fatalf("expected explicit subdirectory /site; got %q", got) } } + +func TestNew_Defaults(t *testing.T) { + cfg := New() + if cfg.Server.Port != 8080 { + t.Errorf("default port = %d, want 8080", cfg.Server.Port) + } + if cfg.App.Host != "http://localhost:8080" { + t.Errorf("default host = %q, want http://localhost:8080", cfg.App.Host) + } + if !cfg.App.SingleUser { + t.Error("expected SingleUser to be true by default") + } + if cfg.Database.Type != "mysql" { + t.Errorf("default database type = %q, want mysql", cfg.Database.Type) + } +} + +func TestUseMySQL(t *testing.T) { + cfg := &Config{} + cfg.UseMySQL(true) + if cfg.Database.Type != "mysql" { + t.Errorf("type = %q, want mysql", cfg.Database.Type) + } + if cfg.Database.Host != "localhost" { + t.Errorf("host = %q, want localhost", cfg.Database.Host) + } + if cfg.Database.Port != 3306 { + t.Errorf("port = %d, want 3306", cfg.Database.Port) + } + + // fresh=false should not reset host/port + cfg.Database.Host = "db.example.com" + cfg.UseMySQL(false) + if cfg.Database.Host != "db.example.com" { + t.Errorf("host changed on fresh=false, got %q", cfg.Database.Host) + } +} + +func TestUseSQLite(t *testing.T) { + cfg := &Config{} + cfg.UseSQLite(true) + if cfg.Database.Type != "sqlite3" { + t.Errorf("type = %q, want sqlite3", cfg.Database.Type) + } + if cfg.Database.FileName != "writefreely.db" { + t.Errorf("filename = %q, want writefreely.db", cfg.Database.FileName) + } + + cfg.Database.FileName = "custom.db" + cfg.UseSQLite(false) + if cfg.Database.FileName != "custom.db" { + t.Errorf("filename changed on fresh=false, got %q", cfg.Database.FileName) + } +} + +func TestIsSecureStandalone(t *testing.T) { + tests := []struct { + name string + cfg Config + want bool + }{ + { + name: "port 443 with cert and key", + cfg: Config{Server: ServerCfg{Port: 443, TLSCertPath: "/cert.pem", TLSKeyPath: "/key.pem"}}, + want: true, + }, + { + name: "port 443 missing cert", + cfg: Config{Server: ServerCfg{Port: 443, TLSKeyPath: "/key.pem"}}, + want: false, + }, + { + name: "port 443 missing key", + cfg: Config{Server: ServerCfg{Port: 443, TLSCertPath: "/cert.pem"}}, + want: false, + }, + { + name: "port 8080 with cert and key", + cfg: Config{Server: ServerCfg{Port: 8080, TLSCertPath: "/cert.pem", TLSKeyPath: "/key.pem"}}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.IsSecureStandalone(); got != tt.want { + t.Errorf("IsSecureStandalone() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAbsoluteHost(t *testing.T) { + tests := []struct { + name string + host string + subdir string + want string + }{ + { + name: "no subdir", + host: "https://example.com", + want: "https://example.com", + }, + { + name: "with subdir appended", + host: "https://example.com", + subdir: "/blog", + want: "https://example.com/blog", + }, + { + name: "host already includes subdir", + host: "https://example.com/blog", + subdir: "/blog", + want: "https://example.com/blog", + }, + { + name: "empty host", + host: "", + want: "", + }, + { + name: "trailing slash stripped", + host: "https://example.com/", + want: "https://example.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := AppCfg{Host: tt.host, Subdirectory: tt.subdir} + if got := cfg.AbsoluteHost(); got != tt.want { + t.Errorf("AbsoluteHost() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestAbsoluteURL(t *testing.T) { + tests := []struct { + name string + host string + subdir string + path string + want string + }{ + { + name: "simple path", + host: "https://example.com", + path: "/about", + want: "https://example.com/about", + }, + { + name: "with subdir prefix", + host: "https://example.com", + subdir: "/blog", + path: "/about", + want: "https://example.com/blog/about", + }, + { + name: "empty host falls back to PrefixPath", + host: "", + path: "/about", + want: "/about", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := AppCfg{Host: tt.host, Subdirectory: tt.subdir} + if got := cfg.AbsoluteURL(tt.path); got != tt.want { + t.Errorf("AbsoluteURL(%q) = %q, want %q", tt.path, got, tt.want) + } + }) + } +} + +func TestStripSubdirectory(t *testing.T) { + tests := []struct { + name string + subdir string + path string + want string + }{ + { + name: "no subdir, non-empty path", + path: "/about", + want: "/about", + }, + { + name: "no subdir, empty path", + path: "", + want: "/", + }, + { + name: "path equals subdir", + subdir: "/blog", + path: "/blog", + want: "/", + }, + { + name: "path under subdir", + subdir: "/blog", + path: "/blog/about", + want: "/about", + }, + { + name: "path unrelated to subdir", + subdir: "/blog", + path: "/other", + want: "/other", + }, + { + name: "empty path with subdir", + subdir: "/blog", + path: "", + want: "/", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := AppCfg{Subdirectory: tt.subdir} + if got := cfg.StripSubdirectory(tt.path); got != tt.want { + t.Errorf("StripSubdirectory(%q) = %q, want %q", tt.path, got, tt.want) + } + }) + } +} + +func TestPrefixPath(t *testing.T) { + tests := []struct { + name string + subdir string + path string + want string + }{ + { + name: "no subdir returns path as-is", + path: "/about", + want: "/about", + }, + { + name: "path gets subdir prepended", + subdir: "/blog", + path: "/about", + want: "/blog/about", + }, + { + name: "root path with subdir", + subdir: "/blog", + path: "/", + want: "/blog/", + }, + { + name: "path already starts with subdir", + subdir: "/blog", + path: "/blog/about", + want: "/blog/about", + }, + { + name: "absolute URL is not prefixed", + subdir: "/blog", + path: "https://example.com/foo", + want: "https://example.com/foo", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := AppCfg{Subdirectory: tt.subdir} + if got := cfg.PrefixPath(tt.path); got != tt.want { + t.Errorf("PrefixPath(%q) = %q, want %q", tt.path, got, tt.want) + } + }) + } +} + +func TestEmailCfg_Enabled(t *testing.T) { + tests := []struct { + name string + cfg EmailCfg + want bool + }{ + { + name: "mailgun domain and key", + cfg: EmailCfg{Domain: "mg.example.com", MailgunPrivate: "key-abc"}, + want: true, + }, + { + name: "SMTP credentials complete", + cfg: EmailCfg{Username: "u", Password: "p", Host: "smtp.example.com", Port: 587}, + want: true, + }, + { + name: "all empty", + cfg: EmailCfg{}, + want: false, + }, + { + name: "SMTP missing port", + cfg: EmailCfg{Username: "u", Password: "p", Host: "smtp.example.com"}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.Enabled(); got != tt.want { + t.Errorf("Enabled() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSignupPath(t *testing.T) { + tests := []struct { + name string + cfg AppCfg + want string + }{ + { + name: "closed registration returns empty", + cfg: AppCfg{OpenRegistration: false}, + want: "", + }, + { + name: "open registration no special flags returns /", + cfg: AppCfg{OpenRegistration: true}, + want: "/", + }, + { + name: "chorus mode returns /signup", + cfg: AppCfg{OpenRegistration: true, Chorus: true}, + want: "/signup", + }, + { + name: "private mode returns /signup", + cfg: AppCfg{OpenRegistration: true, Private: true}, + want: "/signup", + }, + { + name: "non-root landing returns /signup", + cfg: AppCfg{OpenRegistration: true, Landing: "start"}, + want: "/signup", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.SignupPath(); got != tt.want { + t.Errorf("SignupPath() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestLandingPath(t *testing.T) { + tests := []struct { + name string + landing string + want string + }{ + { + name: "empty landing defaults to /", + landing: "", + want: "/", + }, + { + name: "slash prefix preserved", + landing: "/start", + want: "/start", + }, + { + name: "no leading slash gets one added", + landing: "start", + want: "/start", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := AppCfg{Landing: tt.landing} + if got := cfg.LandingPath(); got != tt.want { + t.Errorf("LandingPath() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/config/funcs_test.go b/config/funcs_test.go new file mode 100644 index 0000000..e2991c4 --- /dev/null +++ b/config/funcs_test.go @@ -0,0 +1,140 @@ +package config + +import "testing" + +func TestFriendlyHost(t *testing.T) { + tests := []struct { + name string + host string + want string + }{ + { + name: "plain http host", + host: "http://example.com", + want: "example.com", + }, + { + name: "https host", + host: "https://example.com", + want: "example.com", + }, + { + name: "host with port", + host: "http://example.com:8080", + want: "example.com:8080", + }, + { + name: "https host with port", + host: "https://example.com:443", + want: "example.com:443", + }, + { + name: "punycode host decoded to unicode", + host: "https://xn--nxasmq6b.com", + want: "βόλοσ.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := AppCfg{Host: tt.host} + got := cfg.FriendlyHost() + if got != tt.want { + t.Errorf("FriendlyHost() = %q, want %q (host: %q)", got, tt.want, tt.host) + } + }) + } +} + +func TestCanCreateBlogs(t *testing.T) { + tests := []struct { + name string + maxBlogs int + currentlyUsed uint64 + want bool + }{ + { + name: "unlimited (MaxBlogs=0) always allowed", + maxBlogs: 0, + currentlyUsed: 9999, + want: true, + }, + { + name: "negative MaxBlogs is unlimited", + maxBlogs: -1, + currentlyUsed: 9999, + want: true, + }, + { + name: "under limit", + maxBlogs: 5, + currentlyUsed: 4, + want: true, + }, + { + name: "at limit", + maxBlogs: 5, + currentlyUsed: 5, + want: false, + }, + { + name: "over limit", + maxBlogs: 5, + currentlyUsed: 10, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := AppCfg{MaxBlogs: tt.maxBlogs} + got := cfg.CanCreateBlogs(tt.currentlyUsed) + if got != tt.want { + t.Errorf("CanCreateBlogs(%d) = %v, want %v (MaxBlogs: %d)", tt.currentlyUsed, got, tt.want, tt.maxBlogs) + } + }) + } +} + +func TestOrDefaultString(t *testing.T) { + tests := []struct { + name string + input string + defaultValue string + want string + }{ + { + name: "non-empty input returned as-is", + input: "hello", + defaultValue: "world", + want: "hello", + }, + { + name: "empty input returns default", + input: "", + defaultValue: "world", + want: "world", + }, + { + name: "both empty returns empty default", + input: "", + defaultValue: "", + want: "", + }, + { + name: "whitespace is non-empty", + input: " ", + defaultValue: "fallback", + want: " ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := OrDefaultString(tt.input, tt.defaultValue) + if got != tt.want { + t.Errorf("OrDefaultString(%q, %q) = %q, want %q", tt.input, tt.defaultValue, got, tt.want) + } + }) + } +} diff --git a/config/validation_test.go b/config/validation_test.go new file mode 100644 index 0000000..597d347 --- /dev/null +++ b/config/validation_test.go @@ -0,0 +1,105 @@ +package config + +import "testing" + +func TestValidateDomain(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"http prefix ok", "http://example.com", false}, + {"https prefix ok", "https://example.com", false}, + {"subdomain is ok", "https://subdomain.example.com", false}, + {"https with path ok", "https://example.com/blog", false}, + {"missing scheme", "example.com", true}, + {"ftp scheme rejected", "ftp://example.com", true}, + {"empty string rejected", "", true}, + {"just slashes rejected", "//example.com", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateDomain(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("validateDomain(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestValidatePort(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"minimum valid port", "80", false}, + {"common HTTP port", "8080", false}, + {"HTTPS port", "443", false}, + {"max valid port", "65535", false}, + {"port 79 too low", "79", true}, + {"port 0 too low", "0", true}, + {"negative port", "-1", true}, + {"port above max", "65536", true}, + {"non-numeric", "abc", true}, + {"empty string", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePort(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("validatePort(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestValidateNonEmpty(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"non-empty string ok", "hello", false}, + {"single space ok", " ", false}, + {"empty string errors", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateNonEmpty(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("validateNonEmpty(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestValidateSubdirectory(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"empty string allowed", "", false}, + {"root slash allowed", "/", false}, + {"simple path ok", "/blog", false}, + {"nested path ok", "/my/blog", false}, + {"full URL rejected", "https://example.com/blog", true}, + {"with query string rejected", "/blog?page=1", true}, + {"with fragment rejected", "/blog#section", true}, + {"with space rejected", "/my blog", true}, + {"path with tab rejected", "/my\tblog", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateSubdirectory(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("validateSubdirectory(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} diff --git a/db/dialect_test.go b/db/dialect_test.go new file mode 100644 index 0000000..d92cd00 --- /dev/null +++ b/db/dialect_test.go @@ -0,0 +1,116 @@ +package db + +import "testing" + +func TestDialectType_Table_SQLite(t *testing.T) { + b := DialectSQLite.Table("users") + if b == nil { + t.Fatal("Table() returned nil") + } + if b.Dialect != DialectSQLite { + t.Errorf("Table().Dialect = %v, want DialectSQLite", b.Dialect) + } + if b.Name != "users" { + t.Errorf("Table().Name = %q, want %q", b.Name, "users") + } +} + +func TestDialectType_Table_MySQL(t *testing.T) { + b := DialectMySQL.Table("users") + if b == nil { + t.Fatal("Table() returned nil") + } + if b.Dialect != DialectMySQL { + t.Errorf("Table().Dialect = %v, want DialectMySQL", b.Dialect) + } +} + +func TestDialectType_AlterTable_SQLite(t *testing.T) { + b := DialectSQLite.AlterTable("posts") + if b == nil { + t.Fatal("AlterTable() returned nil") + } + if b.Dialect != DialectSQLite { + t.Errorf("AlterTable().Dialect = %v, want DialectSQLite", b.Dialect) + } + if b.Name != "posts" { + t.Errorf("AlterTable().Name = %q, want %q", b.Name, "posts") + } +} + +func TestDialectType_AlterTable_MySQL(t *testing.T) { + b := DialectMySQL.AlterTable("posts") + if b == nil { + t.Fatal("AlterTable() returned nil") + } + if b.Dialect != DialectMySQL { + t.Errorf("AlterTable().Dialect = %v, want DialectMySQL", b.Dialect) + } +} + +func TestDialect_ColumnDialectField(t *testing.T) { + tests := []struct { + name string + dialect DialectType + colName string + colType ColumnType + wantDialect DialectType + }{ + { + name: "SQLite column preserves dialect", + dialect: DialectSQLite, + colName: "id", + colType: ColumnTypeInteger, + wantDialect: DialectSQLite, + }, + { + name: "MySQL column preserves dialect", + dialect: DialectMySQL, + colName: "username", + colType: ColumnTypeVarChar, + wantDialect: DialectMySQL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + col := tt.dialect.Column(tt.colName, tt.colType, UnsetSize) + if col == nil { + t.Fatal("Column() returned nil") + } + if col.Dialect != tt.wantDialect { + t.Errorf("Column().Dialect = %v, want %v", col.Dialect, tt.wantDialect) + } + if col.Name != tt.colName { + t.Errorf("Column().Name = %q, want %q", col.Name, tt.colName) + } + if col.Type != tt.colType { + t.Errorf("Column().Type = %v, want %v", col.Type, tt.colType) + } + }) + } +} + +func TestDialect_CreateIndex_PreservesDialect(t *testing.T) { + sqliteIdx := DialectSQLite.CreateIndex("idx_test", "posts", "user_id") + if sqliteIdx.Dialect != DialectSQLite { + t.Errorf("CreateIndex dialect = %v, want DialectSQLite", sqliteIdx.Dialect) + } + + mysqlIdx := DialectMySQL.CreateUniqueIndex("uq_test", "users", "email") + if mysqlIdx.Dialect != DialectMySQL { + t.Errorf("CreateUniqueIndex dialect = %v, want DialectMySQL", mysqlIdx.Dialect) + } +} + +func TestDialect_DropIndex_PreservesDialect(t *testing.T) { + sqliteIdx := DialectSQLite.DropIndex("idx_test", "posts") + if sqliteIdx.Dialect != DialectSQLite { + t.Errorf("DropIndex dialect = %v, want DialectSQLite", sqliteIdx.Dialect) + } + + mysqlIdx := DialectMySQL.DropIndex("idx_test", "posts") + if mysqlIdx.Dialect != DialectMySQL { + t.Errorf("DropIndex dialect = %v, want DialectMySQL", mysqlIdx.Dialect) + } +} diff --git a/db/index_test.go b/db/index_test.go new file mode 100644 index 0000000..49f31ae --- /dev/null +++ b/db/index_test.go @@ -0,0 +1,87 @@ +package db + +import "testing" + +func TestCreateIndexSqlBuilder_ToSQL(t *testing.T) { + tests := []struct { + name string + builder *CreateIndexSqlBuilder + want string + wantErr bool + }{ + { + name: "single column non-unique", + builder: DialectMySQL.CreateIndex("idx_posts_user", "posts", "user_id"), + want: "CREATE INDEX idx_posts_user on posts (user_id)", + wantErr: false, + }, + { + name: "single column unique", + builder: DialectMySQL.CreateUniqueIndex("uq_users_email", "users", "email"), + want: "CREATE UNIQUE INDEX uq_users_email on users (email)", + wantErr: false, + }, + { + name: "multi-column index", + builder: DialectSQLite.CreateIndex("idx_posts_col", "posts", "collection_id", "created"), + want: "CREATE INDEX idx_posts_col on posts (collection_id, created)", + wantErr: false, + }, + { + name: "multi-column unique index", + builder: DialectSQLite.CreateUniqueIndex("uq_multi", "table1", "col_a", "col_b", "col_c"), + want: "CREATE UNIQUE INDEX uq_multi on table1 (col_a, col_b, col_c)", + wantErr: false, + }, + { + name: "no columns returns error", + builder: &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: "idx_empty", Table: "posts"}, + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.builder.ToSQL() + if (err != nil) != tt.wantErr { + t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ToSQL() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestDropIndexSqlBuilder_ToSQL(t *testing.T) { + tests := []struct { + name string + builder *DropIndexSqlBuilder + want string + }{ + { + name: "MySQL drop index", + builder: DialectMySQL.DropIndex("idx_posts_user", "posts"), + want: "DROP INDEX idx_posts_user on posts", + }, + { + name: "SQLite drop index", + builder: DialectSQLite.DropIndex("idx_col", "table1"), + want: "DROP INDEX idx_col on table1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.builder.ToSQL() + if err != nil { + t.Fatalf("ToSQL() unexpected error: %v", err) + } + if got != tt.want { + t.Errorf("ToSQL() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/db/raw_test.go b/db/raw_test.go new file mode 100644 index 0000000..fe4d61d --- /dev/null +++ b/db/raw_test.go @@ -0,0 +1,36 @@ +package db + +import "testing" + +func TestRawSqlBuilder_ToSQL(t *testing.T) { + tests := []struct { + name string + query string + }{ + { + name: "simple query returned unchanged", + query: "SELECT 1", + }, + { + name: "complex query", + query: "ALTER TABLE posts ADD COLUMN views INT NOT NULL DEFAULT 0", + }, + { + name: "empty query", + query: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := &RawSqlBuilder{Query: tt.query} + got, err := b.ToSQL() + if err != nil { + t.Fatalf("ToSQL() unexpected error: %v", err) + } + if got != tt.query { + t.Errorf("ToSQL() = %q, want %q", got, tt.query) + } + }) + } +} diff --git a/parse/posts_test.go b/parse/posts_test.go index c377f57..561568e 100644 --- a/parse/posts_test.go +++ b/parse/posts_test.go @@ -1,55 +1,104 @@ /* * Copyright © 2018 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package parse import "testing" func TestPostLede(t *testing.T) { text := map[string]string{ "早安。跨出舒適圈,才能前往": "早安。", "早安。This is my post. It is great.": "早安。", "Hello. 早安。": "Hello.", "Sup? Everyone says punctuation is punctuation.": "Sup?", "Humans are humans, and society is full of good and bad actors. Technology, at the most fundamental level, is a neutral tool that can be used by either to meet any ends. ": "Humans are humans, and society is full of good and bad actors.", `Online Domino Is Must For Everyone Do you want to understand how to play poker online?`: "Online Domino Is Must For Everyone", `おはようございます 私は日本から帰ったばかりです。`: "おはようございます", "Hello, we say, おはよう. We say \"good morning\"": "Hello, we say, おはよう.", } c := 1 for i, o := range text { if s := PostLede(i, true); s != o { t.Errorf("#%d: Got '%s' from '%s'; expected '%s'", c, s, i, o) } c++ } } +func TestTruncate(t *testing.T) { + tests := []struct { + name string + input string + limit int + want string + }{ + { + name: "shorter than limit returns unchanged", + input: "hello", + limit: 10, + want: "hello", + }, + { + name: "exactly at limit returns unchanged", + input: "hello", + limit: 5, + want: "hello", + }, + { + name: "longer than limit is truncated", + input: "hello world", + limit: 5, + want: "hello", + }, + { + name: "multibyte runes counted by rune not byte", + input: "早安。早安。", + limit: 4, + want: "早安。早", + }, + { + name: "empty string returns empty", + input: "", + limit: 5, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := Truncate(tt.input, tt.limit) + if got != tt.want { + t.Errorf("Truncate(%q, %d) = %q, want %q", tt.input, tt.limit, got, tt.want) + } + }) + } +} + func TestTruncToWord(t *testing.T) { text := map[string]string{ "Можливо, ми можемо використовувати інтернет-інструменти, щоб виготовити якийсь текст, який би міг бути і на, і в кінцевому підсумку, буде скорочено, тому що це тривало так довго.": "Можливо, ми можемо використовувати інтернет-інструменти, щоб виготовити якийсь", "早安。This is my post. It is great. It is a long post that is great that is a post that is great.": "早安。This is my post. It is great. It is a long post that is great that is a post", "Sup? Everyone says punctuation is punctuation.": "Sup? Everyone says punctuation is punctuation.", "I arrived in Japan six days ago. Tired from a 10-hour flight after a night-long layover in Calgary, I wandered wide-eyed around Narita airport looking for an ATM.": "I arrived in Japan six days ago. Tired from a 10-hour flight after a night-long", } c := 1 for i, o := range text { if s, _ := TruncToWord(i, 80); s != o { t.Errorf("#%d: Got '%s' from '%s'; expected '%s'", c, s, i, o) } c++ } } diff --git a/postrender_test.go b/postrender_test.go index ec6bbdd..8a1eb81 100644 --- a/postrender_test.go +++ b/postrender_test.go @@ -1,43 +1,91 @@ /* * Copyright © 2021 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import "testing" func TestApplyBasicMarkdown(t *testing.T) { tests := []struct { name string in string result string }{ {"empty", "", ""}, {"empty spaces", " ", ""}, {"empty tabs", "\t", ""}, {"empty newline", "\n", ""}, {"nums", "123", "123"}, {"dot", ".", "."}, {"dash", "-", "-"}, {"plain", "Hello, World!", "Hello, World!"}, {"multibyte", "こんにちは", `こんにちは`}, {"bold", "**안녕하세요**", `안녕하세요`}, {"link", "[WriteFreely](https://writefreely.org)", `WriteFreely`}, {"date", "12. April", `12. April`}, {"table", "| Hi | There |", `| Hi | There |`}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { res := applyBasicMarkdown([]byte(test.in)) if res != test.result { t.Errorf("%s: wanted %s, got %s", test.name, test.result, res) } }) } } + +func TestShortPostDescription(t *testing.T) { + tests := []struct { + name string + content string + want string + }{ + { + name: "short content returned as-is", + content: "Hello, world!", + want: "Hello, world!", + }, + { + name: "content exactly at 140 chars not truncated", + content: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"[:140], + want: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"[:140], + }, + { + name: "content over 140 chars is truncated with ellipsis", + content: "This is a very long post that exceeds the maximum description length of one hundred and forty characters, so it should be truncated with an ellipsis at the end.", + want: "This is a very long post that exceeds the maximum description length of one hundred and forty characters, so it should be truncated with ...", + }, + { + name: "newlines replaced with spaces", + content: "Line one\nLine two\nLine three", + want: "Line one Line two Line three", + }, + { + name: "leading and trailing whitespace trimmed", + content: " trimmed content ", + want: "trimmed content", + }, + { + name: "multibyte runes counted by rune not byte", + content: "日本語のテキストが百四十文字以内に収まっているかどうかをテストします。これは短いテキストです。", + want: "日本語のテキストが百四十文字以内に収まっているかどうかをテストします。これは短いテキストです。", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shortPostDescription(tt.content) + if got != tt.want { + t.Errorf("shortPostDescription() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/posts_test.go b/posts_test.go index 612c178..803fa64 100644 --- a/posts_test.go +++ b/posts_test.go @@ -1,45 +1,86 @@ /* * Copyright © 2020-2021 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely_test import ( "testing" "github.com/guregu/null/zero" "github.com/stretchr/testify/assert" "github.com/writefreely/writefreely" ) func TestPostSummary(t *testing.T) { testCases := map[string]struct { given writefreely.Post expected string }{ "no special chars": {givenPost("Content."), "Content."}, "HTML content": {givenPost("Content
with a
paragraph."), "Content with a paragraph."}, "content with escaped char": {givenPost("Content's all OK."), "Content's all OK."}, "multiline content": {givenPost(`Content in multiple lines.`), "Content in multiple lines."}, } for name, test := range testCases { t.Run(name, func(t *testing.T) { actual := test.given.Summary() assert.Equal(t, test.expected, actual) }) } } func givenPost(content string) writefreely.Post { return writefreely.Post{Title: zero.StringFrom("Title"), Content: content} } + +func givenUntitledPost(id, content string) writefreely.Post { + return writefreely.Post{ID: id, Content: content} +} + +// TestPostSummaryNoTitle covers the untitled post branch of Summary(), where +// the title is derived from the content itself. +func TestPostSummaryNoTitle(t *testing.T) { + testCases := []struct { + name string + given writefreely.Post + expected string + }{ + { + name: "empty content returns empty", + given: givenUntitledPost("abc123", ""), + expected: "", + }, + { + name: "short single-line — title equals description, so returns empty", + // friendlyPostTitle returns the line itself; postDescription with + // title==friendlyId also returns the same content → desc==title → "" + given: givenUntitledPost("abc123", "Short post."), + expected: "", + }, + { + name: "two paragraphs — body after blank line becomes description", + given: givenUntitledPost("abc123", `First paragraph as title. + +Second paragraph as description that is different from the title.`), + expected: "Second paragraph as description that is different from the title.", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := tc.given.Summary() + assert.Equal(t, tc.expected, actual) + }) + } +} diff --git a/semver_test.go b/semver_test.go new file mode 100644 index 0000000..2607ea3 --- /dev/null +++ b/semver_test.go @@ -0,0 +1,114 @@ +package writefreely + +import "testing" + +func TestIsValid(t *testing.T) { + tests := []struct { + v string + want bool + }{ + {"v1.0.0", true}, + {"v0.0.1", true}, + {"v1.2.3", true}, + {"v1", true}, + {"v1.2", true}, + {"v1.0.0-alpha", true}, + {"v1.0.0-alpha.1", true}, + {"v1.0.0+build1", true}, // build metadata without dots + {"v1.0.0+build.1", false}, // this impl doesn't allow dots in build metadata + {"v1.0.0-alpha+001", true}, + {"", false}, + {"1.0.0", false}, // missing v prefix + {"v1.0.0.0", false}, // extra patch segment + {"v01.0.0", false}, // leading zero in major + {"vx.y.z", false}, // non-numeric parts + } + + for _, tt := range tests { + t.Run(tt.v, func(t *testing.T) { + got := IsValid(tt.v) + if got != tt.want { + t.Errorf("IsValid(%q) = %v, want %v", tt.v, got, tt.want) + } + }) + } +} + +func TestCompareSemver(t *testing.T) { + tests := []struct { + name string + v, w string + want int + }{ + {"equal versions", "v1.0.0", "v1.0.0", 0}, + {"major bump", "v2.0.0", "v1.0.0", 1}, + {"major behind", "v1.0.0", "v2.0.0", -1}, + {"minor bump", "v1.1.0", "v1.0.0", 1}, + {"minor behind", "v1.0.0", "v1.1.0", -1}, + {"patch bump", "v1.0.1", "v1.0.0", 1}, + {"patch behind", "v1.0.0", "v1.0.1", -1}, + {"release > prerelease", "v1.0.0", "v1.0.0-alpha", 1}, + {"prerelease < release", "v1.0.0-alpha", "v1.0.0", -1}, + {"both invalid", "bad", "alsobad", 0}, + {"first invalid", "bad", "v1.0.0", -1}, + {"second invalid", "v1.0.0", "bad", 1}, + {"short form v1 == v1.0.0", "v1", "v1.0.0", 0}, + {"short form v1.2 == v1.2.0", "v1.2", "v1.2.0", 0}, + {"v1.2.3 > v1.2", "v1.2.3", "v1.2", 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CompareSemver(tt.v, tt.w) + if got != tt.want { + t.Errorf("CompareSemver(%q, %q) = %d, want %d", tt.v, tt.w, got, tt.want) + } + }) + } +} + +// TestIsValid_EdgeCases targets uncovered branches in semParse, parsePrerelease, and parseBuild. +func TestIsValid_EdgeCases(t *testing.T) { + tests := []struct { + name string + v string + want bool + }{ + // semParse: bad minor prefix — non-dot character after parsed major + {"bad minor prefix v1a", "v1a.0.0", false}, + // semParse: bad patch prefix — non-dot after parsed minor + {"bad patch prefix v1.2a", "v1.2a.0", false}, + // semParse: bad minor version — dot followed by non-digit + {"bad minor version v1.x.0", "v1.x.0", false}, + // semParse: bad patch version — dot followed by non-digit + {"bad patch version v1.0.x", "v1.0.x", false}, + + // parsePrerelease: leading zero in a numeric identifier is invalid + {"prerelease with leading zero", "v1.0.0-alpha.01", false}, + // parsePrerelease: trailing dot (empty identifier after dot) + {"prerelease trailing dot", "v1.0.0-alpha.", false}, + // parsePrerelease: double dot (empty identifier between dots) + {"prerelease double dot", "v1.0.0-alpha..beta", false}, + // parsePrerelease: invalid character + {"prerelease invalid char", "v1.0.0-alpha@1", false}, + // parsePrerelease: valid multi-segment prerelease + {"prerelease numeric segment", "v1.0.0-1.2.3", true}, + + // parseBuild: valid build metadata (no dots allowed in this impl) + {"valid build metadata", "v1.0.0+build123", true}, + // parseBuild: build with invalid character + {"build with space", "v1.0.0+build 1", false}, + + // semParse: junk on end after valid prerelease+build + {"junk after build", "v1.0.0+build1junk!", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsValid(tt.v) + if got != tt.want { + t.Errorf("IsValid(%q) = %v, want %v", tt.v, got, tt.want) + } + }) + } +} diff --git a/spam/spam_test.go b/spam/spam_test.go new file mode 100644 index 0000000..3747cee --- /dev/null +++ b/spam/spam_test.go @@ -0,0 +1,134 @@ +/* + * Copyright © 2020-2021 Musing Studio LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package spam + +import ( + "net/http/httptest" + "testing" +) + +func TestCleanEmail(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "plain email unchanged (sans dots)", + input: "user@example.com", + expected: "user@example.com", + }, + { + name: "strips plus alias", + input: "user+newsletter@example.com", + expected: "user@example.com", + }, + { + name: "strips dots in local part", + input: "us.er@example.com", + expected: "user@example.com", + }, + { + name: "strips dots and plus alias together", + input: "u.s.e.r+tag@example.com", + expected: "user@example.com", + }, + { + name: "uppercased is lowercased", + input: "User@Example.COM", + expected: "user@example.com", + }, + { + name: "missing @ returns empty", + input: "notanemail", + expected: "", + }, + { + name: "empty string returns empty", + input: "", + expected: "", + }, + { + name: "domain with dots is preserved", + input: "user@mail.example.com", + expected: "user@mail.example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CleanEmail(tt.input) + if got != tt.expected { + t.Errorf("CleanEmail(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} + +func TestHoneypotFieldName(t *testing.T) { + // Reset the package-level field so tests are independent. + honeypotField = "" + + name := HoneypotFieldName() + if name == "" { + t.Fatal("HoneypotFieldName() returned empty string") + } + // Should be a fixed-length base-62 string. + if len(name) != 39 { + t.Errorf("HoneypotFieldName() length = %d, want 39", len(name)) + } + // Subsequent calls must return the same value (singleton behaviour). + if got := HoneypotFieldName(); got != name { + t.Errorf("HoneypotFieldName() returned different value on second call: %q vs %q", got, name) + } +} + +func TestGetIP(t *testing.T) { + tests := []struct { + name string + header string + expected string + }{ + { + name: "single IP", + header: "1.2.3.4", + expected: "1.2.3.4", + }, + { + name: "multiple IPs returns first", + header: "1.2.3.4, 5.6.7.8, 9.10.11.12", + expected: "1.2.3.4", + }, + { + name: "trims whitespace from first IP", + header: " 10.0.0.1 , 192.168.1.1", + expected: "10.0.0.1", + }, + { + name: "missing header returns empty", + header: "", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + if tt.header != "" { + req.Header.Set("X-Forwarded-For", tt.header) + } + got := GetIP(req) + if got != tt.expected { + t.Errorf("GetIP() = %q, want %q (X-Forwarded-For: %q)", got, tt.expected, tt.header) + } + }) + } +}