Page Menu · Home · Musing Studio

No One · Temporary

diff --git a/account.go b/account.go
index 2af9fce..d1e6a21 100644
--- a/account.go
+++ b/account.go
@@ -1,1209 +1,1226 @@
/*
* Copyright © 2018-2021 A Bunch Tell LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"encoding/json"
"fmt"
"html/template"
"net/http"
"regexp"
+ "strconv"
"strings"
"sync"
"time"
"github.com/gorilla/csrf"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/guregu/null/zero"
"github.com/writeas/impart"
"github.com/writeas/web-core/auth"
"github.com/writeas/web-core/data"
"github.com/writeas/web-core/log"
"github.com/writefreely/writefreely/author"
"github.com/writefreely/writefreely/config"
"github.com/writefreely/writefreely/page"
)
// userSettings holds the account-settings fields a user can submit, decoded
// either from a form post (schema tags) or from a JSON body (json tags).
type userSettings struct {
	Username string `schema:"username" json:"username"`
	Email    string `schema:"email" json:"email"`
	NewPass  string `schema:"new-pass" json:"new_pass"`
	OldPass  string `schema:"current-pass" json:"current_pass"`
	IsLogOut bool   `schema:"logout" json:"logout"`
}

// UserPage is the template data shared by pages shown to a logged-in user.
type UserPage struct {
	page.StaticPage

	PageTitle string
	Separator template.HTML
	IsAdmin   bool
	CanInvite bool
	CollAlias string
}
// NewUserPage builds the common page data for a logged-in user's page: the
// per-request static page defaults, the given title and flash messages, the
// current path, and the user's admin and invite capabilities.
func NewUserPage(app *App, r *http.Request, u *User, title string, flashes []string) *UserPage {
	static := pageForReq(app, r)
	admin := u.IsAdmin()

	userPage := &UserPage{
		StaticPage: static,
		PageTitle:  title,
		IsAdmin:    admin,
		CanInvite:  canUserInvite(app.cfg, admin),
	}
	// These fields are promoted from the embedded StaticPage.
	userPage.Username = u.Username
	userPage.Flashes = flashes
	userPage.Path = r.URL.Path
	return userPage
}
// canUserInvite reports whether a user may send invites. An empty
// UserInvites setting disables invites entirely; the value "admin" limits
// them to admins; any other value lets everyone invite.
func canUserInvite(cfg *config.Config, isAdmin bool) bool {
	switch cfg.App.UserInvites {
	case "":
		return false
	case "admin":
		return isAdmin
	default:
		return true
	}
}
// SetMessaging is a hook for attaching per-user messaging to the page. It is
// currently a no-op: the lookup it would perform is commented out below.
func (up *UserPage) SetMessaging(u *User) {
	// Disabled:
	//   up.NeedsAuth = app.db.DoesUserNeedAuth(u.ID)
}
// loginAttemptExpiration is how long further password attempts for the same
// alias are rejected after an attempt (see the rate limiting in login).
const loginAttemptExpiration = 3 * time.Second
// actuallyUsernameReg captures the suggested username from a login error of
// the form "username is actually X. Please try that, instead", so webLogin can
// pre-fill the login form with the corrected name.
var actuallyUsernameReg = regexp.MustCompile(`username is actually ([a-z0-9\-]+)\. Please try that, instead`)
// apiSignup is the API handler for account creation. It discards the created
// user and reports only signup's error; for JSON requests the success
// response has already been written by signupWithRegistration.
func apiSignup(app *App, w http.ResponseWriter, r *http.Request) error {
	if _, err := signup(app, w, r); err != nil {
		return err
	}
	return nil
}
// signup decodes a registration request, from a JSON body or a form post, and
// hands it to signupWithRegistration. For form posts from a browser that is
// already logged in, the existing session user is returned and no account is
// created.
func signup(app *App, w http.ResponseWriter, r *http.Request) (*AuthUser, error) {
	if app.cfg.App.DisablePasswordAuth {
		return nil, ErrDisabledPasswordAuth
	}

	var ur userRegistration
	if !IsJSON(r) {
		// Check if user is already logged in
		if existing := getUserSession(app, r); existing != nil {
			return &AuthUser{User: existing}, nil
		}
		if err := r.ParseForm(); err != nil {
			log.Error("Couldn't parse signup form request: %v\n", err)
			return nil, ErrBadFormData
		}
		if err := app.formDecoder.Decode(&ur, r.PostForm); err != nil {
			log.Error("Couldn't decode signup form request: %v\n", err)
			return nil, ErrBadFormData
		}
		return signupWithRegistration(app, ur, w, r)
	}

	if err := json.NewDecoder(r.Body).Decode(&ur); err != nil {
		log.Error("Couldn't parse signup JSON request: %v\n", err)
		return nil, ErrBadJSON
	}
	return signupWithRegistration(app, ur, w, r)
}
// signupWithRegistration validates a decoded registration request, creates
// the user, records invite-code usage, and authenticates the new account.
//
// JSON requests that don't set signup.Web receive a fresh access token in the
// returned AuthUser; every other request gets a cookie session instead. For
// JSON requests the 201 Created response is also written here; for non-JSON
// requests nothing is written and the caller decides what to send.
//
// Validation, hashing and token failures are returned as impart.HTTPError;
// errors from CreateUser, CreateInvitedUser and session saving are passed
// through unchanged.
func signupWithRegistration(app *App, signup userRegistration, w http.ResponseWriter, r *http.Request) (*AuthUser, error) {
	reqJSON := IsJSON(r)

	// Validate required params (alias)
	if signup.Alias == "" {
		return nil, impart.HTTPError{http.StatusBadRequest, "A username is required."}
	}
	if signup.Pass == "" {
		return nil, impart.HTTPError{http.StatusBadRequest, "A password is required."}
	}
	// desiredUsername stays "" unless Normalize is set, in which case it keeps
	// the raw input while Alias becomes its slug.
	var desiredUsername string
	if signup.Normalize {
		// With this option we simply conform the username to what we expect
		// without complaining. Since they might've done something funny, like
		// enter: write.as/Way Out There, we'll use their raw input for the new
		// collection name and sanitize for the slug / username.
		desiredUsername = signup.Alias
		signup.Alias = getSlug(signup.Alias, "")
	}
	if !author.IsValidUsername(app.cfg, signup.Alias) {
		// Ensure the username is syntactically correct.
		return nil, impart.HTTPError{http.StatusPreconditionFailed, "Username is reserved or isn't valid. It must be at least 3 characters long, and can only include letters, numbers, and hyphens."}
	}

	// Hash the password; only the hash is stored.
	hashedPass, err := auth.HashPass([]byte(signup.Pass))
	if err != nil {
		return nil, impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."}
	}

	// Create struct to insert. The email (if any) is stored encrypted.
	u := &User{
		Username:   signup.Alias,
		HashedPass: hashedPass,
		HasPass:    true,
		Email:      prepareUserEmail(signup.Email, app.keys.EmailKey),
		Created:    time.Now().Truncate(time.Second).UTC(),
	}

	// Create actual user
	if err := app.db.CreateUser(app.cfg, u, desiredUsername); err != nil {
		return nil, err
	}

	// Log invite if needed.
	// NOTE(review): the account already exists at this point; if recording
	// the invite fails, the request reports an error but the user is kept.
	if signup.InviteCode != "" {
		err = app.db.CreateInvitedUser(signup.InviteCode, u.ID)
		if err != nil {
			return nil, err
		}
	}

	// Add back unencrypted data for response
	if signup.Email != "" {
		u.Email.String = signup.Email
	}

	resUser := &AuthUser{
		User: u,
	}
	// The response describes one collection named after the alias; its title
	// is the raw (unslugged) input when Normalize was requested.
	title := signup.Alias
	if signup.Normalize {
		title = desiredUsername
	}
	resUser.Collections = &[]Collection{
		{
			Alias: signup.Alias,
			Title: title,
		},
	}

	var token string
	if reqJSON && !signup.Web {
		// API clients authenticate with an access token.
		token, err = app.db.GetAccessToken(u.ID)
		if err != nil {
			return nil, impart.HTTPError{http.StatusInternalServerError, "Could not create access token. Try re-authenticating."}
		}
		resUser.AccessToken = token
	} else {
		// Browsers (and JSON requests flagged Web) get a cookie session.
		session, err := app.sessionStore.Get(r, cookieName)
		if err != nil {
			// The cookie should still save, even if there's an error.
			// Source: https://github.com/gorilla/sessions/issues/16#issuecomment-143642144
			log.Error("Session: %v; ignoring", err)
		}
		session.Values[cookieUserVal] = resUser.User.Cookie()
		err = session.Save(r, w)
		if err != nil {
			log.Error("Couldn't save session: %v", err)
			return nil, err
		}
	}
	if reqJSON {
		return resUser, impart.WriteSuccess(w, resUser, http.StatusCreated)
	}

	return resUser, nil
}
// viewLogout ends a browser session by expiring the session cookie, then
// redirects to the home page.
//
// If the session does not hold a usable *User, the cookie is still expired
// before redirecting. Previously that branch logged "Resetting cookie" but
// saved the session unchanged, so a malformed cookie survived every logout
// attempt. A typed-nil *User is treated the same way instead of panicking on
// u.ID.
func viewLogout(app *App, w http.ResponseWriter, r *http.Request) error {
	session, err := app.sessionStore.Get(r, cookieName)
	if err != nil {
		return ErrInternalCookieSession
	}

	u, ok := session.Values[cookieUserVal].(*User)
	if !ok || u == nil {
		log.Error("Error casting user object on logout. Vals: %+v Resetting cookie.", session.Values)
		// A negative MaxAge makes Save delete the cookie rather than rewrite it.
		session.Options.MaxAge = -1
		if err := session.Save(r, w); err != nil {
			log.Error("Couldn't save session on logout: %v", err)
			return impart.HTTPError{http.StatusInternalServerError, "Unable to save cookie session."}
		}
		return impart.HTTPError{http.StatusFound, "/"}
	}

	// NOTE(review): this lookup was meant to ensure the user has an email or
	// password set before they go (so they don't lose access), but the fetched
	// user is never inspected. Its only effect is that a database failure
	// aborts the logout.
	if _, err := app.db.GetUserByID(u.ID); err != nil && err != ErrUserNotFound {
		return impart.HTTPError{http.StatusInternalServerError, "Unable to fetch user information."}
	}

	session.Options.MaxAge = -1
	if err := session.Save(r, w); err != nil {
		log.Error("Couldn't save session on logout: %v", err)
		return impart.HTTPError{http.StatusInternalServerError, "Unable to save cookie session."}
	}
	return impart.HTTPError{http.StatusFound, "/"}
}
// handleAPILogout deletes the access token given in the Authorization header
// and answers 204 No Content. A missing or unparseable token yields
// ErrNoAccessToken.
func handleAPILogout(app *App, w http.ResponseWriter, r *http.Request) error {
	header := r.Header.Get("Authorization")
	if header == "" {
		return ErrNoAccessToken
	}

	token := auth.GetToken(header)
	if len(token) == 0 {
		return ErrNoAccessToken
	}

	if err := app.db.DeleteToken(token); err != nil {
		return err
	}
	return impart.HTTPError{Status: http.StatusNoContent}
}
// viewLogin renders the login page.
//
// If a one-time login token is supplied in the "with" parameter, a login is
// attempted first. On success, login has already written a complete response
// (a redirect or a JSON body), so viewLogin returns immediately instead of
// rendering the login template on top of it. On failure, the error is shown as
// a flash message on the rendered page.
func viewLogin(app *App, w http.ResponseWriter, r *http.Request) error {
	var earlyError string
	oneTimeToken := r.FormValue("with")
	if oneTimeToken != "" {
		log.Info("Calling login with one-time token.")
		err := login(app, w, r)
		if err == nil {
			return nil
		}
		log.Info("Received error: %v", err)
		earlyError = err.Error()
	}

	session, err := app.sessionStore.Get(r, cookieName)
	if err != nil {
		// Ignore this
		log.Error("Unable to get session; ignoring: %v", err)
	}

	p := &struct {
		page.StaticPage
		*OAuthButtons
		To            string
		Message       template.HTML
		Flashes       []template.HTML
		LoginUsername string
	}{
		StaticPage:    pageForReq(app, r),
		OAuthButtons:  NewOAuthButtons(app.Config()),
		To:            r.FormValue("to"),
		Message:       template.HTML(""),
		Flashes:       []template.HTML{},
		LoginUsername: getTempInfo(app, "login-user", r, w),
	}

	if earlyError != "" {
		p.Flashes = append(p.Flashes, template.HTML(earlyError))
	}

	// Display any error messages
	flashes, _ := getSessionFlashes(app, w, r, session)
	for _, flash := range flashes {
		p.Flashes = append(p.Flashes, template.HTML(flash))
	}
	err = pages["login.tmpl"].ExecuteTemplate(w, "base", p)
	if err != nil {
		log.Error("Unable to render login: %v", err)
		return err
	}
	return nil
}
// webLogin handles a login form post. On failure it does three things:
//   - stores the error message as a session flash;
//   - remembers the attempted (or corrected) username for the login form;
//   - redirects back to /login, preserving the post-login "to" target.
//
// The "to" value is query-escaped before it is appended, so targets
// containing '&', '#' or '?' survive as a single parameter.
func webLogin(app *App, w http.ResponseWriter, r *http.Request) error {
	err := login(app, w, r)
	if err == nil {
		return nil
	}

	username := r.FormValue("alias")
	// Login request was unsuccessful; save the error in the session and redirect them
	if httpErr, ok := err.(impart.HTTPError); ok {
		session, _ := app.sessionStore.Get(r, cookieName)
		if session != nil {
			session.AddFlash(httpErr.Message)
			if saveErr := session.Save(r, w); saveErr != nil {
				log.Error("Unable to save login error flash: %v", saveErr)
			}
		}

		if m := actuallyUsernameReg.FindStringSubmatch(httpErr.Message); len(m) > 0 {
			// Retain fixed username recommendation for the login form
			username = m[1]
		}
	}

	// Pass along certain information
	saveTempInfo(app, "login-user", username, r, w)

	// Retain post-login URL if one was given
	redirectTo := "/login"
	if postLoginRedirect := r.FormValue("to"); postLoginRedirect != "" {
		redirectTo += "?to=" + template.URLQueryEscaper(postLoginRedirect)
	}

	log.Error("Unable to login: %v", err)
	return impart.HTTPError{http.StatusTemporaryRedirect, redirectTo}
}
// loginAttemptUsers maps a login alias to the time.Time until which further
// password attempts for that alias are rejected (see login).
var loginAttemptUsers sync.Map
// login authenticates a user. Credentials come from one of two sources:
//   - a one-time login token in the "with" parameter, or
//   - an alias/password pair (userCredentials) decoded from a JSON body or
//     form post, subject to a per-alias attempt rate limit outside dev mode.
//
// On success login always writes the response itself. JSON API clients
// (without signin.Web) receive an AuthUser with an access token; "verbose"
// also loads posts and collections. Other requests get a cookie session and
// then one of:
//   - a JSON AuthUser for JSON requests;
//   - a 302 redirect to the "to" parameter, restricted to local paths;
//   - a 302 redirect to a default landing page.
//
// On failure it returns an error without writing a response.
func login(app *App, w http.ResponseWriter, r *http.Request) error {
	reqJSON := IsJSON(r)
	oneTimeToken := r.FormValue("with")
	verbose := r.FormValue("all") == "true" || r.FormValue("verbose") == "1" || r.FormValue("verbose") == "true" || (reqJSON && oneTimeToken != "")

	redirectTo := r.FormValue("to")
	// Only follow local, path-absolute targets. Full URLs ("https://host"),
	// scheme-relative ones ("//host") and "/\host" (which browsers normalize
	// to "//host") would let a crafted login link forward a freshly
	// authenticated user to another site (open redirect).
	if !strings.HasPrefix(redirectTo, "/") || strings.HasPrefix(redirectTo, "//") || strings.HasPrefix(redirectTo, "/\\") {
		redirectTo = ""
	}
	if redirectTo == "" {
		if app.cfg.App.SingleUser {
			redirectTo = "/me/new"
		} else {
			redirectTo = "/"
		}
	}

	var u *User
	var err error
	var signin userCredentials

	if app.cfg.App.DisablePasswordAuth {
		return ErrDisabledPasswordAuth
	}

	// Log in with one-time token if one is given
	if oneTimeToken != "" {
		log.Info("Login: Logging user in via token.")
		userID := app.db.GetUserID(oneTimeToken)
		if userID == -1 {
			log.Error("Login: Got user -1 from token")
			err := ErrBadAccessToken
			err.Message = "Expired or invalid login code."
			return err
		}
		log.Info("Login: Found user %d.", userID)

		u, err = app.db.GetUserByID(userID)
		if err != nil {
			log.Error("Unable to fetch user on one-time token login: %v", err)
			return impart.HTTPError{http.StatusInternalServerError, "There was an error retrieving the user you want."}
		}
		log.Info("Login: Got user via token")
	} else {
		// Get params
		if reqJSON {
			decoder := json.NewDecoder(r.Body)
			err := decoder.Decode(&signin)
			if err != nil {
				log.Error("Couldn't parse signin JSON request: %v\n", err)
				return ErrBadJSON
			}
		} else {
			err := r.ParseForm()
			if err != nil {
				log.Error("Couldn't parse signin form request: %v\n", err)
				return ErrBadFormData
			}

			err = app.formDecoder.Decode(&signin, r.PostForm)
			if err != nil {
				log.Error("Couldn't decode signin form request: %v\n", err)
				return ErrBadFormData
			}
		}

		log.Info("Login: Attempting login for '%s'", signin.Alias)

		// Validate required params (all)
		if signin.Alias == "" {
			msg := "Parameter `alias` required."
			if signin.Web {
				msg = "A username is required."
			}
			return impart.HTTPError{http.StatusBadRequest, msg}
		}
		if !signin.EmailLogin && signin.Pass == "" {
			msg := "Parameter `pass` required."
			if signin.Web {
				msg = "A password is required."
			}
			return impart.HTTPError{http.StatusBadRequest, msg}
		}

		// Prevent excessive login attempts on the same account.
		// Skip this check in dev environment.
		// NOTE(review): an entry is only removed when the same alias retries
		// after it expires, so distinct aliases accumulate in loginAttemptUsers.
		if !app.cfg.Server.Dev {
			now := time.Now()
			attemptExp, att := loginAttemptUsers.LoadOrStore(signin.Alias, now.Add(loginAttemptExpiration))
			if att {
				if attemptExpTime, ok := attemptExp.(time.Time); ok {
					if attemptExpTime.After(now) {
						// This user attempted previously, and the period hasn't expired yet
						return impart.HTTPError{http.StatusTooManyRequests, "You're doing that too much."}
					}
					// This user attempted previously, but the time expired; free up space
					loginAttemptUsers.Delete(signin.Alias)
				} else {
					log.Error("Unable to cast expiration to time")
				}
			}
		}

		// Retrieve password
		u, err = app.db.GetUserForAuth(signin.Alias)
		if err != nil {
			log.Info("Unable to getUserForAuth on %s: %v", signin.Alias, err)
			if strings.IndexAny(signin.Alias, "@") > 0 {
				log.Info("Suggesting: %s", ErrUserNotFoundEmail.Message)
				return ErrUserNotFoundEmail
			}
			return err
		}

		// Authenticate
		if u.Email.String == "" {
			// User has no email set, so check if they haven't added a password, either,
			// so we can return a more helpful error message.
			if hasPass, _ := app.db.IsUserPassSet(u.ID); !hasPass {
				log.Info("Tried logging in to %s, but no password or email.", signin.Alias)
				return impart.HTTPError{http.StatusPreconditionFailed, "This user never added a password or email address. Please contact us for help."}
			}
		}
		if len(u.HashedPass) == 0 {
			return impart.HTTPError{http.StatusUnauthorized, "This user never set a password. Perhaps try logging in via OAuth?"}
		}
		if !auth.Authenticated(u.HashedPass, []byte(signin.Pass)) {
			return impart.HTTPError{http.StatusUnauthorized, "Incorrect password."}
		}
	}

	if reqJSON && !signin.Web {
		var token string
		if r.Header.Get("User-Agent") == "" {
			// Get last created token when User-Agent is empty
			token = app.db.FetchLastAccessToken(u.ID)
			if token == "" {
				token, err = app.db.GetAccessToken(u.ID)
			}
		} else {
			token, err = app.db.GetAccessToken(u.ID)
		}
		if err != nil {
			log.Error("Login: Unable to create access token: %v", err)
			return impart.HTTPError{http.StatusInternalServerError, "Could not create access token. Try re-authenticating."}
		}
		resUser := getVerboseAuthUser(app, token, u, verbose)
		return impart.WriteSuccess(w, resUser, http.StatusOK)
	}

	session, err := app.sessionStore.Get(r, cookieName)
	if err != nil {
		// The cookie should still save, even if there's an error.
		log.Error("Login: Session: %v; ignoring", err)
	}

	// Remove unwanted data
	session.Values[cookieUserVal] = u.Cookie()
	err = session.Save(r, w)
	if err != nil {
		log.Error("Login: Couldn't save session: %v", err)
		// TODO: return error
	}

	// Send success
	if reqJSON {
		return impart.WriteSuccess(w, &AuthUser{User: u}, http.StatusOK)
	}
	log.Info("Login: Redirecting to %s", redirectTo)
	w.Header().Set("Location", redirectTo)
	w.WriteHeader(http.StatusFound)
	return nil
}
// getVerboseAuthUser wraps u and its access token in an AuthUser. When
// verbose is true it also loads the user's posts and collections, and sets
// u.HasPass from the database. Each lookup failure is logged with its own
// message, and the corresponding field takes whatever the failed call returned.
func getVerboseAuthUser(app *App, token string, u *User, verbose bool) *AuthUser {
	resUser := &AuthUser{
		AccessToken: token,
		User:        u,
	}

	// Fetch verbose user data if requested
	if verbose {
		posts, err := app.db.GetUserPosts(u)
		if err != nil {
			log.Error("Login: Unable to get user posts: %v", err)
		}
		colls, err := app.db.GetCollections(u, app.cfg.App.Host)
		if err != nil {
			log.Error("Login: Unable to get user collections: %v", err)
		}
		passIsSet, err := app.db.IsUserPassSet(u.ID)
		if err != nil {
			log.Error("Login: Unable to get user password status: %v", err)
		}
		resUser.Posts = posts
		resUser.Collections = colls
		resUser.User.HasPass = passIsSet
	}
	return resUser
}
// viewExportOptions renders the "export" user page with the title "Export"
// and no flash messages.
func viewExportOptions(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
	showUserPage(w, "export", NewUserPage(app, r, u, "Export", nil))
	return nil
}
// viewExportPosts exports the requesting user's posts and returns the data
// with a base filename of the form "<username>-posts-<YYYYMMDDhhmm>" (UTC).
//
// Authentication:
//   - JSON requests use the Authorization header;
//   - all other requests use the session cookie.
//
// The request path selects the format: a ".csv" or ".zip" suffix selects
// those exports; otherwise JSON is produced, indented when "pretty=1".
func viewExportPosts(app *App, w http.ResponseWriter, r *http.Request) ([]byte, string, error) {
	var u *User

	if IsJSON(r) {
		// Use given Authorization header
		token := r.Header.Get("Authorization")
		if token == "" {
			return nil, "", ErrNoAccessToken
		}
		userID := app.db.GetUserID(token)
		if userID == -1 {
			return nil, "", ErrBadAccessToken
		}
		var err error
		if u, err = app.db.GetUserByID(userID); err != nil {
			return nil, "", impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve requested user."}
		}
	} else {
		// Use user cookie
		session, err := app.sessionStore.Get(r, cookieName)
		if err != nil {
			// The cookie should still save, even if there's an error.
			log.Error("Session: %v; ignoring", err)
		}
		var ok bool
		if u, ok = session.Values[cookieUserVal].(*User); !ok {
			return nil, "", ErrNotLoggedIn
		}
	}

	filename := u.Username + "-posts-" + time.Now().Truncate(time.Second).UTC().Format("200601021504")

	posts, err := app.db.GetUserPosts(u)
	if err != nil {
		return nil, filename, err
	}

	switch {
	case strings.HasSuffix(r.URL.Path, ".csv"):
		return exportPostsCSV(app.cfg.App.Host, u, posts), filename, nil
	case strings.HasSuffix(r.URL.Path, ".zip"):
		return exportPostsZip(u, posts), filename, nil
	}

	var out []byte
	if r.FormValue("pretty") == "1" {
		out, err = json.MarshalIndent(posts, "", "\t")
	} else {
		out, err = json.Marshal(posts)
	}
	return out, filename, err
}
// viewExportFull returns a JSON export of the logged-in user, built by
// compileFullExport, with a base filename "<username>-<YYYYMMDDhhmm>" (UTC).
// "pretty=1" produces indented JSON.
func viewExportFull(app *App, w http.ResponseWriter, r *http.Request) ([]byte, string, error) {
	u := getUserSession(app, r)
	if u == nil {
		return nil, "", ErrNotLoggedIn
	}

	filename := u.Username + "-" + time.Now().Truncate(time.Second).UTC().Format("200601021504")
	exportUser := compileFullExport(app, u)

	if r.FormValue("pretty") == "1" {
		out, err := json.MarshalIndent(exportUser, "", "\t")
		return out, filename, err
	}
	out, err := json.Marshal(exportUser)
	return out, filename, err
}
// viewMeAPI responds with the current user's username as JSON.
//
// JSON requests are identified by their Authorization token; other requests
// by their session cookie. An anonymous cookie request gets an empty object
// rather than an error.
func viewMeAPI(app *App, w http.ResponseWriter, r *http.Request) error {
	var me struct {
		ID       int64  `json:"id,omitempty"`
		Username string `json:"username,omitempty"`
	}

	if IsJSON(r) {
		var err error
		if _, me.Username, err = app.db.GetUserDataFromToken(r.Header.Get("Authorization")); err != nil {
			return err
		}
	} else if sessionUser := getUserSession(app, r); sessionUser != nil {
		me.Username = sessionUser.Username
	}
	return impart.WriteSuccess(w, me, http.StatusOK)
}
func viewMyPostsAPI(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
reqJSON := IsJSON(r)
if !reqJSON {
return ErrBadRequestedType
}
+ isAnonPosts := r.FormValue("anonymous") == "1"
+ if isAnonPosts {
+ pageStr := r.FormValue("page")
+ pg, err := strconv.Atoi(pageStr)
+ if err != nil {
+ log.Error("Error parsing page parameter '%s': %s", pageStr, err)
+ pg = 1
+ }
+
+ p, err := app.db.GetAnonymousPosts(u, pg)
+ if err != nil {
+ return err
+ }
+ return impart.WriteSuccess(w, p, http.StatusOK)
+ }
+
var err error
p := GetPostsCache(u.ID)
if p == nil {
userPostsCache.Lock()
if userPostsCache.users[u.ID].ready == nil {
userPostsCache.users[u.ID] = postsCacheItem{ready: make(chan struct{})}
userPostsCache.Unlock()
p, err = app.db.GetUserPosts(u)
if err != nil {
return err
}
CachePosts(u.ID, p)
} else {
userPostsCache.Unlock()
<-userPostsCache.users[u.ID].ready
p = GetPostsCache(u.ID)
}
}
return impart.WriteSuccess(w, p, http.StatusOK)
}
// viewMyCollectionsAPI returns the authenticated user's collections as JSON.
// Non-JSON requests are rejected with ErrBadRequestedType.
func viewMyCollectionsAPI(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
	if IsJSON(r) {
		colls, err := app.db.GetCollections(u, app.cfg.App.Host)
		if err != nil {
			return err
		}
		return impart.WriteSuccess(w, colls, http.StatusOK)
	}
	return ErrBadRequestedType
}
// viewArticles renders the "articles" user page. The page shows:
//   - the user's anonymous posts, from GetAnonymousPosts;
//   - the collections they can publish to;
//   - whether the account is silenced.
//
// Lookup failures are logged rather than returned, so the page still renders
// with whatever data was retrieved.
func viewArticles(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
	// This patch adds a page argument to GetAnonymousPosts; this view always
	// requests page 1.
- p, err := app.db.GetAnonymousPosts(u)
+ p, err := app.db.GetAnonymousPosts(u, 1)
	if err != nil {
		log.Error("unable to fetch anon posts: %v", err)
	}
	// nil-out AnonymousPosts slice for easy detection in the template
	if p != nil && len(*p) == 0 {
		p = nil
	}

	f, err := getSessionFlashes(app, w, r, nil)
	if err != nil {
		log.Error("unable to fetch flashes: %v", err)
	}

	c, err := app.db.GetPublishableCollections(u, app.cfg.App.Host)
	if err != nil {
		log.Error("unable to fetch collections: %v", err)
	}

	silenced, err := app.db.IsUserSilenced(u.ID)
	if err != nil {
		log.Error("view articles: %v", err)
	}
	d := struct {
		*UserPage
		AnonymousPosts *[]PublicPost
		Collections    *[]Collection
		Silenced       bool
	}{
		UserPage:       NewUserPage(app, r, u, u.Username+"'s Posts", f),
		AnonymousPosts: p,
		Collections:    c,
		Silenced:       silenced,
	}
	// Currently a no-op (UserPage.SetMessaging has an empty body).
	d.UserPage.SetMessaging(u)
	// Tell browsers and intermediaries not to cache this page.
	w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
	w.Header().Set("Expires", "Thu, 04 Oct 1990 20:00:00 GMT")
	showUserPage(w, "articles", d)

	return nil
}
// viewCollections renders the user's "Blogs" page with their collections,
// how many they have, whether creating another is allowed, and whether the
// account is silenced.
//
// A failure to count the user's collections is logged (it used to be silently
// discarded). The page still renders using whatever count the call returned,
// so the usage figure and the new-blog flag may be inaccurate in that case.
func viewCollections(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
	c, err := app.db.GetCollections(u, app.cfg.App.Host)
	if err != nil {
		log.Error("unable to fetch collections: %v", err)
		return fmt.Errorf("No collections")
	}
	f, _ := getSessionFlashes(app, w, r, nil)

	uc, err := app.db.GetUserCollectionCount(u.ID)
	if err != nil {
		log.Error("view collections: unable to get collection count: %v", err)
	}

	silenced, err := app.db.IsUserSilenced(u.ID)
	if err != nil {
		log.Error("view collections %v", err)
		return fmt.Errorf("view collections: %v", err)
	}

	d := struct {
		*UserPage
		Collections *[]Collection

		UsedCollections, TotalCollections int

		NewBlogsDisabled bool
		Silenced         bool
	}{
		UserPage:         NewUserPage(app, r, u, u.Username+"'s Blogs", f),
		Collections:      c,
		UsedCollections:  int(uc),
		NewBlogsDisabled: !app.cfg.App.CanCreateBlogs(uc),
		Silenced:         silenced,
	}
	d.UserPage.SetMessaging(u)
	showUserPage(w, "collections", d)
	return nil
}
// viewEditCollection renders the settings page for one of the user's
// collections, including its "monetization_pointer" attribute. A collection
// owned by someone else is reported as ErrCollectionNotFound.
func viewEditCollection(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
	coll, err := app.db.GetCollection(mux.Vars(r)["collection"])
	if err != nil {
		return err
	}
	if coll.OwnerID != u.ID {
		return ErrCollectionNotFound
	}

	// Add collection properties
	coll.Monetization = app.db.GetCollectionAttribute(coll.ID, "monetization_pointer")

	silenced, err := app.db.IsUserSilenced(u.ID)
	if err != nil {
		log.Error("view edit collection %v", err)
		return fmt.Errorf("view edit collection: %v", err)
	}
	flashes, _ := getSessionFlashes(app, w, r, nil)

	view := struct {
		*UserPage
		*Collection
		Silenced bool
	}{
		UserPage:   NewUserPage(app, r, u, "Edit "+coll.DisplayTitle(), flashes),
		Collection: coll,
		Silenced:   silenced,
	}
	view.UserPage.CollAlias = coll.Alias

	showUserPage(w, "collection", view)
	return nil
}
// updateSettings applies account-settings changes (userSettings) for the
// current user.
//
// Request types:
//   - JSON requests authenticate with the Authorization header, may not
//     change the username, and receive the updated user as JSON (or the
//     ChangeSettings error).
//   - Form posts use the cookie session. They always end in a 302 redirect:
//     to /me/logout after a successful update with "logout" set; otherwise
//     to /me/settings (or the "return" parameter), with a flash message.
func updateSettings(app *App, w http.ResponseWriter, r *http.Request) error {
	reqJSON := IsJSON(r)

	var s userSettings
	var u *User
	// sess stays nil for JSON requests; it is only used on the form path.
	var sess *sessions.Session
	var err error
	if reqJSON {
		accessToken := r.Header.Get("Authorization")
		if accessToken == "" {
			return ErrNoAccessToken
		}

		u, err = app.db.GetAPIUser(accessToken)
		if err != nil {
			return ErrBadAccessToken
		}

		decoder := json.NewDecoder(r.Body)
		err := decoder.Decode(&s)
		if err != nil {
			log.Error("Couldn't parse settings JSON request: %v\n", err)
			return ErrBadJSON
		}

		// Prevent all username updates
		// TODO: support changing username via JSON API request
		s.Username = ""
	} else {
		u, sess = getUserAndSession(app, r)
		if u == nil {
			return ErrNotLoggedIn
		}

		err := r.ParseForm()
		if err != nil {
			log.Error("Couldn't parse settings form request: %v\n", err)
			return ErrBadFormData
		}

		err = app.formDecoder.Decode(&s, r.PostForm)
		if err != nil {
			log.Error("Couldn't decode settings form request: %v\n", err)
			return ErrBadFormData
		}
	}

	// Do update
	// NOTE(review): "return" is used as the redirect target without
	// validation; consider restricting it to local paths.
	postUpdateReturn := r.FormValue("return")
	redirectTo := "/me/settings"
	if s.IsLogOut {
		redirectTo += "?logout=1"
	} else if postUpdateReturn != "" {
		redirectTo = postUpdateReturn
	}

	// Only do updates on values we need
	if s.Username != "" && s.Username == u.Username {
		// Username hasn't actually changed; blank it out
		s.Username = ""
	}
	err = app.db.ChangeSettings(app, u, &s)
	if err != nil {
		if reqJSON {
			return err
		}

		// Only HTTPError messages are shown to the user; other errors just
		// redirect without a flash.
		if err, ok := err.(impart.HTTPError); ok {
			addSessionFlash(app, w, r, err.Message, nil)
		}
	} else {
		// Successful update.
		if reqJSON {
			return impart.WriteSuccess(w, u, http.StatusOK)
		}

		if s.IsLogOut {
			redirectTo = "/me/logout"
		} else {
			// Refresh the cookie's copy of the user.
			// NOTE(review): sess is not saved in this function; this relies on
			// addSessionFlash saving the same session — confirm.
			sess.Values[cookieUserVal] = u.Cookie()
			addSessionFlash(app, w, r, "Account updated.", nil)
		}
	}

	w.Header().Set("Location", redirectTo)
	w.WriteHeader(http.StatusFound)
	return nil
}
// updatePassphrase changes the password of the user who owns the access token
// in the Authorization header. The new password ("new") is always required.
// The current password ("current") may be omitted only when the token has
// sudo privileges.
func updatePassphrase(app *App, w http.ResponseWriter, r *http.Request) error {
	token := r.Header.Get("Authorization")
	if token == "" {
		return ErrNoAccessToken
	}

	current, replacement := r.FormValue("current"), r.FormValue("new")
	if replacement == "" {
		return impart.HTTPError{http.StatusBadRequest, "Provide a new password."}
	}

	userID, sudo := app.db.GetUserIDPrivilege(token)
	if userID == -1 {
		return ErrBadAccessToken
	}
	if current == "" && !sudo {
		return impart.HTTPError{http.StatusBadRequest, "Provide current password."}
	}

	hashed, err := auth.HashPass([]byte(replacement))
	if err != nil {
		return impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."}
	}
	if err := app.db.ChangePassphrase(userID, sudo, current, hashed); err != nil {
		return err
	}
	return impart.WriteSuccess(w, struct{}{}, http.StatusOK)
}
// viewStats renders the stats page. It covers one of the user's collections
// when a "collection" route variable is given, and otherwise all of the
// user's posts. The page lists top posts and, when federation is enabled and
// a collection is selected, the collection's ActivityPub follower count.
//
// Without a collection, c is nil. The original unconditionally dereferenced
// it (c.Alias, GetAPFollowers(c)), so the no-collection page panicked; those
// uses are now guarded.
func viewStats(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
	var c *Collection
	var err error
	alias := mux.Vars(r)["collection"]
	if alias != "" {
		c, err = app.db.GetCollection(alias)
		if err != nil {
			return err
		}
		if c.OwnerID != u.ID {
			return ErrCollectionNotFound
		}
	}

	topPosts, err := app.db.GetTopPosts(u, alias)
	if err != nil {
		log.Error("Unable to get top posts: %v", err)
		return err
	}

	flashes, _ := getSessionFlashes(app, w, r, nil)
	titleStats := ""
	if c != nil {
		titleStats = c.DisplayTitle() + " "
	}

	silenced, err := app.db.IsUserSilenced(u.ID)
	if err != nil {
		log.Error("view stats: %v", err)
		return err
	}
	obj := struct {
		*UserPage
		VisitsBlog  string
		Collection  *Collection
		TopPosts    *[]PublicPost
		APFollowers int
		Silenced    bool
	}{
		UserPage:   NewUserPage(app, r, u, titleStats+"Stats", flashes),
		VisitsBlog: alias,
		Collection: c,
		TopPosts:   topPosts,
		Silenced:   silenced,
	}

	if c != nil {
		obj.UserPage.CollAlias = c.Alias
		if app.cfg.App.Federation {
			folls, err := app.db.GetAPFollowers(c)
			if err != nil {
				return err
			}
			if folls != nil {
				obj.APFollowers = len(*folls)
			}
		}
	}

	showUserPage(w, "stats", obj)
	return nil
}
// viewSettings renders the account settings page. The page shows:
//   - the decrypted email and whether a password is set;
//   - the silenced state;
//   - a CSRF form field;
//   - the OAuth section, with linked accounts and the configured providers
//     the user has not linked yet.
//
// Database failures return a generic 500.
func viewSettings(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
	const loadFailed = "Unable to retrieve user data. The humans have been alerted."

	fullUser, err := app.db.GetUserByID(u.ID)
	if err != nil {
		log.Error("Unable to get user for settings: %s", err)
		return impart.HTTPError{http.StatusInternalServerError, loadFailed}
	}

	hasPass, err := app.db.IsUserPassSet(u.ID)
	if err != nil {
		log.Error("Unable to get isUserPassSet for settings: %s", err)
		return impart.HTTPError{http.StatusInternalServerError, loadFailed}
	}

	flashes, _ := getSessionFlashes(app, w, r, nil)

	// A provider is offered for linking while it is configured and the user
	// has no linked account from it yet.
	cfg := app.Config()
	offerSlack := cfg.SlackOauth.ClientID != ""
	offerWriteAs := cfg.WriteAsOauth.ClientID != ""
	offerGitLab := cfg.GitlabOauth.ClientID != ""
	offerGeneric := cfg.GenericOauth.ClientID != ""
	offerGitea := cfg.GiteaOauth.ClientID != ""

	accounts, err := app.db.GetOauthAccounts(r.Context(), u.ID)
	if err != nil {
		log.Error("Unable to get oauth accounts for settings: %s", err)
		return impart.HTTPError{http.StatusInternalServerError, loadFailed}
	}
	for i := range accounts {
		switch accounts[i].Provider {
		case "slack":
			offerSlack = false
		case "write.as":
			offerWriteAs = false
		case "gitlab":
			offerGitLab = false
		case "generic":
			accounts[i].DisplayName = cfg.GenericOauth.DisplayName
			accounts[i].AllowDisconnect = cfg.GenericOauth.AllowDisconnect
			offerGeneric = false
		case "gitea":
			offerGitea = false
		}
	}

	showOauth := offerSlack || offerWriteAs || offerGitLab || offerGeneric || offerGitea || len(accounts) > 0

	obj := struct {
		*UserPage
		Email                   string
		HasPass                 bool
		IsLogOut                bool
		Silenced                bool
		CSRFField               template.HTML
		OauthSection            bool
		OauthAccounts           []oauthAccountInfo
		OauthSlack              bool
		OauthWriteAs            bool
		OauthGitLab             bool
		GitLabDisplayName       string
		OauthGeneric            bool
		OauthGenericDisplayName string
		OauthGitea              bool
		GiteaDisplayName        string
	}{
		UserPage:                NewUserPage(app, r, u, "Account Settings", flashes),
		Email:                   fullUser.EmailClear(app.keys),
		HasPass:                 hasPass,
		IsLogOut:                r.FormValue("logout") == "1",
		Silenced:                fullUser.IsSilenced(),
		CSRFField:               csrf.TemplateField(r),
		OauthSection:            showOauth,
		OauthAccounts:           accounts,
		OauthSlack:              offerSlack,
		OauthWriteAs:            offerWriteAs,
		OauthGitLab:             offerGitLab,
		GitLabDisplayName:       config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName),
		OauthGeneric:            offerGeneric,
		OauthGenericDisplayName: config.OrDefaultString(cfg.GenericOauth.DisplayName, genericOauthDisplayName),
		OauthGitea:              offerGitea,
		GiteaDisplayName:        config.OrDefaultString(cfg.GiteaOauth.DisplayName, giteaDisplayName),
	}

	showUserPage(w, "settings", obj)
	return nil
}
// saveTempInfo stores val under key in the "t" cookie session, where
// getTempInfo can read it back. Failing to load the session yields
// ErrInternalCookieSession. A save failure is logged and returned.
func saveTempInfo(app *App, key, val string, r *http.Request, w http.ResponseWriter) error {
	tmp, err := app.sessionStore.Get(r, "t")
	if err != nil {
		return ErrInternalCookieSession
	}

	tmp.Values[key] = val
	if err = tmp.Save(r, w); err != nil {
		log.Error("Couldn't saveTempInfo for key-val (%s:%s): %v", key, val, err)
	}
	return err
}
// getTempInfo returns the string stored under key in the "t" cookie session.
// When a value is found, the whole "t" session cookie is expired so the value
// is read only once. It returns "" if the session can't be loaded or holds no
// string for key. A failure to expire the cookie is logged.
func getTempInfo(app *App, key string, r *http.Request, w http.ResponseWriter) string {
	tmp, err := app.sessionStore.Get(r, "t")
	if err != nil {
		return ""
	}

	s, ok := tmp.Values[key].(string)
	if !ok {
		return ""
	}

	// Delete cookie
	tmp.Options.MaxAge = -1
	if err := tmp.Save(r, w); err != nil {
		log.Error("Couldn't erase temp data for key %s: %v", key, err)
	}
	return s
}
// handleUserDelete deletes the logged-in user's account when open deletion is
// enabled on the instance.
//
// The user must confirm by submitting their exact username as
// "confirm-username", and admin accounts can't be deleted this way. On success
// the user is redirected to /me/logout.
func handleUserDelete(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
	if !app.cfg.App.OpenDeletion {
		return impart.HTTPError{http.StatusForbidden, "Open account deletion is disabled on this instance."}
	}

	confirmUsername := r.PostFormValue("confirm-username")
	if u.Username != confirmUsername {
		return impart.HTTPError{http.StatusBadRequest, "Confirmation username must match your username exactly."}
	}

	// Check for account deletion safeguards in place
	if u.IsAdmin() {
		return impart.HTTPError{http.StatusForbidden, "Cannot delete admin."}
	}

	err := app.db.DeleteAccount(u.ID)
	if err != nil {
		log.Error("user delete account: %v", err)
		return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not delete account: %v", err)}
	}

	// FIXME: This doesn't ever appear to the user, as (I believe) the value is erased when the session cookie is reset
	_ = addSessionFlash(app, w, r, "Thanks for writing with us! Your account was deleted successfully.", nil)
	return impart.HTTPError{http.StatusFound, "/me/logout"}
}
// removeOauth unlinks the OAuth account identified by the "provider",
// "client_id" and "remote_user_id" form values from the current user, then
// redirects to the settings page. Database errors are returned as a 500
// carrying the error text.
func removeOauth(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
	var (
		provider     = r.FormValue("provider")
		clientID     = r.FormValue("client_id")
		remoteUserID = r.FormValue("remote_user_id")
	)

	if err := app.db.RemoveOauth(r.Context(), u.ID, provider, clientID, remoteUserID); err != nil {
		return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
	}
	return impart.HTTPError{Status: http.StatusFound, Message: "/me/settings"}
}
// prepareUserEmail encrypts input with emailKey for storage.
//
// An empty input yields an invalid (null) zero.String. If encryption fails,
// the error is logged and the result is valid but holds an empty string.
func prepareUserEmail(input string, emailKey []byte) zero.String {
	if input == "" {
		return zero.NewString("", false)
	}

	email := zero.NewString("", true)
	encEmail, err := data.Encrypt(emailKey, input)
	if err != nil {
		log.Error("Unable to encrypt email: %s\n", err)
		return email
	}
	email.String = string(encEmail)
	return email
}
diff --git a/database.go b/database.go
index 88b46e5..df300ce 100644
--- a/database.go
+++ b/database.go
@@ -1,2789 +1,2800 @@
/*
* Copyright © 2018-2021 A Bunch Tell LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"context"
"database/sql"
"fmt"
"github.com/writeas/web-core/silobridge"
wf_db "github.com/writefreely/writefreely/db"
"net/http"
"strings"
"time"
"github.com/guregu/null"
"github.com/guregu/null/zero"
uuid "github.com/nu7hatch/gouuid"
"github.com/writeas/activityserve"
"github.com/writeas/impart"
"github.com/writeas/web-core/activitypub"
"github.com/writeas/web-core/auth"
"github.com/writeas/web-core/data"
"github.com/writeas/web-core/id"
"github.com/writeas/web-core/log"
"github.com/writeas/web-core/query"
"github.com/writefreely/writefreely/author"
"github.com/writefreely/writefreely/config"
"github.com/writefreely/writefreely/key"
)
// MySQL server error numbers inspected by the datastore.
const (
	mySQLErrTooManyConns = 1040 // ER_CON_COUNT_ERROR
	mySQLErrDuplicateKey = 1062 // ER_DUP_ENTRY
	mySQLErrMaxUserConns = 1203 // ER_TOO_MANY_USER_CONNECTIONS
	mySQLErrCollationMix = 1267 // ER_CANT_AGGREGATE_2COLLATIONS
)

// SQL driver names; datastore.driverName is compared against these to pick
// dialect-specific SQL.
const (
	driverMySQL  = "mysql"
	driverSQLite = "sqlite3"
)
// SQLiteEnabled is a package-level flag that defaults to false.
// NOTE(review): presumably set elsewhere when SQLite support is available —
// confirm against its writers.
var SQLiteEnabled bool
// writestore is the persistence interface for users, access tokens, posts,
// collections, invites, instance content and OAuth links. It is satisfied by
// *datastore.
// NOTE(review): GetPostsCount declares no return value; presumably it stores
// its result on the passed *CollectionObj — confirm in its implementation.
type writestore interface {
CreateUser(*config.Config, *User, string) error
UpdateUserEmail(keys *key.Keychain, userID int64, email string) error
UpdateEncryptedUserEmail(int64, []byte) error
GetUserByID(int64) (*User, error)
GetUserForAuth(string) (*User, error)
GetUserForAuthByID(int64) (*User, error)
GetUserNameFromToken(string) (string, error)
GetUserDataFromToken(string) (int64, string, error)
GetAPIUser(header string) (*User, error)
GetUserID(accessToken string) int64
GetUserIDPrivilege(accessToken string) (userID int64, sudo bool)
DeleteToken(accessToken []byte) error
FetchLastAccessToken(userID int64) string
GetAccessToken(userID int64) (string, error)
GetTemporaryAccessToken(userID int64, validSecs int) (string, error)
GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error)
DeleteAccount(userID int64) error
ChangeSettings(app *App, u *User, s *userSettings) error
ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error
GetCollections(u *User, hostName string) (*[]Collection, error)
GetPublishableCollections(u *User, hostName string) (*[]Collection, error)
GetMeStats(u *User) userMeStats
GetTotalCollections() (int64, error)
GetTotalPosts() (int64, error)
GetTopPosts(u *User, alias string) (*[]PublicPost, error)
- GetAnonymousPosts(u *User) (*[]PublicPost, error)
+ GetAnonymousPosts(u *User, page int) (*[]PublicPost, error)
GetUserPosts(u *User) (*[]PublicPost, error)
CreateOwnedPost(post *SubmittedPost, accessToken, collAlias, hostName string) (*PublicPost, error)
CreatePost(userID, collID int64, post *SubmittedPost) (*Post, error)
UpdateOwnedPost(post *AuthenticatedPost, userID int64) error
GetEditablePost(id, editToken string) (*PublicPost, error)
PostIDExists(id string) bool
GetPost(id string, collectionID int64) (*PublicPost, error)
GetOwnedPost(id string, ownerID int64) (*PublicPost, error)
GetPostProperty(id string, collectionID int64, property string) (interface{}, error)
CreateCollectionFromToken(*config.Config, string, string, string) (*Collection, error)
CreateCollection(*config.Config, string, string, int64) (*Collection, error)
GetCollectionBy(condition string, value interface{}) (*Collection, error)
GetCollection(alias string) (*Collection, error)
GetCollectionForPad(alias string) (*Collection, error)
GetCollectionByID(id int64) (*Collection, error)
UpdateCollection(c *SubmittedCollection, alias string) error
DeleteCollection(alias string, userID int64) error
UpdatePostPinState(pinned bool, postID string, collID, ownerID, pos int64) error
GetLastPinnedPostPos(collID int64) int64
GetPinnedPosts(coll *CollectionObj, includeFuture bool) (*[]PublicPost, error)
RemoveCollectionRedirect(t *sql.Tx, alias string) error
GetCollectionRedirect(alias string) (new string)
IsCollectionAttributeOn(id int64, attr string) bool
CollectionHasAttribute(id int64, attr string) bool
CanCollect(cpr *ClaimPostRequest, userID int64) bool
AttemptClaim(p *ClaimPostRequest, query string, params []interface{}, slugIdx int) (sql.Result, error)
DispersePosts(userID int64, postIDs []string) (*[]ClaimPostResult, error)
ClaimPosts(cfg *config.Config, userID int64, collAlias string, posts *[]ClaimPostRequest) (*[]ClaimPostResult, error)
GetPostsCount(c *CollectionObj, includeFuture bool)
GetPosts(cfg *config.Config, c *Collection, page int, includeFuture, forceRecentFirst, includePinned bool) (*[]PublicPost, error)
GetPostsTagged(cfg *config.Config, c *Collection, tag string, page int, includeFuture bool) (*[]PublicPost, error)
GetAPFollowers(c *Collection) (*[]RemoteUser, error)
GetAPActorKeys(collectionID int64) ([]byte, []byte)
CreateUserInvite(id string, userID int64, maxUses int, expires *time.Time) error
GetUserInvites(userID int64) (*[]Invite, error)
GetUserInvite(id string) (*Invite, error)
GetUsersInvitedCount(id string) int64
CreateInvitedUser(inviteID string, userID int64) error
GetDynamicContent(id string) (*instanceContent, error)
UpdateDynamicContent(id, title, content, contentType string) error
GetAllUsers(page uint) (*[]User, error)
GetAllUsersCount() int64
GetUserLastPostTime(id int64) (*time.Time, error)
GetCollectionLastPostTime(id int64) (*time.Time, error)
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
ValidateOAuthState(context.Context, string) (string, string, int64, string, error)
GenerateOAuthState(context.Context, string, string, int64, string) (string, error)
GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error)
RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error
DatabaseInitialized() bool
}
// datastore is the SQL-backed implementation of writestore. It embeds *sql.DB
// so the DB's methods (Exec, QueryRow, Begin, ...) are called directly on it.
type datastore struct {
*sql.DB
// driverName is compared against driverSQLite to choose SQLite- or
// MySQL-flavored SQL fragments (see now, clip, upsert, dateSub).
driverName string
}
// Compile-time check that *datastore implements writestore.
var _ writestore = (*datastore)(nil)
// now returns the SQL expression for the current timestamp in the active
// dialect. SQLite uses a strftime call yielding 'YYYY-MM-DD HH:MM:SS';
// every other driver uses NOW().
func (db *datastore) now() string {
	switch db.driverName {
	case driverSQLite:
		return "strftime('%Y-%m-%d %H:%M:%S','now')"
	default:
		return "NOW()"
	}
}
// clip returns a SQL expression that yields at most the first l characters of
// field.
//
// MySQL uses LEFT(field, l). SQLite has no LEFT, so SUBSTR is used instead.
// SQLite string positions start at 1: SUBSTR(field, 0, l) would return only
// l-1 characters, so the start index must be 1 to match LEFT's result.
func (db *datastore) clip(field string, l int) string {
	if db.driverName == driverSQLite {
		return fmt.Sprintf("SUBSTR(%s, 1, %d)", field, l)
	}
	return fmt.Sprintf("LEFT(%s, %d)", field, l)
}
// upsert returns the dialect-specific clause that turns an INSERT into an
// upsert. The caller appends the "col = ?" assignments.
//   - MySQL: "ON DUPLICATE KEY UPDATE".
//   - SQLite: "ON CONFLICT(<indexedCols>) DO UPDATE SET". indexedCols names
//     the unique index the conflict is detected on.
func (db *datastore) upsert(indexedCols ...string) string {
	if db.driverName != driverSQLite {
		return "ON DUPLICATE KEY UPDATE"
	}
	// NOTE: SQLite UPSERT syntax only works in v3.24.0 (2018-06-04) or later
	// Leaving this for whenever we can upgrade and include it in our binary
	return "ON CONFLICT(" + strings.Join(indexedCols, ", ") + ") DO UPDATE SET"
}
// dateSub returns a SQL expression for "now minus l units" in the active
// dialect. unit is spliced into the SQL verbatim, so it must be a trusted,
// code-supplied keyword.
func (db *datastore) dateSub(l int, unit string) string {
	if db.driverName != driverSQLite {
		return fmt.Sprintf("DATE_SUB(NOW(), INTERVAL %d %s)", l, unit)
	}
	return fmt.Sprintf("DATETIME('now', '-%d %s')", l, unit)
}
// CreateUser inserts a new user, plus a collection whose alias is the username,
// in a single transaction. On success it sets u.ID to the new user's ID.
//
// u.HashedPass must already be hashed. collectionTitle defaults to u.Username
// when empty.
//
// Errors:
//   - a username equal to an existing post ID: 409 "Invalid collection name.";
//   - a duplicate key on either INSERT: 409 "Username is already taken.";
//   - anything else: logged and returned as-is.
func (db *datastore) CreateUser(cfg *config.Config, u *User, collectionTitle string) error {
	if db.PostIDExists(u.Username) {
		return impart.HTTPError{http.StatusConflict, "Invalid collection name."}
	}

	tx, err := db.Begin()
	if err != nil {
		return err
	}

	// rollback aborts the transaction. When detectDup is set, duplicate-key
	// errors map to the username conflict; otherwise the cause is logged with
	// logFmt and returned.
	rollback := func(logFmt string, cause error, detectDup bool) error {
		tx.Rollback()
		if detectDup && db.isDuplicateKeyErr(cause) {
			return impart.HTTPError{http.StatusConflict, "Username is already taken."}
		}
		log.Error(logFmt, cause)
		return cause
	}

	// Step 1: the users row.
	res, err := tx.Exec("INSERT INTO users (username, password, email) VALUES (?, ?, ?)", u.Username, u.HashedPass, u.Email)
	if err != nil {
		return rollback("Rolling back users INSERT: %v\n", err, true)
	}
	if u.ID, err = res.LastInsertId(); err != nil {
		return rollback("Rolling back after LastInsertId: %v\n", err, false)
	}

	// Step 2: the user's default collection.
	title := collectionTitle
	if title == "" {
		title = u.Username
	}
	if _, err = tx.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", u.Username, title, "", defaultVisibility(cfg), u.ID, 0); err != nil {
		return rollback("Rolling back collections INSERT: %v\n", err, true)
	}

	// Drop any redirect previously registered for this alias; its error is
	// intentionally not checked here.
	db.RemoveCollectionRedirect(tx, u.Username)

	if err = tx.Commit(); err != nil {
		return rollback("Rolling back after Commit(): %v\n", err, false)
	}
	return nil
}
// FIXME: We're returning errors inconsistently in this file. Do we use Errorf
// for returned value, or impart?
// UpdateUserEmail encrypts email with keys.EmailKey and stores the ciphertext
// as the email of user userID.
//
// The plaintext address is deliberately kept out of the returned error so it
// can't leak into logs. Identify the account by userID instead.
func (db *datastore) UpdateUserEmail(keys *key.Keychain, userID int64, email string) error {
	encEmail, err := data.Encrypt(keys.EmailKey, email)
	if err != nil {
		return fmt.Errorf("Couldn't encrypt email for user %d: %s", userID, err)
	}
	return db.UpdateEncryptedUserEmail(userID, encEmail)
}
// UpdateEncryptedUserEmail overwrites the stored email of user userID with
// encEmail, which must already be encrypted.
func (db *datastore) UpdateEncryptedUserEmail(userID int64, encEmail []byte) error {
	if _, err := db.Exec("UPDATE users SET email = ? WHERE id = ?", encEmail, userID); err != nil {
		return fmt.Errorf("Unable to update user email: %s", err)
	}
	return nil
}
// CreateCollectionFromToken resolves accessToken to its user and creates a
// collection owned by that user. It returns ErrBadAccessToken when the token
// matches no user.
func (db *datastore) CreateCollectionFromToken(cfg *config.Config, alias, title, accessToken string) (*Collection, error) {
	if ownerID := db.GetUserID(accessToken); ownerID != -1 {
		return db.CreateCollection(cfg, alias, title, ownerID)
	}
	return nil, ErrBadAccessToken
}
// GetUserCollectionCount returns how many collections user userID owns.
// Query failures are logged and returned.
func (db *datastore) GetUserCollectionCount(userID int64) (uint64, error) {
	var total uint64
	err := db.QueryRow("SELECT COUNT(*) FROM collections WHERE owner_id = ?", userID).Scan(&total)
	if err == sql.ErrNoRows {
		// An aggregate without GROUP BY always yields a row, so this is not
		// expected in practice.
		return 0, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user from database."}
	}
	if err != nil {
		log.Error("Couldn't get collections count for user %d: %v", userID, err)
		return 0, err
	}
	return total, nil
}
// CreateCollection inserts a collection with the given alias and title, owned
// by userID, using the instance's default visibility.
//
// Errors:
//   - an alias that collides with an existing post ID: 409;
//   - an alias that already exists as a collection: 409;
//   - other insert failures: logged and returned.
//
// A LastInsertId failure is only logged. The collection is still returned
// with whatever ID LastInsertId produced.
func (db *datastore) CreateCollection(cfg *config.Config, alias, title string, userID int64) (*Collection, error) {
	if db.PostIDExists(alias) {
		return nil, impart.HTTPError{http.StatusConflict, "Invalid collection name."}
	}

	res, err := db.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", alias, title, "", defaultVisibility(cfg), userID, 0)
	if err != nil {
		if !db.isDuplicateKeyErr(err) {
			log.Error("Couldn't add to collections: %v\n", err)
			return nil, err
		}
		return nil, impart.HTTPError{http.StatusConflict, "Collection already exists."}
	}

	coll := &Collection{
		Alias:       alias,
		Title:       title,
		OwnerID:     userID,
		PublicOwner: false,
		Public:      defaultVisibility(cfg) == CollPublic,
	}
	if coll.ID, err = res.LastInsertId(); err != nil {
		log.Error("Couldn't get collection LastInsertId: %v\n", err)
	}
	return coll, nil
}
// GetUserByID loads the user with the given ID, including its password hash,
// email, creation time and status. It returns ErrUserNotFound if no such user
// exists; other query errors are logged and returned.
func (db *datastore) GetUserByID(id int64) (*User, error) {
	user := &User{ID: id}
	row := db.QueryRow("SELECT username, password, email, created, status FROM users WHERE id = ?", id)
	if err := row.Scan(&user.Username, &user.HashedPass, &user.Email, &user.Created, &user.Status); err != nil {
		if err == sql.ErrNoRows {
			return nil, ErrUserNotFound
		}
		log.Error("Couldn't SELECT user password: %v", err)
		return nil, err
	}
	return user, nil
}
// IsUserSilenced reports whether the account with the given ID is currently
// silenced, based on its stored status. A missing user and query failures
// are both returned as errors prefixed with "is user silenced:".
func (db *datastore) IsUserSilenced(id int64) (bool, error) {
	user := &User{ID: id}
	err := db.QueryRow("SELECT status FROM users WHERE id = ?", id).Scan(&user.Status)
	if err == sql.ErrNoRows {
		return false, fmt.Errorf("is user silenced: %v", ErrUserNotFound)
	}
	if err != nil {
		log.Error("Couldn't SELECT user status: %v", err)
		return false, fmt.Errorf("is user silenced: %v", err)
	}
	return user.IsSilenced(), nil
}
// DoesUserNeedAuth reports whether user id has no way to authenticate, i.e.
// neither a passphrase nor an email address is stored.
//
// Lookup failures never produce a positive answer. A missing user returns
// false silently; other query errors are logged and also return false.
func (db *datastore) DoesUserNeedAuth(id int64) bool {
	var pass, email []byte
	err := db.QueryRow("SELECT password, email FROM users WHERE id = ?", id).Scan(&pass, &email)
	if err != nil {
		if err != sql.ErrNoRows {
			log.Error("Couldn't SELECT user %d from users: %v", id, err)
		}
		return false
	}
	// Auth is needed only when both methods are absent.
	return len(email) == 0 && len(pass) == 0
}
// IsUserPassSet reports whether user id has a non-empty stored password.
// A missing user yields (false, nil); query failures are logged and returned.
func (db *datastore) IsUserPassSet(id int64) (bool, error) {
	var hash []byte
	if err := db.QueryRow("SELECT password FROM users WHERE id = ?", id).Scan(&hash); err != nil {
		if err == sql.ErrNoRows {
			return false, nil
		}
		log.Error("Couldn't SELECT user %d from users: %v", id, err)
		return false, err
	}
	return len(hash) > 0, nil
}
// GetUserForAuth loads the credentials and account data of the user named
// username.
//
// If username matches no account, it is normalized with getSlug. When that
// produces a different, existing username, the lookup is retried with it.
// It returns ErrUserNotFound when neither form matches; other query errors
// are logged and returned.
func (db *datastore) GetUserForAuth(username string) (*User, error) {
	user := &User{Username: username}
	err := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE username = ?", username).Scan(&user.ID, &user.HashedPass, &user.Email, &user.Created, &user.Status)
	if err == nil {
		return user, nil
	}
	if err != sql.ErrNoRows {
		log.Error("Couldn't SELECT user password: %v", err)
		return nil, err
	}

	// The user may have entered an unnormalized form of their username.
	if normalized := getSlug(username, ""); normalized != user.Username {
		var existingID int64
		if db.QueryRow("SELECT id FROM users WHERE username = ? LIMIT 1", normalized).Scan(&existingID) == nil {
			return db.GetUserForAuth(normalized)
		}
	}
	return nil, ErrUserNotFound
}
// GetUserForAuthByID loads the credentials and account data of user userID.
// It returns ErrUserNotFound if no such user exists; other query errors are
// logged and returned.
func (db *datastore) GetUserForAuthByID(userID int64) (*User, error) {
	user := &User{ID: userID}
	row := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE id = ?", user.ID)
	if err := row.Scan(&user.ID, &user.HashedPass, &user.Email, &user.Created, &user.Status); err != nil {
		if err == sql.ErrNoRows {
			return nil, ErrUserNotFound
		}
		log.Error("Couldn't SELECT userForAuthByID: %v", err)
		return nil, err
	}
	return user, nil
}
// GetUserNameFromToken returns the username owning the unexpired access
// token accessToken.
//
// Errors:
//   - ErrNoAccessToken when the token parses to nothing;
//   - ErrBadAccessToken when no valid token matches;
//   - ErrInternalGeneral on any other query error.
//
// One-time tokens are deleted after a successful lookup.
func (db *datastore) GetUserNameFromToken(accessToken string) (string, error) {
	tok := auth.GetToken(accessToken)
	if len(tok) == 0 {
		return "", ErrNoAccessToken
	}

	var (
		username string
		oneTime  bool
	)
	q := "SELECT username, one_time FROM accesstokens LEFT JOIN users ON user_id = id WHERE token LIKE ? AND (expires IS NULL OR expires > " + db.now() + ")"
	if err := db.QueryRow(q, tok).Scan(&username, &oneTime); err != nil {
		if err == sql.ErrNoRows {
			return "", ErrBadAccessToken
		}
		return "", ErrInternalGeneral
	}

	if oneTime {
		db.DeleteToken(tok[:])
	}
	return username, nil
}
// GetUserDataFromToken returns the user ID and username owning the unexpired
// access token accessToken.
//
// It returns the same errors as GetUserNameFromToken. One-time tokens are
// deleted after a successful lookup.
func (db *datastore) GetUserDataFromToken(accessToken string) (int64, string, error) {
	tok := auth.GetToken(accessToken)
	if len(tok) == 0 {
		return 0, "", ErrNoAccessToken
	}

	var (
		uid      int64
		username string
		oneTime  bool
	)
	q := "SELECT user_id, username, one_time FROM accesstokens LEFT JOIN users ON user_id = id WHERE token LIKE ? AND (expires IS NULL OR expires > " + db.now() + ")"
	if err := db.QueryRow(q, tok).Scan(&uid, &username, &oneTime); err != nil {
		if err == sql.ErrNoRows {
			return 0, "", ErrBadAccessToken
		}
		return 0, "", ErrInternalGeneral
	}

	if oneTime {
		db.DeleteToken(tok[:])
	}
	return uid, username, nil
}
// GetAPIUser returns the user owning the access token in header. If the
// token matches no user, it returns an error carrying ErrUserNotFound's
// message; otherwise it returns GetUserByID's result.
func (db *datastore) GetAPIUser(header string) (*User, error) {
	uID := db.GetUserID(header)
	if uID == -1 {
		// Use a constant format string. Passing the message itself as the
		// format would misinterpret any '%' it contains (go vet flags this).
		return nil, fmt.Errorf("%s", ErrUserNotFound.Error())
	}
	return db.GetUserByID(uID)
}
// GetUserID parses the hexadecimal accessToken and returns the ID of the user
// it belongs to, or -1 when no user is associated with it. Any sudo
// privilege carried by the token is ignored.
func (db *datastore) GetUserID(accessToken string) int64 {
	userID, _ := db.GetUserIDPrivilege(accessToken)
	return userID
}
// GetUserIDPrivilege resolves accessToken to the owning user's ID and whether
// the token grants sudo.
//
// It returns (-1, false) when:
//   - the token parses to nothing;
//   - the token is expired or unknown;
//   - any query error occurs.
//
// One-time tokens are deleted after a successful lookup.
func (db *datastore) GetUserIDPrivilege(accessToken string) (userID int64, sudo bool) {
	tok := auth.GetToken(accessToken)
	if len(tok) == 0 {
		return -1, false
	}

	var (
		ownerID int64
		isSudo  bool
		oneTime bool
	)
	q := "SELECT user_id, sudo, one_time FROM accesstokens WHERE token LIKE ? AND (expires IS NULL OR expires > " + db.now() + ")"
	if err := db.QueryRow(q, tok).Scan(&ownerID, &isSudo, &oneTime); err != nil {
		return -1, false
	}

	if oneTime {
		db.DeleteToken(tok[:])
	}
	return ownerID, isSudo
}
// DeleteToken removes the access token(s) matching accessToken. It returns a
// 404 HTTPError when nothing was deleted, and any Exec error as-is.
func (db *datastore) DeleteToken(accessToken []byte) error {
	res, err := db.Exec("DELETE FROM accesstokens WHERE token LIKE ?", accessToken)
	if err != nil {
		return err
	}
	if deleted, _ := res.RowsAffected(); deleted == 0 {
		return impart.HTTPError{http.StatusNotFound, "Token is invalid or doesn't exist"}
	}
	return nil
}
// FetchLastAccessToken returns the most recently created, unexpired access
// token of userID, formatted as a UUID string. It does NOT create a token.
//
// It returns "" when:
//   - the user has no such token;
//   - the query fails (the error is logged);
//   - the stored token can't be parsed as a UUID.
func (db *datastore) FetchLastAccessToken(userID int64) string {
	var raw []byte
	err := db.QueryRow("SELECT token FROM accesstokens WHERE user_id = ? AND (expires IS NULL OR expires > "+db.now()+") ORDER BY created DESC LIMIT 1", userID).Scan(&raw)
	if err != nil {
		if err != sql.ErrNoRows {
			log.Error("Failed selecting from accesstoken: %v", err)
		}
		return ""
	}

	tok, err := uuid.Parse(raw)
	if err != nil {
		return ""
	}
	return tok.String()
}
// GetAccessToken creates and returns a new reusable access token for userID
// that never expires.
func (db *datastore) GetAccessToken(userID int64) (string, error) {
	const noExpiry, reusable = 0, false
	return db.GetTemporaryOneTimeAccessToken(userID, noExpiry, reusable)
}
// GetTemporaryAccessToken creates and returns a new reusable access token for
// userID. The token expires after validSecs seconds; when validSecs is 0 it
// never expires.
func (db *datastore) GetTemporaryAccessToken(userID int64, validSecs int) (string, error) {
	const reusable = false
	return db.GetTemporaryOneTimeAccessToken(userID, validSecs, reusable)
}
// GetTemporaryOneTimeAccessToken creates a new access token for userID and
// returns it as a UUID string.
//   - validSecs > 0: the token expires after that many seconds; otherwise it
//     never expires.
//   - oneTime: the token is deleted after its first successful lookup.
//
// The expiry expression is built per dialect:
//   - MySQL uses DATE_ADD.
//   - SQLite has no DATE_ADD, so DATETIME('now', '+N seconds') is used. It
//     yields the same 'YYYY-MM-DD HH:MM:SS' form that db.now() produces for
//     the expiry comparisons.
func (db *datastore) GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error) {
	u, err := uuid.NewV4()
	if err != nil {
		log.Error("Unable to generate token: %v", err)
		return "", err
	}

	// Insert UUID to `accesstokens` in its raw byte form.
	binTok := u[:]

	// expirationVal is spliced into the SQL. It only ever holds NULL or an
	// expression built from an int, never user input.
	expirationVal := "NULL"
	if validSecs > 0 {
		if db.driverName == driverSQLite {
			expirationVal = fmt.Sprintf("DATETIME('now', '+%d seconds')", validSecs)
		} else {
			// Pass db.now() as an argument rather than concatenating it into
			// the format string, so its text is never parsed as verbs.
			expirationVal = fmt.Sprintf("DATE_ADD(%s, INTERVAL %d SECOND)", db.now(), validSecs)
		}
	}

	_, err = db.Exec("INSERT INTO accesstokens (token, user_id, one_time, expires) VALUES (?, ?, ?, "+expirationVal+")", string(binTok), userID, oneTime)
	if err != nil {
		log.Error("Couldn't INSERT accesstoken: %v", err)
		return "", err
	}
	return u.String(), nil
}
// CreateOwnedPost creates a post on behalf of the user owning accessToken.
//   - Empty accessToken: the post is created with no owner (user ID -1).
//   - Unknown token: ErrBadAccessToken.
//   - collAlias given: the post is placed in that collection, which must be
//     owned by the token's user (ErrForbiddenCollection otherwise). The
//     collection's public view is attached to the result.
//
// If CreatePost fails, the partially built PublicPost is returned alongside
// the error.
func (db *datastore) CreateOwnedPost(post *SubmittedPost, accessToken, collAlias, hostName string) (*PublicPost, error) {
	var (
		userID int64 = -1
		collID int64 = -1
		coll   *Collection
		err    error
	)

	if accessToken != "" {
		if userID = db.GetUserID(accessToken); userID == -1 {
			return nil, ErrBadAccessToken
		}
		if collAlias != "" {
			if coll, err = db.GetCollection(collAlias); err != nil {
				return nil, err
			}
			coll.hostName = hostName
			if coll.OwnerID != userID {
				return nil, ErrForbiddenCollection
			}
			collID = coll.ID
		}
	}

	pub := &PublicPost{}
	if pub.Post, err = db.CreatePost(userID, collID, post); err != nil {
		return pub, err
	}
	if coll != nil {
		coll.ForPublic()
		pub.Collection = &CollectionObj{Collection: *coll}
	}
	return pub, nil
}
// CreatePost inserts a new post with a random friendly ID and returns it.
//
// Ownership and slug:
//   - userID > 0: the post is owned by that user.
//   - Also collID > 0: the post is placed in that collection and gets a slug.
//     The slug is taken from post.Slug, else derived from post.Title, else
//     from post.Content, else falls back to the post ID.
//
// Other details:
//   - An invalid font falls back to "norm".
//   - post.Created, if set, must use the "2006-01-02T15:04:05Z" layout.
//     Unparsable values are logged and replaced with the current time.
//   - On a duplicate-key error the insert is retried once with a uniquified
//     slug.
func (db *datastore) CreatePost(userID, collID int64, post *SubmittedPost) (*Post, error) {
	idLen := postIDLen
	friendlyID := id.GenerateFriendlyRandomString(idLen)

	// Handle appearance / font face
	appearance := post.Font
	if !post.isFontValid() {
		appearance = "norm"
	}

	var err error
	ownerID := sql.NullInt64{
		Valid: false,
	}
	ownerCollID := sql.NullInt64{
		Valid: false,
	}
	slug := sql.NullString{"", false}

	// If an alias was supplied, we'll add this to the collection as well.
	if userID > 0 {
		ownerID.Int64 = userID
		ownerID.Valid = true
		if collID > 0 {
			ownerCollID.Int64 = collID
			ownerCollID.Valid = true
			var slugVal string
			if post.Slug != nil && *post.Slug != "" {
				slugVal = *post.Slug
			} else {
				if post.Title != nil && *post.Title != "" {
					slugVal = getSlug(*post.Title, post.Language.String)
					if slugVal == "" {
						slugVal = getSlug(*post.Content, post.Language.String)
					}
				} else {
					slugVal = getSlug(*post.Content, post.Language.String)
				}
			}
			if slugVal == "" {
				slugVal = friendlyID
			}
			slug = sql.NullString{slugVal, true}
		}
	}

	// currentTime returns time.Now(), converted to UTC for SQLite, which stores
	// datetimes in UTC.
	currentTime := func() time.Time {
		t := time.Now()
		if db.driverName == driverSQLite {
			t = t.UTC()
		}
		return t
	}

	created := currentTime()
	if post.Created != nil {
		created, err = time.Parse("2006-01-02T15:04:05Z", *post.Created)
		if err != nil {
			log.Error("Unable to parse Created time '%s': %v", *post.Created, err)
			created = currentTime()
		}
	}

	stmt, err := db.Prepare("INSERT INTO posts (id, slug, title, content, text_appearance, language, rtl, privacy, owner_id, collection_id, created, updated, view_count) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, " + db.now() + ", ?)")
	if err != nil {
		return nil, err
	}
	defer stmt.Close()
	_, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0)
	if err != nil {
		if db.isDuplicateKeyErr(err) {
			// Duplicate entry error; try a new slug
			// TODO: make this a little more robust
			slug = sql.NullString{id.GenSafeUniqueSlug(slug.String), true}
			_, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0)
			if err != nil {
				return nil, handleFailedPostInsert(fmt.Errorf("Retried slug generation, still failed: %v", err))
			}
		} else {
			return nil, handleFailedPostInsert(err)
		}
	}

	// post.Title may be nil (it is nil-checked above), so don't dereference it
	// unconditionally.
	title := ""
	if post.Title != nil {
		title = *post.Title
	}

	// TODO: return Created field in proper format
	return &Post{
		ID:       friendlyID,
		Slug:     null.NewString(slug.String, slug.Valid),
		Font:     appearance,
		Language: zero.NewString(post.Language.String, post.Language.Valid),
		RTL:      zero.NewBool(post.IsRTL.Bool, post.IsRTL.Valid),
		OwnerID:  null.NewInt(userID, true),
		// Report the collection the post was actually inserted into. This
		// previously echoed userID by mistake.
		CollectionID: null.NewInt(ownerCollID.Int64, ownerCollID.Valid),
		Created:      created.Truncate(time.Second).UTC(),
		Updated:      time.Now().Truncate(time.Second).UTC(),
		Title:        zero.NewString(title, true),
		Content:      *(post.Content),
	}, nil
}
// UpdateOwnedPost updates an existing post with only the given fields in the
// supplied AuthenticatedPost.
//
// Only a post whose owner_id equals userID is modified, and the updated
// timestamp is always bumped.
//
// Returns:
//   - ErrPostNoUpdatableVals when no updatable field is set;
//   - an error when Created doesn't match postMetaDateFormat;
//   - ErrUnauthorizedEditPost when no row changed and no post with that ID is
//     owned by userID.
//
// NOTE: params must be appended in the same order as the "?" placeholders
// they fill: SET values first, then post.ID, then userID.
func (db *datastore) UpdateOwnedPost(post *AuthenticatedPost, userID int64) error {
params := []interface{}{}
var queryUpdates, sep, authCondition string
if post.Slug != nil && *post.Slug != "" {
queryUpdates += sep + "slug = ?"
sep = ", "
// Slugs are normalized before storage.
params = append(params, getSlug(*post.Slug, ""))
}
if post.Content != nil {
queryUpdates += sep + "content = ?"
sep = ", "
params = append(params, post.Content)
}
if post.Title != nil {
queryUpdates += sep + "title = ?"
sep = ", "
params = append(params, post.Title)
}
if post.Language.Valid {
queryUpdates += sep + "language = ?"
sep = ", "
params = append(params, post.Language.String)
}
if post.IsRTL.Valid {
queryUpdates += sep + "rtl = ?"
sep = ", "
params = append(params, post.IsRTL.Bool)
}
if post.Font != "" {
queryUpdates += sep + "text_appearance = ?"
sep = ", "
params = append(params, post.Font)
}
if post.Created != nil {
createTime, err := time.Parse(postMetaDateFormat, *post.Created)
if err != nil {
log.Error("Unable to parse Created date: %v", err)
return fmt.Errorf("That's the incorrect format for Created date.")
}
queryUpdates += sep + "created = ?"
sep = ", "
params = append(params, createTime)
}
// WHERE parameters...
// id = ?
params = append(params, post.ID)
// AND owner_id = ?
authCondition = "(owner_id = ?)"
params = append(params, userID)
if queryUpdates == "" {
return ErrPostNoUpdatableVals
}
// Appended after the emptiness check so it never counts as a real update.
queryUpdates += sep + "updated = " + db.now()
res, err := db.Exec("UPDATE posts SET "+queryUpdates+" WHERE id = ? AND "+authCondition, params...)
if err != nil {
log.Error("Unable to update owned post: %v", err)
return err
}
rowsAffected, _ := res.RowsAffected()
if rowsAffected == 0 {
// Show the correct error message if nothing was updated
// params[len(params)-1] is userID, the last value appended above.
var dummy int
err := db.QueryRow("SELECT 1 FROM posts WHERE id = ? AND "+authCondition, post.ID, params[len(params)-1]).Scan(&dummy)
switch {
case err == sql.ErrNoRows:
return ErrUnauthorizedEditPost
case err != nil:
// Lookup failures are only logged; the update is reported as successful.
log.Error("Failed selecting from posts: %v", err)
}
return nil
}
return nil
}
// GetCollectionBy loads the first collection matching condition, a WHERE
// fragment whose single placeholder is bound to value.
//
// condition is concatenated into the query, so it must be a trusted,
// code-supplied string.
//
// Errors:
//   - no match: a 404 HTTPError;
//   - the database reports high load: ErrUnavailable;
//   - other query errors: logged and returned.
func (db *datastore) GetCollectionBy(condition string, value interface{}) (*Collection, error) {
	coll := &Collection{}

	// Nullable columns are scanned into zero.String and unwrapped below.
	// FIXME: change Collection to reflect database values. Add helper functions to get actual values
	var styleSheet, script, signature, format zero.String
	err := db.QueryRow("SELECT id, alias, title, description, style_sheet, script, post_signature, format, owner_id, privacy, view_count FROM collections WHERE "+condition, value).
		Scan(&coll.ID, &coll.Alias, &coll.Title, &coll.Description, &styleSheet, &script, &signature, &format, &coll.OwnerID, &coll.Visibility, &coll.Views)
	if err == sql.ErrNoRows {
		return nil, impart.HTTPError{http.StatusNotFound, "Collection doesn't exist."}
	}
	if db.isHighLoadError(err) {
		return nil, ErrUnavailable
	}
	if err != nil {
		log.Error("Failed selecting from collections: %v", err)
		return nil, err
	}

	coll.StyleSheet = styleSheet.String
	coll.Script = script.String
	coll.Signature = signature.String
	coll.Format = format.String
	coll.Public = coll.IsPublic()
	coll.db = db
	return coll, nil
}
// GetCollection loads the collection with the given alias.
func (db *datastore) GetCollection(alias string) (*Collection, error) {
	const byAlias = "alias = ?"
	return db.GetCollectionBy(byAlias, alias)
}
// GetCollectionForPad loads the minimal collection data (ID, alias, title,
// description, visibility) for the given alias.
//
// On failure the partially filled collection is still returned:
//   - with a 404 HTTPError if the alias doesn't exist;
//   - with ErrInternalGeneral (after logging) on other query errors.
func (db *datastore) GetCollectionForPad(alias string) (*Collection, error) {
	coll := &Collection{Alias: alias}
	err := db.QueryRow("SELECT id, alias, title, description, privacy FROM collections WHERE alias = ?", alias).
		Scan(&coll.ID, &coll.Alias, &coll.Title, &coll.Description, &coll.Visibility)
	if err != nil {
		if err == sql.ErrNoRows {
			return coll, impart.HTTPError{http.StatusNotFound, "Collection doesn't exist."}
		}
		log.Error("Failed selecting from collections: %v", err)
		return coll, ErrInternalGeneral
	}
	coll.Public = coll.IsPublic()
	return coll, nil
}
// GetCollectionByID loads the collection with the given numeric ID.
func (db *datastore) GetCollectionByID(id int64) (*Collection, error) {
	const byID = "id = ?"
	return db.GetCollectionBy(byID, id)
}
// GetCollectionFromDomain loads the collection whose host column equals host.
func (db *datastore) GetCollectionFromDomain(host string) (*Collection, error) {
	const byHost = "host = ?"
	return db.GetCollectionBy(byHost, host)
}
// UpdateCollection applies the set fields of c to the collection with the
// given alias, provided it is owned by c.OwnerID.
//
// Besides the collections row, it maintains:
//   - the render_mathjax collection attribute;
//   - the monetization_pointer collection attribute;
//   - the hashed collection password, when a visibility is set together with
//     a passphrase.
//
// Returns:
//   - ErrPostNoUpdatableVals when no collection column would change;
//   - ErrUnauthorizedEditPost when no collection with that alias is owned by
//     c.OwnerID.
func (db *datastore) UpdateCollection(c *SubmittedCollection, alias string) error {
	q := query.NewUpdate().
		SetStringPtr(c.Title, "title").
		SetStringPtr(c.Description, "description").
		SetNullString(c.StyleSheet, "style_sheet").
		SetNullString(c.Script, "script").
		SetNullString(c.Signature, "post_signature")

	if c.Format != nil {
		cf := &CollectionFormat{Format: c.Format.String}
		if cf.Valid() {
			q.SetNullString(c.Format, "format")
		}
	}

	var updatePass bool
	if c.Visibility != nil && (collVisibility(*c.Visibility)&CollProtected == 0 || c.Pass != "") {
		q.SetIntPtr(c.Visibility, "privacy")
		if c.Pass != "" {
			updatePass = true
		}
	}

	// WHERE values
	q.Where("alias = ? AND owner_id = ?", alias, c.OwnerID)
	if q.Updates == "" {
		return ErrPostNoUpdatableVals
	}

	// Look up the collection ID scoped to the owner, so the attribute writes
	// below can never touch a collection c.OwnerID doesn't own. A missing
	// row yields the same ErrUnauthorizedEditPost the post-UPDATE ownership
	// check would produce, but without modifying any attributes first.
	var collID int64
	err := db.QueryRow("SELECT id FROM collections WHERE alias = ? AND owner_id = ?", alias, c.OwnerID).Scan(&collID)
	switch {
	case err == sql.ErrNoRows:
		return ErrUnauthorizedEditPost
	case err != nil:
		log.Error("Failed selecting from collections: %v. Some things won't work.", err)
	}

	// Update MathJax value
	if c.MathJax {
		if db.driverName == driverSQLite {
			_, err = db.Exec("INSERT OR REPLACE INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?)", collID, "render_mathjax", "1")
		} else {
			_, err = db.Exec("INSERT INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?) "+db.upsert("collection_id", "attribute")+" value = ?", collID, "render_mathjax", "1", "1")
		}
		if err != nil {
			log.Error("Unable to insert render_mathjax value: %v", err)
			return err
		}
	} else {
		_, err = db.Exec("DELETE FROM collectionattributes WHERE collection_id = ? AND attribute = ?", collID, "render_mathjax")
		if err != nil {
			log.Error("Unable to delete render_mathjax value: %v", err)
			return err
		}
	}

	// Update Monetization value
	if c.Monetization != nil {
		skipUpdate := false
		if *c.Monetization != "" {
			// Strip away any excess spaces
			trimmed := strings.TrimSpace(*c.Monetization)
			// Only update value when it starts with "$", per spec: https://paymentpointers.org
			if strings.HasPrefix(trimmed, "$") {
				c.Monetization = &trimmed
			} else {
				// Value appears invalid, so don't update
				skipUpdate = true
			}
		}
		if !skipUpdate {
			// ON DUPLICATE KEY UPDATE is MySQL-only. Use SQLite's equivalent
			// there, as the MathJax attribute does.
			if db.driverName == driverSQLite {
				_, err = db.Exec("INSERT OR REPLACE INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?)", collID, "monetization_pointer", *c.Monetization)
			} else {
				_, err = db.Exec("INSERT INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE value = ?", collID, "monetization_pointer", *c.Monetization, *c.Monetization)
			}
			if err != nil {
				log.Error("Unable to insert monetization_pointer value: %v", err)
				return err
			}
		}
	}

	// Update rest of the collection data
	res, err := db.Exec("UPDATE collections SET "+q.Updates+" WHERE "+q.Conditions, q.Params...)
	if err != nil {
		log.Error("Unable to update collection: %v", err)
		return err
	}

	if rowsAffected, _ := res.RowsAffected(); rowsAffected == 0 {
		// Show the correct error message if nothing was updated
		var dummy int
		err = db.QueryRow("SELECT 1 FROM collections WHERE alias = ? AND owner_id = ?", alias, c.OwnerID).Scan(&dummy)
		switch {
		case err == sql.ErrNoRows:
			return ErrUnauthorizedEditPost
		case err != nil:
			log.Error("Failed selecting from collections: %v", err)
		}
		if !updatePass {
			return nil
		}
	}

	if updatePass {
		hashedPass, err := auth.HashPass([]byte(c.Pass))
		if err != nil {
			log.Error("Unable to create hash: %s", err)
			return impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."}
		}
		if db.driverName == driverSQLite {
			_, err = db.Exec("INSERT OR REPLACE INTO collectionpasswords (collection_id, password) VALUES ((SELECT id FROM collections WHERE alias = ?), ?)", alias, hashedPass)
		} else {
			_, err = db.Exec("INSERT INTO collectionpasswords (collection_id, password) VALUES ((SELECT id FROM collections WHERE alias = ?), ?) "+db.upsert("collection_id")+" password = ?", alias, hashedPass, hashedPass)
		}
		if err != nil {
			return err
		}
	}

	return nil
}
// postCols lists the posts columns selected by the single-post getters. Their
// order must match the Scan destinations in GetEditablePost, GetPost and
// GetOwnedPost.
const postCols = "id, slug, text_appearance, language, rtl, privacy, " +
	"owner_id, collection_id, pinned_position, " +
	"created, updated, view_count, title, content"
// GetEditablePost returns the post with the given ID as a PublicPost. Owner is
// populated when the post has an owner.
//
// NOTE(review): despite its name and the editToken parameter, this function
// does not verify editToken — anyone who knows the post ID gets the post.
// Callers must perform their own authorization; confirm this is intended.
//
// Errors:
//   - no post has that ID: ErrPostNotFound;
//   - the post has neither title nor content: ErrPostUnpublished;
//   - other query errors: logged and returned.
func (db *datastore) GetEditablePost(id, editToken string) (*PublicPost, error) {
	// FIXME: code duplicated from getPost()
	// TODO: add slight logic difference to getPost / one func
	var ownerName sql.NullString
	p := &Post{}

	row := db.QueryRow("SELECT "+postCols+", (SELECT username FROM users WHERE users.id = posts.owner_id) AS username FROM posts WHERE id = ? LIMIT 1", id)
	err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content, &ownerName)
	switch {
	case err == sql.ErrNoRows:
		return nil, ErrPostNotFound
	case err != nil:
		// This query reads the posts table; the message previously said "collections".
		log.Error("Failed selecting from posts: %v", err)
		return nil, err
	}

	if p.Content == "" && p.Title.String == "" {
		return nil, ErrPostUnpublished
	}

	res := p.processPost()
	if ownerName.Valid {
		res.Owner = &PublicUser{Username: ownerName.String}
	}

	return &res, nil
}
// PostIDExists reports whether a post with the given ID exists. Any query
// error, including no match, is treated as "doesn't exist".
func (db *datastore) PostIDExists(id string) bool {
	var found bool
	if err := db.QueryRow("SELECT 1 FROM posts WHERE id = ?", id).Scan(&found); err != nil {
		return false
	}
	return found
}
// GetPost gets a public-facing post object from the database.
//   - collectionID > 0: id is treated as a slug within that collection.
//   - otherwise: id is the post ID.
//
// Owner is populated when the post has an owner.
//
// Errors:
//   - no match in a collection: ErrCollectionPageNotFound;
//   - no match by ID: ErrPostNotFound;
//   - no title and no content: ErrPostUnpublished.
//
// TODO: break this into two functions:
//   - GetPost(id string)
//   - GetCollectionPost(slug string, collectionID int64)
func (db *datastore) GetPost(id string, collectionID int64) (*PublicPost, error) {
	var ownerName sql.NullString
	p := &Post{}

	var row *sql.Row
	var where string
	params := []interface{}{id}
	if collectionID > 0 {
		where = "slug = ? AND collection_id = ?"
		params = append(params, collectionID)
	} else {
		where = "id = ?"
	}
	row = db.QueryRow("SELECT "+postCols+", (SELECT username FROM users WHERE users.id = posts.owner_id) AS username FROM posts WHERE "+where+" LIMIT 1", params...)
	err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content, &ownerName)
	switch {
	case err == sql.ErrNoRows:
		if collectionID > 0 {
			return nil, ErrCollectionPageNotFound
		}
		return nil, ErrPostNotFound
	case err != nil:
		// This query reads the posts table; the message previously said "collections".
		log.Error("Failed selecting from posts: %v", err)
		return nil, err
	}

	if p.Content == "" && p.Title.String == "" {
		return nil, ErrPostUnpublished
	}

	res := p.processPost()
	if ownerName.Valid {
		res.Owner = &PublicUser{Username: ownerName.String}
	}

	return &res, nil
}
// GetOwnedPost returns the published post with the given ID, but only if its
// owner_id equals ownerID. Unlike GetPost, the owner's username is not
// looked up.
//
// Returns ErrPostNotFound when no post matches both id and ownerID, and
// ErrPostUnpublished when the post has neither a title nor content. Other
// database errors are logged and returned.
//
// TODO: don't duplicate getPost() functionality
func (db *datastore) GetOwnedPost(id string, ownerID int64) (*PublicPost, error) {
	p := &Post{}
	var row *sql.Row
	where := "id = ? AND owner_id = ?"
	params := []interface{}{id, ownerID}
	row = db.QueryRow("SELECT "+postCols+" FROM posts WHERE "+where+" LIMIT 1", params...)
	err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content)
	switch {
	case err == sql.ErrNoRows:
		return nil, ErrPostNotFound
	case err != nil:
		// This query reads from posts; the old message wrongly said "collections".
		log.Error("Failed selecting from posts: %v", err)
		return nil, err
	}

	if p.Content == "" && p.Title.String == "" {
		return nil, ErrPostUnpublished
	}

	res := p.processPost()
	return &res, nil
}
// GetPostProperty returns a single property of one post.
//
// The only supported property is "views", which reads view_count. With a
// non-zero collectionID, id is matched against the slug within that
// collection. Otherwise it is matched against the post ID.
//
// An unsupported property yields a 400 impart.HTTPError, and a missing post
// a 404. Other database errors are logged and returned as-is.
func (db *datastore) GetPostProperty(id string, collectionID int64, property string) (interface{}, error) {
	var selectQuery string
	switch property {
	case "views":
		selectQuery = "view_count AS views"
	default:
		return nil, impart.HTTPError{http.StatusBadRequest, fmt.Sprintf("Invalid property: %s.", property)}
	}

	stmt := "SELECT " + selectQuery + " FROM posts WHERE id = ? LIMIT 1"
	args := []interface{}{id}
	if collectionID != 0 {
		stmt = "SELECT " + selectQuery + " FROM posts WHERE slug = ? AND collection_id = ? LIMIT 1"
		args = append(args, collectionID)
	}

	var value interface{}
	if err := db.QueryRow(stmt, args...).Scan(&value); err != nil {
		if err == sql.ErrNoRows {
			return nil, impart.HTTPError{http.StatusNotFound, "Post not found."}
		}
		log.Error("Failed selecting post: %v", err)
		return nil, err
	}
	return value, nil
}
// GetPostsCount sets c.TotalPosts to the number of standard (non-pinned)
// posts in the collection. Posts dated in the future are counted only when
// includeFuture is true.
//
// On a query error the error is logged and c.TotalPosts ends up 0.
func (db *datastore) GetPostsCount(c *CollectionObj, includeFuture bool) {
	timeCondition := ""
	if !includeFuture {
		timeCondition = "AND created <= " + db.now()
	}

	var total int64
	err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE collection_id = ? AND pinned_position IS NULL "+timeCondition, c.ID).Scan(&total)
	if err != nil && err != sql.ErrNoRows {
		log.Error("Failed selecting from collections: %v", err)
	}
	// total is still zero whenever Scan failed.
	c.TotalPosts = int(total)
}
// GetPosts retrieves the posts of collection c, ordered by creation time.
//
// Order is ascending when the collection's format says so, unless
// forceRecentFirst is set, in which case the newest post comes first. Posts
// dated in the future are included only when includeFuture is true. Pinned
// posts are included only when includePinned is true.
//
// With page > 0, one page of cf.PostsPerPage() posts is returned (page 1 is
// the first page). With page <= 0, no LIMIT is applied and every matching
// post is returned.
//
// A row that fails to scan stops iteration. Posts read before it are still
// returned, with a nil error.
//
// TODO: change includeFuture to isOwner, since that's how it's used
func (db *datastore) GetPosts(cfg *config.Config, c *Collection, page int, includeFuture, forceRecentFirst, includePinned bool) (*[]PublicPost, error) {
	cf := c.NewFormat()

	order := "DESC"
	if cf.Ascending() && !forceRecentFirst {
		order = "ASC"
	}

	perPage := cf.PostsPerPage()
	limitStr := ""
	if page > 0 {
		offset := (page - 1) * perPage
		limitStr = fmt.Sprintf(" LIMIT %d, %d", offset, perPage)
	}

	timeCondition := ""
	if !includeFuture {
		timeCondition = "AND created <= " + db.now()
	}
	pinnedCondition := ""
	if !includePinned {
		pinnedCondition = "AND pinned_position IS NULL"
	}

	stmt := "SELECT " + postCols + " FROM posts WHERE collection_id = ? " + pinnedCondition + " " + timeCondition + " ORDER BY created " + order + limitStr
	rows, err := db.Query(stmt, c.ID)
	if err != nil {
		log.Error("Failed selecting from posts: %v", err)
		return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve collection posts."}
	}
	defer rows.Close()

	// TODO: extract this common row scanning logic for queries using `postCols`
	result := []PublicPost{}
	for rows.Next() {
		post := &Post{}
		if err := rows.Scan(&post.ID, &post.Slug, &post.Font, &post.Language, &post.RTL, &post.Privacy, &post.OwnerID, &post.CollectionID, &post.PinnedPosition, &post.Created, &post.Updated, &post.ViewCount, &post.Title, &post.Content); err != nil {
			log.Error("Failed scanning row: %v", err)
			break
		}
		post.extractData()
		post.augmentContent(c)
		post.formatContent(cfg, c, includeFuture)
		result = append(result, post.processPost())
	}
	if err := rows.Err(); err != nil {
		log.Error("Error after Next() on rows: %v", err)
	}
	return &result, nil
}
// GetPostsTagged retrieves the posts of collection c whose content contains
// the hashtag #tag, compared case-insensitively by lowercasing both sides.
// Posts are ordered by creation time, ascending or descending as the
// collection format specifies.
//
// Posts dated in the future are included only when includeFuture is true.
// With page > 0, one page of cf.PostsPerPage() posts is returned. With
// page <= 0, every matching post is returned.
//
// Matching is done with a regular expression. On SQLite it uses `regexp`
// with a `\b` word boundary. Other drivers use RLIKE with `[[:>:]]`.
// NOTE(review): tag is concatenated into the pattern unescaped, so regex
// metacharacters in it change what matches.
//
// A row that fails to scan stops iteration. Posts read so far are returned.
//
// TODO: change includeFuture to isOwner, since that's how it's used
func (db *datastore) GetPostsTagged(cfg *config.Config, c *Collection, tag string, page int, includeFuture bool) (*[]PublicPost, error) {
	cf := c.NewFormat()

	order := "DESC"
	if cf.Ascending() {
		order = "ASC"
	}

	perPage := cf.PostsPerPage()
	limitStr := ""
	if page > 0 {
		limitStr = fmt.Sprintf(" LIMIT %d, %d", (page-1)*perPage, perPage)
	}

	timeCondition := ""
	if !includeFuture {
		timeCondition = "AND created <= " + db.now()
	}

	lowerTag := strings.ToLower(tag)
	var (
		rows *sql.Rows
		err  error
	)
	switch db.driverName {
	case driverSQLite:
		rows, err = db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? AND LOWER(content) regexp ? "+timeCondition+" ORDER BY created "+order+limitStr, c.ID, `.*#`+lowerTag+`\b.*`)
	default:
		rows, err = db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? AND LOWER(content) RLIKE ? "+timeCondition+" ORDER BY created "+order+limitStr, c.ID, "#"+lowerTag+"[[:>:]]")
	}
	if err != nil {
		log.Error("Failed selecting from posts: %v", err)
		return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve collection posts."}
	}
	defer rows.Close()

	// TODO: extract this common row scanning logic for queries using `postCols`
	result := []PublicPost{}
	for rows.Next() {
		post := &Post{}
		if err := rows.Scan(&post.ID, &post.Slug, &post.Font, &post.Language, &post.RTL, &post.Privacy, &post.OwnerID, &post.CollectionID, &post.PinnedPosition, &post.Created, &post.Updated, &post.ViewCount, &post.Title, &post.Content); err != nil {
			log.Error("Failed scanning row: %v", err)
			break
		}
		post.extractData()
		post.augmentContent(c)
		post.formatContent(cfg, c, includeFuture)
		result = append(result, post.processPost())
	}
	if err := rows.Err(); err != nil {
		log.Error("Error after Next() on rows: %v", err)
	}
	return &result, nil
}
// GetAPFollowers returns the remote (ActivityPub) users following collection
// c. Each entry carries the follower's actor ID, inbox, and shared inbox.
//
// Previously, scan errors were ignored and the zero or partially filled
// follower was appended anyway. Now a row that fails to scan is logged and
// skipped, and errors from iteration itself (rows.Err) are logged too. Both
// are best-effort: the followers read successfully are still returned with a
// nil error.
func (db *datastore) GetAPFollowers(c *Collection) (*[]RemoteUser, error) {
	rows, err := db.Query("SELECT actor_id, inbox, shared_inbox FROM remotefollows f INNER JOIN remoteusers u ON f.remote_user_id = u.id WHERE collection_id = ?", c.ID)
	if err != nil {
		log.Error("Failed selecting from followers: %v", err)
		return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve followers."}
	}
	defer rows.Close()

	followers := []RemoteUser{}
	for rows.Next() {
		f := RemoteUser{}
		if err := rows.Scan(&f.ActorID, &f.Inbox, &f.SharedInbox); err != nil {
			log.Error("Failed scanning follower row: %v", err)
			continue
		}
		followers = append(followers, f)
	}
	if err := rows.Err(); err != nil {
		log.Error("Error after Next() on rows: %v", err)
	}
	return &followers, nil
}
// CanCollect reports whether userID may add the post cpr.ID to a collection.
// That is the case exactly when the user already owns the post.
//
// NOTE: this is currently only used to potentially add owned posts to a
// collection. SIDE EFFECT: on success, cpr.Slug is overwritten with a slug
// generated from the post's title, content, and language.
// FIXME: make this side effect more explicit (or extract it)
func (db *datastore) CanCollect(cpr *ClaimPostRequest, userID int64) bool {
	var (
		title, content string
		lang           sql.NullString
	)
	err := db.QueryRow("SELECT title, content, language FROM posts WHERE id = ? AND owner_id = ?", cpr.ID, userID).Scan(&title, &content, &lang)
	if err == sql.ErrNoRows {
		return false
	}
	if err != nil {
		log.Error("Failed on post CanCollect(%s, %d): %v", cpr.ID, userID, err)
		return false
	}

	// The post content is already at hand and the post is collectable, so
	// generate its slug now.
	cpr.Slug = getSlugFromPost(title, content, lang.String)
	return true
}
// AttemptClaim executes query with params and returns its result.
//
// If the query fails with a duplicate-key error and slugIdx is >= 0
// (params[slugIdx] being the slug argument), a new slug is derived from
// p.Slug via id.GenSafeUniqueSlug. It is written to both p.Slug and
// params[slugIdx], and the query is retried. This repeats until the query
// succeeds, fails for another reason, or no different slug can be generated.
func (db *datastore) AttemptClaim(p *ClaimPostRequest, query string, params []interface{}, slugIdx int) (sql.Result, error) {
	for {
		result, err := db.Exec(query, params...)
		if err == nil {
			return result, nil
		}
		if !db.isDuplicateKeyErr(err) || slugIdx < 0 {
			return result, fmt.Errorf("attemptClaim: %s", err)
		}

		candidate := id.GenSafeUniqueSlug(p.Slug)
		if candidate == p.Slug {
			// Sanity check to prevent retrying forever
			return result, fmt.Errorf("GenSafeUniqueSlug generated nothing unique: %s", candidate)
		}
		p.Slug = candidate
		params[slugIdx] = p.Slug
	}
}
// DispersePosts removes each post in postIDs that userID owns from its
// collection, by setting collection_id to NULL. It reports one
// ClaimPostResult per input ID. The returned error is always nil; failures
// are reported per post:
//   - 429 for an ID that already appeared earlier in postIDs,
//   - the status and message of GetPost's impart.HTTPError (e.g. not found),
//   - 500 for any other lookup failure or a failed UPDATE,
//   - 409 when the post is owned by a different user.
//
// Successful results carry the post as fetched before the update.
func (db *datastore) DispersePosts(userID int64, postIDs []string) (*[]ClaimPostResult, error) {
	postClaimReqs := map[string]bool{}
	res := []ClaimPostResult{}
	for _, postID := range postIDs {
		r := ClaimPostResult{Code: 0, ErrorMessage: ""}

		// Perform post validation
		if postID == "" {
			r.ErrorMessage = "Missing post ID. "
		}
		if _, ok := postClaimReqs[postID]; ok {
			r.Code = 429
			r.ErrorMessage = "You've already tried anonymizing this post."
			r.ID = postID
			res = append(res, r)
			continue
		}
		postClaimReqs[postID] = true

		// Get full post information to return
		fullPost, err := db.GetPost(postID, 0)
		if err != nil {
			if herr, ok := err.(impart.HTTPError); ok {
				r.Code = herr.Status
				r.ErrorMessage = herr.Message
			} else {
				// fullPost is nil here. Previously execution fell through to
				// fullPost.OwnerID below and panicked.
				log.Error("Error getting post in dispersePosts: %v", err)
				r.Code = http.StatusInternalServerError
				r.ErrorMessage = "A glitch happened on our end."
			}
			r.ID = postID
			res = append(res, r)
			continue
		}
		if fullPost.OwnerID.Int64 != userID {
			r.Code = http.StatusConflict
			r.ErrorMessage = "Post is already owned by someone else."
			r.ID = postID
			res = append(res, r)
			continue
		}

		// "AND owner_id = ?" is only a sanity check; ownership was verified
		// just above.
		_, err = db.Exec("UPDATE posts SET collection_id = NULL WHERE id = ? AND owner_id = ?", postID, userID)
		if err != nil {
			r.Code = http.StatusInternalServerError
			r.ErrorMessage = "A glitch happened on our end."
			r.ID = postID
			res = append(res, r)
			log.Error("dispersePosts (post %s): %v", postID, err)
			continue
		}

		// Post was successfully dispersed. An UPDATE that changed no rows
		// (e.g. the post was already outside any collection) is also 200.
		r.Code = http.StatusOK
		r.Post = fullPost
		res = append(res, r)
	}
	return &res, nil
}
// ClaimPosts assigns each requested post to userID and, optionally, to a
// collection. It reports one ClaimPostResult per request. The returned error
// is always nil; failures are reported per post.
//
// When collAlias is non-empty (the /collections/{alias}/collect case), every
// post goes into that collection. When it is empty (the /posts/claim case),
// each post's own CollectionAlias is used. A post with no alias only gets an
// owner.
//
// A post is claimable when the user already owns it (CanCollect, which also
// sets p.Slug as a side effect), or when its edit token is given and it has
// no owner yet. Slug collisions are resolved by AttemptClaim. With
// CreateCollection set, the target collection is created first.
func (db *datastore) ClaimPosts(cfg *config.Config, userID int64, collAlias string, posts *[]ClaimPostRequest) (*[]ClaimPostResult, error) {
	postClaimReqs := map[string]bool{}
	res := []ClaimPostResult{}
	postCollAlias := collAlias
	for i := range *posts {
		// p is a copy: slug updates below don't touch the caller's slice.
		// (The old "&p == nil" guard could never be true and was removed.)
		p := (*posts)[i]
		r := ClaimPostResult{Code: 0, ErrorMessage: ""}

		// Perform post validation
		if p.ID == "" {
			r.ErrorMessage = "Missing post ID `id`. "
		}
		if _, ok := postClaimReqs[p.ID]; ok {
			r.Code = 429
			r.ErrorMessage = "You've already tried claiming this post."
			r.ID = p.ID
			res = append(res, r)
			continue
		}
		postClaimReqs[p.ID] = true

		canCollect := db.CanCollect(&p, userID)
		if !canCollect && p.Token == "" {
			// TODO: ensure post isn't owned by anyone else when a valid modify
			// token is given.
			r.ErrorMessage += "Missing post Edit Token `token`."
		}
		if r.ErrorMessage != "" {
			// Post validate failed
			r.Code = http.StatusBadRequest
			r.ID = p.ID
			res = append(res, r)
			continue
		}

		var err error
		var query string
		var params []interface{}
		slugIdx := -1
		var coll *Collection
		if collAlias == "" {
			// Posts are being claimed at /posts/claim, not
			// /collections/{alias}/collect, so use the post's own collection
			// to associate it with.
			postCollAlias = p.CollectionAlias
		}
		if postCollAlias != "" {
			// Associate this post with a collection
			if p.CreateCollection {
				// This is a new collection
				// TODO: consider removing this. This seriously complicates this
				// method and adds another (unnecessary?) logic path.
				coll, err = db.CreateCollection(cfg, postCollAlias, "", userID)
				if err != nil {
					if err, ok := err.(impart.HTTPError); ok {
						r.Code = err.Status
						r.ErrorMessage = err.Message
					} else {
						r.Code = http.StatusInternalServerError
						r.ErrorMessage = "Unknown error occurred creating collection"
					}
					r.ID = p.ID
					res = append(res, r)
					continue
				}
			} else {
				// Attempt to add to existing collection
				coll, err = db.GetCollection(postCollAlias)
				if err != nil {
					if err, ok := err.(impart.HTTPError); ok {
						if err.Status == http.StatusNotFound {
							// Show obfuscated "forbidden" response, as if attempting to add to an
							// unowned blog.
							r.Code = ErrForbiddenCollection.Status
							r.ErrorMessage = ErrForbiddenCollection.Message
						} else {
							r.Code = err.Status
							r.ErrorMessage = err.Message
						}
					} else {
						r.Code = http.StatusInternalServerError
						r.ErrorMessage = "Unknown error occurred claiming post with collection"
					}
					r.ID = p.ID
					res = append(res, r)
					continue
				}
				if coll.OwnerID != userID {
					r.Code = ErrForbiddenCollection.Status
					r.ErrorMessage = ErrForbiddenCollection.Message
					r.ID = p.ID
					res = append(res, r)
					continue
				}
			}
			if p.Slug == "" {
				p.Slug = p.ID
			}
			if canCollect {
				// User already owns this post, so just add it to the given
				// collection.
				query = "UPDATE posts SET collection_id = ?, slug = ? WHERE id = ? AND owner_id = ?"
				params = []interface{}{coll.ID, p.Slug, p.ID, userID}
				slugIdx = 1
			} else {
				query = "UPDATE posts SET owner_id = ?, collection_id = ?, slug = ? WHERE id = ? AND modify_token = ? AND owner_id IS NULL"
				params = []interface{}{userID, coll.ID, p.Slug, p.ID, p.Token}
				slugIdx = 2
			}
		} else {
			query = "UPDATE posts SET owner_id = ? WHERE id = ? AND modify_token = ? AND owner_id IS NULL"
			params = []interface{}{userID, p.ID, p.Token}
		}
		_, err = db.AttemptClaim(&p, query, params, slugIdx)
		if err != nil {
			r.Code = http.StatusInternalServerError
			r.ErrorMessage = "An unknown error occurred."
			r.ID = p.ID
			res = append(res, r)
			log.Error("claimPosts (post %s): %v", p.ID, err)
			continue
		}

		// Get full post information to return
		var fullPost *PublicPost
		if p.Token != "" {
			fullPost, err = db.GetEditablePost(p.ID, p.Token)
		} else {
			fullPost, err = db.GetPost(p.ID, 0)
		}
		if err != nil {
			if herr, ok := err.(impart.HTTPError); ok {
				r.Code = herr.Status
				r.ErrorMessage = herr.Message
			} else {
				// fullPost is nil here. Previously execution fell through to
				// fullPost.OwnerID below and panicked.
				log.Error("claimPosts (post %s): fetching claimed post: %v", p.ID, err)
				r.Code = http.StatusInternalServerError
				r.ErrorMessage = "An unknown error occurred."
			}
			r.ID = p.ID
			res = append(res, r)
			continue
		}
		if fullPost.OwnerID.Int64 != userID {
			r.Code = http.StatusConflict
			r.ErrorMessage = "Post is already owned by someone else."
			r.ID = p.ID
			res = append(res, r)
			continue
		}

		// Post was successfully claimed. If the UPDATE affected no rows (it was
		// already claimed by this user), that is still reported as 200.
		r.Code = http.StatusOK
		r.Post = fullPost
		if coll != nil {
			r.Post.Collection = &CollectionObj{Collection: *coll}
		}
		res = append(res, r)
	}
	return &res, nil
}
// UpdatePostPinState pins (pinned == true) or unpins the post postID.
//
// When pinning, pos is used as-is if it lies in 1..20. Otherwise the post is
// placed right after collID's current last pinned post, or at position 1
// when collID has no pinned posts (or the lookup fails). When unpinning, pos
// and collID are ignored and pinned_position is cleared.
//
// NOTE(review): ownerID is not used in either UPDATE, so the caller must
// verify ownership.
func (db *datastore) UpdatePostPinState(pinned bool, postID string, collID, ownerID, pos int64) error {
	var err error
	if pinned {
		if pos <= 0 || pos > 20 {
			// GetLastPinnedPostPos returns -1 when nothing is pinned. The old code
			// added 1 first and then compared against -1, which never matched,
			// so the first pin landed at position 0 instead of 1.
			last := db.GetLastPinnedPostPos(collID)
			pos = last + 1
			if last < 0 {
				pos = 1
			}
		}
		_, err = db.Exec("UPDATE posts SET pinned_position = ? WHERE id = ?", pos, postID)
	} else {
		_, err = db.Exec("UPDATE posts SET pinned_position = NULL WHERE id = ?", postID)
	}
	if err != nil {
		log.Error("Unable to update pinned post: %v", err)
		return err
	}
	return nil
}
// GetLastPinnedPostPos returns the largest pinned_position among collID's
// posts. It returns -1 when the collection has no pinned posts or the query
// fails; failures other than "no rows" are also logged.
func (db *datastore) GetLastPinnedPostPos(collID int64) int64 {
	var maxPos sql.NullInt64
	err := db.QueryRow("SELECT MAX(pinned_position) FROM posts WHERE collection_id = ? AND pinned_position IS NOT NULL", collID).Scan(&maxPos)
	if err != nil {
		if err != sql.ErrNoRows {
			log.Error("Failed selecting from posts: %v", err)
		}
		return -1
	}
	// MAX() over zero rows yields NULL.
	if maxPos.Valid {
		return maxPos.Int64
	}
	return -1
}
// GetPinnedPosts returns coll's pinned posts ordered by pinned_position.
// Content is passed through db.clip("content", 80), presumably truncating it
// to 80 characters. Posts dated in the future are included only when
// includeFuture is true. Each returned post has Collection set to coll.
//
// A row that fails to scan stops iteration, and posts read so far are
// returned. Iteration errors are now checked with rows.Err(), matching the
// other listing functions.
func (db *datastore) GetPinnedPosts(coll *CollectionObj, includeFuture bool) (*[]PublicPost, error) {
	// FIXME: sqlite-backed instances don't include ellipsis on truncated titles
	timeCondition := ""
	if !includeFuture {
		timeCondition = "AND created <= " + db.now()
	}
	rows, err := db.Query("SELECT id, slug, title, "+db.clip("content", 80)+", pinned_position FROM posts WHERE collection_id = ? AND pinned_position IS NOT NULL "+timeCondition+" ORDER BY pinned_position ASC", coll.ID)
	if err != nil {
		log.Error("Failed selecting pinned posts: %v", err)
		return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve pinned posts."}
	}
	defer rows.Close()

	posts := []PublicPost{}
	for rows.Next() {
		p := &Post{}
		err = rows.Scan(&p.ID, &p.Slug, &p.Title, &p.Content, &p.PinnedPosition)
		if err != nil {
			log.Error("Failed scanning row: %v", err)
			break
		}
		p.extractData()
		p.augmentContent(&coll.Collection)

		pp := p.processPost()
		pp.Collection = coll
		posts = append(posts, pp)
	}
	if err := rows.Err(); err != nil {
		log.Error("Error after Next() on rows: %v", err)
	}
	return &posts, nil
}
// GetCollections returns every collection owned by u, in ascending ID order.
// For each collection, hostName is stored before URL is computed via
// CanonicalURL(), and Public is set from IsPublic().
//
// A row that fails to scan stops iteration, and collections read so far are
// returned with a nil error.
func (db *datastore) GetCollections(u *User, hostName string) (*[]Collection, error) {
	rows, err := db.Query("SELECT id, alias, title, description, privacy, view_count FROM collections WHERE owner_id = ? ORDER BY id ASC", u.ID)
	if err != nil {
		log.Error("Failed selecting from collections: %v", err)
		return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user collections."}
	}
	defer rows.Close()

	owned := []Collection{}
	for rows.Next() {
		var coll Collection
		if err := rows.Scan(&coll.ID, &coll.Alias, &coll.Title, &coll.Description, &coll.Visibility, &coll.Views); err != nil {
			log.Error("Failed scanning row: %v", err)
			break
		}
		coll.hostName = hostName
		coll.URL = coll.CanonicalURL()
		coll.Public = coll.IsPublic()
		owned = append(owned, coll)
	}
	if err := rows.Err(); err != nil {
		log.Error("Error after Next() on rows: %v", err)
	}
	return &owned, nil
}
// GetPublishableCollections returns u's collections, like GetCollections. It
// returns an error (a 500 impart.HTTPError with guidance) when the user owns
// none.
func (db *datastore) GetPublishableCollections(u *User, hostName string) (*[]Collection, error) {
	colls, err := db.GetCollections(u, hostName)
	switch {
	case err != nil:
		return nil, err
	case len(*colls) == 0:
		return nil, impart.HTTPError{http.StatusInternalServerError, "You don't seem to have any blogs; they might've moved to another account. Try logging out and logging into your other account."}
	}
	return colls, nil
}
// GetPublicCollections returns the collections with privacy = 1 whose owner
// has users.status = 0 (presumably "public" collections of active accounts;
// confirm against the constants), in ascending ID order. hostName is stored
// on each collection before URL is computed via CanonicalURL().
//
// A row that fails to scan stops iteration, and collections read so far are
// returned with a nil error.
func (db *datastore) GetPublicCollections(hostName string) (*[]Collection, error) {
	rows, err := db.Query(`SELECT c.id, alias, title, description, privacy, view_count
FROM collections c
LEFT JOIN users u ON u.id = c.owner_id
WHERE c.privacy = 1 AND u.status = 0
ORDER BY id ASC`)
	if err != nil {
		log.Error("Failed selecting public collections: %v", err)
		return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve public collections."}
	}
	defer rows.Close()

	public := []Collection{}
	for rows.Next() {
		var coll Collection
		if err := rows.Scan(&coll.ID, &coll.Alias, &coll.Title, &coll.Description, &coll.Visibility, &coll.Views); err != nil {
			log.Error("Failed scanning row: %v", err)
			break
		}
		coll.hostName = hostName
		coll.URL = coll.CanonicalURL()
		coll.Public = coll.IsPublic()
		public = append(public, coll)
	}
	if err := rows.Err(); err != nil {
		log.Error("Error after Next() on rows: %v", err)
	}
	return &public, nil
}
// GetMeStats gathers three counts for u:
//   - TotalCollections: from GetUserCollectionCount, with its error discarded,
//   - TotalArticles: posts owned by u that are in no collection,
//   - CollectionPosts: posts owned by u that are in a collection.
//
// Count query failures (other than "no rows") are logged, and the affected
// count stays zero.
func (db *datastore) GetMeStats(u *User) userMeStats {
	stats := userMeStats{}

	// User counts
	stats.TotalCollections, _ = db.GetUserCollectionCount(u.ID)

	var standalone uint64
	if err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ? AND collection_id IS NULL", u.ID).Scan(&standalone); err != nil && err != sql.ErrNoRows {
		log.Error("Couldn't get articles count for user %d: %v", u.ID, err)
	}
	stats.TotalArticles = standalone

	var inCollections uint64
	if err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ? AND collection_id IS NOT NULL", u.ID).Scan(&inCollections); err != nil && err != sql.ErrNoRows {
		log.Error("Couldn't get coll posts count for user %d: %v", u.ID, err)
	}
	stats.CollectionPosts = inCollections

	return stats
}
// GetTotalCollections counts the collections whose owner has
// users.status = 0. A query error is logged and returned together with a
// zero count.
func (db *datastore) GetTotalCollections() (collCount int64, err error) {
	const countQuery = `
SELECT COUNT(*)
FROM collections c
LEFT JOIN users u ON u.id = c.owner_id
WHERE u.status = 0`
	var n int64
	if scanErr := db.QueryRow(countQuery).Scan(&n); scanErr != nil {
		log.Error("Unable to fetch collections count: %v", scanErr)
		return n, scanErr
	}
	return n, nil
}
// GetTotalPosts counts the posts whose owner has users.status = 0. A query
// error is logged and returned together with a zero count.
func (db *datastore) GetTotalPosts() (postCount int64, err error) {
	const countQuery = `
SELECT COUNT(*)
FROM posts p
LEFT JOIN users u ON u.id = p.owner_id
WHERE u.status = 0`
	var n int64
	if scanErr := db.QueryRow(countQuery).Scan(&n); scanErr != nil {
		log.Error("Unable to fetch posts count: %v", scanErr)
		return n, scanErr
	}
	return n, nil
}
// GetTopPosts returns up to 25 of u's posts, most viewed first (ties broken
// by newest). If alias is non-empty, only posts in the collection with that
// alias are included. A post that belongs to a collection gets that
// collection's alias, title, description, and view count attached.
//
// A row that fails to scan stops iteration. If that happens before any post
// was read, a 500 impart.HTTPError is returned; otherwise the posts read so
// far are returned with a nil error.
func (db *datastore) GetTopPosts(u *User, alias string) (*[]PublicPost, error) {
	args := []interface{}{u.ID}
	aliasFilter := ""
	if alias != "" {
		aliasFilter = " AND alias = ?"
		args = append(args, alias)
	}
	rows, err := db.Query("SELECT p.id, p.slug, p.view_count, p.title, c.alias, c.title, c.description, c.view_count FROM posts p LEFT JOIN collections c ON p.collection_id = c.id WHERE p.owner_id = ?"+aliasFilter+" ORDER BY p.view_count DESC, created DESC LIMIT 25", args...)
	if err != nil {
		log.Error("Failed selecting from posts: %v", err)
		return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user top posts."}
	}
	defer rows.Close()

	posts := []PublicPost{}
	scanFailed := false
	for rows.Next() {
		var (
			p                              Post
			coll                           Collection
			collAlias, collTitle, collDesc sql.NullString
			collViews                      sql.NullInt64
		)
		if err := rows.Scan(&p.ID, &p.Slug, &p.ViewCount, &p.Title, &collAlias, &collTitle, &collDesc, &collViews); err != nil {
			log.Error("Failed scanning User.getPosts() row: %v", err)
			scanFailed = true
			break
		}
		p.extractData()
		pub := p.processPost()
		// The LEFT JOIN yields NULL collection columns for posts outside any
		// collection.
		if collAlias.Valid && collAlias.String != "" {
			coll.Alias = collAlias.String
			coll.Title = collTitle.String
			coll.Description = collDesc.String
			coll.Views = collViews.Int64
			pub.Collection = &CollectionObj{Collection: coll}
		}
		posts = append(posts, pub)
	}
	if err := rows.Err(); err != nil {
		log.Error("Error after Next() on rows: %v", err)
	}
	if scanFailed && len(posts) == 0 {
		// There were a lot of errors
		return nil, impart.HTTPError{http.StatusInternalServerError, "Unable to get data."}
	}
	return &posts, nil
}
// GetAnonymousPosts returns u's posts that are not in any collection, newest
// first (ORDER BY created DESC).
//
// In the new (+) signature, page selects a page of 10 posts: page 1 is the
// newest ten, page 2 the next ten, and so on. For page <= 0 no LIMIT clause
// is added, and every such post is returned.
//
// A row that fails to scan stops iteration. Posts read so far are returned
// with a nil error.
-func (db *datastore) GetAnonymousPosts(u *User) (*[]PublicPost, error) {
- rows, err := db.Query("SELECT id, view_count, title, created, updated, content FROM posts WHERE owner_id = ? AND collection_id IS NULL ORDER BY created DESC", u.ID)
+func (db *datastore) GetAnonymousPosts(u *User, page int) (*[]PublicPost, error) {
+ pagePosts := 10 // posts per page
+ start := page*pagePosts - pagePosts // row offset of the requested page, i.e. (page-1)*pagePosts
+ if page == 0 {
+ start = 0
+ pagePosts = 1000 // NOTE(review): no effect — LIMIT is only added for page > 0, so page 0 returns all posts
+ }
+
+ limitStr := ""
+ if page > 0 {
+ limitStr = fmt.Sprintf(" LIMIT %d, %d", start, pagePosts)
+ }
+ rows, err := db.Query("SELECT id, view_count, title, created, updated, content FROM posts WHERE owner_id = ? AND collection_id IS NULL ORDER BY created DESC"+limitStr, u.ID)
if err != nil {
log.Error("Failed selecting from posts: %v", err)
return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user anonymous posts."}
}
defer rows.Close()
posts := []PublicPost{}
for rows.Next() {
p := Post{}
err = rows.Scan(&p.ID, &p.ViewCount, &p.Title, &p.Created, &p.Updated, &p.Content)
if err != nil {
log.Error("Failed scanning row: %v", err)
break // stop at the first bad row; posts read so far are still returned
}
p.extractData()
posts = append(posts, p.processPost())
}
err = rows.Err()
if err != nil {
log.Error("Error after Next() on rows: %v", err)
}
return &posts, nil
}
// GetUserPosts returns every post owned by u, oldest first. A post that
// belongs to a collection gets that collection's alias, title, description,
// and view count attached.
//
// A row that fails to scan stops iteration. If that happens before any post
// was read, a 500 impart.HTTPError is returned; otherwise the posts read so
// far are returned with a nil error.
func (db *datastore) GetUserPosts(u *User) (*[]PublicPost, error) {
	rows, err := db.Query("SELECT p.id, p.slug, p.view_count, p.title, p.created, p.updated, p.content, p.text_appearance, p.language, p.rtl, c.alias, c.title, c.description, c.view_count FROM posts p LEFT JOIN collections c ON collection_id = c.id WHERE p.owner_id = ? ORDER BY created ASC", u.ID)
	if err != nil {
		log.Error("Failed selecting from posts: %v", err)
		return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user posts."}
	}
	defer rows.Close()

	posts := []PublicPost{}
	scanFailed := false
	for rows.Next() {
		var (
			p                              Post
			coll                           Collection
			collAlias, collTitle, collDesc sql.NullString
			collViews                      sql.NullInt64
		)
		if err := rows.Scan(&p.ID, &p.Slug, &p.ViewCount, &p.Title, &p.Created, &p.Updated, &p.Content, &p.Font, &p.Language, &p.RTL, &collAlias, &collTitle, &collDesc, &collViews); err != nil {
			log.Error("Failed scanning User.getPosts() row: %v", err)
			scanFailed = true
			break
		}
		p.extractData()
		pub := p.processPost()
		// The LEFT JOIN yields NULL collection columns for posts outside any
		// collection.
		if collAlias.Valid && collAlias.String != "" {
			coll.Alias = collAlias.String
			coll.Title = collTitle.String
			coll.Description = collDesc.String
			coll.Views = collViews.Int64
			pub.Collection = &CollectionObj{Collection: coll}
		}
		posts = append(posts, pub)
	}
	if err := rows.Err(); err != nil {
		log.Error("Error after Next() on rows: %v", err)
	}
	if scanFailed && len(posts) == 0 {
		// There were a lot of errors
		return nil, impart.HTTPError{http.StatusInternalServerError, "Unable to get data."}
	}
	return &posts, nil
}
// GetUserPostsCount returns how many posts userID owns. It returns 0 when
// the query fails; failures other than "no rows" are logged.
func (db *datastore) GetUserPostsCount(userID int64) int64 {
	var total int64
	err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ?", userID).Scan(&total)
	if err == nil {
		return total
	}
	if err != sql.ErrNoRows {
		log.Error("Failed selecting posts count for user %d: %v", userID, err)
	}
	return 0
}
// ChangeSettings takes a User and applies the changes in the given
// userSettings, MODIFYING THE USER with successful changes.
//
// Steps, in order:
//  1. Email (if s.Email is set): encrypted with app.keys.EmailKey and queued
//     for the final UPDATE. u.Email is updated immediately, before anything
//     is persisted.
//  2. Username (if s.Username is set): validated, then changed in its own
//     transaction. The transaction also renames the user's same-named
//     collection and records a collection redirect. It commits right away,
//     independent of the final UPDATE.
//  3. Passphrase (if s.NewPass is set): if the user already has one,
//     s.OldPass must match it. The new hash is queued for the final UPDATE.
//
// Queued email/password changes are written in one UPDATE at the end. If
// step 3 fails, a deferred hook still stores the new encrypted email on its
// own, so an email change isn't lost because of a bad password.
//
// Returns ErrPostNoUpdatableVals when nothing was requested. If the final
// UPDATE affects no rows and no user with u.ID exists, it returns
// ErrUnauthorizedGeneral.
func (db *datastore) ChangeSettings(app *App, u *User, s *userSettings) error {
	var errPass error
	q := query.NewUpdate()

	// Update email if given
	if s.Email != "" {
		encEmail, err := data.Encrypt(app.keys.EmailKey, s.Email)
		if err != nil {
			// Don't write the plaintext address to the log; it is only ever
			// persisted encrypted.
			log.Error("Couldn't encrypt email for user %d: %v", u.ID, err)
			return impart.HTTPError{http.StatusInternalServerError, "Unable to encrypt email address."}
		}
		q.SetBytes(encEmail, "email")

		// Update the email if something goes awry updating the password
		defer func() {
			if errPass != nil {
				db.UpdateEncryptedUserEmail(u.ID, encEmail)
			}
		}()

		u.Email = zero.StringFrom(s.Email)
	}

	// Update username if given
	var newUsername string
	if s.Username != "" {
		var ie *impart.HTTPError
		newUsername, ie = getValidUsername(app, s.Username, u.Username)
		if ie != nil {
			// Username is invalid
			return *ie
		}
		if !author.IsValidUsername(app.cfg, newUsername) {
			// Ensure the username is syntactically correct.
			return impart.HTTPError{http.StatusPreconditionFailed, "Username isn't valid."}
		}

		t, err := db.Begin()
		if err != nil {
			log.Error("Couldn't start username change transaction: %v", err)
			return err
		}

		_, err = t.Exec("UPDATE users SET username = ? WHERE id = ?", newUsername, u.ID)
		if err != nil {
			t.Rollback()
			if db.isDuplicateKeyErr(err) {
				return impart.HTTPError{http.StatusConflict, "Username is already taken."}
			}
			log.Error("Unable to update users table: %v", err)
			return ErrInternalGeneral
		}

		_, err = t.Exec("UPDATE collections SET alias = ? WHERE alias = ? AND owner_id = ?", newUsername, u.Username, u.ID)
		if err != nil {
			t.Rollback()
			if db.isDuplicateKeyErr(err) {
				return impart.HTTPError{http.StatusConflict, "Username is already taken."}
			}
			log.Error("Unable to update collection: %v", err)
			return ErrInternalGeneral
		}

		// Keep track of name changes for redirection. Redirect bookkeeping is
		// best-effort: failures are logged but don't abort the rename.
		db.RemoveCollectionRedirect(t, newUsername)
		_, err = t.Exec("UPDATE collectionredirects SET new_alias = ? WHERE new_alias = ?", newUsername, u.Username)
		if err != nil {
			log.Error("Unable to update collectionredirects: %v", err)
		}
		_, err = t.Exec("INSERT INTO collectionredirects (prev_alias, new_alias) VALUES (?, ?)", u.Username, newUsername)
		if err != nil {
			log.Error("Unable to add new collectionredirect: %v", err)
		}

		err = t.Commit()
		if err != nil {
			t.Rollback()
			log.Error("Rolling back after Commit(): %v\n", err)
			return err
		}

		u.Username = newUsername
	}

	// Update passphrase if given
	if s.NewPass != "" {
		// Check if user has already set a password
		var err error
		u.HasPass, err = db.IsUserPassSet(u.ID)
		if err != nil {
			errPass = impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve user data."}
			return errPass
		}

		if u.HasPass {
			// Check if currently-set password is correct
			hashedPass := u.HashedPass
			if len(hashedPass) == 0 {
				authUser, err := db.GetUserForAuthByID(u.ID)
				if err != nil {
					errPass = err
					return errPass
				}
				hashedPass = authUser.HashedPass
			}
			if !auth.Authenticated(hashedPass, []byte(s.OldPass)) {
				errPass = impart.HTTPError{http.StatusUnauthorized, "Incorrect password."}
				return errPass
			}
		}
		hashedPass, err := auth.HashPass([]byte(s.NewPass))
		if err != nil {
			errPass = impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."}
			return errPass
		}

		q.SetBytes(hashedPass, "password")
	}

	// WHERE values
	q.Append(u.ID)

	if q.Updates == "" {
		if s.Username == "" {
			return ErrPostNoUpdatableVals
		}

		// Nothing to update except username. That was successful, so return now.
		return nil
	}

	res, err := db.Exec("UPDATE users SET "+q.Updates+" WHERE id = ?", q.Params...)
	if err != nil {
		// This UPDATE targets users; the old message wrongly said "collection".
		log.Error("Unable to update users: %v", err)
		return err
	}

	rowsAffected, _ := res.RowsAffected()
	if rowsAffected == 0 {
		// Show the correct error message if nothing was updated
		var dummy int
		err := db.QueryRow("SELECT 1 FROM users WHERE id = ?", u.ID).Scan(&dummy)
		switch {
		case err == sql.ErrNoRows:
			return ErrUnauthorizedGeneral
		case err != nil:
			log.Error("Failed selecting from users: %v", err)
		}
		return nil
	}

	if s.NewPass != "" && !u.HasPass {
		u.HasPass = true
	}

	return nil
}
// ChangePassphrase stores hashedPass as userID's password.
//
// Unless sudo is true, curPass must match the stored password. Returns
// ErrUserNotFound for an unknown user, a 401 impart.HTTPError for a wrong
// current password, or the database error (logged) on failure.
func (db *datastore) ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error {
	var storedHash []byte
	if err := db.QueryRow("SELECT password FROM users WHERE id = ?", userID).Scan(&storedHash); err != nil {
		if err == sql.ErrNoRows {
			return ErrUserNotFound
		}
		log.Error("Couldn't SELECT user password for change: %v", err)
		return err
	}

	// sudo skips verification of the current password.
	if !sudo && !auth.Authenticated(storedHash, []byte(curPass)) {
		return impart.HTTPError{http.StatusUnauthorized, "Incorrect password."}
	}

	if _, err := db.Exec("UPDATE users SET password = ? WHERE id = ?", hashedPass, userID); err != nil {
		log.Error("Could not update passphrase: %v", err)
		return err
	}
	return nil
}
// RemoveCollectionRedirect deletes, within transaction t, any redirect whose
// previous alias is alias. The error is logged before being returned.
func (db *datastore) RemoveCollectionRedirect(t *sql.Tx, alias string) error {
	if _, err := t.Exec("DELETE FROM collectionredirects WHERE prev_alias = ?", alias); err != nil {
		log.Error("Unable to delete from collectionredirects: %v", err)
		return err
	}
	return nil
}
// GetCollectionRedirect returns the alias that alias now redirects to, or ""
// when there is no redirect. Errors other than "no rows" and those
// db.isIgnorableError accepts are logged.
func (db *datastore) GetCollectionRedirect(alias string) (newAlias string) {
	err := db.QueryRow("SELECT new_alias FROM collectionredirects WHERE prev_alias = ?", alias).Scan(&newAlias)
	switch {
	case err == nil, err == sql.ErrNoRows, db.isIgnorableError(err):
		// Nothing worth reporting.
	default:
		log.Error("Failed selecting from collectionredirects: %v", err)
	}
	return newAlias
}
// DeleteCollection deletes the collection alias owned by userID.
//
// The user's primary blog (the collection whose alias equals their username)
// cannot be deleted (403). A missing or unowned collection yields a 404.
//
// In one transaction, the collection's posts are detached (collection_id
// set to NULL), redirects to or from the alias are removed, any collection
// password is removed, and the collection row is deleted. The first failing
// statement rolls the transaction back and its error is returned.
func (db *datastore) DeleteCollection(alias string, userID int64) error {
	var username string
	if err := db.QueryRow("SELECT username FROM users WHERE id = ?", userID).Scan(&username); err != nil {
		return err
	}

	// Ensure user isn't deleting their main blog
	if alias == username {
		return impart.HTTPError{http.StatusForbidden, "You cannot currently delete your primary blog."}
	}

	c := &Collection{Alias: alias}
	err := db.QueryRow("SELECT id FROM collections WHERE alias = ? AND owner_id = ?", alias, userID).Scan(&c.ID)
	if err == sql.ErrNoRows {
		return impart.HTTPError{http.StatusNotFound, "Collection doesn't exist or you're not allowed to delete it."}
	}
	if err != nil {
		log.Error("Failed selecting from collections: %v", err)
		return ErrInternalGeneral
	}

	t, err := db.Begin()
	if err != nil {
		return err
	}

	steps := []struct {
		stmt string
		args []interface{}
	}{
		// Float all collection's posts
		{"UPDATE posts SET collection_id = NULL WHERE collection_id = ? AND owner_id = ?", []interface{}{c.ID, userID}},
		// Remove redirects to or from this collection
		{"DELETE FROM collectionredirects WHERE prev_alias = ? OR new_alias = ?", []interface{}{alias, alias}},
		// Remove any optional collection password
		{"DELETE FROM collectionpasswords WHERE collection_id = ?", []interface{}{c.ID}},
		// Finally, delete collection itself
		{"DELETE FROM collections WHERE id = ?", []interface{}{c.ID}},
	}
	for _, step := range steps {
		if _, err := t.Exec(step.stmt, step.args...); err != nil {
			t.Rollback()
			return err
		}
	}

	if err := t.Commit(); err != nil {
		t.Rollback()
		return err
	}
	return nil
}
// IsCollectionAttributeOn reports whether attribute attr of collection id is
// stored with the value "1". A missing attribute or a failed query yields
// false; only unexpected query failures are logged.
func (db *datastore) IsCollectionAttributeOn(id int64, attr string) bool {
	var value string
	if err := db.QueryRow("SELECT value FROM collectionattributes WHERE collection_id = ? AND attribute = ?", id, attr).Scan(&value); err != nil {
		if err != sql.ErrNoRows {
			log.Error("Couldn't SELECT value in isCollectionAttributeOn for attribute '%s': %v", attr, err)
		}
		return false
	}
	return value == "1"
}
// CollectionHasAttribute reports whether any value is stored for attribute
// attr of collection id, regardless of what that value is. A failed query
// yields false and is logged (unless it is simply "no rows").
func (db *datastore) CollectionHasAttribute(id int64, attr string) bool {
	var ignored string
	err := db.QueryRow("SELECT value FROM collectionattributes WHERE collection_id = ? AND attribute = ?", id, attr).Scan(&ignored)
	if err == nil {
		return true
	}
	if err != sql.ErrNoRows {
		log.Error("Couldn't SELECT value in collectionHasAttribute for attribute '%s': %v", attr, err)
	}
	return false
}
// GetCollectionAttribute returns the stored value of attribute attr for
// collection id, or "" if it isn't set or the query fails (failures other
// than "no rows" are logged).
func (db *datastore) GetCollectionAttribute(id int64, attr string) string {
	var value string
	row := db.QueryRow("SELECT value FROM collectionattributes WHERE collection_id = ? AND attribute = ?", id, attr)
	switch err := row.Scan(&value); {
	case err == nil:
		return value
	case err != sql.ErrNoRows:
		log.Error("Couldn't SELECT value in getCollectionAttribute for attribute '%s': %v", attr, err)
	}
	return ""
}
// SetCollectionAttribute stores value v for attribute attr of collection id.
// A failure is logged and returned.
//
// NOTE(review): this is a plain INSERT, not an upsert. It presumably fails if
// the attribute is already set; confirm against the table's keys.
func (db *datastore) SetCollectionAttribute(id int64, attr, v string) error {
	const insert = "INSERT INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?)"
	_, err := db.Exec(insert, id, attr, v)
	if err == nil {
		return nil
	}
	log.Error("Unable to INSERT into collectionattributes: %v", err)
	return err
}
// DeleteAccount will delete the entire account for userID
//
// It first loads the ID and alias of every collection the user owns. Then, in
// a single transaction, it removes each collection's attributes, password,
// incoming redirects, keys and remote follows. After that it deletes the
// collections themselves, followed by the user's access tokens, OAuth links,
// posts, user attributes, invites and finally the user row. Any failure rolls
// back the transaction and is logged and returned. Deleted row counts are
// logged at info level.
//
// NOTE(review): rows.Err() is not checked after the collection loop, so an
// iteration error would silently produce a partial collection list. Also,
// only redirects *to* each collection (new_alias) are removed here, whereas
// DeleteCollection removes both prev_alias and new_alias matches. Confirm
// both behaviors are intended.
func (db *datastore) DeleteAccount(userID int64) error {
// Get all collections
rows, err := db.Query("SELECT id, alias FROM collections WHERE owner_id = ?", userID)
if err != nil {
log.Error("Unable to get collections: %v", err)
return err
}
defer rows.Close()
colls := []Collection{}
var c Collection
for rows.Next() {
err = rows.Scan(&c.ID, &c.Alias)
if err != nil {
log.Error("Unable to scan collection cols: %v", err)
return err
}
colls = append(colls, c)
}
// Start transaction
t, err := db.Begin()
if err != nil {
log.Error("Unable to begin: %v", err)
return err
}
// Clean up all collection related information
var res sql.Result
for _, c := range colls {
// Delete collection attributes
res, err = t.Exec("DELETE FROM collectionattributes WHERE collection_id = ?", c.ID)
if err != nil {
t.Rollback()
log.Error("Unable to delete attributes on %s: %v", c.Alias, err)
return err
}
rs, _ := res.RowsAffected()
log.Info("Deleted %d for %s from collectionattributes", rs, c.Alias)
// Remove any optional collection password
res, err = t.Exec("DELETE FROM collectionpasswords WHERE collection_id = ?", c.ID)
if err != nil {
t.Rollback()
log.Error("Unable to delete passwords on %s: %v", c.Alias, err)
return err
}
rs, _ = res.RowsAffected()
log.Info("Deleted %d for %s from collectionpasswords", rs, c.Alias)
// Remove redirects to this collection
res, err = t.Exec("DELETE FROM collectionredirects WHERE new_alias = ?", c.Alias)
if err != nil {
t.Rollback()
log.Error("Unable to delete redirects on %s: %v", c.Alias, err)
return err
}
rs, _ = res.RowsAffected()
log.Info("Deleted %d for %s from collectionredirects", rs, c.Alias)
// Remove any collection keys
res, err = t.Exec("DELETE FROM collectionkeys WHERE collection_id = ?", c.ID)
if err != nil {
t.Rollback()
log.Error("Unable to delete keys on %s: %v", c.Alias, err)
return err
}
rs, _ = res.RowsAffected()
log.Info("Deleted %d for %s from collectionkeys", rs, c.Alias)
// TODO: federate delete collection
// Remove remote follows
res, err = t.Exec("DELETE FROM remotefollows WHERE collection_id = ?", c.ID)
if err != nil {
t.Rollback()
log.Error("Unable to delete remote follows on %s: %v", c.Alias, err)
return err
}
rs, _ = res.RowsAffected()
log.Info("Deleted %d for %s from remotefollows", rs, c.Alias)
}
// Delete collections
res, err = t.Exec("DELETE FROM collections WHERE owner_id = ?", userID)
if err != nil {
t.Rollback()
log.Error("Unable to delete collections: %v", err)
return err
}
rs, _ := res.RowsAffected()
log.Info("Deleted %d from collections", rs)
// Delete tokens
res, err = t.Exec("DELETE FROM accesstokens WHERE user_id = ?", userID)
if err != nil {
t.Rollback()
log.Error("Unable to delete access tokens: %v", err)
return err
}
rs, _ = res.RowsAffected()
log.Info("Deleted %d from accesstokens", rs)
// Delete linked OAuth accounts
res, err = t.Exec("DELETE FROM oauth_users WHERE user_id = ?", userID)
if err != nil {
t.Rollback()
log.Error("Unable to delete oauth_users: %v", err)
return err
}
rs, _ = res.RowsAffected()
log.Info("Deleted %d from oauth_users", rs)
// Delete posts
// TODO: should maybe get each row so we can federate a delete
// if so needs to be outside of transaction like collections
res, err = t.Exec("DELETE FROM posts WHERE owner_id = ?", userID)
if err != nil {
t.Rollback()
log.Error("Unable to delete posts: %v", err)
return err
}
rs, _ = res.RowsAffected()
log.Info("Deleted %d from posts", rs)
// Delete user attributes
res, err = t.Exec("DELETE FROM userattributes WHERE user_id = ?", userID)
if err != nil {
t.Rollback()
log.Error("Unable to delete attributes: %v", err)
return err
}
rs, _ = res.RowsAffected()
log.Info("Deleted %d from userattributes", rs)
// Delete user invites
res, err = t.Exec("DELETE FROM userinvites WHERE owner_id = ?", userID)
if err != nil {
t.Rollback()
log.Error("Unable to delete invites: %v", err)
return err
}
rs, _ = res.RowsAffected()
log.Info("Deleted %d from userinvites", rs)
// Delete the user
res, err = t.Exec("DELETE FROM users WHERE id = ?", userID)
if err != nil {
t.Rollback()
log.Error("Unable to delete user: %v", err)
return err
}
rs, _ = res.RowsAffected()
log.Info("Deleted %d from users", rs)
// Commit all changes to the database
err = t.Commit()
if err != nil {
t.Rollback()
log.Error("Unable to commit: %v", err)
return err
}
// TODO: federate delete actor
return nil
}
// GetAPActorKeys returns the ActivityPub public and private keys stored for a
// collection. If none exist yet, a new pair is generated and saved. On any
// database failure it logs the error and returns nil, nil.
func (db *datastore) GetAPActorKeys(collectionID int64) ([]byte, []byte) {
	var pub, priv []byte
	err := db.QueryRow("SELECT public_key, private_key FROM collectionkeys WHERE collection_id = ?", collectionID).Scan(&pub, &priv)
	if err == nil {
		return pub, priv
	}
	if err != sql.ErrNoRows {
		log.Error("Couldn't SELECT collectionkeys: %v", err)
		return nil, nil
	}
	// No keys stored yet for this collection: create a pair and persist it.
	pub, priv = activitypub.GenerateKeys()
	if _, err = db.Exec("INSERT INTO collectionkeys (collection_id, public_key, private_key) VALUES (?, ?, ?)", collectionID, pub, priv); err != nil {
		log.Error("Unable to INSERT new activitypub keypair: %v", err)
		return nil, nil
	}
	return pub, priv
}
// CreateUserInvite stores a new, active (inactive = 0) invite with code id,
// owned by userID, allowing maxUses uses and expiring at expires. The created
// timestamp is set by the database.
func (db *datastore) CreateUserInvite(id string, userID int64, maxUses int, expires *time.Time) error {
	query := "INSERT INTO userinvites (id, owner_id, max_uses, created, expires, inactive) VALUES (?, ?, ?, " + db.now() + ", ?, 0)"
	_, err := db.Exec(query, id, userID, maxUses, expires)
	return err
}
// GetUserInvites returns every invite created by userID, newest first. The
// slice is never nil.
//
// If the query fails, a 500 impart.HTTPError is returned. A row that fails to
// scan is logged and stops iteration; it is not appended as a zero-value
// invite, which previously happened because the Scan error was ignored. Any
// error reported by rows.Err after iteration is also logged. Both cases match
// the other list queries in this file.
func (db *datastore) GetUserInvites(userID int64) (*[]Invite, error) {
	rows, err := db.Query("SELECT id, max_uses, created, expires, inactive FROM userinvites WHERE owner_id = ? ORDER BY created DESC", userID)
	if err != nil {
		log.Error("Failed selecting from userinvites: %v", err)
		return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user invites."}
	}
	defer rows.Close()

	is := []Invite{}
	for rows.Next() {
		i := Invite{}
		if err = rows.Scan(&i.ID, &i.MaxUses, &i.Created, &i.Expires, &i.Inactive); err != nil {
			log.Error("Failed scanning userinvites row: %v", err)
			break
		}
		is = append(is, i)
	}
	if err = rows.Err(); err != nil {
		log.Error("Error after Next() on rows: %v", err)
	}
	return &is, nil
}
// GetUserInvite fetches the invite with code id. A missing invite, or an
// error the datastore considers ignorable, is reported as a 404
// impart.HTTPError. Other query failures are logged and returned as-is.
func (db *datastore) GetUserInvite(id string) (*Invite, error) {
	var inv Invite
	row := db.QueryRow("SELECT id, max_uses, created, expires, inactive FROM userinvites WHERE id = ?", id)
	err := row.Scan(&inv.ID, &inv.MaxUses, &inv.Created, &inv.Expires, &inv.Inactive)
	if err == sql.ErrNoRows || db.isIgnorableError(err) {
		return nil, impart.HTTPError{http.StatusNotFound, "Invite doesn't exist."}
	}
	if err != nil {
		log.Error("Failed selecting invite: %v", err)
		return nil, err
	}
	return &inv, nil
}
// IsUsersInvite reports whether the invite with the given code was created by
// userID. "No such invite" is not an error and simply yields false. Any other
// query failure is logged and returned together with false.
func (db *datastore) IsUsersInvite(code string, userID int64) (bool, error) {
	var found string
	err := db.QueryRow("SELECT id FROM userinvites WHERE id = ? AND owner_id = ?", code, userID).Scan(&found)
	switch err {
	case nil, sql.ErrNoRows:
		return found != "", nil
	default:
		log.Error("Failed selecting invite: %v", err)
		return false, err
	}
}
// GetUsersInvitedCount returns how many users signed up with invite id. It
// returns 0 if the count can't be read; failures are logged.
func (db *datastore) GetUsersInvitedCount(id string) int64 {
	var count int64
	if err := db.QueryRow("SELECT COUNT(*) FROM usersinvited WHERE invite_id = ?", id).Scan(&count); err != nil {
		if err != sql.ErrNoRows {
			log.Error("Failed selecting users invited count: %v", err)
		}
		return 0
	}
	return count
}
// CreateInvitedUser records that userID signed up using invite inviteID.
func (db *datastore) CreateInvitedUser(inviteID string, userID int64) error {
	const insert = "INSERT INTO usersinvited (invite_id, user_id) VALUES (?, ?)"
	if _, err := db.Exec(insert, inviteID, userID); err != nil {
		return err
	}
	return nil
}
// GetInstancePages returns all instance content whose content_type is "page".
func (db *datastore) GetInstancePages() ([]*instanceContent, error) {
	const pageType = "page"
	return db.GetAllDynamicContent(pageType)
}
func (db *datastore) GetAllDynamicContent(t string) ([]*instanceContent, error) {
where := ""
params := []interface{}{}
if t != "" {
where = " WHERE content_type = ?"
params = append(params, t)
}
rows, err := db.Query("SELECT id, title, content, updated, content_type FROM appcontent"+where, params...)
if err != nil {
log.Error("Failed selecting from appcontent: %v", err)
return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve instance pages."}
}
defer rows.Close()
pages := []*instanceContent{}
for rows.Next() {
c := &instanceContent{}
err = rows.Scan(&c.ID, &c.Title, &c.Content, &c.Updated, &c.Type)
if err != nil {
log.Error("Failed scanning row: %v", err)
break
}
pages = append(pages, c)
}
err = rows.Err()
if err != nil {
log.Error("Error after Next() on rows: %v", err)
}
return pages, nil
}
// GetDynamicContent loads the appcontent row with the given id. It returns
// (nil, nil) when no such row exists. Other failures are logged and returned.
func (db *datastore) GetDynamicContent(id string) (*instanceContent, error) {
	content := &instanceContent{ID: id}
	row := db.QueryRow("SELECT title, content, updated, content_type FROM appcontent WHERE id = ?", id)
	if err := row.Scan(&content.Title, &content.Content, &content.Updated, &content.Type); err != nil {
		if err == sql.ErrNoRows {
			return nil, nil
		}
		log.Error("Couldn't SELECT FROM appcontent for id '%s': %v", id, err)
		return nil, err
	}
	return content, nil
}
// UpdateDynamicContent inserts or overwrites the appcontent row with the
// given id, setting its updated timestamp to now. SQLite uses INSERT OR
// REPLACE; other drivers use the datastore's upsert clause. Failures are
// logged and returned.
func (db *datastore) UpdateDynamicContent(id, title, content, contentType string) error {
	var err error
	switch db.driverName {
	case driverSQLite:
		_, err = db.Exec("INSERT OR REPLACE INTO appcontent (id, title, content, updated, content_type) VALUES (?, ?, ?, "+db.now()+", ?)", id, title, content, contentType)
	default:
		_, err = db.Exec("INSERT INTO appcontent (id, title, content, updated, content_type) VALUES (?, ?, ?, "+db.now()+", ?) "+db.upsert("id")+" title = ?, content = ?, updated = "+db.now(), id, title, content, contentType, title, content)
	}
	if err != nil {
		log.Error("Unable to INSERT appcontent for '%s': %v", id, err)
	}
	return err
}
// GetAllUsers returns one page of users (adminUsersPerPage per page), newest
// first. page is 1-based; 0 and 1 both return the first page.
//
// If the query fails, a 500 impart.HTTPError is returned. A row that fails to
// scan stops iteration and is logged, and any error reported by rows.Err
// afterwards is logged too, matching GetAllDynamicContent. In both cases the
// users read so far are returned with a nil error.
func (db *datastore) GetAllUsers(page uint) (*[]User, error) {
	limitStr := fmt.Sprintf("0, %d", adminUsersPerPage)
	if page > 1 {
		limitStr = fmt.Sprintf("%d, %d", (page-1)*adminUsersPerPage, adminUsersPerPage)
	}

	// limitStr contains only formatted integers, so concatenating it is safe.
	rows, err := db.Query("SELECT id, username, created, status FROM users ORDER BY created DESC LIMIT " + limitStr)
	if err != nil {
		log.Error("Failed selecting from users: %v", err)
		return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve all users."}
	}
	defer rows.Close()

	users := []User{}
	for rows.Next() {
		u := User{}
		err = rows.Scan(&u.ID, &u.Username, &u.Created, &u.Status)
		if err != nil {
			log.Error("Failed scanning GetAllUsers() row: %v", err)
			break
		}
		users = append(users, u)
	}
	if err = rows.Err(); err != nil {
		log.Error("Error after Next() on rows: %v", err)
	}
	return &users, nil
}
// GetAllUsersCount returns the total number of users, or 0 if it can't be
// read (failures are logged).
func (db *datastore) GetAllUsersCount() int64 {
	var total int64
	err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&total)
	if err == nil {
		return total
	}
	if err != sql.ErrNoRows {
		log.Error("Failed selecting all users count: %v", err)
	}
	return 0
}
// GetUserLastPostTime returns the creation time of the newest post owned by
// user id, or (nil, nil) if the user has no posts. Other failures are logged
// and returned.
func (db *datastore) GetUserLastPostTime(id int64) (*time.Time, error) {
	var created time.Time
	row := db.QueryRow("SELECT created FROM posts WHERE owner_id = ? ORDER BY created DESC LIMIT 1", id)
	switch err := row.Scan(&created); {
	case err == nil:
		return &created, nil
	case err == sql.ErrNoRows:
		return nil, nil
	default:
		log.Error("Failed selecting last post time from posts: %v", err)
		return nil, err
	}
}
// SetUserStatus changes a user's status in the database. See UserStatus for
// the possible values.
//
// The underlying database error is wrapped with %w rather than %v, so callers
// can inspect it with errors.Is / errors.As. The message text is unchanged.
func (db *datastore) SetUserStatus(id int64, status UserStatus) error {
	if _, err := db.Exec("UPDATE users SET status = ? WHERE id = ?", status, id); err != nil {
		return fmt.Errorf("failed to update user status: %w", err)
	}
	return nil
}
// GetCollectionLastPostTime returns the creation time of the newest post in
// collection id, or (nil, nil) if it has no posts. Other failures are logged
// and returned.
func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
	var newest time.Time
	err := db.QueryRow("SELECT created FROM posts WHERE collection_id = ? ORDER BY created DESC LIMIT 1", id).Scan(&newest)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, nil
		}
		log.Error("Failed selecting last post time from posts: %v", err)
		return nil, err
	}
	return &newest, nil
}
// GenerateOAuthState creates a random 24-character state token, stores it as
// unused for provider/clientID, and returns it. The token is linked to
// attachUser when that is > 0 and to inviteCode when that is non-empty;
// otherwise those columns are NULL.
func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64, inviteCode string) (string, error) {
	token := id.Generate62RandomString(24)
	args := []interface{}{
		token,
		provider,
		clientID,
		sql.NullInt64{Valid: attachUser > 0, Int64: attachUser},
		sql.NullString{Valid: inviteCode != "", String: inviteCode},
	}
	query := "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id, invite_code) VALUES (?, ?, ?, FALSE, " + db.now() + ", ?, ?)"
	if _, err := db.ExecContext(ctx, query, args...); err != nil {
		return "", fmt.Errorf("unable to record oauth client state: %w", err)
	}
	return token, nil
}
// ValidateOAuthState consumes an unused OAuth state token created by
// GenerateOAuthState. It returns the provider, client ID, attached user ID
// (0 if none) and invite code ("" if none) recorded for it. The lookup and
// the "used = TRUE" update run in one transaction, and only unused states
// match the lookup, so a token can be redeemed at most once.
//
// Any failure, including an unknown or already-used state, is returned as a
// non-nil error with all other results zeroed. Previously the error was
// dropped, which made an invalid state indistinguishable from a valid one
// with empty fields.
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) {
	var provider string
	var clientID string
	var attachUserID sql.NullInt64
	var inviteCode sql.NullString
	err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
		err := tx.
			QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id, invite_code FROM oauth_client_states WHERE state = ? AND used = FALSE", state).
			Scan(&provider, &clientID, &attachUserID, &inviteCode)
		if err != nil {
			return err
		}
		// Mark the state consumed within the same transaction.
		res, err := tx.ExecContext(ctx, "UPDATE oauth_client_states SET used = TRUE WHERE state = ?", state)
		if err != nil {
			return err
		}
		rowsAffected, err := res.RowsAffected()
		if err != nil {
			return err
		}
		if rowsAffected != 1 {
			return fmt.Errorf("state not found")
		}
		return nil
	})
	if err != nil {
		return "", "", 0, "", err
	}
	return provider, clientID, attachUserID.Int64, inviteCode.String, nil
}
// RecordRemoteUserID links localUserID to a remote OAuth account and stores
// its access token. It inserts a new row or replaces/updates an existing one
// (INSERT OR REPLACE on SQLite, the datastore's upsert clause elsewhere).
// Failures are logged and returned.
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
	var err error
	switch db.driverName {
	case driverSQLite:
		_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?)", localUserID, remoteUserID, provider, clientID, accessToken)
	default:
		_, err = db.ExecContext(ctx, "INSERT INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?) "+db.upsert("user")+" access_token = ?", localUserID, remoteUserID, provider, clientID, accessToken, accessToken)
	}
	if err != nil {
		log.Error("Unable to INSERT oauth_users for '%d': %v", localUserID, err)
	}
	return err
}
// GetIDForRemoteUser returns a user ID associated with a remote user ID.
func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
var userID int64 = -1
err := db.
QueryRowContext(ctx, "SELECT user_id FROM oauth_users WHERE remote_user_id = ? AND provider = ? AND client_id = ?", remoteUserID, provider, clientID).
Scan(&userID)
// Not finding a record is OK.
if err != nil && err != sql.ErrNoRows {
return -1, err
}
return userID, nil
}
// oauthAccountInfo describes one remote OAuth account linked to a local user.
// GetOauthAccounts fills only Provider, ClientID and RemoteUserID.
// NOTE(review): DisplayName and AllowDisconnect are presumably set by callers.
// Confirm before relying on them.
type oauthAccountInfo struct {
Provider string
ClientID string
RemoteUserID string
DisplayName string
AllowDisconnect bool
}
// GetOauthAccounts returns the remote OAuth accounts linked to userID, with
// Provider, ClientID and RemoteUserID filled in. The result is nil when there
// are none.
//
// If the query fails, a 500 impart.HTTPError is returned. A row that fails to
// scan stops iteration and is logged; the log message previously named
// GetAllUsers by copy-paste mistake. Errors reported by rows.Err are now
// logged too. The accounts read so far are returned with a nil error.
func (db *datastore) GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error) {
	rows, err := db.QueryContext(ctx, "SELECT provider, client_id, remote_user_id FROM oauth_users WHERE user_id = ? ", userID)
	if err != nil {
		log.Error("Failed selecting from oauth_users: %v", err)
		return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user oauth accounts."}
	}
	defer rows.Close()

	var records []oauthAccountInfo
	for rows.Next() {
		info := oauthAccountInfo{}
		if err = rows.Scan(&info.Provider, &info.ClientID, &info.RemoteUserID); err != nil {
			log.Error("Failed scanning GetOauthAccounts() row: %v", err)
			break
		}
		records = append(records, info)
	}
	if err = rows.Err(); err != nil {
		log.Error("Error after Next() on rows: %v", err)
	}
	return records, nil
}
// DatabaseInitialized reports whether the datastore's schema has been created,
// judged by whether a `users` table exists. SQLite is checked via
// sqlite_master and other drivers via SHOW TABLES. Query failures other than
// "no rows" are logged and treated as "not initialized".
func (db *datastore) DatabaseInitialized() bool {
	query := "SHOW TABLES LIKE 'users'"
	if db.driverName == driverSQLite {
		query = "SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'users'"
	}

	var tableName string
	err := db.QueryRow(query).Scan(&tableName)
	if err == nil {
		return true
	}
	if err != sql.ErrNoRows {
		log.Error("Couldn't SHOW TABLES: %v", err)
	}
	return false
}
func (db *datastore) RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error {
_, err := db.ExecContext(ctx, `DELETE FROM oauth_users WHERE user_id = ? AND provider = ? AND client_id = ? AND remote_user_id = ?`, userID, provider, clientID, remoteUserID)
return err
}
// stringLogln appends a formatted line to *out. s is a fmt format string
// applied to v, and a newline is added after it. The parameter was renamed
// from `log` so it no longer shadows the imported log package.
func stringLogln(out *string, s string, v ...interface{}) {
	*out += fmt.Sprintf(s+"\n", v...)
}
// handleFailedPostInsert logs a failed INSERT into posts and returns err
// unchanged, so callers can write `return handleFailedPostInsert(err)`.
func handleFailedPostInsert(err error) error {
	const msg = "Couldn't insert into posts: %v"
	log.Error(msg, err)
	return err
}
// GetProfilePageFromHandle resolves a handle of the form "user@domain"
// (leading '@' characters are ignored) to a profile location.
//
// Handles on known non-ActivityPub services resolve to their silobridge
// profile URL. Otherwise the result is the user's ActivityPub actor IRI:
//   - If the remote user is already stored with this handle, that is used.
//   - If not, the IRI is resolved with RemoteLookup. A stored user without a
//     handle gets the handle recorded. An unknown user is fetched and
//     inserted into remoteusers.
//
// If the remote actor can't be fetched, the error is now returned. Previously
// a remoteusers row was inserted using the unfetched actor's inbox values.
func (db *datastore) GetProfilePageFromHandle(app *App, handle string) (string, error) {
	handle = strings.TrimLeft(handle, "@")
	parts := strings.Split(handle, "@")
	if len(parts) != 2 {
		return "", fmt.Errorf("invalid handle format")
	}
	domain := parts[1]

	// Check non-AP instances
	if siloProfileURL := silobridge.Profile(parts[0], domain); siloProfileURL != "" {
		return siloProfileURL, nil
	}

	remoteUser, err := getRemoteUserFromHandle(app, handle)
	if err == nil {
		return remoteUser.ActorID, nil
	}

	// can't find using handle in the table but the table may already have this user without
	// handle from a previous version
	// TODO: Make this determination. We should know whether a user exists without a handle, or doesn't exist at all
	actorIRI := RemoteLookup(handle)
	if _, errRemoteUser := getRemoteUser(app, actorIRI); errRemoteUser == nil {
		// The user exists, so we only need to record its handle.
		if _, err := app.db.Exec("UPDATE remoteusers SET handle = ? WHERE actor_id = ?", handle, actorIRI); err != nil {
			log.Error("Couldn't update handle '%s' for user %s: %v", handle, actorIRI, err)
		}
		return actorIRI, nil
	}

	// this probably means we don't have the user in the table so let's try to insert it
	// here we need to ask the server for the inboxes
	remoteActor, err := activityserve.NewRemoteActor(actorIRI)
	if err != nil {
		log.Error("Couldn't fetch remote actor: %v", err)
		return "", err
	}
	if debugging {
		log.Info("%s %s %s %s", actorIRI, remoteActor.GetInbox(), remoteActor.GetSharedInbox(), handle)
	}
	_, err = app.db.Exec("INSERT INTO remoteusers (actor_id, inbox, shared_inbox, handle) VALUES(?, ?, ?, ?)", actorIRI, remoteActor.GetInbox(), remoteActor.GetSharedInbox(), handle)
	if err != nil {
		log.Error("Couldn't insert remote user: %v", err)
		return "", err
	}
	return actorIRI, nil
}
diff --git a/export.go b/export.go
index 592bc0c..bdfd7c4 100644
--- a/export.go
+++ b/export.go
@@ -1,132 +1,132 @@
/*
* Copyright © 2018-2019 A Bunch Tell LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"archive/zip"
"bytes"
"encoding/csv"
"strings"
"time"
"github.com/writeas/web-core/log"
)
// exportPostsCSV renders posts as CSV with the header
// id,slug,blog,url,created,title,body. Newlines in a post body are written as
// the two characters `\n`. Side effect: each post's Collection (when present)
// has its hostName set to hostName before its canonical URL is built. u is
// currently unused. A CSV writer error is logged, and whatever was written is
// returned.
func exportPostsCSV(hostName string, u *User, posts *[]PublicPost) []byte {
	records := [][]string{
		{"id", "slug", "blog", "url", "created", "title", "body"},
	}
	for _, post := range *posts {
		blogAlias := ""
		if coll := post.Collection; coll != nil {
			blogAlias = coll.Alias
			coll.hostName = hostName
		}
		records = append(records, []string{
			post.ID,
			post.Slug.String,
			blogAlias,
			post.CanonicalURL(hostName),
			post.Created8601(),
			post.Title.String,
			strings.ReplaceAll(post.Content, "\n", "\\n"),
		})
	}

	var buf bytes.Buffer
	cw := csv.NewWriter(&buf)
	cw.WriteAll(records) // calls Flush internally
	if err := cw.Error(); err != nil {
		log.Info("error writing csv: %v", err)
	}
	return buf.Bytes()
}
// exportedTxt is one text file in a zip export. Name is the path inside the
// archive, Title and Body are the post's title and content, and Mod is the
// file's modification time (the post's creation time in exportPostsZip).
// It is built with a positional literal, so don't reorder the fields.
type exportedTxt struct {
Name, Title, Body string
Mod time.Time
}
// exportPostsZip packages posts into a zip archive and returns its bytes.
//
// Each post becomes "[<collection alias>/][<slug>_]<post ID>.txt" with the
// post's creation time as its modification time. The file content is
// "# <title>\n\n" (only when the post has a title) followed by the body.
// u is currently unused.
//
// If a file's header can't be created, the error is logged and that post is
// skipped. Previously the nil writer returned by CreateHeader was written to,
// which panics. Write and Close errors are logged.
func exportPostsZip(u *User, posts *[]PublicPost) []byte {
	// Create a buffer to write our archive to.
	b := new(bytes.Buffer)
	// Create a new zip archive.
	w := zip.NewWriter(b)

	files := make([]exportedTxt, 0, len(*posts))
	for _, p := range *posts {
		var filename string
		if p.Collection != nil {
			filename += p.Collection.Alias + "/"
		}
		if p.Slug.String != "" {
			filename += p.Slug.String + "_"
		}
		filename += p.ID + ".txt"
		files = append(files, exportedTxt{filename, p.Title.String, p.Content, p.Created})
	}

	for _, file := range files {
		head := &zip.FileHeader{Name: file.Name}
		head.SetModTime(file.Mod)
		f, err := w.CreateHeader(head)
		if err != nil {
			// f is nil on error; writing to it would panic.
			log.Error("export zip header: %v", err)
			continue
		}
		var fullPost string
		if file.Title != "" {
			fullPost = "# " + file.Title + "\n\n"
		}
		fullPost += file.Body
		if _, err = f.Write([]byte(fullPost)); err != nil {
			log.Error("export zip write: %v", err)
		}
	}

	// Close writes the central directory; the archive is incomplete without it.
	if err := w.Close(); err != nil {
		log.Error("export zip close: %v", err)
	}
	return b.Bytes()
}
// compileFullExport gathers the data for a full export of u: the user itself,
// the user's anonymous posts, and each of their collections along with the
// collection's posts (GetPosts) and post count (GetPostsCount). Fetch errors
// are logged, not returned.
//
// NOTE(review): if GetCollections or GetAnonymousPosts returns an error along
// with a nil pointer, the later *posts / *colls dereferences would panic.
// Confirm what those helpers return on failure. Also, the new 0 argument to
// GetAnonymousPosts presumably means "first page / no paging"; verify.
func compileFullExport(app *App, u *User) *ExportUser {
exportUser := &ExportUser{
User: u,
}
colls, err := app.db.GetCollections(u, app.cfg.App.Host)
if err != nil {
log.Error("unable to fetch collections: %v", err)
}
- posts, err := app.db.GetAnonymousPosts(u)
+ posts, err := app.db.GetAnonymousPosts(u, 0)
if err != nil {
log.Error("unable to fetch anon posts: %v", err)
}
exportUser.AnonymousPosts = *posts
var collObjs []CollectionObj
for _, c := range *colls {
co := &CollectionObj{Collection: c}
co.Posts, err = app.db.GetPosts(app.cfg, &c, 0, true, false, true)
if err != nil {
log.Error("unable to get collection posts: %v", err)
}
app.db.GetPostsCount(co, true)
collObjs = append(collObjs, *co)
}
exportUser.Collections = &collObjs
return exportUser
}
diff --git a/handle.go b/handle.go
index 01d5728..4c454ec 100644
--- a/handle.go
+++ b/handle.go
@@ -1,935 +1,955 @@
/*
* Copyright © 2018-2021 A Bunch Tell LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"fmt"
"html/template"
"net/http"
"net/url"
"runtime/debug"
"strconv"
"strings"
"time"
"github.com/gorilla/sessions"
"github.com/prologic/go-gopher"
"github.com/writeas/impart"
"github.com/writeas/web-core/log"
"github.com/writefreely/writefreely/config"
"github.com/writefreely/writefreely/page"
)
// UserLevel represents the required user level for accessing an endpoint
type UserLevel int
// Possible UserLevel values. Each one describes what kind of visitor a route
// accepts.
const (
UserLevelNoneType UserLevel = iota // user or not -- ignored
UserLevelOptionalType // user or not -- object fetched if user
UserLevelNoneRequiredType // non-user (required)
UserLevelUserType // user (required)
)
// UserLevelNone always returns UserLevelNoneType, regardless of cfg.
func UserLevelNone(cfg *config.Config) UserLevel { return UserLevelNoneType }

// UserLevelOptional always returns UserLevelOptionalType, regardless of cfg.
func UserLevelOptional(cfg *config.Config) UserLevel { return UserLevelOptionalType }

// UserLevelNoneRequired always returns UserLevelNoneRequiredType, regardless of cfg.
func UserLevelNoneRequired(cfg *config.Config) UserLevel { return UserLevelNoneRequiredType }

// UserLevelUser always returns UserLevelUserType, regardless of cfg.
func UserLevelUser(cfg *config.Config) UserLevel { return UserLevelUserType }
// UserLevelReader returns the level required on routes where published
// content is read. On a public instance any visitor may read
// (UserLevelOptionalType). A private instance (cfg.App.Private) requires a
// logged-in user (UserLevelUserType).
func UserLevelReader(cfg *config.Config) UserLevel {
	if !cfg.App.Private {
		return UserLevelOptionalType
	}
	return UserLevelUserType
}
// Function types adapted by Handler's wrapper methods.
type (
// handlerFunc handles a web request with access to the App.
handlerFunc func(app *App, w http.ResponseWriter, r *http.Request) error
// gopherFunc handles a Gopher-protocol request.
gopherFunc func(app *App, w gopher.ResponseWriter, r *gopher.Request) error
// userHandlerFunc handles a request for an already-authenticated user.
userHandlerFunc func(app *App, u *User, w http.ResponseWriter, r *http.Request) error
// userApperHandlerFunc is like userHandlerFunc but receives the Apper
// (used by AdminApper).
userApperHandlerFunc func(apper Apper, u *User, w http.ResponseWriter, r *http.Request) error
// dataHandlerFunc returns response bytes and a string (presumably the
// content type; confirm at call sites) instead of writing them itself.
dataHandlerFunc func(app *App, w http.ResponseWriter, r *http.Request) ([]byte, string, error)
// authFunc resolves the user for a request, or returns an error (see
// apiAuth, optionalAPIAuth, webAuth).
authFunc func(app *App, r *http.Request) (*User, error)
// UserLevelFunc returns the UserLevel a route requires under cfg.
UserLevelFunc func(cfg *config.Config) UserLevel
)
// Handler wraps application handler funcs with authentication, request
// logging, panic recovery and error reporting (see User, Admin, UserAll).
type Handler struct {
// errors holds the templates used for error pages; see SetErrorPages.
errors *ErrorPages
// sessionStore is taken from the App in NewHandler.
sessionStore sessions.Store
app Apper
}
// ErrorPages hold template HTML error pages for displaying errors to the user.
// In each, there should be a defined template named "base".
// NewHandler installs minimal built-in pages; NewWFHandler replaces them
// with WriteFreely's template files.
type ErrorPages struct {
// NotFound: default is a bare "404 Not found." page.
NotFound *template.Template
// Gone: default is a bare "410 Gone." page.
Gone *template.Template
// InternalServerError is also rendered when a wrapped handler panics.
InternalServerError *template.Template
// UnavailableError: default is a bare "503 Service is temporarily unavailable." page.
UnavailableError *template.Template
// Blank renders a page from its data's .Title and .Content fields.
Blank *template.Template
}
// NewHandler returns a Handler for apper. It uses minimal inline HTML error
// pages and the App's session store. Call SetErrorPages, or use NewWFHandler,
// for full templates.
func NewHandler(apper Apper) *Handler {
	mustParse := func(src string) *template.Template {
		return template.Must(template.New("").Parse(src))
	}
	errPages := &ErrorPages{
		NotFound:            mustParse("{{define \"base\"}}<html><head><title>404</title></head><body><p>Not found.</p></body></html>{{end}}"),
		Gone:                mustParse("{{define \"base\"}}<html><head><title>410</title></head><body><p>Gone.</p></body></html>{{end}}"),
		InternalServerError: mustParse("{{define \"base\"}}<html><head><title>500</title></head><body><p>Internal server error.</p></body></html>{{end}}"),
		UnavailableError:    mustParse("{{define \"base\"}}<html><head><title>503</title></head><body><p>Service is temporarily unavailable.</p></body></html>{{end}}"),
		Blank:               mustParse("{{define \"base\"}}<html><head><title>{{.Title}}</title></head><body><p>{{.Content}}</p></body></html>{{end}}"),
	}
	return &Handler{
		errors:       errPages,
		sessionStore: apper.App().SessionStore(),
		app:          apper,
	}
}
// NewWFHandler returns a Handler whose error pages come from WriteFreely's
// template files, read from the pages map. You MUST call
// writefreely.InitTemplates() before this.
func NewWFHandler(apper Apper) *Handler {
	h := NewHandler(apper)
	h.errors = &ErrorPages{
		NotFound:            pages["404-general.tmpl"],
		Gone:                pages["410.tmpl"],
		InternalServerError: pages["500.tmpl"],
		UnavailableError:    pages["503.tmpl"],
		Blank:               pages["blank.tmpl"],
	}
	return h
}
// SetErrorPages sets the given set of ErrorPages as templates for any errors
// that come up. It replaces the Handler's current set entirely.
func (h *Handler) SetErrorPages(e *ErrorPages) {
h.errors = e
}
// User handles requests made in the web application by the authenticated user.
// This provides user-friendly HTML pages and actions that work in the browser.
//
// Requests without a logged-in session get ErrNotLoggedIn and f is not
// called. A panic in f is recovered, logged with its stack trace, and
// answered with the InternalServerError page. Every request is logged with
// its resulting status and duration. The error f returns is passed to
// handleHTTPError.
func (h *Handler) User(f userHandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h.handleHTTPError(w, r, func() error {
var status int
start := time.Now()
// Runs last: converts a panic into a 500 page, then logs the request.
defer func() {
if e := recover(); e != nil {
log.Error("%s: %s", e, debug.Stack())
h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
status = http.StatusInternalServerError
}
log.Info(h.app.ReqLog(r, status, time.Since(start)))
}()
u := getUserSession(h.app.App(), r)
if u == nil {
err := ErrNotLoggedIn
status = err.Status
return err
}
err := f(h.app.App(), u, w, r)
// Derive the logged status from f's result.
if err == nil {
status = http.StatusOK
} else if err, ok := err.(impart.HTTPError); ok {
status = err.Status
} else {
status = http.StatusInternalServerError
}
return err
}())
}
}
// Admin handles requests on /admin routes
//
// Only a logged-in admin user reaches f. Anyone else gets a 404 rather than
// a 403, presumably so the admin area's existence isn't revealed. Panic
// recovery, request logging and error handling work as in User.
func (h *Handler) Admin(f userHandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h.handleHTTPError(w, r, func() error {
var status int
start := time.Now()
// Runs last: converts a panic into a 500 page, then logs the request.
defer func() {
if e := recover(); e != nil {
log.Error("%s: %s", e, debug.Stack())
h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
status = http.StatusInternalServerError
}
log.Info(h.app.ReqLog(r, status, time.Since(start)))
}()
u := getUserSession(h.app.App(), r)
if u == nil || !u.IsAdmin() {
err := impart.HTTPError{http.StatusNotFound, ""}
status = err.Status
return err
}
err := f(h.app.App(), u, w, r)
if err == nil {
status = http.StatusOK
} else if err, ok := err.(impart.HTTPError); ok {
status = err.Status
} else {
status = http.StatusInternalServerError
}
return err
}())
}
}
// AdminApper handles requests on /admin routes that require an Apper.
//
// It is identical to Admin except that f receives the Handler's Apper rather
// than h.app.App(). Non-admins get a 404 and f is not called.
func (h *Handler) AdminApper(f userApperHandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h.handleHTTPError(w, r, func() error {
var status int
start := time.Now()
// Runs last: converts a panic into a 500 page, then logs the request.
defer func() {
if e := recover(); e != nil {
log.Error("%s: %s", e, debug.Stack())
h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
status = http.StatusInternalServerError
}
log.Info(h.app.ReqLog(r, status, time.Since(start)))
}()
u := getUserSession(h.app.App(), r)
if u == nil || !u.IsAdmin() {
err := impart.HTTPError{http.StatusNotFound, ""}
status = err.Status
return err
}
err := f(h.app, u, w, r)
if err == nil {
status = http.StatusOK
} else if err, ok := err.(impart.HTTPError); ok {
status = err.Status
} else {
status = http.StatusInternalServerError
}
return err
}())
}
}
// apiAuth identifies the user by the access token in the request's
// Authorization header. It returns ErrNoAccessToken when the header is
// missing and ErrBadAccessToken when the token matches no user. The returned
// User has only its ID set.
func apiAuth(app *App, r *http.Request) (*User, error) {
	token := r.Header.Get("Authorization")
	if token == "" {
		return nil, ErrNoAccessToken
	}
	userID := app.db.GetUserID(token)
	if userID == -1 {
		return nil, ErrBadAccessToken
	}
	return &User{ID: userID}, nil
}
// optionalAPIAuth identifies the user by the access token in the request's
// Authorization header, like apiAuth. A missing header is reported as
// ErrNotLoggedIn instead of ErrNoAccessToken. It is intended for endpoints
// that also accept cookie-based sessions, but this function only inspects the
// header. The returned User has only its ID set.
func optionalAPIAuth(app *App, r *http.Request) (*User, error) {
	token := r.Header.Get("Authorization")
	if token == "" {
		return nil, ErrNotLoggedIn
	}
	if userID := app.db.GetUserID(token); userID != -1 {
		return &User{ID: userID}, nil
	}
	return nil, ErrBadAccessToken
}
// webAuth identifies the user from their web session (cookie), returning
// ErrNotLoggedIn when there is no logged-in session.
func webAuth(app *App, r *http.Request) (*User, error) {
	if u := getUserSession(app, r); u != nil {
		return u, nil
	}
	return nil, ErrNotLoggedIn
}
// UserAPI handles requests made in the API by the authenticated user.
// The user is identified only by the access token in the Authorization
// header (see apiAuth). Errors are passed to handleError rather than
// handleHTTPError, because UserAll is called with web=false.
func (h *Handler) UserAPI(f userHandlerFunc) http.HandlerFunc {
return h.UserAll(false, f, apiAuth)
}
+// UserWebAPI handles endpoints that accept a user authorized either via the web (cookies) or an Authorization header.
+//
+// The web session is checked first. apiAuth's Authorization-header token is
+// used only when there is no logged-in session, and its errors are returned
+// as-is. Errors go to handleError (UserAll with web=false).
+func (h *Handler) UserWebAPI(f userHandlerFunc) http.HandlerFunc {
+ return h.UserAll(false, f, func(app *App, r *http.Request) (*User, error) {
+ // Authorize user via cookies
+ u := getUserSession(app, r)
+ if u != nil {
+ return u, nil
+ }
+
+ // Fall back to access token, since user isn't logged in via web
+ var err error
+ u, err = apiAuth(app, r)
+ if err != nil {
+ return nil, err
+ }
+
+ return u, nil
+ })
+}
+
// UserAll wraps f in a handler that first authenticates the request with a
// and then calls f with the resulting user.
//
// If a fails, its error is returned and f is not called. A panic in a or f is
// recovered, logged with its stack trace, and answered with a 500 via
// impart.WriteError. Every request is logged with its status and duration.
// The final error goes to handleHTTPError when web is true, and to
// handleError otherwise.
func (h *Handler) UserAll(web bool, f userHandlerFunc, a authFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handleFunc := func() error {
var status int
start := time.Now()
// Runs last: converts a panic into a 500 response, then logs the request.
defer func() {
if e := recover(); e != nil {
log.Error("%s: %s", e, debug.Stack())
impart.WriteError(w, impart.HTTPError{http.StatusInternalServerError, "Something didn't work quite right."})
status = 500
}
log.Info(h.app.ReqLog(r, status, time.Since(start)))
}()
u, err := a(h.app.App(), r)
if err != nil {
if err, ok := err.(impart.HTTPError); ok {
status = err.Status
} else {
status = 500
}
return err
}
err = f(h.app.App(), u, w, r)
if err == nil {
status = 200
} else if err, ok := err.(impart.HTTPError); ok {
status = err.Status
} else {
status = 500
}
return err
}
if web {
h.handleHTTPError(w, r, handleFunc())
} else {
h.handleError(w, r, handleFunc())
}
}
}
// RedirectOnErr wraps f so that any error it returns becomes a redirect. A
// 302 (http.StatusFound) impart.HTTPError from f is passed through unchanged,
// letting the handler pick its own target. Every other error is replaced by
// a 302 to loc. A nil error is returned as nil.
func (h *Handler) RedirectOnErr(f handlerFunc, loc string) handlerFunc {
	return func(app *App, w http.ResponseWriter, r *http.Request) error {
		err := f(app, w, r)
		if err == nil {
			return nil
		}
		if ie, isHTTP := err.(impart.HTTPError); isHTTP && ie.Status == http.StatusFound {
			return ie
		}
		return impart.HTTPError{Status: http.StatusFound, Message: loc}
	}
}
// Page returns a handler that renders the template registered in pages under
// the name n, or a 404 when no such template exists. A rendering failure is
// logged and returned to the Web wrapper. Any user level is accepted.
func (h *Handler) Page(n string) http.HandlerFunc {
	render := func(app *App, w http.ResponseWriter, r *http.Request) error {
		tmpl, found := pages[n]
		if !found {
			return impart.HTTPError{Status: http.StatusNotFound, Message: "Page not found."}
		}
		if err := tmpl.ExecuteTemplate(w, "base", pageForReq(app, r)); err != nil {
			log.Error("Unable to render page: %v", err)
			return err
		}
		return nil
	}
	return h.Web(render, UserLevelOptional)
}
// WebErrors wraps a browser-facing action handler. It enforces the same user
// level rules as Web:
//   - UserLevelNoneRequiredType redirects a logged-in user to the "to" page.
//   - UserLevelUserType returns ErrNotLoggedIn when the session has no user.
//
// Errors from f are handled as follows:
//   - A non-redirect impart.HTTPError has its message stored as a session
//     flash (addSessionFlash), and the user is redirected back to the Referer
//     instead of being shown an error page.
//   - Other errors are logged and turned into a 500 impart.HTTPError, so
//     handleHTTPError renders the error page with a matching status code.
//
// Panics are recovered, and every request is logged with status and duration.
//
// NOTE(review): when ul reports UserLevelNoneType, session stays nil and is
// passed to addSessionFlash as-is. Confirm addSessionFlash tolerates a nil
// session. An empty Referer produces a redirect with an empty Location.
func (h *Handler) WebErrors(f handlerFunc, ul UserLevelFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// TODO: factor out this logic shared with Web()
		h.handleHTTPError(w, r, func() error {
			var status int
			start := time.Now()
			defer func() {
				if e := recover(); e != nil {
					u := getUserSession(h.app.App(), r)
					username := "None"
					if u != nil {
						username = u.Username
					}
					log.Error("User: %s\n\n%s: %s", username, e, debug.Stack())
					h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
					status = 500
				}
				log.Info(h.app.ReqLog(r, status, time.Since(start)))
			}()

			var session *sessions.Session
			var err error
			// Evaluate the required user level once per request.
			level := ul(h.app.App().cfg)
			if level != UserLevelNoneType {
				session, err = h.sessionStore.Get(r, cookieName)
				if err != nil && (level == UserLevelNoneRequiredType || level == UserLevelUserType) {
					// Cookie is required, but we can ignore this error
					log.Error("Handler: Unable to get session (for user permission %d); ignoring: %v", level, err)
				}
				_, gotUser := session.Values[cookieUserVal].(*User)
				if level == UserLevelNoneRequiredType && gotUser {
					to := correctPageFromLoginAttempt(r)
					log.Info("Handler: Required NO user, but got one. Redirecting to %s", to)
					err := impart.HTTPError{Status: http.StatusFound, Message: to}
					status = err.Status
					return err
				} else if level == UserLevelUserType && !gotUser {
					log.Info("Handler: Required a user, but DIDN'T get one. Sending not logged in.")
					err := ErrNotLoggedIn
					status = err.Status
					return err
				}
			}

			// TODO: pass User object to function
			err = f(h.app.App(), w, r)
			if err == nil {
				status = 200
			} else if httpErr, ok := err.(impart.HTTPError); ok {
				status = httpErr.Status
				if status < 300 || status > 399 {
					addSessionFlash(h.app.App(), w, r, httpErr.Message, session)
					return impart.HTTPError{Status: http.StatusFound, Message: r.Referer()}
				}
			} else {
				// Pass err as an argument, never as the format string, so a "%"
				// in the error text cannot corrupt the log line.
				log.Error("[Web handler] 500: %v", err)
				status = 500
				// Let handleHTTPError render the 500 page with a proper status
				// code. Rendering it here and returning err would make
				// handleHTTPError append a JSON error body to the HTML page.
				return impart.HTTPError{Status: http.StatusInternalServerError, Message: "Something didn't work quite right."}
			}
			return err
		}())
	}
}
// CollectionPostOrStatic routes a collection sub-path. A path that contains a
// "." is served as a static file (through app.shttp), unless isRaw reports a
// raw-content request. Everything else is handled as a collection post view
// at reader level.
func (h *Handler) CollectionPostOrStatic(w http.ResponseWriter, r *http.Request) {
	if !strings.Contains(r.URL.Path, ".") || isRaw(r) {
		h.Web(viewCollectionPost, UserLevelReader)(w, r)
		return
	}

	// Serve static file, logging the request once it has been handled.
	began := time.Now()
	defer func() {
		log.Info(h.app.ReqLog(r, 200, time.Since(began)))
	}()
	h.app.App().shttp.ServeHTTP(w, r)
}
// Web handles requests made in the web application. This provides user-
// friendly HTML pages and actions that work in the browser.
//
// Before calling f, Web enforces the user level reported by ul:
//   - UserLevelNoneRequiredType redirects an already logged-in user to the
//     "to" page (correctPageFromLoginAttempt).
//   - UserLevelUserType returns ErrNotLoggedIn when the session holds no user.
//
// impart.HTTPErrors from f are rendered by handleHTTPError. Any other error is
// logged and turned into a 500 impart.HTTPError, so handleHTTPError renders
// the error page with a matching status code. Panics are recovered, and every
// request is logged with its status and duration.
func (h *Handler) Web(f handlerFunc, ul UserLevelFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		h.handleHTTPError(w, r, func() error {
			var status int
			start := time.Now()
			defer func() {
				if e := recover(); e != nil {
					u := getUserSession(h.app.App(), r)
					username := "None"
					if u != nil {
						username = u.Username
					}
					log.Error("User: %s\n\n%s: %s", username, e, debug.Stack())
					log.Info("Web deferred internal error render")
					h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
					status = 500
				}
				log.Info(h.app.ReqLog(r, status, time.Since(start)))
			}()

			// Evaluate the required user level once per request.
			level := ul(h.app.App().cfg)
			if level != UserLevelNoneType {
				session, err := h.sessionStore.Get(r, cookieName)
				if err != nil && (level == UserLevelNoneRequiredType || level == UserLevelUserType) {
					// Cookie is required, but we can ignore this error
					log.Error("Handler: Unable to get session (for user permission %d); ignoring: %v", level, err)
				}
				_, gotUser := session.Values[cookieUserVal].(*User)
				if level == UserLevelNoneRequiredType && gotUser {
					to := correctPageFromLoginAttempt(r)
					log.Info("Handler: Required NO user, but got one. Redirecting to %s", to)
					err := impart.HTTPError{Status: http.StatusFound, Message: to}
					status = err.Status
					return err
				} else if level == UserLevelUserType && !gotUser {
					log.Info("Handler: Required a user, but DIDN'T get one. Sending not logged in.")
					err := ErrNotLoggedIn
					status = err.Status
					return err
				}
			}

			// TODO: pass User object to function
			err := f(h.app.App(), w, r)
			if err == nil {
				status = 200
			} else if httpErr, ok := err.(impart.HTTPError); ok {
				status = httpErr.Status
			} else {
				// Pass err as an argument, never as the format string, so a "%"
				// in the error text cannot corrupt the log line.
				log.Error("[Web handler] 500: %v", err)
				status = 500
				// Let handleHTTPError render the 500 page with a proper status
				// code. Rendering it here and returning err would make
				// handleHTTPError append a JSON error body to the HTML page.
				return impart.HTTPError{Status: http.StatusInternalServerError, Message: "Something didn't work quite right."}
			}
			return err
		}())
	}
}
// All wraps an API handler that performs no user-level check of its own. It
// adds panic recovery and request logging. Errors are written by handleError:
// redirects for 3xx, impart error responses otherwise.
func (h *Handler) All(f handlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		h.handleError(w, r, func() error {
			// TODO: return correct "success" status
			start := time.Now()
			status := 200
			defer func() {
				if e := recover(); e != nil {
					log.Error("%s:\n%s", e, debug.Stack())
					impart.WriteError(w, impart.HTTPError{Status: http.StatusInternalServerError, Message: "Something didn't work quite right."})
					status = 500
				}
				log.Info(h.app.ReqLog(r, status, time.Since(start)))
			}()

			// TODO: do any needed authentication
			err := f(h.app.App(), w, r)
			// A nil err fails the type assertion, leaving status at 200.
			if httpErr, isHTTP := err.(impart.HTTPError); isHTTP {
				status = httpErr.Status
			} else if err != nil {
				status = 500
			}
			return err
		}())
	}
}
// OAuth wraps an OAuth endpoint handler with panic recovery and request
// logging. Errors returned by f are reported through handleOAuthError.
func (h *Handler) OAuth(f handlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		h.handleOAuthError(w, r, func() error {
			// TODO: return correct "success" status
			start := time.Now()
			status := 200
			defer func() {
				if e := recover(); e != nil {
					log.Error("%s:\n%s", e, debug.Stack())
					impart.WriteError(w, impart.HTTPError{Status: http.StatusInternalServerError, Message: "Something didn't work quite right."})
					status = 500
				}
				log.Info(h.app.ReqLog(r, status, time.Since(start)))
			}()

			err := f(h.app.App(), w, r)
			if httpErr, isHTTP := err.(impart.HTTPError); isHTTP {
				status = httpErr.Status
			} else if err != nil {
				status = 500
			}
			return err
		}())
	}
}
// AllReader wraps a public, read-only handler. It allows any origin
// (Access-Control-Allow-Origin: *).
//
// On private instances the reader must be authenticated, either by access
// token (Authorization header) or, when no token is sent, by session cookie.
// If neither succeeds, the authentication error is returned and f is not
// called. Errors are written by handleError. Panics are recovered, and each
// request is logged with its status and duration.
func (h *Handler) AllReader(f handlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		h.handleError(w, r, func() error {
			status := 200
			start := time.Now()
			defer func() {
				if e := recover(); e != nil {
					log.Error("%s:\n%s", e, debug.Stack())
					impart.WriteError(w, impart.HTTPError{Status: http.StatusInternalServerError, Message: "Something didn't work quite right."})
					status = 500
				}
				log.Info(h.app.ReqLog(r, status, time.Since(start)))
			}()

			// Allow any origin, as public endpoints are handled in here
			w.Header().Set("Access-Control-Allow-Origin", "*")

			if h.app.App().cfg.App.Private {
				// This instance is private, so ensure it's being accessed by a
				// valid user. Check for an access token first.
				_, authErr := optionalAPIAuth(h.app.App(), r)
				if authErr == ErrNotLoggedIn {
					// No access token was given: fall back to web (cookie) auth.
					_, authErr = webAuth(h.app.App(), r)
				}
				if authErr != nil {
					// Derive the logged status from the error actually returned.
					// A successful cookie fallback leaves status untouched.
					if httpErr, ok := authErr.(impart.HTTPError); ok {
						status = httpErr.Status
					} else {
						status = 500
					}
					return authErr
				}
			}

			err := f(h.app.App(), w, r)
			if err != nil {
				if httpErr, ok := err.(impart.HTTPError); ok {
					status = httpErr.Status
				} else {
					status = 500
				}
			}
			return err
		}())
	}
}
// Download wraps a handler that produces a file for the user to save. f
// returns the file contents and a base filename. The extension and
// Content-Type come from the request path's suffix: .csv, .zip, or JSON by
// default. The response is sent as an attachment with an explicit
// Content-Length. Errors returned by f are rendered by handleHTTPError, and
// panics render the internal-server-error page.
//
// NOTE(review): ul is accepted for signature parity with the other wrappers
// but is not consulted here. Presumably f enforces authentication itself;
// confirm.
func (h *Handler) Download(f dataHandlerFunc, ul UserLevelFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		h.handleHTTPError(w, r, func() error {
			var status int
			start := time.Now()
			defer func() {
				if e := recover(); e != nil {
					log.Error("%s: %s", e, debug.Stack())
					h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
					status = 500
				}
				log.Info(h.app.ReqLog(r, status, time.Since(start)))
			}()

			data, filename, err := f(h.app.App(), w, r)
			if err != nil {
				if httpErr, ok := err.(impart.HTTPError); ok {
					status = httpErr.Status
				} else {
					status = 500
				}
				return err
			}

			ext := ".json"
			ct := "application/json"
			if strings.HasSuffix(r.URL.Path, ".csv") {
				ext = ".csv"
				ct = "text/csv"
			} else if strings.HasSuffix(r.URL.Path, ".zip") {
				ext = ".zip"
				ct = "application/zip"
			}
			w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s%s", filename, ext))
			w.Header().Set("Content-Type", ct)
			w.Header().Set("Content-Length", strconv.Itoa(len(data)))
			// Write the bytes directly. Converting a potentially large export to
			// a string just to print it copies the whole payload. A write error
			// means the client went away after the headers were sent, so there
			// is nothing further to report.
			w.Write(data)
			status = 200
			return nil
		}())
	}
}
// Redirect returns a handler that sends a 302 to url, after enforcing the
// same user level rules as Web:
//   - UserLevelNoneRequiredType redirects a logged-in user to the "to" page.
//   - UserLevelUserType returns ErrNotLoggedIn when the session has no user.
//
// Every request is logged, including ones rejected by the user level check.
// Previously, early returns skipped the log line.
func (h *Handler) Redirect(url string, ul UserLevelFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		h.handleHTTPError(w, r, func() error {
			start := time.Now()
			var status int
			// Log on every return path, with whatever status was set.
			defer func() {
				log.Info(h.app.ReqLog(r, status, time.Since(start)))
			}()

			// Evaluate the required user level once per request.
			level := ul(h.app.App().cfg)
			if level != UserLevelNoneType {
				session, err := h.sessionStore.Get(r, cookieName)
				if err != nil && (level == UserLevelNoneRequiredType || level == UserLevelUserType) {
					// Cookie is required, but we can ignore this error
					log.Error("Handler: Unable to get session (for user permission %d); ignoring: %v", level, err)
				}
				_, gotUser := session.Values[cookieUserVal].(*User)
				if level == UserLevelNoneRequiredType && gotUser {
					to := correctPageFromLoginAttempt(r)
					log.Info("Handler: Required NO user, but got one. Redirecting to %s", to)
					err := impart.HTTPError{Status: http.StatusFound, Message: to}
					status = err.Status
					return err
				} else if level == UserLevelUserType && !gotUser {
					log.Info("Handler: Required a user, but DIDN'T get one. Sending not logged in.")
					err := ErrNotLoggedIn
					status = err.Status
					return err
				}
			}

			status = sendRedirect(w, http.StatusFound, url)
			return nil
		}())
	}
}
// handleHTTPError writes err to the browser as an HTML page or a redirect.
// A nil err writes nothing. An impart.HTTPError is handled by status:
//   - 3xx: redirect to err.Message.
//   - 401: redirect to /login, passing the current path (and escaped query) as "to".
//   - 410: the Gone page, with err.Message (if non-empty) inserted as raw HTML.
//   - 404: the Not Found page. Fediverse (application/activity+json) requests
//     get only the status.
//   - 500 / 503: the matching error page.
//   - 202: an empty impart success response.
//   - anything else: a generic "Uh oh (status)" page showing err.Message.
//
// Any other error produces a generic 500 via impart.WriteError.
func (h *Handler) handleHTTPError(w http.ResponseWriter, r *http.Request, err error) {
	if err == nil {
		return
	}
	httpErr, ok := err.(impart.HTTPError)
	if !ok {
		impart.WriteError(w, impart.HTTPError{Status: http.StatusInternalServerError, Message: "This is an unhelpful error message for a miscellaneous internal error."})
		return
	}

	switch status := httpErr.Status; {
	case status >= 300 && status < 400:
		sendRedirect(w, status, httpErr.Message)
	case status == http.StatusUnauthorized:
		q := ""
		if r.URL.RawQuery != "" {
			q = url.QueryEscape("?" + r.URL.RawQuery)
		}
		sendRedirect(w, http.StatusFound, "/login?to="+r.URL.Path+q)
	case status == http.StatusGone:
		w.WriteHeader(status)
		p := &struct {
			page.StaticPage
			Content *template.HTML
		}{
			StaticPage: pageForReq(h.app.App(), r),
		}
		if httpErr.Message != "" {
			co := template.HTML(httpErr.Message)
			p.Content = &co
		}
		h.errors.Gone.ExecuteTemplate(w, "base", p)
	case status == http.StatusNotFound:
		w.WriteHeader(status)
		if strings.Contains(r.Header.Get("Accept"), "application/activity+json") {
			// This is a fediverse request; simply return the header
			return
		}
		h.errors.NotFound.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
	case status == http.StatusInternalServerError:
		w.WriteHeader(status)
		log.Info("handleHTTPErorr internal error render")
		h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
	case status == http.StatusServiceUnavailable:
		w.WriteHeader(status)
		h.errors.UnavailableError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
	case status == http.StatusAccepted:
		impart.WriteSuccess(w, "", status)
	default:
		// NOTE(review): no WriteHeader here, so this page is served with a 200
		// status. Confirm that is intended.
		p := &struct {
			page.StaticPage
			Title   string
			Content template.HTML
		}{
			pageForReq(h.app.App(), r),
			fmt.Sprintf("Uh oh (%d)", status),
			template.HTML(fmt.Sprintf("<p style=\"text-align: center\" class=\"introduction\">%s</p>", httpErr.Message)),
		}
		h.errors.Blank.ExecuteTemplate(w, "base", p)
	}
}
// handleError writes err as an API-style response. A nil err writes nothing.
// For an impart.HTTPError, 3xx statuses become redirects and every other
// status is written with impart.WriteError. Any other error becomes a generic
// 500: an impart error when IsJSON(r) reports true, otherwise the HTML
// internal-server-error page.
func (h *Handler) handleError(w http.ResponseWriter, r *http.Request, err error) {
	if err == nil {
		return
	}
	if httpErr, isHTTP := err.(impart.HTTPError); isHTTP {
		if httpErr.Status >= 300 && httpErr.Status < 400 {
			sendRedirect(w, httpErr.Status, httpErr.Message)
		} else {
			impart.WriteError(w, httpErr)
		}
		return
	}
	if !IsJSON(r) {
		h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
		return
	}
	impart.WriteError(w, impart.HTTPError{Status: http.StatusInternalServerError, Message: "This is an unhelpful error message for a miscellaneous internal error."})
}
// handleOAuthError reports err for OAuth endpoints. A nil err writes nothing.
// For an impart.HTTPError, 3xx statuses become redirects and every other
// status is written with impart.WriteOAuthError. Any other error is written
// as a generic 500 through impart.WriteOAuthError.
func (h *Handler) handleOAuthError(w http.ResponseWriter, r *http.Request, err error) {
	if err == nil {
		return
	}
	httpErr, isHTTP := err.(impart.HTTPError)
	switch {
	case !isHTTP:
		impart.WriteOAuthError(w, impart.HTTPError{Status: http.StatusInternalServerError, Message: "This is an unhelpful error message for a miscellaneous internal error."})
	case httpErr.Status >= 300 && httpErr.Status < 400:
		sendRedirect(w, httpErr.Status, httpErr.Message)
	default:
		impart.WriteOAuthError(w, httpErr)
	}
}
// correctPageFromLoginAttempt returns the local path a user should be sent to
// after logging in, taken from the "to" form/query value. A missing value
// yields "/", and a value without a leading slash is made root-relative.
//
// Only same-site paths are returned. A value a browser would resolve to a
// different host falls back to "/", which prevents an open redirect. This
// covers:
//   - protocol-relative "//host";
//   - "/\host" (browsers treat "\" like "/");
//   - anything containing tab/CR/LF, which browsers strip from URLs.
func correctPageFromLoginAttempt(r *http.Request) string {
	to := r.FormValue("to")
	if to == "" {
		return "/"
	}
	if !strings.HasPrefix(to, "/") {
		to = "/" + to
	}
	if strings.HasPrefix(to, "//") || strings.ContainsAny(to, "\\\t\r\n") {
		return "/"
	}
	return to
}
// LogHandlerFunc wraps a plain http.HandlerFunc with panic recovery and
// request logging. It is used, for example, for the webfinger and nodeinfo
// endpoints.
//
// On private instances the request must be authenticated, either by access
// token or, when no token is sent, by session cookie. Otherwise the
// authentication error is rendered by handleHTTPError and f is not called.
func (h *Handler) LogHandlerFunc(f http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		h.handleHTTPError(w, r, func() error {
			status := 200
			start := time.Now()
			defer func() {
				if e := recover(); e != nil {
					log.Error("Handler.LogHandlerFunc\n\n%s: %s", e, debug.Stack())
					h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r))
					status = 500
				}
				// TODO: log actual status code returned
				log.Info(h.app.ReqLog(r, status, time.Since(start)))
			}()

			if h.app.App().cfg.App.Private {
				// This instance is private, so ensure it's being accessed by a
				// valid user. Check for an access token first.
				_, authErr := optionalAPIAuth(h.app.App(), r)
				if authErr == ErrNotLoggedIn {
					// No access token was given: fall back to web (cookie) auth.
					_, authErr = webAuth(h.app.App(), r)
				}
				if authErr != nil {
					// Derive the logged status from the error actually returned.
					// A successful cookie fallback leaves status untouched.
					if httpErr, ok := authErr.(impart.HTTPError); ok {
						status = httpErr.Status
					} else {
						status = 500
					}
					return authErr
				}
			}

			f(w, r)
			return nil
		}())
	}
}
// Gopher adapts a gopherFunc into a gopher.HandlerFunc. Every selector served
// is logged. A handler error or a recovered panic is logged and reported to
// the client with w.WriteError.
func (h *Handler) Gopher(f gopherFunc) gopher.HandlerFunc {
	return func(w gopher.ResponseWriter, r *gopher.Request) {
		defer func() {
			if recovered := recover(); recovered != nil {
				log.Error("%s: %s", recovered, debug.Stack())
				w.WriteError("An internal error occurred")
			}
			log.Info("gopher: %s", r.Selector)
		}()

		if err := f(h.app.App(), w, r); err != nil {
			log.Error("failed: %s", err)
			w.WriteError("the page failed for some reason (see logs)")
		}
	}
}
// sendRedirect sets the Location header and writes the given status code. It
// returns code so callers can record it as the response status.
func sendRedirect(w http.ResponseWriter, code int, location string) int {
	header := w.Header()
	header.Set("Location", location)
	w.WriteHeader(code)
	return code
}
// cacheControl wraps next so every response it serves carries
// "Cache-Control: public, max-age=604800, immutable" (one week).
func cacheControl(next http.Handler) http.Handler {
	const policy = "public, max-age=604800, immutable"
	wrapped := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Cache-Control", policy)
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(wrapped)
}
diff --git a/routes.go b/routes.go
index 2b23bd1..1244e97 100644
--- a/routes.go
+++ b/routes.go
@@ -1,235 +1,235 @@
/*
* Copyright © 2018-2021 A Bunch Tell LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"net/http"
"net/url"
"path/filepath"
"strings"
"github.com/gorilla/csrf"
"github.com/gorilla/mux"
"github.com/writeas/go-webfinger"
"github.com/writeas/web-core/log"
"github.com/writefreely/go-nodeinfo"
)
// InitStaticRoutes serves the static directory (under
// cfg.Server.StaticParentDir) with week-long cache headers. The handler is
// registered as a catch-all on r and is also stored in app.shttp for direct
// use (e.g. by CollectionPostOrStatic).
// TODO: this should just be a func, not method
func (app *App) InitStaticRoutes(r *mux.Router) {
	staticRoot := filepath.Join(app.cfg.Server.StaticParentDir, staticDir)
	files := cacheControl(http.FileServer(http.Dir(staticRoot)))

	app.shttp = http.NewServeMux()
	app.shttp.Handle("/", files)
	r.PathPrefix("/").Handler(files)
}
// InitRoutes adds dynamic routes for the given mux.Router.
// InitRoutes registers every dynamic route on r, using a Handler built from
// apper, and returns r. gorilla/mux uses the first registered route that
// matches, so the order below matters. Special pages are registered before
// the catch-all collection and post patterns at the end.
func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
// Create handler
handler := NewWFHandler(apper)
// Set up routes
// hostSubroute is only used in the log lines below.
// NOTE(review): if App.Host lacks "://", Index returns -1 and the slice starts at
// index 2; this only affects the logged name.
hostSubroute := apper.App().cfg.App.Host[strings.Index(apper.App().cfg.App.Host, "://")+3:]
if apper.App().cfg.App.SingleUser {
hostSubroute = "{domain}"
} else {
if strings.HasPrefix(hostSubroute, "localhost") {
hostSubroute = "localhost"
}
}
if apper.App().cfg.App.SingleUser {
log.Info("Adding %s routes (single user)...", hostSubroute)
} else {
log.Info("Adding %s routes (multi-user)...", hostSubroute)
}
// Primary app routes
write := r.PathPrefix("/").Subrouter()
// Federation endpoint configurations
wf := webfinger.Default(wfResolver{apper.App().db, apper.App().cfg})
wf.NoTLSHandler = nil
// Federation endpoints
// host-meta
write.HandleFunc("/.well-known/host-meta", handler.Web(handleViewHostMeta, UserLevelReader))
// webfinger
write.HandleFunc(webfinger.WebFingerPath, handler.LogHandlerFunc(http.HandlerFunc(wf.Webfinger)))
// nodeinfo
niCfg := nodeInfoConfig(apper.App().db, apper.App().cfg)
ni := nodeinfo.NewService(*niCfg, nodeInfoResolver{apper.App().cfg, apper.App().db})
write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover)))
write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo)))
// handle mentions
write.HandleFunc("/@/{handle}", handler.Web(handleViewMention, UserLevelReader))
configureSlackOauth(handler, write, apper.App())
configureWriteAsOauth(handler, write, apper.App())
configureGitlabOauth(handler, write, apper.App())
configureGenericOauth(handler, write, apper.App())
configureGiteaOauth(handler, write, apper.App())
// Set up dynamic page handlers
// Handle auth
auth := write.PathPrefix("/api/auth/").Subrouter()
// The API signup endpoint only exists when registration is open.
if apper.App().cfg.App.OpenRegistration {
auth.HandleFunc("/signup", handler.All(apiSignup)).Methods("POST")
}
auth.HandleFunc("/login", handler.All(login)).Methods("POST")
auth.HandleFunc("/read", handler.WebErrors(handleWebCollectionUnlock, UserLevelNone)).Methods("POST")
auth.HandleFunc("/me", handler.All(handleAPILogout)).Methods("DELETE")
// Handle logged in user sections
me := write.PathPrefix("/me").Subrouter()
me.HandleFunc("/", handler.Redirect("/me", UserLevelUser))
me.HandleFunc("/c", handler.Redirect("/me/c/", UserLevelUser)).Methods("GET")
me.HandleFunc("/c/", handler.User(viewCollections)).Methods("GET")
me.HandleFunc("/c/{collection}", handler.User(viewEditCollection)).Methods("GET")
me.HandleFunc("/c/{collection}/stats", handler.User(viewStats)).Methods("GET")
me.Path("/delete").Handler(csrf.Protect(apper.App().keys.CSRFKey)(handler.User(handleUserDelete))).Methods("POST")
me.HandleFunc("/posts", handler.Redirect("/me/posts/", UserLevelUser)).Methods("GET")
me.HandleFunc("/posts/", handler.User(viewArticles)).Methods("GET")
// The three export formats share one handler; Download picks the extension
// and Content-Type from the path suffix.
me.HandleFunc("/posts/export.csv", handler.Download(viewExportPosts, UserLevelUser)).Methods("GET")
me.HandleFunc("/posts/export.zip", handler.Download(viewExportPosts, UserLevelUser)).Methods("GET")
me.HandleFunc("/posts/export.json", handler.Download(viewExportPosts, UserLevelUser)).Methods("GET")
me.HandleFunc("/export", handler.User(viewExportOptions)).Methods("GET")
me.HandleFunc("/export.json", handler.Download(viewExportFull, UserLevelUser)).Methods("GET")
me.HandleFunc("/import", handler.User(viewImport)).Methods("GET")
me.Path("/settings").Handler(csrf.Protect(apper.App().keys.CSRFKey)(handler.User(viewSettings))).Methods("GET")
me.HandleFunc("/invites", handler.User(handleViewUserInvites)).Methods("GET")
me.HandleFunc("/logout", handler.Web(viewLogout, UserLevelNone)).Methods("GET")
write.HandleFunc("/api/me", handler.All(viewMeAPI)).Methods("GET")
// Authenticated user API. /posts accepts either cookie or token auth (UserWebAPI);
// /collections requires an access token (UserAPI).
apiMe := write.PathPrefix("/api/me/").Subrouter()
apiMe.HandleFunc("/", handler.All(viewMeAPI)).Methods("GET")
- apiMe.HandleFunc("/posts", handler.UserAPI(viewMyPostsAPI)).Methods("GET")
+ apiMe.HandleFunc("/posts", handler.UserWebAPI(viewMyPostsAPI)).Methods("GET")
apiMe.HandleFunc("/collections", handler.UserAPI(viewMyCollectionsAPI)).Methods("GET")
apiMe.HandleFunc("/password", handler.All(updatePassphrase)).Methods("POST")
apiMe.HandleFunc("/self", handler.All(updateSettings)).Methods("POST")
apiMe.HandleFunc("/invites", handler.User(handleCreateUserInvite)).Methods("POST")
apiMe.HandleFunc("/import", handler.User(handleImport)).Methods("POST")
apiMe.HandleFunc("/oauth/remove", handler.User(removeOauth)).Methods("POST")
// Sign up validation
write.HandleFunc("/api/alias", handler.All(handleUsernameCheck)).Methods("POST")
write.HandleFunc("/api/markdown", handler.All(handleRenderMarkdown)).Methods("POST")
// The instance's own host is also accepted as a collection alias below.
instanceURL, _ := url.Parse(apper.App().Config().App.Host)
host := instanceURL.Host
// Handle collections
write.HandleFunc("/api/collections", handler.All(newCollection)).Methods("POST")
apiColls := write.PathPrefix("/api/collections/").Subrouter()
apiColls.HandleFunc("/"+host, handler.AllReader(fetchCollection)).Methods("GET")
apiColls.HandleFunc("/{alias:[0-9a-zA-Z\\-]+}", handler.AllReader(fetchCollection)).Methods("GET")
apiColls.HandleFunc("/{alias:[0-9a-zA-Z\\-]+}", handler.All(existingCollection)).Methods("POST", "DELETE")
apiColls.HandleFunc("/{alias}/posts", handler.AllReader(fetchCollectionPosts)).Methods("GET")
apiColls.HandleFunc("/{alias}/posts", handler.All(newPost)).Methods("POST")
apiColls.HandleFunc("/{alias}/posts/{post}", handler.AllReader(fetchPost)).Methods("GET")
apiColls.HandleFunc("/{alias}/posts/{post:[a-zA-Z0-9]{10}}", handler.All(existingPost)).Methods("POST")
apiColls.HandleFunc("/{alias}/posts/{post}/{property}", handler.AllReader(fetchPostProperty)).Methods("GET")
apiColls.HandleFunc("/{alias}/collect", handler.All(addPost)).Methods("POST")
apiColls.HandleFunc("/{alias}/pin", handler.All(pinPost)).Methods("POST")
apiColls.HandleFunc("/{alias}/unpin", handler.All(pinPost)).Methods("POST")
apiColls.HandleFunc("/{alias}/inbox", handler.All(handleFetchCollectionInbox)).Methods("POST")
apiColls.HandleFunc("/{alias}/outbox", handler.AllReader(handleFetchCollectionOutbox)).Methods("GET")
apiColls.HandleFunc("/{alias}/following", handler.AllReader(handleFetchCollectionFollowing)).Methods("GET")
apiColls.HandleFunc("/{alias}/followers", handler.AllReader(handleFetchCollectionFollowers)).Methods("GET")
// Handle posts
write.HandleFunc("/api/posts", handler.All(newPost)).Methods("POST")
posts := write.PathPrefix("/api/posts/").Subrouter()
posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.AllReader(fetchPost)).Methods("GET")
posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(existingPost)).Methods("POST", "PUT")
posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(deletePost)).Methods("DELETE")
posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}/{property}", handler.AllReader(fetchPostProperty)).Methods("GET")
posts.HandleFunc("/claim", handler.All(addPost)).Methods("POST")
posts.HandleFunc("/disperse", handler.All(dispersePost)).Methods("POST")
write.HandleFunc("/auth/signup", handler.Web(handleWebSignup, UserLevelNoneRequired)).Methods("POST")
write.HandleFunc("/auth/login", handler.Web(webLogin, UserLevelNoneRequired)).Methods("POST")
write.HandleFunc("/admin", handler.Admin(handleViewAdminDash)).Methods("GET")
write.HandleFunc("/admin/monitor", handler.Admin(handleViewAdminMonitor)).Methods("GET")
write.HandleFunc("/admin/settings", handler.Admin(handleViewAdminSettings)).Methods("GET")
write.HandleFunc("/admin/users", handler.Admin(handleViewAdminUsers)).Methods("GET")
write.HandleFunc("/admin/user/{username}", handler.Admin(handleViewAdminUser)).Methods("GET")
write.HandleFunc("/admin/user/{username}/delete", handler.Admin(handleAdminDeleteUser)).Methods("POST")
write.HandleFunc("/admin/user/{username}/status", handler.Admin(handleAdminToggleUserStatus)).Methods("POST")
write.HandleFunc("/admin/user/{username}/passphrase", handler.Admin(handleAdminResetUserPass)).Methods("POST")
write.HandleFunc("/admin/pages", handler.Admin(handleViewAdminPages)).Methods("GET")
write.HandleFunc("/admin/page/{slug}", handler.Admin(handleViewAdminPage)).Methods("GET")
write.HandleFunc("/admin/update/config", handler.AdminApper(handleAdminUpdateConfig)).Methods("POST")
write.HandleFunc("/admin/update/{page}", handler.Admin(handleAdminUpdateSite)).Methods("POST")
write.HandleFunc("/admin/updates", handler.Admin(handleViewAdminUpdates)).Methods("GET")
// Handle special pages first
write.HandleFunc("/login", handler.Web(viewLogin, UserLevelNoneRequired))
write.HandleFunc("/signup", handler.Web(handleViewLanding, UserLevelNoneRequired))
write.HandleFunc("/invite/{code:[a-zA-Z0-9]+}", handler.Web(handleViewInvite, UserLevelOptional)).Methods("GET")
// TODO: show a reader-specific 404 page if the function is disabled
write.HandleFunc("/read", handler.Web(viewLocalTimeline, UserLevelReader))
RouteRead(handler, UserLevelReader, write.PathPrefix("/read").Subrouter())
// In single-user mode, draft editing lives under /d so it doesn't collide with
// the blog's own post slugs at the root.
draftEditPrefix := ""
if apper.App().cfg.App.SingleUser {
draftEditPrefix = "/d"
write.HandleFunc("/me/new", handler.Web(handleViewPad, UserLevelUser)).Methods("GET")
} else {
write.HandleFunc("/new", handler.Web(handleViewPad, UserLevelUser)).Methods("GET")
}
// All the existing stuff
write.HandleFunc(draftEditPrefix+"/{action}/edit", handler.Web(handleViewPad, UserLevelUser)).Methods("GET")
write.HandleFunc(draftEditPrefix+"/{action}/meta", handler.Web(handleViewMeta, UserLevelUser)).Methods("GET")
// Collections
if apper.App().cfg.App.SingleUser {
RouteCollections(handler, write.PathPrefix("/").Subrouter())
} else {
write.HandleFunc("/{prefix:[@~$!\\-+]}{collection}", handler.Web(handleViewCollection, UserLevelReader))
write.HandleFunc("/{collection}/", handler.Web(handleViewCollection, UserLevelReader))
RouteCollections(handler, write.PathPrefix("/{prefix:[@~$!\\-+]?}{collection}").Subrouter())
// Posts (registered below, after collections)
}
write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional))
write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional))
return r
}
// RouteCollections registers the per-collection routes on r: logout, pages,
// tags, feeds, sitemap, posts and their editors. Routes are added in the
// listed order because mux uses the first route that matches.
func RouteCollections(handler *Handler, r *mux.Router) {
	routes := []struct {
		path string
		fn   http.HandlerFunc
	}{
		{"/logout", handler.Web(handleLogOutCollection, UserLevelOptional)},
		{"/page/{page:[0-9]+}", handler.Web(handleViewCollection, UserLevelReader)},
		{"/tag:{tag}", handler.Web(handleViewCollectionTag, UserLevelReader)},
		{"/tag:{tag}/feed/", handler.Web(ViewFeed, UserLevelReader)},
		{"/sitemap.xml", handler.AllReader(handleViewSitemap)},
		{"/feed/", handler.AllReader(ViewFeed)},
		{"/{slug}", handler.CollectionPostOrStatic},
		{"/{slug}/edit", handler.Web(handleViewPad, UserLevelUser)},
		{"/{slug}/edit/meta", handler.Web(handleViewMeta, UserLevelUser)},
	}
	for _, route := range routes {
		r.HandleFunc(route.path, route.fn)
	}
	// The trailing-slash post redirect is the only GET-restricted route.
	r.HandleFunc("/{slug}/", handler.Web(handleCollectionPostRedirect, UserLevelReader)).Methods("GET")
}
// RouteRead registers the local-timeline ("Reader") routes on r, each wrapped
// by handler.Web at user level readPerm. Order is preserved because mux uses
// the first route that matches (the "/{author}" pattern must come after the
// fixed prefixes).
func RouteRead(handler *Handler, readPerm UserLevelFunc, r *mux.Router) {
	routes := []struct {
		path string
		fn   handlerFunc
	}{
		{"/api/posts", viewLocalTimelineAPI},
		{"/p/{page}", viewLocalTimeline},
		{"/feed/", viewLocalTimelineFeed},
		{"/t/{tag}", viewLocalTimeline},
		{"/a/{post}", handlePostIDRedirect},
		{"/{author}", viewLocalTimeline},
		{"/", viewLocalTimeline},
	}
	for _, route := range routes {
		r.HandleFunc(route.path, handler.Web(route.fn, readPerm))
	}
}
diff --git a/static/js/posts.js b/static/js/posts.js
index 58b55a2..dfc30b7 100644
--- a/static/js/posts.js
+++ b/static/js/posts.js
@@ -1,315 +1,332 @@
/**
* Functionality for managing local Write.as posts.
*
* Dependencies:
* h.js
*/
/**
 * Switches document.body between the 'light' and 'dark' classes and swaps the
 * toolbar icons in #belt (.tool img) to the matching image variant. The new
 * class is then saved under 'padTheme'. Does nothing when the body has
 * neither class.
 */
function toggleTheme() {
	var icons;
	try {
		icons = Array.prototype.slice.call(document.getElementById('belt').querySelectorAll('.tool img'));
	} catch (e) {}

	var swap;
	var current = document.body.className;
	if (current == 'light') {
		document.body.className = 'dark';
		swap = ['_dark@2x.png', '@2x.png'];
	} else if (current == 'dark') {
		document.body.className = 'light';
		swap = ['@2x.png', '_dark@2x.png'];
	} else {
		// Don't alter the theme
		return;
	}

	// Best effort: the toolbar may not exist on this page.
	try {
		for (var n = 0; n < icons.length; n++) {
			icons[n].src = icons[n].src.replace(swap[0], swap[1]);
		}
	} catch (e) {}

	H.set('padTheme', document.body.className);
}
// On load, apply a previously saved non-default theme.
if (H.get('padTheme', 'light') != 'light') {
toggleTheme();
}
// Guard flag checked by delPost.
// NOTE(review): nothing in this chunk ever sets it to true, so it does not
// currently prevent overlapping deletes.
var deleting = false;
/**
 * Asks for confirmation, then deletes post `id` via deletePost and removes its
 * element from the list. The token stored for the post in `posts` (if any) is
 * passed along. `owned` lets deletion proceed even without a stored token.
 * After removal, the list is topped up from `posts` to keep a full page.
 */
function delPost(e, id, owned) {
e.preventDefault();
if (deleting) {
return;
}
// TODO: UNDO!
if (window.confirm('Are you sure you want to delete this post?')) {
// Look up the post's token among the locally stored posts.
var token;
for (var i=0; i<posts.length; i++) {
if (posts[i].id == id) {
token = posts[i].token;
break;
}
}
if (owned || token) {
// AJAX
deletePost(id, token, function() {
// Remove post from list
var $postEl = document.getElementById('post-' + id);
$postEl.parentNode.removeChild($postEl);
// NOTE(review): this relies on `posts` already reflecting the deletion;
// presumably deletePost updates it. Confirm.
if (posts.length == 0) {
displayNoPosts();
return;
}
// Fill in full page of posts
// NOTE(review): if no .post elements remain, $postsChildren[-1] is undefined
// and reading .id throws. When the last visible post matches posts[0],
// posts[i-1] is undefined. The loop stops before the final element and does
// not break after a match.
var $postsChildren = $posts.el.getElementsByClassName('post');
if ($postsChildren.length < postsPerPage && $postsChildren.length < posts.length) {
var lastVisiblePostID = $postsChildren[$postsChildren.length-1].id;
lastVisiblePostID = lastVisiblePostID.substr(lastVisiblePostID.indexOf('-')+1);
for (var i=0; i<posts.length-1; i++) {
if (posts[i].id == lastVisiblePostID) {
var $moreBtn = document.getElementById('more-posts');
if ($moreBtn) {
// Should always land here (?)
$posts.el.insertBefore(createPostEl(posts[i-1]), $moreBtn);
} else {
$posts.el.appendChild(createPostEl(posts[i-1]));
}
}
}
}
});
} else {
alert('Something went seriously wrong. Try refreshing.');
}
}
}
var getFormattedDate = function(d) {
var mos = [
"January", "February", "March",
"April", "May", "June", "July",
"August", "September", "October",
"November", "December"
];
var day = d.getDate();
var mo = d.getMonth();
var yr = d.getFullYear();
return mos[mo] + ' ' + day + ', ' + yr;
};
// Locally stored (unsynced) posts, read through H from the JSON under 'posts'.
var posts = JSON.parse(H.get('posts', '[]'));
// Computes the page count and loads the current page.
// NOTE(review): `pages` is assigned without `var` (an implicit global unless it
// is declared elsewhere). `page`, `postsPerPage` and `loadPage` are defined
// outside this chunk.
var initialListPop = function() {
pages = Math.ceil(posts.length / postsPerPage);
loadPage(page, true);
};
// The list container is #posts, or #unsynced-posts when #posts is absent.
var $posts = H.getEl("posts");
if ($posts.el == null) {
$posts = H.getEl("unsynced-posts");
}
// Show a placeholder until the list is populated.
$posts.el.innerHTML = '<p class="status">Reading...</p>';
/**
 * Builds the "More..." pager element (#more-posts), which links to the page
 * after the global `page`.
 */
var createMorePostsEl = function() {
	var $more = document.createElement('div');
	$more.id = 'more-posts';
	$more.innerHTML = '<p><a href="#' + (page+1) + '">More...</a></p>';
	return $more;
};
/**
 * Actions for the locally stored (unsynced) post list, exposed as
 * localPosts.dismissError, localPosts.deletePost and localPosts.undoDelete.
 */
var localPosts = function() {
var $delPost, lastDelPost, lastInfoHTML;
var $info = He.get('unsynced-posts-info');
// Index of the post with the given id in `posts`, or -1.
var findPostIdx = function(id) {
for (var i=0; i<posts.length; i++) {
if (posts[i].id == id) {
return i;
}
}
return -1;
};
// Removes an error message (the element before the link's parent) and the
// nav element containing the dismiss link.
var DismissError = function(e, el) {
e.preventDefault();
var $errorMsg = el.parentNode.previousElementSibling;
$errorMsg.parentNode.removeChild($errorMsg);
var $errorMsgNav = el.parentNode;
$errorMsgNav.parentNode.removeChild($errorMsgNav);
};
// After confirmation, removes the post from `posts`, marks its element
// 'del-undo', hides the header/info when nothing visible remains, and saves
// the updated list.
var DeletePostLocal = function(e, el, id) {
e.preventDefault();
if (!window.confirm('Are you sure you want to delete this post?')) {
return;
}
var i = findPostIdx(id);
if (i > -1) {
lastDelPost = posts.splice(i, 1)[0];
$delPost = H.getEl('post-'+id);
$delPost.setClass('del-undo');
// Count children not marked for deletion. This `var i` reuses the
// function-scoped `i` above, whose value is no longer needed.
var $unsyncPosts = document.getElementById('unsynced-posts');
var visible = $unsyncPosts.children.length;
for (var i=0; i < $unsyncPosts.children.length; i++) { // NOTE: *.children support in IE9+
if ($unsyncPosts.children[i].className.indexOf('del-undo') !== -1) {
visible--;
}
}
if (visible == 0) {
H.getEl('unsynced-posts-header').hide();
// TODO: fix undo functionality and don't do the following:
H.getEl('unsynced-posts-info').hide();
}
H.set('posts', JSON.stringify(posts));
// TODO: fix undo functionality and re-add
//lastInfoHTML = $info.innerHTML;
//$info.innerHTML = 'Unsynced entry deleted. <a href="#" onclick="localPosts.undoDelete()">Undo</a>.';
}
};
// NOTE(review): incomplete. lastDelPost is never put back into `posts`, and
// lastInfoHTML is never assigned while the lines above stay commented out.
var UndoDelete = function() {
// TODO: fix this header reappearing
H.getEl('unsynced-posts-header').show();
$delPost.removeClass('del-undo');
$info.innerHTML = lastInfoHTML;
};
return {
dismissError: DismissError,
deletePost: DeletePostLocal,
undoDelete: UndoDelete,
};
}();
-var createPostEl = function(post) {
+// movePostHTML returns the markup of the page's hidden #move-tmpl element with every
+// POST_ID placeholder replaced by postID, or "" when the page has no such element.
+var movePostHTML = function(postID) {
+ let $tmpl = document.getElementById('move-tmpl');
+ if ($tmpl === null) {
+ return "";
+ }
+ return $tmpl.innerHTML.replace(/POST_ID/g, postID);
+}
+// createPostEl builds the <div class="post"> element for one post in the list.
+//   post:  must have `id`; `title`, `body`, `summary`, `created` and `error` are
+//          rendered when present.
+//   owned: only when exactly `true` is `true` passed as delPost's third argument
+//          (NOTE(review): presumably "post belongs to this account" — confirm in delPost).
+// Title, summary and body preview have "<" escaped. post.id and post.error are
+// concatenated into the HTML without escaping.
+var createPostEl = function(post, owned) {
var $post = document.createElement('div');
- var title = (post.title || post.id);
+ // p.title (parsed by H.createPost) is the fallback title when post.title is empty.
+ // NOTE(review): assumes H.createPost tolerates an undefined post.body — confirm.
+ let p = H.createPost(post.id, "", post.body)
+ var title = (post.title || p.title || post.id);
title = title.replace(/</g, "&lt;");
$post.id = 'post-' + post.id;
$post.className = 'post';
$post.innerHTML = '<h3><a href="/' + post.id + '">' + title + '</a></h3>';
var posted = "";
if (post.created) {
posted = getFormattedDate(new Date(post.created))
}
var hasDraft = H.exists('draft' + post.id);
- $post.innerHTML += '<h4><date>' + posted + '</date> <a class="action" href="/pad/' + post.id + '">edit' + (hasDraft ? 'ed' : '') + '</a> <a class="delete action" href="/' + post.id + '" onclick="delPost(event, \'' + post.id + '\')">delete</a></h4>';
+ $post.innerHTML += '<h4><date>' + posted + '</date> <a class="action" href="/pad/' + post.id + '">edit' + (hasDraft ? 'ed' : '') + '</a> <a class="delete action" href="/' + post.id + '" onclick="delPost(event, \'' + post.id + '\'' + (owned === true ? ', true' : '') + ')">delete</a> '+movePostHTML(post.id)+'</h4>';
if (post.error) {
$post.innerHTML += '<p class="error"><strong>Sync error:</strong> ' + post.error + ' <nav><a href="#" onclick="localPosts.dismissError(event, this)">dismiss</a> <a href="#" onclick="localPosts.deletePost(event, this, \''+post.id+'\')">remove post</a></nav></p>';
}
if (post.summary) {
+ // TODO: switch to using p.summary, after ensuring it matches summary generated on the backend.
$post.innerHTML += '<p>' + post.summary.replace(/</g, "&lt;") + '</p>';
+ } else if (post.body) {
+ // No summary: show the first 140 characters of the body, adding "..." when truncated.
+ var preview;
+ if (post.body.length > 140) {
+ preview = post.body.substr(0, 140) + '...';
+ } else {
+ preview = post.body;
+ }
+ $post.innerHTML += '<p>' + preview.replace(/</g, "&lt;") + '</p>';
}
return $post;
};
// loadPage renders page p (1-based) of the local post list, newest post first.
//   loadAll: when true, clear the list and render pages 1..p. When false, append only
//            page p below what is already shown; the hashchange handler uses this when
//            the hash advances by exactly one page.
// Afterwards a "More..." link to page p+1 is placed at the end whenever pages beyond p
// exist, and postsLoaded(posts.length) is called if the host page defines it.
var loadPage = function(p, loadAll) {
	// Keep the global current page in sync so createMorePostsEl links to p+1 rather
	// than to the page the list was first opened on.
	page = p;
	if (loadAll) {
		$posts.el.innerHTML = '';
	}
	var startPost = posts.length - 1 - (loadAll ? 0 : ((p-1)*postsPerPage));
	var endPost = posts.length - 1 - (p*postsPerPage);
	for (var i=startPost; i>=0 && i>endPost; i--) {
		$posts.el.appendChild(createPostEl(posts[i]));
	}
	if (!loadAll) {
		// Drop the previous link. It may be missing, e.g. after a manual hash edit.
		var $moreEl = document.getElementById('more-posts');
		if ($moreEl) {
			$moreEl.parentNode.removeChild($moreEl);
		}
	}
	// Incremental loads also need a fresh link when more pages remain.
	if (p < pages) {
		$posts.el.appendChild(createMorePostsEl());
	}
	try {
		postsLoaded(posts.length);
	} catch (e) {
		// Deliberately ignored: postsLoaded is optional per page, and errors it throws are
		// not allowed to break list rendering.
	}
};
// getPageNum returns the 1-based page number encoded in a URL's fragment (e.g. "...#3").
// It reads window.location.hash when no url is given. It returns 1 when the fragment is
// missing, non-numeric, or less than 1.
var getPageNum = function(url) {
	var hash;
	if (url) {
		var hashIdx = url.indexOf('#');
		// Without a fragment there is no page number; don't parse the whole URL.
		hash = hashIdx === -1 ? '' : url.substr(hashIdx+1);
	} else {
		hash = window.location.hash.substr(1);
	}
	var num = parseInt(hash, 10);
	// Page 0 or a negative page would render an empty slice of the list.
	if (isNaN(num) || num < 1) {
		num = 1;
	}
	return num;
};
// Pagination state: posts per page, the total page count (set by initialListPop), and
// the current page taken from the URL hash.
var postsPerPage = 10;
var pages = 0;
var page = getPageNum();

// On hash navigation, append just the new page when moving forward by exactly one page.
// Otherwise re-render the list from scratch.
window.addEventListener('hashchange', function(evt) {
	var target = getPageNum();
	var previous = getPageNum(evt.oldURL);
	var appendOnly = (target == previous + 1);
	loadPage(target, !appendOnly);
});
// deletePost sends DELETE /api/posts/<postID> (with ?token=... when a token is given).
// The post's delete link shows "..." while the request is in flight. Outcomes:
//   204 or 404: the post is removed from the local `posts` list, the list is saved via
//               H.set('posts', ...), and callback() is invoked.
//   409:        the post is synced to another account; the link is restored and the user alerted.
//   otherwise:  the link is restored and the user is asked to try again.
// The global `deleting` flag is true from the call until the response arrives.
var deletePost = function(postID, token, callback) {
	deleting = true;
	var $delBtn = document.getElementById('post-' + postID).getElementsByClassName('delete action')[0];
	$delBtn.innerHTML = '...';
	var query = (typeof token === 'undefined') ? '' : "?token=" + encodeURIComponent(token);
	var xhr = new XMLHttpRequest();
	xhr.open("DELETE", "/api/posts/" + postID + query, true);
	xhr.onreadystatechange = function() {
		if (xhr.readyState != 4) {
			return;
		}
		deleting = false;
		switch (xhr.status) {
		case 204:
		case 404:
			// TODO: keep the spliced-out post (with its full content) so it can be restored.
			for (var i = 0; i < posts.length; i++) {
				if (posts[i].id == postID) {
					posts.splice(i, 1);
					break;
				}
			}
			H.set('posts', JSON.stringify(posts));
			callback();
			break;
		case 409:
			$delBtn.innerHTML = 'delete';
			alert("Post is synced to another account. Delete the post from that account instead.");
			// TODO: offer a local-only "remove" action here instead of "delete", and persist that state.
			break;
		default:
			$delBtn.innerHTML = 'delete';
			alert("Failed to delete. Please try again.");
		}
	};
	xhr.send();
};
// True when a non-empty 'lastDoc' value is stored (presumably an unfinished post from /pad).
var hasWritten = H.get('lastDoc', '') !== '';

// displayNoPosts renders the empty state for the post list. Signed-in pages (global
// `auth` truthy) just get an empty list. Anonymous visitors get a call to action, which
// becomes "Finish your post" when an unfinished document exists.
var displayNoPosts = function() {
	if (auth) {
		$posts.el.innerHTML = '';
		return;
	}
	var cta = hasWritten
		? '<a href="/pad">Finish your post</a> and it\'ll appear here.'
		: '<a href="/pad">Create a post</a> and it\'ll appear here.';
	// Use the container resolved at startup: #posts may be absent (the list can live
	// in #unsynced-posts), in which case H.getEl("posts").el would be null.
	$posts.el.innerHTML = '<p class="status">No posts created yet.</p><p class="status">' + cta + '</p>';
};

if (posts.length == 0) {
	displayNoPosts();
} else {
	initialListPop();
}
diff --git a/templates/user/articles.tmpl b/templates/user/articles.tmpl
index e96d51e..92f9c40 100644
--- a/templates/user/articles.tmpl
+++ b/templates/user/articles.tmpl
@@ -1,152 +1,227 @@
{{define "articles"}}
{{template "header" .}}
+<style type="text/css">
+ a.loading {
+ font-style: italic;
+ color: #666;
+ }
+ #move-tmpl {
+ display: none;
+ }
+</style>
<div class="snug content-container">
{{if .Flashes}}<ul class="errors">
{{range .Flashes}}<li class="urgent">{{.}}</li>{{end}}
</ul>{{end}}
{{if .Silenced}}
{{template "user-silenced"}}
{{end}}
<h1 id="posts-header">Drafts</h1>
{{ if .AnonymousPosts }}
<p>These are your draft posts. You can share them individually (without a blog) or move them to your blog when you're ready.</p>
- <div class="atoms posts">
+{{/* #anon-posts is the container that loadMorePosts (script below) appends fetched posts to. */}}
+ <div id="anon-posts" class="atoms posts">
{{ range $el := .AnonymousPosts }}<div id="post-{{.ID}}" class="post">
<h3><a href="/{{if $.SingleUser}}d/{{end}}{{.ID}}" itemprop="url">{{.DisplayTitle}}</a></h3>
<h4>
<date datetime="{{.Created}}" pubdate itemprop="datePublished" content="{{.Created}}">{{.DisplayDate}}</date>
<a class="action" href="/{{if $.SingleUser}}d/{{end}}{{.ID}}/edit">edit</a>
<a class="delete action" href="/{{.ID}}" onclick="delPost(event, '{{.ID}}', true)">delete</a>
{{ if $.Collections }}
{{if gt (len $.Collections) 1}}<div class="action flat-select">
<select id="move-{{.ID}}" onchange="postActions.multiMove(this, '{{.ID}}', {{if $.SingleUser}}true{{else}}false{{end}})" title="Move this post to one of your blogs">
<option style="display:none"></option>
{{range $.Collections}}<option value="{{.Alias}}">{{.DisplayTitle}}</option>{{end}}
</select>
<label for="move-{{.ID}}">move to...</label>
<img class="ic-18dp" src="/img/ic_down_arrow_dark@2x.png" />
</div>{{else}}
{{range $.Collections}}
<a class="action" href="/{{$el.ID}}" title="Publish this post to your blog '{{.DisplayTitle}}'" onclick="postActions.move(this, '{{$el.ID}}', '{{.Alias}}', {{if $.SingleUser}}true{{else}}false{{end}});return false">move to {{.DisplayTitle}}</a>
{{end}}
{{end}}
{{ end }}
</h4>
{{if .Summary}}<p>{{.SummaryHTML}}</p>{{end}}
</div>{{end}}
-</div>{{ else }}<div id="no-posts-published">
+</div>
+{{/* Rendered only when the first page holds exactly 10 posts; loadMorePosts removes the link once a fetched page has fewer than 10. */}}
+{{if eq (len .AnonymousPosts) 10}}<p id="load-more-p"><a href="#load">Load more...</a></p>{{end}}
+{{ else }}<div id="no-posts-published">
<p>Your anonymous and draft posts will show up here once you've published some. You'll be able to share them individually (without a blog) or move them to a blog when you're ready.</p>
{{if not .SingleUser}}<p>Alternatively, see your blogs and their posts on your <a href="/me/c/">Blogs</a> page.</p>{{end}}
+
<p class="text-cta"><a href="{{if .SingleUser}}/me/new{{else}}/{{end}}">Start writing</a></p></div>{{ end }}
<div id="moving"></div>
<h2 id="unsynced-posts-header" style="display: none">unsynced posts</h2>
<div id="unsynced-posts-info" style="margin-top: 1em"></div>
<div id="unsynced-posts" class="atoms"></div>
</div>
+{{/* Hidden (#move-tmpl is display:none) "move to..." controls for post elements built client-side. movePostHTML copies this element's innerHTML and replaces every POST_ID with the real post ID. */}}
+{{ if .Collections }}
+ <div id="move-tmpl">
+ {{if gt (len .Collections) 1}}
+ <div class="action flat-select">
+ <select id="move-POST_ID" onchange="postActions.multiMove(this, 'POST_ID', {{if .SingleUser}}true{{else}}false{{end}})" title="Move this post to one of your blogs">
+ <option style="display:none"></option>
+ {{range .Collections}}<option value="{{.Alias}}">{{.DisplayTitle}}</option>{{end}}
+ </select>
+ <label for="move-POST_ID">move to...</label>
+ <img class="ic-18dp" src="/img/ic_down_arrow_dark@2x.png" />
+ </div>
+ {{else}}
+ {{range .Collections}}
+ <a class="action" href="/POST_ID" title="Publish this post to your blog '{{.DisplayTitle}}'" onclick="postActions.move(this, 'POST_ID', '{{.Alias}}', {{if $.SingleUser}}true{{else}}false{{end}});return false">move to {{.DisplayTitle}}</a>
+ {{end}}
+ {{end}}
+ </div>
+{{ end }}
+
<script src="/js/h.js"></script>
<script src="/js/postactions.js"></script>
<script>
var auth = true;
// postsLoaded is called from loadPage with n, the number of posts in the browser-side
// `posts` store (read via H.get('posts')) that are not yet synced to this account.
// When n > 0 it does the following:
//   - reveals the "unsynced posts" header;
//   - fills the info banner with a "Sync ... now" link;
//   - hides the empty-state block if present;
//   - wires the link to claim every local post via POST /api/posts/claim.
// After a successful claim, claimed posts are dropped from local storage, per-post
// failures are recorded in each post's `error` field, and the page reloads.
function postsLoaded(n) {
	if (n == 0) {
		return;
	}
	document.getElementById('unsynced-posts-header').style.display = 'block';
	var syncing = false;
	var $pInfo = document.getElementById('unsynced-posts-info');
	$pInfo.className = 'alert info';
	var plural = n != 1;
	$pInfo.innerHTML = '<p>You have <strong>'+n+'</strong> post'+(plural?'s that aren\'t':' that isn\'t')+' synced to your account yet. <a href="#" id="btn-sync">Sync '+(plural?'them':'it')+' now</a>.</p>';
	var $noPosts = document.getElementById('no-posts-published');
	if ($noPosts != null) {
		$noPosts.style.display = 'none';
		document.getElementById('posts-header').style.display = 'none';
	}
	H.getEl('btn-sync').on('click', function(e) {
		e.preventDefault();
		if (syncing) {
			return;
		}
		// Keep a reference to the link. Inside the XHR callback below, `this` is the
		// XMLHttpRequest, so status text written to `this` there never reached the page.
		var $btn = this;
		var http = new XMLHttpRequest();
		var params = [];
		var posts = JSON.parse(H.get('posts', '[]'));
		if (posts.length > 0) {
			for (var i=0; i<posts.length; i++) {
				params.push({id: posts[i].id, token: posts[i].token});
			}
		}
		$btn.style.fontWeight = 'bold';
		$btn.innerText = 'Syncing '+(plural?'them':'it')+' now...';
		http.open("POST", "/api/posts/claim", true);
		// Send the proper header information along with the request
		http.setRequestHeader("Content-type", "application/json");
		http.onreadystatechange = function() {
			if (http.readyState == 4) {
				syncing = false;
				$btn.innerText = 'Importing '+(plural?'them':'it')+' now...';
				if (http.status == 200) {
					var res = JSON.parse(http.responseText);
					if (res.data.length > 0) {
						if (res.data.length != posts.length) {
							// TODO: handle a response that doesn't line up with the request
							console.error("Request and result array length didn't match!");
							return;
						}
						for (var i=0; i<res.data.length; i++) {
							if (res.data[i].code == 200) {
								// Post successfully claimed.
								for (var j=0; j<posts.length; j++) {
									// Find post in local store
									if (posts[j].id == res.data[i].post.id) {
										// Remove this post
										posts.splice(j, 1);
										break;
									}
								}
							} else {
								for (var j=0; j<posts.length; j++) {
									// Find post in local store
									if (posts[j].id == res.data[i].id) {
										// Note the error in the local post
										posts[j].error = res.data[i].error_msg;
										break;
									}
								}
							}
						}
						H.set('posts', JSON.stringify(posts));
						location.reload();
					}
				} else {
					// TODO: handle error visually (option to retry)
					console.error("Didn't work at all, man.");
					$btn.style.fontWeight = 'normal';
					$btn.innerText = 'Sync '+(plural?'them':'it')+' now';
				}
			}
		}
		http.send(JSON.stringify(params));
		syncing = true;
	});
}
+
+// "Load more..." paging for the drafts list (#anon-posts), fetched from /api/me/posts.
+// #load-more-p is only rendered when the first page is full, so $loadMore.el may be null.
+var $loadMore = H.getEl("load-more-p");
+var curPage = 1;
+var isLoadingMore = false;
+// loadMorePosts fetches the next page of anonymous posts and appends them to #anon-posts.
+// It is registered as the link's click handler, so `this` is the clicked <a>. The link
+// is removed once a page returns fewer than 10 posts (the size of a full page here).
+function loadMorePosts(e) {
+	// Don't follow href="#load": changing the hash fires the page's hashchange handler,
+	// which would needlessly re-render the locally stored post list.
+	if (e) {
+		e.preventDefault();
+	}
+	if (isLoadingMore === true) {
+		return;
+	}
+	var $link = this;
+	isLoadingMore = true;
+
+	$link.className = 'loading';
+	$link.textContent = 'Loading posts...';
+
+	var $posts = H.getEl("anon-posts");
+
+	curPage++;
+
+	var http = new XMLHttpRequest();
+	var url = "/api/me/posts?anonymous=1&page=" + curPage;
+	http.open("GET", url, true);
+	http.setRequestHeader("Content-type", "application/json");
+	http.onreadystatechange = function() {
+		if (http.readyState == 4) {
+			if (http.status == 200) {
+				var data = JSON.parse(http.responseText);
+				// A Go nil slice is encoded as JSON null, so treat a missing list as an empty page
+				// instead of throwing and leaving the link stuck on "Loading posts...".
+				// NOTE(review): confirm what this endpoint returns past the last page.
+				var pagePosts = data.data || [];
+				for (var i=0; i<pagePosts.length; i++) {
+					$posts.el.appendChild(createPostEl(pagePosts[i], true));
+				}
+				if (pagePosts.length < 10 && $loadMore.el.parentNode) {
+					$loadMore.el.parentNode.removeChild($loadMore.el);
+				}
+			} else {
+				alert("Failed to load more posts. Please try again.");
+				curPage--;
+			}
+			isLoadingMore = false;
+			$link.className = '';
+			$link.textContent = 'Load more...';
+		}
+	}
+	http.send();
+}
+// Without the link (fewer than 10 drafts) there is nothing to bind; avoid a TypeError.
+if ($loadMore.el) {
+	$loadMore.el.querySelector('a').addEventListener('click', loadMorePosts);
+}
</script>
<script src="/js/posts.js"></script>
{{template "footer" .}}
{{end}}

File Metadata

Mime Type
text/x-diff
Expires
Mon, Nov 25, 5:58 AM (1 d, 19 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3104737

Event Timeline