Page MenuHomeMusing Studio

No OneTemporary

diff --git a/author/author.go b/author/author.go
index a95eb61..0c2ea57 100644
--- a/author/author.go
+++ b/author/author.go
@@ -1,136 +1,137 @@
/*
* Copyright © 2018-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package author
import (
- "github.com/writeas/web-core/log"
- "github.com/writefreely/writefreely/config"
"os"
"path/filepath"
"regexp"
+
+ "github.com/writeas/web-core/log"
+ "github.com/writefreely/writefreely/config"
)
// validUsernameReg matches a well-formed username: an ASCII letter or digit
// first, then any number of ASCII letters, digits, or hyphens.
var validUsernameReg = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9-]*$`)
// reservedUsernames is the set of usernames that IsValidUsername rejects
// outright. Only the presence of a key is checked; every value is true.
var reservedUsernames = map[string]bool{
"a": true,
"about": true,
"add": true,
"admin": true,
"administrator": true,
"adminzone": true,
"api": true,
"article": true,
"articles": true,
"auth": true,
"authenticate": true,
"browse": true,
"c": true,
"categories": true,
"category": true,
"changes": true,
"community": true,
"create": true,
"css": true,
"data": true,
"dev": true,
"developers": true,
"draft": true,
"drafts": true,
"edit": true,
"edits": true,
"faq": true,
"feed": true,
"feedback": true,
"guide": true,
"guides": true,
"help": true,
"index": true,
"invite": true,
"js": true,
"login": true,
"logout": true,
"me": true,
"media": true,
"meta": true,
"metadata": true,
"new": true,
"news": true,
"oauth": true,
"post": true,
"posts": true,
"privacy": true,
"publication": true,
"publications": true,
"publish": true,
"random": true,
"read": true,
"reader": true,
"register": true,
"remove": true,
"signin": true,
"signout": true,
"signup": true,
"start": true,
"status": true,
"summary": true,
"support": true,
"tag": true,
"tags": true,
"team": true,
"template": true,
"templates": true,
"terms": true,
"terms-of-service": true,
"termsofservice": true,
"theme": true,
"themes": true,
"tips": true,
"tos": true,
"update": true,
"updates": true,
"user": true,
"users": true,
"yourname": true,
}
// IsValidUsername returns true if the given username is at least
// cfg.App.MinUsernameLen bytes long, does not share a name with any entry
// under the "pages" directory, is not reserved, and matches the valid
// username format.
//
// If the pages directory cannot be walked, the error is logged and false is
// returned.
//
// The walk only compares names; it never writes to reservedUsernames, so
// calls do not modify shared package state.
func IsValidUsername(cfg *config.Config, username string) bool {
	// Username has to be above a character limit
	if len(username) < cfg.App.MinUsernameLen {
		return false
	}

	// Username is invalid if page with the same name exists. So traverse
	// available pages, looking for an entry with the same name.
	pageExists := false
	err := filepath.Walk(filepath.Join(cfg.Server.PagesParentDir, "pages"), func(path string, i os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if i.Name() == username {
			pageExists = true
		}
		return nil
	})
	if err != nil {
		log.Error("[IMPORTANT WARNING]: Could not determine IsValidUsername! %s", err)
		return false
	}
	if pageExists {
		return false
	}

	// Username is invalid if it is reserved!
	if _, reserved := reservedUsernames[username]; reserved {
		return false
	}

	return validUsernameReg.MatchString(username)
}
diff --git a/config/funcs.go b/config/funcs.go
index 7c6fd4f..3bbcfa7 100644
--- a/config/funcs.go
+++ b/config/funcs.go
@@ -1,62 +1,63 @@
/*
* Copyright © 2018, 2020-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package config
import (
- "github.com/writeas/web-core/log"
- "golang.org/x/net/idna"
"net/http"
"net/url"
"strings"
"time"
+
+ "github.com/writeas/web-core/log"
+ "golang.org/x/net/idna"
)
// FriendlyHost returns the app's Host sans any schema, with the hostname
// passed through idna.ToUnicode and the port (if any) re-appended.
//
// If the Host can't be parsed or converted, the error is logged and the Host
// with only its scheme stripped is returned. The same fallback is used when
// parsing yields no hostname at all (e.g. a Host configured without a scheme).
func (ac AppCfg) FriendlyHost() string {
	// Only strip a scheme when one is actually present; slicing at a fixed
	// offset from a failed search would cut off (or run past) the host.
	rawHost := ac.Host
	if i := strings.Index(ac.Host, "://"); i >= 0 {
		rawHost = ac.Host[i+len("://"):]
	}

	u, err := url.Parse(ac.Host)
	if err != nil {
		log.Error("url.Parse failed on %s: %s", ac.Host, err)
		return rawHost
	}
	d, err := idna.ToUnicode(u.Hostname())
	if err != nil {
		log.Error("idna.ToUnicode failed on %s: %s", ac.Host, err)
		return rawHost
	}
	if d == "" {
		// Nothing was recognized as a hostname; don't return an empty string.
		return rawHost
	}
	res := d
	if u.Port() != "" {
		res += ":" + u.Port()
	}
	return res
}
// CanCreateBlogs reports whether another blog may be created when
// currentlyUsed blogs already exist. A MaxBlogs of zero or less means there
// is no limit.
func (ac AppCfg) CanCreateBlogs(currentlyUsed uint64) bool {
	if ac.MaxBlogs <= 0 {
		return true
	}
	// MaxBlogs is positive here, so widening it to uint64 is lossless.
	// Narrowing currentlyUsed to int instead could wrap to a negative number.
	return currentlyUsed < uint64(ac.MaxBlogs)
}
// OrDefaultString picks input when it is non-empty; otherwise it falls back
// to defaultValue.
func OrDefaultString(input, defaultValue string) string {
	if input != "" {
		return input
	}
	return defaultValue
}
// DefaultHTTPClient builds a new HTTP client whose only non-zero setting is a
// 10-second overall request timeout.
func DefaultHTTPClient() *http.Client {
	client := new(http.Client)
	client.Timeout = 10 * time.Second
	return client
}
diff --git a/database-sqlite.go b/database-sqlite.go
index 7d33315..c34bca2 100644
--- a/database-sqlite.go
+++ b/database-sqlite.go
@@ -1,73 +1,74 @@
//go:build sqlite && !wflib
// +build sqlite,!wflib
/*
* Copyright © 2019-2020 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"database/sql"
+ "regexp"
+
"github.com/go-sql-driver/mysql"
"github.com/mattn/go-sqlite3"
"github.com/writeas/web-core/log"
- "regexp"
)
// init flags SQLite support as enabled and registers the "sqlite3_with_regex"
// driver, whose connections gain a `regexp` SQL function backed by
// regexp.MatchString.
func init() {
	SQLiteEnabled = true

	matchRegexp := func(pattern, subject string) (bool, error) {
		return regexp.MatchString(pattern, subject)
	}
	registerRegexp := func(conn *sqlite3.SQLiteConn) error {
		return conn.RegisterFunc("regexp", matchRegexp, true)
	}

	sql.Register("sqlite3_with_regex", &sqlite3.SQLiteDriver{
		ConnectHook: registerRegexp,
	})
}
// isDuplicateKeyErr reports whether err is a duplicate-key failure for the
// configured driver: a constraint error for SQLite, or error number
// mySQLErrDuplicateKey for MySQL. Any other driver name is logged and yields
// false.
func (db *datastore) isDuplicateKeyErr(err error) bool {
	switch db.driverName {
	case driverSQLite:
		if sqliteErr, ok := err.(sqlite3.Error); ok {
			return sqliteErr.Code == sqlite3.ErrConstraint
		}
	case driverMySQL:
		if mysqlErr, ok := err.(*mysql.MySQLError); ok {
			return mysqlErr.Number == mySQLErrDuplicateKey
		}
	default:
		log.Error("isDuplicateKeyErr: failed check for unrecognized driver '%s'", db.driverName)
	}
	return false
}
// isIgnorableError reports whether err is a MySQL error numbered
// mySQLErrCollationMix. For every driver other than MySQL it logs an error
// and returns false.
func (db *datastore) isIgnorableError(err error) bool {
	if db.driverName != driverMySQL {
		log.Error("isIgnorableError: failed check for unrecognized driver '%s'", db.driverName)
		return false
	}

	mysqlErr, ok := err.(*mysql.MySQLError)
	return ok && mysqlErr.Number == mySQLErrCollationMix
}
// isHighLoadError reports whether err is a MySQL error numbered
// mySQLErrMaxUserConns or mySQLErrTooManyConns. It is always false for
// non-MySQL drivers.
func (db *datastore) isHighLoadError(err error) bool {
	if db.driverName != driverMySQL {
		return false
	}

	mysqlErr, ok := err.(*mysql.MySQLError)
	if !ok {
		return false
	}
	code := mysqlErr.Number
	return code == mySQLErrMaxUserConns || code == mySQLErrTooManyConns
}
diff --git a/database_activitypub.go b/database_activitypub.go
index 9df3724..b832d37 100644
--- a/database_activitypub.go
+++ b/database_activitypub.go
@@ -1,49 +1,50 @@
/*
* Copyright © 2024 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"database/sql"
"fmt"
+
"github.com/writeas/web-core/activitystreams"
"github.com/writeas/web-core/log"
)
// apAddRemoteUser inserts the given remote actor into remoteusers and its
// public key into remoteuserkeys using transaction t, and returns the new
// remoteusers row ID.
//
// On any failure the transaction is rolled back and -1 is returned together
// with an error wrapping the underlying cause. This function never commits t.
func apAddRemoteUser(app *App, t *sql.Tx, fullActor *activitystreams.Person) (int64, error) {
	// Add remote user locally, since it wasn't found before
	res, err := t.Exec("INSERT INTO remoteusers (actor_id, inbox, shared_inbox, url) VALUES (?, ?, ?, ?)", fullActor.ID, fullActor.Inbox, fullActor.Endpoints.SharedInbox, fullActor.URL)
	if err != nil {
		t.Rollback()
		return -1, fmt.Errorf("couldn't add new remoteuser in DB: %w", err)
	}

	remoteUserID, err := res.LastInsertId()
	if err != nil {
		t.Rollback()
		return -1, fmt.Errorf("no lastinsertid for followers, rolling back: %w", err)
	}

	// Add in key
	_, err = t.Exec("INSERT INTO remoteuserkeys (id, remote_user_id, public_key) VALUES (?, ?, ?)", fullActor.PublicKey.ID, remoteUserID, fullActor.PublicKey.PublicKeyPEM)
	if err != nil {
		// Duplicate-key failures and all other failures were previously
		// handled by two identical branches, so they share one path here.
		t.Rollback()
		log.Error("Couldn't add follower keys in DB: %v\n", err)
		return -1, fmt.Errorf("couldn't add follower keys in DB: %w", err)
	}

	return remoteUserID, nil
}
diff --git a/database_test.go b/database_test.go
index c114077..fc0523c 100644
--- a/database_test.go
+++ b/database_test.go
@@ -1,50 +1,51 @@
package writefreely
import (
"context"
"database/sql"
- "github.com/stretchr/testify/assert"
"testing"
+
+ "github.com/stretchr/testify/assert"
)
// TestOAuthDatastore exercises the OAuth-related datastore methods against a
// test database: state generation and validation, recording a remote user ID
// twice, and looking the local user back up. It is skipped unless
// runMySQLTests() reports true.
func TestOAuthDatastore(t *testing.T) {
if !runMySQLTests() {
t.Skip("skipping mysql tests")
}
withTestDB(t, func(db *sql.DB) {
ctx := context.Background()
ds := &datastore{
DB: db,
driverName: "",
}
// A freshly generated state is expected to be 24 characters long and stored
// as unused.
state, err := ds.GenerateOAuthState(ctx, "test", "development", 0, "")
assert.NoError(t, err)
assert.Len(t, state, 24)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)
// After validation, the same state row is expected to be marked as used.
_, _, _, _, err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)
var localUserID int64 = 99
var remoteUserID = "100"
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_a")
assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_a'", localUserID, remoteUserID)
// Recording the same local/remote pair again with a new token is expected to
// leave a single oauth_users row carrying the new token.
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_b")
assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_b'", localUserID, remoteUserID)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users`")
// The remote ID must resolve back to the local user recorded above.
foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID, "test", "test")
assert.NoError(t, err)
assert.Equal(t, localUserID, foundUserID)
})
}
diff --git a/db/create_test.go b/db/create_test.go
index 369d5c1..b971bdf 100644
--- a/db/create_test.go
+++ b/db/create_test.go
@@ -1,146 +1,147 @@
package db
import (
- "github.com/stretchr/testify/assert"
"testing"
+
+ "github.com/stretchr/testify/assert"
)
func TestDialect_Column(t *testing.T) {
c1 := DialectSQLite.Column("foo", ColumnTypeBool, UnsetSize)
assert.Equal(t, DialectSQLite, c1.Dialect)
c2 := DialectMySQL.Column("foo", ColumnTypeBool, UnsetSize)
assert.Equal(t, DialectMySQL, c2.Dialect)
}
// TestColumnType_Format is a table-driven test of ColumnType.Format. Each case
// gives a column type, a dialect and an optional size, and states the expected
// SQL type string or that an error is expected.
func TestColumnType_Format(t *testing.T) {
type args struct {
dialect DialectType
size OptionalInt
}
tests := []struct {
name string
d ColumnType
args args
want string
wantErr bool
}{
{"Sqlite bool", ColumnTypeBool, args{dialect: DialectSQLite}, "INTEGER", false},
{"Sqlite small int", ColumnTypeSmallInt, args{dialect: DialectSQLite}, "INTEGER", false},
{"Sqlite int", ColumnTypeInteger, args{dialect: DialectSQLite}, "INTEGER", false},
{"Sqlite char", ColumnTypeChar, args{dialect: DialectSQLite}, "TEXT", false},
{"Sqlite varchar", ColumnTypeVarChar, args{dialect: DialectSQLite}, "TEXT", false},
{"Sqlite text", ColumnTypeText, args{dialect: DialectSQLite}, "TEXT", false},
{"Sqlite datetime", ColumnTypeDateTime, args{dialect: DialectSQLite}, "DATETIME", false},
{"MySQL bool", ColumnTypeBool, args{dialect: DialectMySQL}, "TINYINT(1)", false},
{"MySQL small int", ColumnTypeSmallInt, args{dialect: DialectMySQL}, "SMALLINT", false},
{"MySQL small int with param", ColumnTypeSmallInt, args{dialect: DialectMySQL, size: OptionalInt{true, 3}}, "SMALLINT(3)", false},
{"MySQL int", ColumnTypeInteger, args{dialect: DialectMySQL}, "INT", false},
{"MySQL int with param", ColumnTypeInteger, args{dialect: DialectMySQL, size: OptionalInt{true, 11}}, "INT(11)", false},
{"MySQL char", ColumnTypeChar, args{dialect: DialectMySQL}, "CHAR", false},
{"MySQL char with param", ColumnTypeChar, args{dialect: DialectMySQL, size: OptionalInt{true, 4}}, "CHAR(4)", false},
{"MySQL varchar", ColumnTypeVarChar, args{dialect: DialectMySQL}, "VARCHAR", false},
{"MySQL varchar with param", ColumnTypeVarChar, args{dialect: DialectMySQL, size: OptionalInt{true, 25}}, "VARCHAR(25)", false},
{"MySQL text", ColumnTypeText, args{dialect: DialectMySQL}, "TEXT", false},
{"MySQL datetime", ColumnTypeDateTime, args{dialect: DialectMySQL}, "DATETIME", false},
// Out-of-range column type and dialect values are both expected to fail.
{"invalid column type", 10000, args{dialect: DialectMySQL}, "", true},
{"invalid dialect", ColumnTypeBool, args{dialect: 10000}, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.d.Format(tt.args.dialect, tt.args.size)
// Fail when an error appears without being expected, or vice versa.
if (err != nil) != tt.wantErr {
t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Format() got = %v, want %v", got, tt.want)
}
})
}
}
// TestColumn_Build is a table-driven test of Column.String. Each case builds a
// Column from the listed fields and compares the generated column definition
// against the expected SQL fragment.
func TestColumn_Build(t *testing.T) {
	type fields struct {
		Dialect    DialectType
		Name       string
		Nullable   bool
		Default    OptionalString
		Type       ColumnType
		Size       OptionalInt
		PrimaryKey bool
	}
	tests := []struct {
		name    string
		fields  fields
		want    string
		wantErr bool
	}{
		{"Sqlite bool", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER NOT NULL", false},
		{"Sqlite bool nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER", false},
		{"Sqlite small int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
		{"Sqlite small int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo INTEGER", false},
		{"Sqlite int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER NOT NULL", false},
		{"Sqlite int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER", false},
		{"Sqlite char", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT NOT NULL", false},
		{"Sqlite char nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT", false},
		{"Sqlite varchar", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT NOT NULL", false},
		{"Sqlite varchar nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT", false},
		{"Sqlite text", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false},
		{"Sqlite text nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false},
		{"Sqlite datetime", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false},
		{"Sqlite datetime nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false},
		{"MySQL bool", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1) NOT NULL", false},
		{"MySQL bool nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1)", false},
		{"MySQL small int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false},
		{"MySQL small int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo SMALLINT", false},
		{"MySQL int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT NOT NULL", false},
		{"MySQL int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT", false},
		{"MySQL char", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR NOT NULL", false},
		{"MySQL char nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR", false},
		{"MySQL varchar", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR NOT NULL", false},
		{"MySQL varchar nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR", false},
		{"MySQL text", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false},
		{"MySQL text nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false},
		{"MySQL datetime", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false},
		{"MySQL datetime nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Column{
				Dialect:    tt.fields.Dialect,
				Name:       tt.fields.Name,
				Nullable:   tt.fields.Nullable,
				Default:    tt.fields.Default,
				Type:       tt.fields.Type,
				Size:       tt.fields.Size,
				PrimaryKey: tt.fields.PrimaryKey,
			}
			got, err := c.String()
			// Check the error expectation unconditionally. It used to be
			// checked only when the output already mismatched, which hid
			// unexpected errors that came with a matching string.
			if (err != nil) != tt.wantErr {
				t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got != tt.want {
				t.Errorf("String() got = %v, want %v", got, tt.want)
			}
		})
	}
}
// TestCreateTableSqlBuilder_ToSQL builds a MySQL CREATE TABLE statement with
// three columns and two unique constraints, and compares it with the exact
// expected SQL text.
func TestCreateTableSqlBuilder_ToSQL(t *testing.T) {
	const want = "CREATE TABLE IF NOT EXISTS foo ( bar INT NOT NULL PRIMARY KEY, baz TEXT NOT NULL, qux DATETIME NOT NULL DEFAULT NOW(), UNIQUE(bar), UNIQUE(bar,baz) )"

	barCol := DialectMySQL.Column("bar", ColumnTypeInteger, UnsetSize).SetPrimaryKey(true)
	bazCol := DialectMySQL.Column("baz", ColumnTypeText, UnsetSize)
	quxCol := DialectMySQL.Column("qux", ColumnTypeDateTime, UnsetSize).SetDefault("NOW()")

	got, err := DialectMySQL.
		Table("foo").
		SetIfNotExists(true).
		Column(barCol).
		Column(bazCol).
		Column(quxCol).
		UniqueConstraint("bar").
		UniqueConstraint("bar", "baz").
		ToSQL()

	assert.NoError(t, err)
	assert.Equal(t, want, got)
}
diff --git a/email.go b/email.go
index 5651bee..e5283a9 100644
--- a/email.go
+++ b/email.go
@@ -1,512 +1,512 @@
/*
* Copyright © 2019-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"bytes"
"database/sql"
"encoding/json"
"fmt"
- "github.com/writefreely/writefreely/mailer"
"html/template"
"net/http"
"strings"
"time"
"github.com/aymerick/douceur/inliner"
"github.com/gorilla/mux"
stripmd "github.com/writeas/go-strip-markdown/v2"
"github.com/writeas/impart"
"github.com/writeas/web-core/data"
"github.com/writeas/web-core/log"
"github.com/writefreely/writefreely/key"
+ "github.com/writefreely/writefreely/mailer"
"github.com/writefreely/writefreely/spam"
)
const (
// emailSendDelay is not referenced by the code visible in this section.
// NOTE(review): presumably the delay applied before emailing a published
// post — confirm the unit and usage against its callers.
emailSendDelay = 15
)
type (
// SubmittedSubscription is the subscription request decoded from a JSON
// body or POST form by handleCreateEmailSubscription.
SubmittedSubscription struct {
// CollAlias is taken from the "alias" route variable, not the payload.
CollAlias string
// UserID is filled in from the current user session, when there is one.
UserID int64
Email string `schema:"email" json:"email"`
// Web marks a browser submission; the handler then redirects instead of
// writing a success response.
Web bool `schema:"web" json:"web"`
// Slug is appended to the redirect target for web submissions.
Slug string `schema:"slug" json:"slug"`
From string `schema:"from" json:"from"`
}
// EmailSubscriber is one email subscriber of a collection.
EmailSubscriber struct {
ID string
CollID int64
UserID sql.NullInt64
Email sql.NullString
Subscribed time.Time
Token string
Confirmed bool
AllowExport bool
// acctEmail is the encrypted account email; FinalEmail decrypts it when
// the subscriber has a UserID but no Email of its own.
acctEmail sql.NullString
}
)
// FinalEmail returns the address to use for this subscriber. The subscriber's
// own Email is used unless they have a UserID and no Email; in that case the
// encrypted account email is decrypted with keys.EmailKey. A decryption
// failure is logged and yields an empty string.
func (es *EmailSubscriber) FinalEmail(keys *key.Keychain) string {
	needsAccountEmail := es.UserID.Valid && !es.Email.Valid
	if !needsAccountEmail {
		return es.Email.String
	}

	plain, err := data.Decrypt(keys.EmailKey, []byte(es.acctEmail.String))
	if err != nil {
		log.Error("Error decrypting user email: %v", err)
		return ""
	}
	return string(plain)
}
// SubscribedFriendly formats the subscription time as a long date, e.g.
// "January 2, 2006".
func (es *EmailSubscriber) SubscribedFriendly() string {
	const friendlyLayout = "January 2, 2006"
	return es.Subscribed.Format(friendlyLayout)
}
// handleCreateEmailSubscription subscribes an email address or the logged-in
// user to the collection named by the "alias" route variable.
//
// The request is decoded from JSON or a POST form. Nothing is subscribed, and
// the client is redirected to the collection, when the author is silenced, a
// honeypot field is filled in, or neither an email nor a user is supplied.
// Otherwise the subscription is stored and, unless the email is already
// confirmed, a confirmation email is sent. Web submissions end in a redirect
// with a flash message; all others get a 202 Accepted response.
func handleCreateEmailSubscription(app *App, w http.ResponseWriter, r *http.Request) error {
	reqJSON := IsJSON(r)
	vars := mux.Vars(r)
	var err error

	ss := SubmittedSubscription{
		CollAlias: vars["alias"],
	}
	u := getUserSession(app, r)
	if u != nil {
		ss.UserID = u.ID
	}
	if reqJSON {
		// Decode JSON request
		decoder := json.NewDecoder(r.Body)
		err = decoder.Decode(&ss)
		if err != nil {
			log.Error("Couldn't parse new subscription JSON request: %v\n", err)
			return ErrBadJSON
		}
	} else {
		err = r.ParseForm()
		if err != nil {
			log.Error("Couldn't parse new subscription form request: %v\n", err)
			return ErrBadFormData
		}
		err = app.formDecoder.Decode(&ss, r.PostForm)
		if err != nil {
			log.Error("Continuing, but error decoding new subscription form request: %v\n", err)
			//return ErrBadFormData
		}
	}

	c, err := app.db.GetCollection(ss.CollAlias)
	if err != nil {
		log.Error("getCollection: %s", err)
		return err
	}
	c.hostName = app.cfg.App.Host
	from := c.CanonicalURL()

	isAuthorBanned, err := app.db.IsUserSilenced(c.OwnerID)
	if err != nil {
		// Don't drop the failure silently; the request continues as before.
		log.Error("Unable to check whether author is silenced: %v", err)
	}
	if isAuthorBanned {
		log.Info("Author is silenced, so subscription is blocked.")
		return impart.HTTPError{http.StatusFound, from}
	}

	if ss.Web {
		if u != nil && u.ID == c.OwnerID {
			from = "/" + c.Alias + "/"
		}
		from += ss.Slug
	}

	if r.FormValue(spam.HoneypotFieldName()) != "" || r.FormValue("fake_password") != "" {
		log.Info("Honeypot field was filled out! Not subscribing.")
		return impart.HTTPError{http.StatusFound, from}
	}
	if ss.Email == "" && ss.UserID < 1 {
		log.Info("No subscriber data. Not subscribing.")
		return impart.HTTPError{http.StatusFound, from}
	}

	confirmed := app.db.IsSubscriberConfirmed(ss.Email)
	es, err := app.db.AddEmailSubscription(c.ID, ss.UserID, ss.Email, confirmed)
	if err != nil {
		log.Error("addEmailSubscription: %s", err)
		return err
	}

	// Send confirmation email if needed
	if !confirmed {
		err = sendSubConfirmEmail(app, c, ss.Email, es.ID, es.Token)
		if err != nil {
			log.Error("Failed to send subscription confirmation email: %s", err)
			return err
		}
	}

	if ss.Web {
		session, err := app.sessionStore.Get(r, userEmailCookieName)
		if err != nil {
			// The cookie should still save, even if there's an error.
			// Source: https://github.com/gorilla/sessions/issues/16#issuecomment-143642144
			log.Error("Getting user email cookie: %v; ignoring", err)
		}
		if confirmed {
			addSessionFlash(app, w, r, "<strong>Subscribed</strong>. You'll now receive future blog posts via email.", nil)
		} else {
			addSessionFlash(app, w, r, "Please check your email and <strong>click the confirmation link</strong> to subscribe.", nil)
		}
		session.Values[userEmailCookieVal] = ss.Email
		err = session.Save(r, w)
		if err != nil {
			log.Error("save email cookie: %s", err)
			return err
		}
		return impart.HTTPError{http.StatusFound, from}
	}
	return impart.WriteSuccess(w, "", http.StatusAccepted)
}
// handleExportEmailSubscriptions returns the email subscribers of the
// collection named by the "alias" route variable as newline-separated bytes,
// together with a timestamped file name.
//
// Only the logged-in owner of the collection may export; otherwise
// ErrNotLoggedIn or ErrForbiddenCollectionAccess is returned.
func handleExportEmailSubscriptions(app *App, w http.ResponseWriter, r *http.Request) ([]byte, string, error) {
	alias := mux.Vars(r)["alias"]

	u := getUserSession(app, r)
	if u == nil {
		return nil, "", ErrNotLoggedIn
	}

	c, err := app.db.GetCollection(alias)
	if err != nil {
		return nil, "", err
	}

	// Verify permissions / ownership
	if u.ID != c.OwnerID {
		return nil, "", ErrForbiddenCollectionAccess
	}

	stamp := time.Now().Truncate(time.Second).UTC().Format("200601021504")
	filename := "subscribers-" + alias + "-" + stamp

	subs, err := app.db.GetEmailSubscribers(c.ID, true)
	if err != nil {
		return nil, filename, err
	}

	// One address per line, without a trailing newline.
	var out []byte
	for _, sub := range subs {
		out = append(out, sub.Email.String...)
		out = append(out, '\n')
	}
	return bytes.TrimRight(out, "\n"), filename, nil
}
// handleDeleteEmailSubscription removes an email subscription from a
// collection.
//
// When the "subscriber" route variable is set, the request comes from an
// unsubscribe link and is authorized only by subscriber ID and token.
// Otherwise the subscriber is identified by the "email" form value, the
// logged-in user, or the email saved in the session cookie, in that order.
//
// GET requests end in a redirect with a flash message; all others get a
// 202 Accepted response.
func handleDeleteEmailSubscription(app *App, w http.ResponseWriter, r *http.Request) error {
	alias := collectionAliasFromReq(r)

	vars := mux.Vars(r)
	subID := vars["subscriber"]
	email := r.FormValue("email")
	token := r.FormValue("t")
	slug := r.FormValue("slug")
	isWeb := r.Method == "GET"

	// Display collection if this is a collection
	var c *Collection
	var err error
	if app.cfg.App.SingleUser {
		c, err = app.db.GetCollectionByID(1)
	} else {
		c, err = app.db.GetCollection(alias)
	}
	if err != nil {
		log.Error("Get collection: %s", err)
		return err
	}

	from := c.CanonicalURL()

	if subID != "" {
		// User unsubscribing via email, so assume action is taken by either current
		// user or not current user, and only use the request's information to
		// satisfy this unsubscribe, i.e. subscriberID and token.
		err = app.db.DeleteEmailSubscriber(subID, token)
	} else {
		// User unsubscribing through the web app, so assume action is taken by
		// currently-auth'd user.
		var userID int64
		u := getUserSession(app, r)
		if u != nil {
			// User is logged in
			userID = u.ID
			if userID == c.OwnerID {
				from = "/" + c.Alias + "/"
			}
		}
		if email == "" && userID <= 0 {
			// Get email address from saved cookie
			session, err := app.sessionStore.Get(r, userEmailCookieName)
			if err != nil {
				log.Error("Unable to get email cookie: %s", err)
			} else if saved, ok := session.Values[userEmailCookieVal].(string); ok {
				// The value is absent when no subscription was made in this
				// session; an unchecked assertion would panic in that case.
				email = saved
			}
		}

		if email == "" && userID <= 0 {
			err = fmt.Errorf("No subscriber given.")
			log.Error("Not deleting subscription: %s", err)
			return err
		}

		err = app.db.DeleteEmailSubscriberByUser(email, userID, c.ID)
	}
	if err != nil {
		log.Error("Unable to delete subscriber: %v", err)
		return err
	}

	if isWeb {
		from += slug
		addSessionFlash(app, w, r, "<strong>Unsubscribed</strong>. You will no longer receive these blog posts via email.", nil)
		return impart.HTTPError{http.StatusFound, from}
	}
	return impart.WriteSuccess(w, "", http.StatusAccepted)
}
// handleConfirmEmailSubscription marks the subscriber named by the
// "subscriber" route variable as confirmed, using the "t" form value as the
// token. Whether or not confirmation succeeds, the client is redirected to
// the collection with a flash message: a thank-you on success, the error text
// on failure.
func handleConfirmEmailSubscription(app *App, w http.ResponseWriter, r *http.Request) error {
	alias := collectionAliasFromReq(r)
	subID := mux.Vars(r)["subscriber"]
	token := r.FormValue("t")

	var (
		c   *Collection
		err error
	)
	if app.cfg.App.SingleUser {
		c, err = app.db.GetCollectionByID(1)
	} else {
		c, err = app.db.GetCollection(alias)
	}
	if err != nil {
		log.Error("Get collection: %s", err)
		return err
	}

	redirectTo := c.CanonicalURL()

	flash := "<strong>Confirmed</strong>! Thanks. Now you'll receive future blog posts via email."
	if confirmErr := app.db.UpdateSubscriberConfirmed(subID, token); confirmErr != nil {
		flash = confirmErr.Error()
	}
	addSessionFlash(app, w, r, flash, nil)
	return impart.HTTPError{http.StatusFound, redirectTo}
}
// emailPost sends the post p to the email subscribers of collection collID.
//
// It builds a plain-text and an HTML version of the post (CSS is inlined into
// the HTML), adds every subscriber as a recipient with per-recipient "id",
// "to" and "token" variables for the unsubscribe link, and sends one message
// through the configured mailer. It returns nil without sending when the
// collection has no subscribers.
//
// A recipient that cannot be added is logged and skipped. Any other failure
// is returned.
func emailPost(app *App, p *PublicPost, collID int64) error {
	p.augmentContent()

	// Do some shortcode replacement.
	// Since the user is receiving this email, we can assume they're subscribed via email.
	p.Content = strings.Replace(p.Content, "<!--emailsub-->", `<p id="emailsub">You're subscribed to email updates.</p>`, -1)

	if p.HTMLContent == template.HTML("") {
		p.formatContent(app.cfg, false, false)
	}
	p.augmentReadingDestination()

	title := p.Title.String
	if title != "" {
		title = p.Title.String + "\n\n"
	}
	plainMsg := title + "A new post from " + p.CanonicalURL(app.cfg.App.Host) + "\n\n" + stripmd.Strip(p.Content)
	plainMsg += `
---------------------------------------------------------------------------------
Originally published on ` + p.Collection.DisplayTitle() + ` (` + p.Collection.CanonicalURL() + `), a blog you subscribe to.
Sent to %recipient.to%. Unsubscribe: ` + p.Collection.CanonicalURL() + `email/unsubscribe/%recipient.id%?t=%recipient.token%`

	mlr, err := mailer.New(app.cfg.Email)
	if err != nil {
		return err
	}
	m, err := mlr.NewMessage(p.Collection.DisplayTitle()+" <"+p.Collection.Alias+"@"+app.cfg.Email.Domain+">", stripmd.Strip(p.DisplayTitle()), plainMsg)
	if err != nil {
		return err
	}
	replyTo := app.db.GetCollectionAttribute(collID, collAttrLetterReplyTo)
	if replyTo != "" {
		m.SetReplyTo(replyTo)
	}

	subs, err := app.db.GetEmailSubscribers(collID, true)
	if err != nil {
		log.Error("Unable to get email subscribers: %v", err)
		return err
	}
	if len(subs) == 0 {
		return nil
	}

	// From here on, title holds the HTML heading rather than the plain text.
	if title != "" {
		title = string(`<h2 id="title">` + p.FormattedDisplayTitle() + `</h2>`)
	}
	m.AddTag("New post")

	fontFam := "Lora, Palatino, Baskerville, serif"
	if p.IsSans() {
		fontFam = `"Open Sans", Tahoma, Arial, sans-serif`
	} else if p.IsMonospace() {
		fontFam = `Hack, consolas, Menlo-Regular, Menlo, Monaco, monospace, monospace`
	}

	// TODO: move this to a templated file and LESS-generated stylesheet
	fullHTML := `<html>
<head>
<style>
body {
font-size: 120%;
font-family: ` + fontFam + `;
margin: 1em 2em;
}
#article {
line-height: 1.5;
margin: 1.5em 0;
white-space: pre-wrap;
word-wrap: break-word;
}
h1, h2, h3, h4, h5, h6, p, code {
display: inline
}
img, iframe, video {
max-width: 100%
}
#title {
margin-bottom: 1em;
display: block;
}
.intro {
font-style: italic;
font-size: 0.95em;
}
div#footer {
text-align: center;
max-width: 35em;
margin: 2em auto;
}
div#footer p {
display: block;
font-size: 0.86em;
color: #666;
}
hr {
border: 1px solid #ccc;
margin: 2em 1em;
}
p#emailsub {
text-align: center;
display: inline-block !important;
width: 100%;
font-style: italic;
}
</style>
</head>
<body>
<div id="article">` + title + `<p class="intro">From <a href="` + p.CanonicalURL(app.cfg.App.Host) + `">` + p.DisplayCanonicalURL() + `</a></p>
` + string(p.HTMLContent) + `</div>
<hr />
<div id="footer">
<p>Originally published on <a href="` + p.Collection.CanonicalURL() + `">` + p.Collection.DisplayTitle() + `</a>, a blog you subscribe to.</p>
<p>Sent to %recipient.to%. <a href="` + p.Collection.CanonicalURL() + `email/unsubscribe/%recipient.id%?t=%recipient.token%">Unsubscribe</a>.</p>
</div>
</body>
</html>`

	// inline CSS
	html, err := inliner.Inline(fullHTML)
	if err != nil {
		log.Error("Unable to inline email HTML: %v", err)
		return err
	}

	m.SetHTML(html)

	log.Info("[email] Adding %d recipient(s)", len(subs))
	for _, s := range subs {
		e := s.FinalEmail(app.keys)
		log.Info("[email] Adding %s", e)
		err = m.AddRecipientAndVariables(e, map[string]string{
			"id":    s.ID,
			"to":    e,
			"token": s.Token,
		})
		if err != nil {
			log.Error("Unable to add receipient %s: %s", e, err)
		}
	}

	err = mlr.Send(m)
	if err != nil {
		log.Error("Unable to send post email: %v", err)
		return err
	}
	// Only report success once Send has actually succeeded.
	log.Info("[email] Email sent")

	return nil
}
// sendSubConfirmEmail emails a subscription confirmation link for collection
// c to the given address. The link embeds subID and token.
//
// It returns an error when email is empty, or when the mailer cannot be
// created, the message cannot be built, or sending fails.
func sendSubConfirmEmail(app *App, c *Collection, email, subID, token string) error {
	if email == "" {
		return fmt.Errorf("You must supply an email to verify.")
	}

	// Send email
	mlr, err := mailer.New(app.cfg.Email)
	if err != nil {
		return err
	}

	plainMsg := "Confirm your subscription to " + c.DisplayTitle() + ` (` + c.CanonicalURL() + `) to start receiving future posts. Simply click the following link (or copy and paste it into your browser):
` + c.CanonicalURL() + "email/confirm/" + subID + "?t=" + token + `
If you didn't subscribe to this site or you're not sure why you're getting this email, you can delete it. You won't be subscribed or receive any future emails.`
	m, err := mlr.NewMessage(c.DisplayTitle()+" <"+c.Alias+"@"+app.cfg.Email.Domain+">", "Confirm your subscription to "+c.DisplayTitle(), plainMsg, fmt.Sprintf("<%s>", email))
	if err != nil {
		return err
	}
	m.AddTag("Email Verification")

	// This string is not run through a format function, so the percent sign
	// in font-size must be a single literal "%".
	m.SetHTML(`<html>
<body style="font-family:Lora, 'Palatino Linotype', Palatino, Baskerville, 'Book Antiqua', 'New York', 'DejaVu serif', serif; font-size: 100%; margin:1em 2em;">
<div style="font-size: 1.2em;">
<p>Confirm your subscription to <a href="` + c.CanonicalURL() + `">` + c.DisplayTitle() + `</a> to start receiving future posts:</p>
<p><a href="` + c.CanonicalURL() + `email/confirm/` + subID + `?t=` + token + `">Subscribe to ` + c.DisplayTitle() + `</a></p>
<p>If you didn't subscribe to this site or you're not sure why you're getting this email, you can delete it. You won't be subscribed or receive any future emails.</p>
</div>
</body>
</html>`)

	return mlr.Send(m)
}
diff --git a/jobs.go b/jobs.go
index 251b82d..02fcfeb 100644
--- a/jobs.go
+++ b/jobs.go
@@ -1,72 +1,73 @@
package writefreely
import (
- "github.com/writeas/web-core/log"
"time"
+
+ "github.com/writeas/web-core/log"
)
// PostJob is a queued action to perform for a post.
type PostJob struct {
ID int64
// PostID is the ID of the post the job applies to.
PostID string
// Action names what to do. NOTE(review): "email" is the only action fetched
// in the code visible here — confirm whether others exist.
Action string
// Delay is stored as given by the caller of addJob.
// NOTE(review): unit is not shown here — confirm against the datastore.
Delay int64
}
// addJob queues the given action for post p with the given delay by inserting
// a new job into the datastore.
func addJob(app *App, p *PublicPost, action string, delay int64) error {
	job := PostJob{PostID: p.ID, Action: action, Delay: delay}
	return app.db.InsertJob(&job)
}
// startPublishJobsQueue runs forever: every 62 seconds it fetches the pending
// "email" jobs and hands them to runJobs. Failures are logged and the loop
// keeps going, so this function never returns.
func startPublishJobsQueue(app *App) {
	ticker := time.NewTicker(62 * time.Second)

	// processDue performs one fetch-and-run pass.
	processDue := func() {
		log.Info("[jobs] Fetching email publish jobs...")
		pending, err := app.db.GetJobsToRun("email")
		if err != nil {
			log.Error("[jobs] %s - Skipping.", err)
			return
		}
		log.Info("[jobs] Running %d email publish jobs...", len(pending))
		if err := runJobs(app, pending, true); err != nil {
			log.Error("[jobs] Failed: %s", err)
		}
	}

	for {
		log.Info("[jobs] Done.")
		<-ticker.C
		processDue()
	}
}
// runJobs processes each job in jobs by loading its post and collection and
// emailing the post via emailPost.
//
// A job is deleted when it succeeds, or when reqColl is true and its post is
// not part of a collection. Jobs that fail for any other reason are logged and
// left in place. The function currently always returns nil.
func runJobs(app *App, jobs []*PostJob, reqColl bool) error {
	for _, j := range jobs {
		p, err := app.db.GetPost(j.PostID, 0)
		if err != nil {
			log.Info("[job #%d] Unable to get post: %s", j.ID, err)
			continue
		}
		if !p.CollectionID.Valid && reqColl {
			log.Info("[job #%d] Post %s not part of a collection", j.ID, p.ID)
			app.db.DeleteJob(j.ID)
			continue
		}
		coll, err := app.db.GetCollectionByID(p.CollectionID.Int64)
		if err != nil {
			log.Info("[job #%d] Unable to get collection: %s", j.ID, err)
			continue
		}
		coll.hostName = app.cfg.App.Host
		coll.ForPublic()
		p.Collection = &CollectionObj{Collection: *coll}

		err = emailPost(app, p, p.Collection.ID)
		if err != nil {
			// Include the cause; without it a failed send cannot be diagnosed.
			log.Error("[job #%d] Failed to email post %s: %v", j.ID, p.ID, err)
			continue
		}
		log.Info("[job #%d] Success for post %s.", j.ID, p.ID)
		app.db.DeleteJob(j.ID)
	}
	return nil
}
diff --git a/keys.go b/keys.go
index b5896f7..fa2adb5 100644
--- a/keys.go
+++ b/keys.go
@@ -1,74 +1,75 @@
/*
* Copyright © 2018-2019, 2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
- "github.com/writeas/web-core/log"
- "github.com/writefreely/writefreely/key"
"os"
"path/filepath"
+
+ "github.com/writeas/web-core/log"
+ "github.com/writefreely/writefreely/key"
)
// keysDir is the name of the directory that holds the key files.
const (
keysDir = "keys"
)
// Paths of the individual key files. They start out relative to keysDir;
// initKeyPaths places them under the configured keys parent directory.
var (
emailKeyPath = filepath.Join(keysDir, "email.aes256")
cookieAuthKeyPath = filepath.Join(keysDir, "cookies_auth.aes256")
cookieKeyPath = filepath.Join(keysDir, "cookies_enc.aes256")
csrfKeyPath = filepath.Join(keysDir, "csrf.aes256")
)
// InitKeys asks the given Apper to load its encryption keys into memory,
// logging that it is doing so. Any error from LoadKeys is returned unchanged.
func InitKeys(apper Apper) error {
	log.Info("Loading encryption keys...")
	if err := apper.LoadKeys(); err != nil {
		return err
	}
	return nil
}
// initKeyPaths points the package-level key paths at the key files inside
// <KeysParentDir>/<keysDir>.
//
// The paths are rebuilt from the file names rather than by prefixing the
// current values, so calling this more than once (for example from several
// App instances in tests) does not stack the parent directory repeatedly.
// The file names must stay in sync with the defaults of the path variables.
func initKeyPaths(app *App) {
	dir := filepath.Join(app.cfg.Server.KeysParentDir, keysDir)
	emailKeyPath = filepath.Join(dir, "email.aes256")
	cookieAuthKeyPath = filepath.Join(dir, "cookies_auth.aes256")
	cookieKeyPath = filepath.Join(dir, "cookies_enc.aes256")
	csrfKeyPath = filepath.Join(dir, "csrf.aes256")
}
// generateKey creates a new encryption key file at path, readable and writable
// only by the owner (mode 0600).
//
// User data encrypted with a key cannot be recovered without it, so an
// existing file is never overwritten: in that case a message is logged and nil
// is returned. Errors from stat (other than "not exist"), key generation and
// writing are logged and returned.
func generateKey(path string) error {
	// Check if key file exists
	_, statErr := os.Stat(path)
	switch {
	case statErr == nil:
		log.Info("%s already exists. rm the file if you understand the consequences.", path)
		return nil
	case !os.IsNotExist(statErr):
		log.Error("%s", statErr)
		return statErr
	}

	log.Info("Generating %s.", path)
	keyBytes, genErr := key.GenerateBytes(key.EncKeysBytes)
	if genErr != nil {
		log.Error("FAILED. %s. Run writefreely --gen-keys again.", genErr)
		return genErr
	}

	if writeErr := os.WriteFile(path, keyBytes, 0600); writeErr != nil {
		log.Error("FAILED writing file: %s", writeErr)
		return writeErr
	}

	log.Info("Success.")
	return nil
}
diff --git a/mailer/mailer.go b/mailer/mailer.go
index 30892e6..a08643c 100644
--- a/mailer/mailer.go
+++ b/mailer/mailer.go
@@ -1,181 +1,182 @@
/*
* Copyright © 2024 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package mailer
import (
"fmt"
+ "strings"
+
"github.com/mailgun/mailgun-go"
"github.com/writeas/web-core/log"
"github.com/writefreely/writefreely/config"
mail "github.com/xhit/go-simple-mail/v2"
- "strings"
)
type (
// Mailer holds configurations for the preferred mailing provider.
// New populates either mailGun or smtp, depending on the configuration.
Mailer struct {
smtp *mail.SMTPServer
mailGun *mailgun.MailgunImpl
}
// Message holds the email contents and metadata for the preferred mailing provider.
// NewMessage fills mgMsg for a Mailgun mailer and smtpMsg for an SMTP mailer.
Message struct {
mgMsg *mailgun.Message
smtpMsg *SmtpMessage
}
// SmtpMessage is the provider-independent representation of an email that
// is delivered over SMTP by Mailer.Send.
SmtpMessage struct {
from string
replyTo string // optional; only applied by Send when non-empty
subject string
recipients []Recipient
html string // HTML body
text string // plain-text alternative
}
// Recipient is a single addressee together with the per-recipient variables
// that Send substitutes for %recipient.<name>% placeholders in the bodies.
Recipient struct {
email string
vars map[string]string
}
)
// New builds a Mailer from the instance's config.EmailCfg.
//
// Mailgun is chosen when both a domain and a private key are set; otherwise
// SMTP is chosen when username, password, host and a positive port are set.
// When neither is fully configured an error is returned.
func New(eCfg config.EmailCfg) (*Mailer, error) {
	hasMailgun := eCfg.Domain != "" && eCfg.MailgunPrivate != ""
	hasSMTP := eCfg.Username != "" && eCfg.Password != "" && eCfg.Host != "" && eCfg.Port > 0

	switch {
	case hasMailgun:
		mg := mailgun.NewMailgun(eCfg.Domain, eCfg.MailgunPrivate)
		if eCfg.MailgunEurope {
			mg.SetAPIBase("https://api.eu.mailgun.net/v3")
		}
		return &Mailer{mailGun: mg}, nil

	case hasSMTP:
		srv := mail.NewSMTPClient()
		srv.Host = eCfg.Host
		srv.Port = eCfg.Port
		srv.Username = eCfg.Username
		srv.Password = eCfg.Password
		if eCfg.EnableStartTLS {
			srv.Encryption = mail.EncryptionSTARTTLS
		}
		// To allow sending multiple email
		srv.KeepAlive = true
		return &Mailer{smtp: srv}, nil

	default:
		return nil, fmt.Errorf("no email provider is configured")
	}
}
// NewMessage creates a new Message from the given parameters for whichever
// provider this Mailer was configured with.
//
// from is the sender, subject the subject line, text the plain-text body, and
// to the initial recipients. The returned error is currently always nil.
func (m *Mailer) NewMessage(from, subject, text string, to ...string) (*Message, error) {
	msg := &Message{}
	if m.mailGun != nil {
		msg.mgMsg = m.mailGun.NewMessage(from, subject, text, to...)
	} else if m.smtp != nil {
		msg.smtpMsg = &SmtpMessage{
			from:    from,
			replyTo: "",
			subject: subject,
			// Length 0, capacity len(to): the loop below appends, so a
			// non-zero length would leave empty recipients at the front.
			recipients: make([]Recipient, 0, len(to)),
			html:       "",
			text:       text,
		}
		for _, r := range to {
			msg.smtpMsg.recipients = append(msg.smtpMsg.recipients, Recipient{r, make(map[string]string)})
		}
	}
	return msg, nil
}
// SetHTML sets the HTML body of the message on whichever provider message is
// present. It does nothing when neither is set.
func (m *Message) SetHTML(html string) {
	switch {
	case m.smtpMsg != nil:
		m.smtpMsg.html = html
	case m.mgMsg != nil:
		m.mgMsg.SetHtml(html)
	}
}
// SetReplyTo sets the Reply-To address of the message on whichever provider
// message is present. Like SetHTML, it does nothing when neither is set,
// instead of dereferencing a nil Mailgun message.
func (m *Message) SetReplyTo(replyTo string) {
	if m.smtpMsg != nil {
		m.smtpMsg.replyTo = replyTo
	} else if m.mgMsg != nil {
		m.mgMsg.SetReplyTo(replyTo)
	}
}
// AddTag attaches a tag to the Message for providers that support it. Only the
// Mailgun message is tagged; for SMTP messages this is a no-op.
func (m *Message) AddTag(tag string) {
	if m.mgMsg == nil {
		return
	}
	m.mgMsg.AddTag(tag)
}
// AddRecipientAndVariables adds recipient r to the message together with its
// per-recipient variables, which replace %recipient.<name>% placeholders in
// the message bodies.
//
// For SMTP messages the recipient is appended and nil is returned. For
// Mailgun messages the variables are converted to the map type Mailgun
// expects and Mailgun's result is returned. An error is returned when the
// message holds no provider message at all.
func (m *Message) AddRecipientAndVariables(r string, vars map[string]string) error {
	if m.smtpMsg != nil {
		m.smtpMsg.recipients = append(m.smtpMsg.recipients, Recipient{r, vars})
		return nil
	}
	if m.mgMsg == nil {
		// Previously this dereferenced a nil Mailgun message and panicked.
		return fmt.Errorf("message has no email provider configured")
	}

	varsInterfaces := make(map[string]interface{}, len(vars))
	for k, v := range vars {
		varsInterfaces[k] = v
	}
	return m.mgMsg.AddRecipientAndVariables(r, varsInterfaces)
}
// Send sends the given message via the preferred provider.
//
// Over SMTP, one individual email is sent per recipient, with that
// recipient's variables substituted for %recipient.<name>% placeholders in
// both bodies. Per-recipient failures are logged; an error is returned only
// when no email at all could be sent, so callers do not retry messages that
// were partially delivered. Over Mailgun the message is sent in one call and
// its error is returned.
func (m *Mailer) Send(msg *Message) error {
	if m.smtp != nil {
		client, err := m.smtp.Connect()
		if err != nil {
			return err
		}
		// The server is configured with KeepAlive, so the connection stays
		// open after each send; close it once all recipients are handled.
		defer client.Close()

		emailSent := false
		for _, r := range msg.smtpMsg.recipients {
			customMsg := mail.NewMSG()
			customMsg.SetFrom(msg.smtpMsg.from)
			if msg.smtpMsg.replyTo != "" {
				customMsg.SetReplyTo(msg.smtpMsg.replyTo)
			}
			customMsg.SetSubject(msg.smtpMsg.subject)
			customMsg.AddTo(r.email)

			cText := msg.smtpMsg.text
			cHtml := msg.smtpMsg.html
			for v, value := range r.vars {
				placeHolder := fmt.Sprintf("%%recipient.%s%%", v)
				cText = strings.ReplaceAll(cText, placeHolder, value)
				cHtml = strings.ReplaceAll(cHtml, placeHolder, value)
			}
			customMsg.SetBody(mail.TextHTML, cHtml)
			customMsg.AddAlternative(mail.TextPlain, cText)

			// customMsg.Error carries any problem recorded while building
			// the message; only attempt delivery when there was none.
			e := customMsg.Error
			if e == nil {
				e = customMsg.Send(client)
			}
			if e == nil {
				emailSent = true
			} else {
				log.Error("Unable to send email to %s: %v", r.email, e)
				err = e
			}
		}
		if !emailSent {
			// only send an error if no email could be sent (to avoid retry of successfully sent emails)
			return err
		}
	} else if m.mailGun != nil {
		_, _, err := m.mailGun.Send(msg.mgMsg)
		if err != nil {
			return err
		}
	}
	return nil
}
diff --git a/main_test.go b/main_test.go
index 9db7a7e..4d28143 100644
--- a/main_test.go
+++ b/main_test.go
@@ -1,153 +1,154 @@
package writefreely
import (
"context"
"database/sql"
"encoding/gob"
"errors"
"fmt"
- uuid "github.com/nu7hatch/gouuid"
- "github.com/stretchr/testify/assert"
"math/rand"
"os"
"strings"
"testing"
"time"
+
+ uuid "github.com/nu7hatch/gouuid"
+ "github.com/stretchr/testify/assert"
)
// testDB is the shared base database handle. TestMain opens it when MySQL
// tests are enabled and closes it after the test run.
var testDB *sql.DB
// ScopedTestBody is a test body that receives a database handle, as passed
// to withTestDB.
type ScopedTestBody func(*sql.DB)
// TestMain provides testing infrastructure within this package.
//
// It seeds math/rand, registers *User with gob and, when MySQL tests are
// enabled (see runMySQLTests), opens the shared testDB from the WF_USER,
// WF_PASSWORD, WF_DB and WF_HOST environment variables before running the
// tests, closing it afterwards.
func TestMain(m *testing.M) {
	rand.Seed(time.Now().UTC().UnixNano())
	gob.Register(&User{})

	if runMySQLTests() {
		var err error
		testDB, err = initMySQL(os.Getenv("WF_USER"), os.Getenv("WF_PASSWORD"), os.Getenv("WF_DB"), os.Getenv("WF_HOST"))
		if err != nil {
			fmt.Println(err)
			// Exit non-zero: simply returning would end the test binary
			// successfully without having run a single test.
			os.Exit(1)
		}
	}

	code := m.Run()

	if runMySQLTests() {
		if closeErr := testDB.Close(); closeErr != nil {
			fmt.Println(closeErr)
		}
	}
	os.Exit(code)
}
// runMySQLTests reports whether MySQL-backed tests are enabled, which is the
// case when the TEST_MYSQL environment variable is set to a non-empty value.
func runMySQLTests() bool {
	enabled := os.Getenv("TEST_MYSQL")
	return enabled != ""
}
// initMySQL opens a MySQL database handle on port 3306 and verifies it with
// ensureMySQL.
//
// dbUser and dbPassword are required. An empty dbHost defaults to
// "localhost" and an empty dbName to "writefreely". When the verification
// fails, the handle is closed before the error is returned.
func initMySQL(dbUser, dbPassword, dbName, dbHost string) (*sql.DB, error) {
	if dbUser == "" || dbPassword == "" {
		return nil, errors.New("database user or password not set")
	}
	if dbHost == "" {
		dbHost = "localhost"
	}
	if dbName == "" {
		dbName = "writefreely"
	}

	dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s?charset=utf8mb4&parseTime=true", dbUser, dbPassword, dbHost, dbName)
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, err
	}
	if err := ensureMySQL(db); err != nil {
		// Don't leak the handle when the server cannot be reached;
		// the ping error is the one worth reporting.
		_ = db.Close()
		return nil, err
	}
	return db, nil
}
// ensureMySQL pings db and, when the ping succeeds, caps the number of open
// connections at 250. A failed ping is returned as the error and leaves the
// connection limit untouched.
func ensureMySQL(db *sql.DB) error {
	pingErr := db.Ping()
	if pingErr == nil {
		db.SetMaxOpenConns(250)
	}
	return pingErr
}
// withTestDB provides a scoped database connection: it creates a temporary
// test database from the shared testDB, runs testBody against it and removes
// it again afterwards.
//
// If the temporary database cannot be created, the test is marked failed and
// testBody is not run; no cleanup function exists in that case.
func withTestDB(t *testing.T, testBody ScopedTestBody) {
	db, cleanup, err := newTestDatabase(testDB,
		os.Getenv("WF_USER"),
		os.Getenv("WF_PASSWORD"),
		os.Getenv("WF_DB"),
		os.Getenv("WF_HOST"),
	)
	if !assert.NoError(t, err) {
		return
	}
	defer func() {
		assert.NoError(t, cleanup())
	}()

	testBody(db)
}
// newTestDatabase creates a new temporary test database. When a test
// database connection is returned, it will have created a new database and
// initialized it with tables from a reference database.
//
// The reference database is dbName or, when that is empty, the database
// currently selected on base. The new database is named after the reference
// database plus a UUID-based suffix. The returned function closes the new
// connection and drops the temporary database.
//
// On failure, everything created up to that point is removed again on a
// best-effort basis, and the original error is returned.
func newTestDatabase(base *sql.DB, dbUser, dbPassword, dbName, dbHost string) (*sql.DB, func() error, error) {
	baseName := dbName
	if baseName == "" {
		if err := base.QueryRow("SELECT DATABASE()").Scan(&baseName); err != nil {
			return nil, nil, err
		}
	}

	tUUID, err := uuid.NewV4()
	if err != nil {
		return nil, nil, err
	}
	suffix := strings.Replace(tUUID.String(), "-", "_", -1)
	newDBName := baseName + suffix

	if _, err := base.Exec("CREATE DATABASE " + newDBName); err != nil {
		return nil, nil, err
	}
	dropDB := func() error {
		_, dropErr := base.Exec("DROP DATABASE " + newDBName)
		return dropErr
	}

	newDB, err := initMySQL(dbUser, dbPassword, newDBName, dbHost)
	if err != nil {
		_ = dropDB() // best effort; the connection error is the one to report
		return nil, nil, err
	}
	cleanup := func() error {
		if closeErr := newDB.Close(); closeErr != nil {
			fmt.Println(closeErr)
		}
		return dropDB()
	}

	// copySchema recreates every table of the reference database, without
	// data, in the new database.
	copySchema := func() error {
		rows, err := base.Query("SHOW TABLES IN " + baseName)
		if err != nil {
			return err
		}
		defer rows.Close()

		for rows.Next() {
			var tableName string
			if err := rows.Scan(&tableName); err != nil {
				return err
			}
			query := fmt.Sprintf("CREATE TABLE %s LIKE %s.%s", tableName, baseName, tableName)
			if _, err := newDB.Exec(query); err != nil {
				return err
			}
		}
		// Surface iteration errors that ended the loop early.
		return rows.Err()
	}
	if err := copySchema(); err != nil {
		_ = cleanup() // best effort; the schema error is the one to report
		return nil, nil, err
	}

	return newDB, cleanup, nil
}
// countRows runs query with args, scans its single result into an integer
// and asserts that it equals count. The query is expected to return exactly
// one integer column, such as a COUNT(*).
func countRows(t *testing.T, ctx context.Context, db *sql.DB, count int, query string, args ...interface{}) {
	t.Helper() // report failures at the caller's line

	var returned int
	err := db.QueryRowContext(ctx, query, args...).Scan(&returned)
	// args is a slice of arbitrary values, so it is formatted with %v.
	assert.NoError(t, err, "error executing query %s and args %v", query, args)
	assert.Equal(t, count, returned, "unexpected return count %d, expected %d from %s and args %v", returned, count, query, args)
}
diff --git a/monetization.go b/monetization.go
index 4d6b42b..05e81db 100644
--- a/monetization.go
+++ b/monetization.go
@@ -1,160 +1,161 @@
/*
* Copyright © 2020-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"bytes"
"fmt"
- "github.com/gorilla/mux"
- "github.com/writeas/impart"
- "github.com/writeas/web-core/log"
"io"
"net/http"
"net/url"
"os"
"strings"
+
+ "github.com/gorilla/mux"
+ "github.com/writeas/impart"
+ "github.com/writeas/web-core/log"
)
// displayMonetization returns the payment pointer to expose for a collection.
//
//   - An empty monetization value yields an empty string.
//   - Pointers hosted on a *.xrptipbot.com host are returned unchanged, since
//     xrp tip bot doesn't support stream receipts.
//   - Otherwise, when the PAYMENT_HOST environment variable is set, the result
//     is "<PAYMENT_HOST>/<alias>".
//   - Otherwise the pointer is path-escaped and appended to the
//     webmonetization.org receipts endpoint.
func displayMonetization(monetization, alias string) string {
	if monetization == "" {
		return ""
	}

	asURL := strings.Replace(monetization, "$", "https://", 1)
	if parsed, err := url.Parse(asURL); err == nil && strings.HasSuffix(parsed.Host, ".xrptipbot.com") {
		// xrp tip bot doesn't support stream receipts, so return plain pointer
		return monetization
	}

	if host := os.Getenv("PAYMENT_HOST"); host != "" {
		return host + "/" + alias
	}
	return "$webmonetization.org/api/receipts/" + url.PathEscape(monetization)
}
// handleSPSPEndpoint writes the monetization pointer of the collection named
// by the "id" form value to w.
//
// On a single-user instance, an id with a '.' after its first character is
// resolved to collection 1; otherwise the id is looked up as a collection
// alias. A 404 HTTPError is returned when the collection has no pointer.
func handleSPSPEndpoint(app *App, w http.ResponseWriter, r *http.Request) error {
	idStr := r.FormValue("id")
	id, err := url.QueryUnescape(idStr)
	if err != nil {
		log.Error("Unable to unescape: %s", err)
		return err
	}

	var c *Collection
	if strings.IndexRune(id, '.') > 0 && app.cfg.App.SingleUser {
		c, err = app.db.GetCollectionByID(1)
	} else {
		c, err = app.db.GetCollection(id)
	}
	if err != nil {
		return err
	}

	pointer := c.Monetization
	if pointer == "" {
		return impart.HTTPError{Status: http.StatusNotFound, Message: "No monetization pointer."}
	}

	// Write the pointer verbatim. It is stored data, not a format string, so
	// it must not go through Fprintf, where any '%' would be interpreted.
	fmt.Fprint(w, pointer)
	return nil
}
// handleGetSplitContent returns the part of a post that follows the paid
// short code, once the supplied receipt has been verified.
//
// The post is identified by the "post" route variable and, optionally, the
// collection by the "alias" route variable. The "receipt" form value is
// required and is checked with verifyReceipt. The JSON response holds the
// raw content ("body") and its rendered HTML ("html_body"); both are empty
// when the post contains no paid short code.
func handleGetSplitContent(app *App, w http.ResponseWriter, r *http.Request) error {
	var (
		coll         *Collection
		collID       int64
		collLookupID string
	)

	vars := mux.Vars(r)
	if alias := vars["alias"]; alias != "" {
		// Fetch collection information, since an alias is provided
		found, err := app.db.GetCollection(alias)
		if err != nil {
			return err
		}
		coll, collID, collLookupID = found, found.ID, found.Alias
	}

	post, err := app.db.GetPost(vars["post"], collID)
	if err != nil {
		return err
	}

	receipt := r.FormValue("receipt")
	if receipt == "" {
		return impart.HTTPError{Status: http.StatusBadRequest, Message: "No `receipt` given."}
	}
	if err := verifyReceipt(receipt, collLookupID); err != nil {
		return err
	}

	var d struct {
		Content     string `json:"body"`
		HTMLContent string `json:"html_body"`
	}
	if idx := strings.Index(post.Content, shortCodePaid); idx >= 0 {
		var baseURL string
		if coll != nil {
			baseURL = coll.CanonicalURL()
		}
		d.Content = post.Content[idx+len(shortCodePaid):]
		d.HTMLContent = applyMarkdown([]byte(d.Content), baseURL, app.cfg)
	}

	return impart.WriteSuccess(w, d, http.StatusOK)
}
// verifyReceipt POSTs receipt to the receipts verifier for the given id and
// returns nil when the verifier answers 200 OK.
//
// The verifier is "<RECEIPTS_HOST>/receipts" when the RECEIPTS_HOST
// environment variable is set, and the webmonetization.org verify endpoint
// otherwise; id is sent as the query-escaped "id" parameter. Transport
// errors are returned as is; any other status is returned as an
// impart.HTTPError carrying the verifier's status code and response body.
func verifyReceipt(receipt, id string) error {
	// Escape id so it cannot alter the query string.
	escapedID := url.QueryEscape(id)
	receiptsHost := os.Getenv("RECEIPTS_HOST")
	if receiptsHost == "" {
		receiptsHost = "https://webmonetization.org/api/receipts/verify?id=" + escapedID
	} else {
		receiptsHost = fmt.Sprintf("%s/receipts?id=%s", receiptsHost, escapedID)
	}

	log.Info("Verifying receipt %s at %s", receipt, receiptsHost)
	r, err := http.NewRequest("POST", receiptsHost, bytes.NewBufferString(receipt))
	if err != nil {
		log.Error("Unable to create new request to %s: %s", receiptsHost, err)
		return err
	}

	resp, err := http.DefaultClient.Do(r)
	if err != nil {
		log.Error("Unable to Do() request to %s: %s", receiptsHost, err)
		return err
	}
	// On a nil error the client always returns a non-nil response and body.
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Error("Unable to read %s response body: %s", receiptsHost, err)
		return err
	}
	log.Info("Status : %s", resp.Status)
	log.Info("Response: %s", body)

	if resp.StatusCode != http.StatusOK {
		log.Error("Bad response from %s:\nStatus: %d\n%s", receiptsHost, resp.StatusCode, string(body))
		return impart.HTTPError{Status: resp.StatusCode, Message: string(body)}
	}
	return nil
}
diff --git a/nodeinfo.go b/nodeinfo.go
index 5a38d96..608aa74 100644
--- a/nodeinfo.go
+++ b/nodeinfo.go
@@ -1,121 +1,122 @@
/*
* Copyright © 2018-2019, 2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
+ "strings"
+
"github.com/writeas/web-core/log"
"github.com/writefreely/go-nodeinfo"
"github.com/writefreely/writefreely/config"
- "strings"
)
// nodeInfoResolver supplies the instance's registration setting and usage
// statistics for NodeInfo, reading them from the configuration and database.
type nodeInfoResolver struct {
cfg *config.Config
db *datastore
}
// nodeInfoConfig assembles the static NodeInfo configuration for this
// instance from cfg.
//
// The node description is the configured site description, falling back to
// a generic one when empty. On single-user instances it is replaced by the
// description of collection 1, provided that collection can be loaded.
func nodeInfoConfig(db *datastore, cfg *config.Config) *nodeinfo.Config {
	description := cfg.App.SiteDesc
	if description == "" {
		description = "Minimal, federated blogging platform."
	}
	if cfg.App.SingleUser {
		// Fetch blog information, instead
		if blog, err := db.GetCollectionByID(1); err == nil {
			description = blog.Description
		}
	}

	meta := nodeinfo.Metadata{
		NodeName:        cfg.App.SiteName,
		NodeDescription: description,
		Private:         cfg.App.Private,
		Software: nodeinfo.SoftwareMeta{
			HomePage: softwareURL,
			GitHub:   "https://github.com/writefreely/writefreely",
			Follow:   "https://writing.exchange/@writefreely",
		},
		MaxBlogs:     cfg.App.MaxBlogs,
		PublicReader: cfg.App.LocalTimeline,
		Invites:      cfg.App.UserInvites != "",
	}

	return &nodeinfo.Config{
		BaseURL:  cfg.App.Host,
		InfoURL:  "/api/nodeinfo",
		Metadata: meta,
		Protocols: []nodeinfo.NodeProtocol{
			nodeinfo.ProtocolActivityPub,
		},
		Services: nodeinfo.Services{
			Inbound: []nodeinfo.NodeService{},
			Outbound: []nodeinfo.NodeService{
				nodeinfo.ServiceRSS,
			},
		},
		Software: nodeinfo.SoftwareInfo{
			Name:    strings.ToLower(serverSoftware),
			Version: softwareVer,
		},
	}
}
// IsOpenRegistration reports the configured OpenRegistration setting. The
// returned error is always nil.
func (r nodeInfoResolver) IsOpenRegistration() (bool, error) {
return r.cfg.App.OpenRegistration, nil
}
// Usage gathers the instance's usage statistics for NodeInfo.
//
// The user total is the number of collections (0 when it cannot be fetched)
// and LocalPosts the number of posts. The half-year and monthly activity
// figures are only queried when PublicStats is enabled, and count distinct
// collections with posts updated within the period. Query failures are
// logged and leave the affected figure as is; the returned error is always
// nil.
func (r nodeInfoResolver) Usage() (nodeinfo.Usage, error) {
	var activeHalfYear, activeMonth int

	collCount, err := r.db.GetTotalCollections()
	if err != nil {
		collCount = 0
	}
	postCount, err := r.db.GetTotalPosts()
	if err != nil {
		log.Error("Unable to fetch post counts: %v", err)
	}

	// scanCount runs a single-value count query into dest, logging failMsg
	// (with the error) when it fails.
	scanCount := func(query string, dest *int, failMsg string) {
		if err := r.db.QueryRow(query).Scan(dest); err != nil {
			log.Error(failMsg, err)
		}
	}

	if r.cfg.App.PublicStats {
		// Display bi-yearly / monthly stats
		scanCount(`SELECT COUNT(*) FROM (
SELECT DISTINCT collection_id
FROM posts
INNER JOIN collections c
ON collection_id = c.id
WHERE collection_id IS NOT NULL
AND updated > DATE_SUB(NOW(), INTERVAL 6 MONTH)) co`, &activeHalfYear, "Failed getting 6-month active user stats: %s")

		scanCount(`SELECT COUNT(*) FROM (
SELECT DISTINCT collection_id
FROM posts
INNER JOIN collections c
ON collection_id = c.id
WHERE collection_id IS NOT NULL
AND updated > DATE_SUB(NOW(), INTERVAL 1 MONTH)) co`, &activeMonth, "Failed getting 1-month active user stats: %s")
	}

	usage := nodeinfo.Usage{
		Users: nodeinfo.UsageUsers{
			Total:          int(collCount),
			ActiveHalfYear: activeHalfYear,
			ActiveMonth:    activeMonth,
		},
		LocalPosts: int(postCount),
	}
	return usage, nil
}
diff --git a/oauth_generic.go b/oauth_generic.go
index 36f166c..6fa09d7 100644
--- a/oauth_generic.go
+++ b/oauth_generic.go
@@ -1,144 +1,145 @@
/*
* Copyright © 2020-2021 Musing Studio LLC and respective authors.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"context"
"errors"
"fmt"
- "github.com/writeas/web-core/log"
"net/http"
"net/url"
"strings"
+
+ "github.com/writeas/web-core/log"
)
// genericOauthClient is an oauthClient for an arbitrary OAuth provider whose
// endpoints and response field names are supplied through its fields.
type genericOauthClient struct {
ClientID string
ClientSecret string
AuthLocation string // URL the user is sent to for logging in (see buildLoginURL)
ExchangeLocation string // URL an authorization code is POSTed to for a token
InspectLocation string // URL queried with the access token for user details
CallbackLocation string // sent as redirect_uri
Scope string
MapUserID string // key of the user ID in the inspect response
MapUsername string // key of the username in the inspect response
MapDisplayName string // key of the display name in the inspect response
MapEmail string // key of the email address in the inspect response
HttpClient HttpClient
}
// Compile-time check that genericOauthClient implements oauthClient.
var _ oauthClient = genericOauthClient{}
const (
genericOauthDisplayName = "OAuth"
)
// GetProvider returns the provider identifier, "generic".
func (c genericOauthClient) GetProvider() string {
return "generic"
}
// GetClientID returns the configured client ID.
func (c genericOauthClient) GetClientID() string {
return c.ClientID
}
// GetCallbackLocation returns the configured callback (redirect) URL.
func (c genericOauthClient) GetCallbackLocation() string {
return c.CallbackLocation
}
// buildLoginURL returns the provider's authorization URL for the given state
// value: AuthLocation with client_id, redirect_uri, response_type=code,
// state and scope set as query parameters. It fails when AuthLocation cannot
// be parsed as a URL.
func (c genericOauthClient) buildLoginURL(state string) (string, error) {
	loginURL, err := url.Parse(c.AuthLocation)
	if err != nil {
		return "", err
	}

	params := loginURL.Query()
	for _, kv := range [][2]string{
		{"client_id", c.ClientID},
		{"redirect_uri", c.CallbackLocation},
		{"response_type", "code"},
		{"state", state},
		{"scope", c.Scope},
	} {
		params.Set(kv[0], kv[1])
	}
	loginURL.RawQuery = params.Encode()

	return loginURL.String(), nil
}
// exchangeOauthCode exchanges an authorization code for an access token by
// POSTing a form to ExchangeLocation. The client credentials are sent both in
// the form and as HTTP basic auth.
//
// The request is bound to ctx, so cancelling ctx aborts it. An error is
// returned when the request fails, the provider does not answer 200 OK, the
// response cannot be decoded, or the decoded response carries an error.
func (c genericOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
	form := url.Values{}
	form.Add("client_id", c.ClientID)
	form.Add("client_secret", c.ClientSecret)
	form.Add("grant_type", "authorization_code")
	form.Add("redirect_uri", c.CallbackLocation)
	form.Add("scope", c.Scope)
	form.Add("code", code)

	// Build the request with ctx attached. Request.WithContext returns a copy,
	// so calling it and discarding the result leaves the request unbound.
	req, err := http.NewRequestWithContext(ctx, "POST", c.ExchangeLocation, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("User-Agent", ServerUserAgent(""))
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth(c.ClientID, c.ClientSecret)

	resp, err := c.HttpClient.Do(req)
	if err != nil {
		return nil, err
	}
	// HttpClient is an interface, so don't assume the Body is always set.
	if resp.Body != nil {
		defer resp.Body.Close()
	}
	if resp.StatusCode != http.StatusOK {
		return nil, errors.New("unable to exchange code for access token")
	}

	var tokenResponse TokenResponse
	if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
		return nil, err
	}
	if tokenResponse.Error != "" {
		return nil, errors.New(tokenResponse.Error)
	}
	return &tokenResponse, nil
}
// inspectOauthAccessToken fetches the user's details from InspectLocation
// using accessToken as a bearer token, and maps the configured response keys
// (MapUserID, MapUsername, MapDisplayName, MapEmail) onto an InspectResponse.
//
// The request is bound to ctx, so cancelling ctx aborts it. An error is
// returned when the request fails, the provider does not answer 200 OK, the
// response cannot be decoded, or it contains no string value under MapUserID.
// Other mapped values that are missing or not strings are left empty.
func (c genericOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
	// Build the request with ctx attached. Request.WithContext returns a copy,
	// so calling it and discarding the result leaves the request unbound.
	req, err := http.NewRequestWithContext(ctx, "GET", c.InspectLocation, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("User-Agent", ServerUserAgent(""))
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Bearer "+accessToken)

	resp, err := c.HttpClient.Do(req)
	if err != nil {
		return nil, err
	}
	// HttpClient is an interface, so don't assume the Body is always set.
	if resp.Body != nil {
		defer resp.Body.Close()
	}
	if resp.StatusCode != http.StatusOK {
		return nil, errors.New("unable to inspect access token")
	}

	// since we don't know what the JSON from the server will look like, we create a
	// generic interface and then map manually to values set in the config
	var genericInterface map[string]interface{}
	if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &genericInterface); err != nil {
		return nil, err
	}

	// map each relevant field in inspectResponse to the mapped field from the config
	var inspectResponse InspectResponse
	inspectResponse.UserID, _ = genericInterface[c.MapUserID].(string)
	if inspectResponse.UserID == "" {
		log.Error("[CONFIGURATION ERROR] Generic OAuth provider returned empty UserID value (`%s`).\n Do you need to configure a different `map_user_id` value for this provider?", c.MapUserID)
		return nil, fmt.Errorf("no UserID (`%s`) value returned", c.MapUserID)
	}
	inspectResponse.Username, _ = genericInterface[c.MapUsername].(string)
	inspectResponse.DisplayName, _ = genericInterface[c.MapDisplayName].(string)
	inspectResponse.Email, _ = genericInterface[c.MapEmail].(string)

	return &inspectResponse, nil
}
diff --git a/oauth_gitea.go b/oauth_gitea.go
index 3ecd4a7..249007f 100644
--- a/oauth_gitea.go
+++ b/oauth_gitea.go
@@ -1,133 +1,134 @@
package writefreely
import (
"context"
"errors"
"fmt"
- "github.com/writeas/web-core/log"
"net/http"
"net/url"
"strings"
+
+ "github.com/writeas/web-core/log"
)
// giteaOauthClient is an oauthClient for a Gitea instance whose endpoints and
// response field names are supplied through its fields.
type giteaOauthClient struct {
ClientID string
ClientSecret string
AuthLocation string // URL the user is sent to for logging in (see buildLoginURL)
ExchangeLocation string // URL an authorization code is POSTed to for a token
InspectLocation string // URL queried with the access token for user details
CallbackLocation string // sent as redirect_uri
Scope string
MapUserID string // key of the user ID in the inspect response
MapUsername string // key of the username in the inspect response
MapDisplayName string // key of the display name in the inspect response
MapEmail string // key of the email address in the inspect response
HttpClient HttpClient
}
// Compile-time check that giteaOauthClient implements oauthClient.
var _ oauthClient = giteaOauthClient{}
const (
giteaDisplayName = "Gitea"
)
// GetProvider returns the provider identifier, "gitea".
func (c giteaOauthClient) GetProvider() string {
return "gitea"
}
// GetClientID returns the configured client ID.
func (c giteaOauthClient) GetClientID() string {
return c.ClientID
}
// GetCallbackLocation returns the configured callback (redirect) URL.
func (c giteaOauthClient) GetCallbackLocation() string {
return c.CallbackLocation
}
// buildLoginURL returns the Gitea authorization URL for the given state
// value: AuthLocation with client_id, redirect_uri, response_type=code,
// state and scope set as query parameters. It fails when AuthLocation cannot
// be parsed as a URL.
func (c giteaOauthClient) buildLoginURL(state string) (string, error) {
	loginURL, err := url.Parse(c.AuthLocation)
	if err != nil {
		return "", err
	}

	params := loginURL.Query()
	for _, kv := range [][2]string{
		{"client_id", c.ClientID},
		{"redirect_uri", c.CallbackLocation},
		{"response_type", "code"},
		{"state", state},
		{"scope", c.Scope},
	} {
		params.Set(kv[0], kv[1])
	}
	loginURL.RawQuery = params.Encode()

	return loginURL.String(), nil
}
// exchangeOauthCode exchanges an authorization code for an access token by
// POSTing a form to ExchangeLocation. The client credentials are sent as HTTP
// basic auth only.
//
// The request is bound to ctx, so cancelling ctx aborts it. An error is
// returned when the request fails, Gitea does not answer 200 OK, the
// response cannot be decoded, or the decoded response carries an error.
func (c giteaOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
	form := url.Values{}
	form.Add("grant_type", "authorization_code")
	form.Add("redirect_uri", c.CallbackLocation)
	form.Add("scope", c.Scope)
	form.Add("code", code)

	// Build the request with ctx attached. Request.WithContext returns a copy,
	// so calling it and discarding the result leaves the request unbound.
	req, err := http.NewRequestWithContext(ctx, "POST", c.ExchangeLocation, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("User-Agent", ServerUserAgent(""))
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth(c.ClientID, c.ClientSecret)

	resp, err := c.HttpClient.Do(req)
	if err != nil {
		return nil, err
	}
	// HttpClient is an interface, so don't assume the Body is always set.
	if resp.Body != nil {
		defer resp.Body.Close()
	}
	if resp.StatusCode != http.StatusOK {
		return nil, errors.New("unable to exchange code for access token")
	}

	var tokenResponse TokenResponse
	if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
		return nil, err
	}
	if tokenResponse.Error != "" {
		return nil, errors.New(tokenResponse.Error)
	}
	return &tokenResponse, nil
}
// inspectOauthAccessToken fetches the user's details from InspectLocation
// using accessToken as a bearer token, and maps the configured response keys
// (MapUserID, MapUsername, MapDisplayName, MapEmail) onto an InspectResponse.
//
// The request is bound to ctx, so cancelling ctx aborts it. An error is
// returned when the request fails, Gitea does not answer 200 OK, the response
// cannot be decoded, or it contains no string value under MapUserID. Other
// mapped values that are missing or not strings are left empty.
func (c giteaOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
	// Build the request with ctx attached. Request.WithContext returns a copy,
	// so calling it and discarding the result leaves the request unbound.
	req, err := http.NewRequestWithContext(ctx, "GET", c.InspectLocation, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("User-Agent", ServerUserAgent(""))
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Bearer "+accessToken)

	resp, err := c.HttpClient.Do(req)
	if err != nil {
		return nil, err
	}
	// HttpClient is an interface, so don't assume the Body is always set.
	if resp.Body != nil {
		defer resp.Body.Close()
	}
	if resp.StatusCode != http.StatusOK {
		return nil, errors.New("unable to inspect access token")
	}

	// since we don't know what the JSON from the server will look like, we create a
	// generic interface and then map manually to values set in the config
	var genericInterface map[string]interface{}
	if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &genericInterface); err != nil {
		return nil, err
	}

	// map each relevant field in inspectResponse to the mapped field from the config
	var inspectResponse InspectResponse
	inspectResponse.UserID, _ = genericInterface[c.MapUserID].(string)
	if inspectResponse.UserID == "" {
		log.Error("[CONFIGURATION ERROR] Gitea OAuth provider returned empty UserID value (`%s`).\n Do you need to configure a different `map_user_id` value for this provider?", c.MapUserID)
		return nil, fmt.Errorf("no UserID (`%s`) value returned", c.MapUserID)
	}
	inspectResponse.Username, _ = genericInterface[c.MapUsername].(string)
	inspectResponse.DisplayName, _ = genericInterface[c.MapDisplayName].(string)
	inspectResponse.Email, _ = genericInterface[c.MapEmail].(string)

	return &inspectResponse, nil
}
diff --git a/oauth_signup.go b/oauth_signup.go
index b9b52c5..b5f881b 100644
--- a/oauth_signup.go
+++ b/oauth_signup.go
@@ -1,231 +1,232 @@
/*
* Copyright © 2020-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"crypto/sha256"
"encoding/hex"
"fmt"
- "github.com/writeas/impart"
- "github.com/writeas/web-core/auth"
- "github.com/writeas/web-core/log"
- "github.com/writefreely/writefreely/page"
"html/template"
"net/http"
"strings"
"time"
+
+ "github.com/writeas/impart"
+ "github.com/writeas/web-core/auth"
+ "github.com/writeas/web-core/log"
+ "github.com/writefreely/writefreely/page"
)
// viewOauthSignupVars is the data passed to the signup-oauth.tmpl template.
// The Token* fields, Provider, ClientID, TokenHash and InviteCode are copied
// from the signed page parameters; LoginUsername, Alias and Email hold the
// values shown in the form, which submitted form values may override.
type viewOauthSignupVars struct {
page.StaticPage
To string
Message template.HTML
Flashes []template.HTML // session flashes plus any error to display
AccessToken string
TokenUsername string
TokenAlias string // TODO: rename this to match the data it represents: the collection title
TokenEmail string
TokenRemoteUser string
Provider string
ClientID string
TokenHash string
InviteCode string
LoginUsername string
Alias string // TODO: rename this to match the data it represents: the collection title
Email string
}
// Names of the form values read by the OAuth signup handlers.
const (
oauthParamAccessToken = "access_token"
oauthParamTokenUsername = "token_username"
oauthParamTokenAlias = "token_alias"
oauthParamTokenEmail = "token_email"
oauthParamTokenRemoteUserID = "token_remote_user"
oauthParamClientID = "client_id"
oauthParamProvider = "provider"
oauthParamHash = "signature" // hex hash that protects the token_* values
oauthParamUsername = "username"
oauthParamAlias = "alias"
oauthParamEmail = "email"
oauthParamPassword = "password"
oauthParamInviteCode = "invite_code"
)
// oauthSignupPageParams carries the provider-supplied values that are passed
// through the OAuth signup form, together with the hash that protects them.
type oauthSignupPageParams struct {
	AccessToken     string
	TokenUsername   string
	TokenAlias      string // TODO: rename this to match the data it represents: the collection title
	TokenEmail      string
	TokenRemoteUser string
	ClientID        string
	Provider        string
	TokenHash       string
	InviteCode      string
}

// HashTokenParams returns the hex-encoded SHA-256 digest of key followed by
// the access token, token username, token alias, token email, remote user
// ID, client ID and provider, in that order. TokenHash and InviteCode are
// not part of the digest.
func (p oauthSignupPageParams) HashTokenParams(key string) string {
	digest := sha256.New()
	for _, part := range []string{
		key,
		p.AccessToken,
		p.TokenUsername,
		p.TokenAlias,
		p.TokenEmail,
		p.TokenRemoteUser,
		p.ClientID,
		p.Provider,
	} {
		digest.Write([]byte(part))
	}
	return hex.EncodeToString(digest.Sum(nil))
}
// viewOauthSignup handles a submitted OAuth signup form.
//
// It first checks that the submitted signature matches the hash of the
// token parameters and rejects the request with 400 otherwise. It then
// validates the form, creates the user (with a password only when one was
// given), records the invite when an invite code is present, links the
// remote user ID to the new account, and logs the user in. Most failures
// re-render the signup page with the error; a failure to record the invite
// is returned directly.
func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.Request) error {
	tp := &oauthSignupPageParams{
		AccessToken:     r.FormValue(oauthParamAccessToken),
		TokenUsername:   r.FormValue(oauthParamTokenUsername),
		TokenAlias:      r.FormValue(oauthParamTokenAlias),
		TokenEmail:      r.FormValue(oauthParamTokenEmail),
		TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID),
		ClientID:        r.FormValue(oauthParamClientID),
		Provider:        r.FormValue(oauthParamProvider),
		InviteCode:      r.FormValue(oauthParamInviteCode),
	}

	// TokenHash is not part of the hash input, so the value can be reused.
	expectedHash := tp.HashTokenParams(h.Config.Server.HashSeed)
	if expectedHash != r.FormValue(oauthParamHash) {
		return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."}
	}
	tp.TokenHash = expectedHash

	if err := h.validateOauthSignup(r); err != nil {
		return h.showOauthSignupPage(app, w, r, tp, err)
	}

	hashedPass := []byte{}
	password := r.FormValue(oauthParamPassword)
	hasPass := password != ""
	if hasPass {
		var hashErr error
		hashedPass, hashErr = auth.HashPass([]byte(password))
		if hashErr != nil {
			return h.showOauthSignupPage(app, w, r, tp, fmt.Errorf("unable to hash password"))
		}
	}

	newUser := &User{
		Username:   r.FormValue(oauthParamUsername),
		HashedPass: hashedPass,
		HasPass:    hasPass,
		Email:      prepareUserEmail(r.FormValue(oauthParamEmail), h.EmailKey),
		Created:    time.Now().Truncate(time.Second).UTC(),
	}

	// The collection title defaults to the username.
	displayName := r.FormValue(oauthParamAlias)
	if displayName == "" {
		displayName = r.FormValue(oauthParamUsername)
	}

	if err := h.DB.CreateUser(h.Config, newUser, displayName, ""); err != nil {
		return h.showOauthSignupPage(app, w, r, tp, err)
	}

	// Log invite if needed
	if tp.InviteCode != "" {
		if err := app.db.CreateInvitedUser(tp.InviteCode, newUser.ID); err != nil {
			return err
		}
	}

	linkErr := h.DB.RecordRemoteUserID(
		r.Context(),
		newUser.ID,
		r.FormValue(oauthParamTokenRemoteUserID),
		r.FormValue(oauthParamProvider),
		r.FormValue(oauthParamClientID),
		r.FormValue(oauthParamAccessToken),
	)
	if linkErr != nil {
		return h.showOauthSignupPage(app, w, r, tp, linkErr)
	}

	if err := loginOrFail(h.Store, w, r, newUser); err != nil {
		return h.showOauthSignupPage(app, w, r, tp, err)
	}
	return nil
}
// validateOauthSignup checks the submitted signup form and returns a 400
// HTTPError describing the first problem found, or nil when the form is
// acceptable.
//
// The username must be at least Config.App.MinUsernameLen and at most 100
// bytes long. The email address is optional; when given it must contain
// exactly one '@' with a non-empty part on each side.
func (h oauthHandler) validateOauthSignup(r *http.Request) error {
	username := r.FormValue(oauthParamUsername)
	if len(username) < h.Config.App.MinUsernameLen {
		return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too short."}
	}
	if len(username) > 100 {
		return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too long."}
	}

	// The collection title is not validated here. It was previously read and
	// defaulted without ever being used, so that dead code has been dropped.

	email := r.FormValue(oauthParamEmail)
	if len(email) > 0 {
		parts := strings.Split(email, "@")
		if len(parts) != 2 || (len(parts[0]) < 1 || len(parts[1]) < 1) {
			return impart.HTTPError{Status: http.StatusBadRequest, Message: "Invalid email address"}
		}
	}
	return nil
}
// showOauthSignupPage renders the "base" template of signup-oauth.tmpl.
// Username, alias and email default to the values carried in tp and are
// overridden by any non-empty value submitted with the request form. Session
// flashes and, when non-nil, errMsg are shown as flash messages. It returns
// the template execution error, if any.
func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *http.Request, tp *oauthSignupPageParams, errMsg error) error {
username := tp.TokenUsername
collTitle := tp.TokenAlias
email := tp.TokenEmail
session, err := app.sessionStore.Get(r, cookieName)
if err != nil {
// Ignore this
log.Error("Unable to get session; ignoring: %v", err)
}
// Values the user already typed into the form win over the token defaults.
if tmpValue := r.FormValue(oauthParamUsername); len(tmpValue) > 0 {
username = tmpValue
}
if tmpValue := r.FormValue(oauthParamAlias); len(tmpValue) > 0 {
collTitle = tmpValue
}
if tmpValue := r.FormValue(oauthParamEmail); len(tmpValue) > 0 {
email = tmpValue
}
p := &viewOauthSignupVars{
StaticPage: pageForReq(app, r),
To: r.FormValue("to"),
Flashes: []template.HTML{},
AccessToken: tp.AccessToken,
TokenUsername: tp.TokenUsername,
TokenAlias: tp.TokenAlias,
TokenEmail: tp.TokenEmail,
TokenRemoteUser: tp.TokenRemoteUser,
Provider: tp.Provider,
ClientID: tp.ClientID,
TokenHash: tp.TokenHash,
InviteCode: tp.InviteCode,
LoginUsername: username,
Alias: collTitle,
Email: email,
}
// Display any error messages
// The error from getSessionFlashes is deliberately discarded here.
flashes, _ := getSessionFlashes(app, w, r, session)
for _, flash := range flashes {
p.Flashes = append(p.Flashes, template.HTML(flash))
}
if errMsg != nil {
// NOTE(review): the error text is converted to template.HTML without
// escaping — confirm no caller passes an error containing user input.
p.Flashes = append(p.Flashes, template.HTML(errMsg.Error()))
}
err = pages["signup-oauth.tmpl"].ExecuteTemplate(w, "base", p)
if err != nil {
log.Error("Unable to render signup-oauth: %v", err)
return err
}
return nil
}
diff --git a/oauth_slack.go b/oauth_slack.go
index 40f50e4..d18564a 100644
--- a/oauth_slack.go
+++ b/oauth_slack.go
@@ -1,178 +1,179 @@
/*
* Copyright © 2019-2020 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"context"
"errors"
- "github.com/gosimple/slug"
"net/http"
"net/url"
"strings"
+
+ "github.com/gosimple/slug"
)
// slackOauthClient holds the settings used to run the OAuth flow against Slack.
type slackOauthClient struct {
// ClientID and ClientSecret are sent as basic-auth credentials on the code exchange.
ClientID string
ClientSecret string
// TeamID is sent as the "team" parameter of the login URL.
TeamID string
// CallbackLocation is sent as the "redirect_uri" parameter.
CallbackLocation string
// HttpClient performs the outgoing requests.
HttpClient HttpClient
}
// slackExchangeResponse is the JSON body decoded from the slackExchangeLocation
// response when exchanging an authorization code.
type slackExchangeResponse struct {
OK bool `json:"ok"`
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TeamName string `json:"team_name"`
TeamID string `json:"team_id"`
// Error is returned to the caller as the error text when OK is false.
Error string `json:"error"`
}
// slackIdentity is the "user" object of a slackUserIdentityResponse.
type slackIdentity struct {
Name string `json:"name"`
ID string `json:"id"`
Email string `json:"email"`
}
// slackTeam is the "team" object of a slackUserIdentityResponse.
type slackTeam struct {
Name string `json:"name"`
ID string `json:"id"`
}
// slackUserIdentityResponse is the JSON body decoded from the
// slackIdentityLocation response when inspecting an access token.
type slackUserIdentityResponse struct {
OK bool `json:"ok"`
User slackIdentity `json:"user"`
Team slackTeam `json:"team"`
// Error is returned to the caller as the error text when OK is false.
Error string `json:"error"`
}
// Slack endpoints used by slackOauthClient.
const (
// slackAuthLocation is the base of the login URL built by buildLoginURL.
slackAuthLocation = "https://slack.com/oauth/authorize"
// slackExchangeLocation receives the POST that trades a code for a token.
slackExchangeLocation = "https://slack.com/api/oauth.access"
// slackIdentityLocation is queried with the bearer token to identify the user.
slackIdentityLocation = "https://slack.com/api/users.identity"
)
// Compile-time check that slackOauthClient implements oauthClient.
var _ oauthClient = slackOauthClient{}
// GetProvider returns the provider identifier for this client, "slack".
func (c slackOauthClient) GetProvider() string {
return "slack"
}
// GetClientID returns the configured OAuth client ID.
func (c slackOauthClient) GetClientID() string {
return c.ClientID
}
// GetCallbackLocation returns the configured redirect URI.
func (c slackOauthClient) GetCallbackLocation() string {
return c.CallbackLocation
}
// buildLoginURL returns the Slack authorization URL the user is redirected to
// in order to start the OAuth flow, carrying the given state value.
func (c slackOauthClient) buildLoginURL(state string) (string, error) {
	loginURL, err := url.Parse(slackAuthLocation)
	if err != nil {
		return "", err
	}

	params := loginURL.Query()
	for _, kv := range [][2]string{
		{"client_id", c.ClientID},
		{"scope", "identity.basic identity.email identity.team"},
		{"redirect_uri", c.CallbackLocation},
		{"state", state},
		// Without "team" the user could pick which team to authenticate
		// through, and the configured team would then have to be matched
		// against the fetched profile. That is extra post-auth work we
		// don't want to do.
		{"team", c.TeamID},
		// Not explicitly listed in the Slack OAuth docs, but part of the
		// spec, so it is included anyway.
		{"response_type", "code"},
	} {
		params.Set(kv[0], kv[1])
	}

	loginURL.RawQuery = params.Encode()
	return loginURL.String(), nil
}
// exchangeOauthCode trades the authorization code for an access token by
// POSTing to slackExchangeLocation with the client credentials as basic auth.
// It returns an error when the request fails, the status is not 200, the body
// cannot be decoded, or Slack reports ok=false (in which case the error text
// is Slack's own error string).
func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
	form := url.Values{}
	// The oauth.access documentation doesn't explicitly mention this
	// parameter, but it is part of the spec, so we include it anyway.
	// https://api.slack.com/methods/oauth.access
	form.Add("grant_type", "authorization_code")
	form.Add("redirect_uri", c.CallbackLocation)
	form.Add("code", code)

	// Build the request with ctx attached; Request.WithContext returns a copy,
	// so calling it without using the result leaves the request without ctx.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, slackExchangeLocation, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("User-Agent", ServerUserAgent(""))
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth(c.ClientID, c.ClientSecret)

	resp, err := c.HttpClient.Do(req)
	if err != nil {
		return nil, err
	}
	// Release the connection on every path. The nil check tolerates test
	// doubles that return a Response without a Body.
	if resp.Body != nil {
		defer resp.Body.Close()
	}
	if resp.StatusCode != http.StatusOK {
		return nil, errors.New("unable to exchange code for access token")
	}

	var tokenResponse slackExchangeResponse
	if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
		return nil, err
	}
	if !tokenResponse.OK {
		return nil, errors.New(tokenResponse.Error)
	}
	return tokenResponse.TokenResponse(), nil
}
// inspectOauthAccessToken fetches the identity behind accessToken from
// slackIdentityLocation, sending the token as a bearer credential. It returns
// an error when the request fails, the status is not 200, the body cannot be
// decoded, or Slack reports ok=false (in which case the error text is Slack's
// own error string).
func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
	// Build the request with ctx attached; Request.WithContext returns a copy,
	// so calling it without using the result leaves the request without ctx.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, slackIdentityLocation, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("User-Agent", ServerUserAgent(""))
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Bearer "+accessToken)

	resp, err := c.HttpClient.Do(req)
	if err != nil {
		return nil, err
	}
	// Release the connection on every path. The nil check tolerates test
	// doubles that return a Response without a Body.
	if resp.Body != nil {
		defer resp.Body.Close()
	}
	if resp.StatusCode != http.StatusOK {
		return nil, errors.New("unable to inspect access token")
	}

	var inspectResponse slackUserIdentityResponse
	if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil {
		return nil, err
	}
	if !inspectResponse.OK {
		return nil, errors.New(inspectResponse.Error)
	}
	return inspectResponse.InspectResponse(), nil
}
// InspectResponse converts the Slack identity payload into an InspectResponse.
// The username is the slug.Make form of the Slack user name; the display name
// is the Slack user name unchanged.
func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse {
return &InspectResponse{
UserID: resp.User.ID,
Username: slug.Make(resp.User.Name),
DisplayName: resp.User.Name,
Email: resp.User.Email,
}
}
// TokenResponse converts the Slack exchange payload into a TokenResponse.
// Only the access token is carried over.
func (resp slackExchangeResponse) TokenResponse() *TokenResponse {
return &TokenResponse{
AccessToken: resp.AccessToken,
}
}
diff --git a/oauth_test.go b/oauth_test.go
index 1bd0bfc..326e7ed 100644
--- a/oauth_test.go
+++ b/oauth_test.go
@@ -1,261 +1,262 @@
/*
* Copyright © 2019-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"context"
"fmt"
- "github.com/gorilla/sessions"
- "github.com/stretchr/testify/assert"
- "github.com/writeas/impart"
- "github.com/writeas/web-core/id"
- "github.com/writefreely/writefreely/config"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
+
+ "github.com/gorilla/sessions"
+ "github.com/stretchr/testify/assert"
+ "github.com/writeas/impart"
+ "github.com/writeas/web-core/id"
+ "github.com/writefreely/writefreely/config"
)
// MockOAuthDatastoreProvider is a test double whose DB, Config and
// SessionStore methods call the matching Do* hook when it is set and fall
// back to a built-in default otherwise.
type MockOAuthDatastoreProvider struct {
DoDB func() OAuthDatastore
DoConfig func() *config.Config
DoSessionStore func() sessions.Store
}
// MockOAuthDatastore is a test double for OAuthDatastore. Each method calls
// the matching Do* hook when it is set and returns a fixed default otherwise.
type MockOAuthDatastore struct {
DoGenerateOAuthState func(context.Context, string, string, int64, string) (string, error)
DoValidateOAuthState func(context.Context, string) (string, string, int64, string, error)
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
// DoCreateUser takes one string fewer than CreateUser: the description
// argument is not forwarded to the hook.
DoCreateUser func(*config.Config, *User, string) error
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
DoGetUserByID func(int64) (*User, error)
}
// Compile-time check that *MockOAuthDatastore implements OAuthDatastore.
var _ OAuthDatastore = &MockOAuthDatastore{}
// StringReadCloser wraps a *strings.Reader and adds a Close method, so that
// it can be used as an http.Response Body in tests.
type StringReadCloser struct {
*strings.Reader
}
// Close does nothing and always returns nil.
func (src *StringReadCloser) Close() error {
return nil
}
// MockHTTPClient is a test double for an HTTP client. Requests are answered
// by the DoDo hook when one is installed.
type MockHTTPClient struct {
	DoDo func(req *http.Request) (*http.Response, error)
}

// Do forwards req to the DoDo hook. Without a hook it returns an empty
// response and a nil error.
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
	if m.DoDo == nil {
		return &http.Response{}, nil
	}
	return m.DoDo(req)
}
// SessionStore returns the store produced by the DoSessionStore hook, or a
// fresh cookie store with a fixed test key when no hook is installed.
func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store {
	if m.DoSessionStore == nil {
		return sessions.NewCookieStore([]byte("secret-key"))
	}
	return m.DoSessionStore()
}
// DB returns the datastore produced by the DoDB hook, or a new
// MockOAuthDatastore with no hooks when none is installed.
func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore {
	if m.DoDB == nil {
		return &MockOAuthDatastore{}
	}
	return m.DoDB()
}
// Config returns the configuration produced by the DoConfig hook. Without a
// hook it builds a new SQLite-backed config with "development" credentials
// for both the Write.as and the Slack OAuth providers.
func (m *MockOAuthDatastoreProvider) Config() *config.Config {
	if hook := m.DoConfig; hook != nil {
		return hook()
	}

	writeAs := config.WriteAsOauthCfg{
		ClientID:        "development",
		ClientSecret:    "development",
		AuthLocation:    "https://write.as/oauth/login",
		TokenLocation:   "https://write.as/oauth/token",
		InspectLocation: "https://write.as/oauth/inspect",
	}
	slack := config.SlackOauthCfg{
		ClientID:     "development",
		ClientSecret: "development",
		TeamID:       "development",
	}

	cfg := config.New()
	cfg.UseSQLite(true)
	cfg.WriteAsOauth = writeAs
	cfg.SlackOauth = slack
	return cfg
}
// ValidateOAuthState delegates to the DoValidateOAuthState hook, or returns
// zero values and a nil error when no hook is installed.
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) {
	if m.DoValidateOAuthState == nil {
		return "", "", 0, "", nil
	}
	return m.DoValidateOAuthState(ctx, state)
}
// GetIDForRemoteUser delegates to the DoGetIDForRemoteUser hook, or returns
// -1 and a nil error when no hook is installed.
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
	if m.DoGetIDForRemoteUser == nil {
		return -1, nil
	}
	return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID)
}
// CreateUser delegates to the DoCreateUser hook; the description argument is
// not passed on. Without a hook it assigns ID 1 to u and reports success.
func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username, description string) error {
	if m.DoCreateUser == nil {
		u.ID = 1
		return nil
	}
	return m.DoCreateUser(cfg, u, username)
}
// RecordRemoteUserID delegates to the DoRecordRemoteUserID hook, or reports
// success when no hook is installed.
func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
	if m.DoRecordRemoteUserID == nil {
		return nil
	}
	return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken)
}
// GetUserByID delegates to the DoGetUserByID hook, or returns a zero-valued
// User and a nil error when no hook is installed.
func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
	if m.DoGetUserByID == nil {
		return &User{}, nil
	}
	return m.DoGetUserByID(userID)
}
// GenerateOAuthState delegates to the DoGenerateOAuthState hook, or returns a
// value from id.Generate62RandomString(14) when no hook is installed.
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64, inviteCode string) (string, error) {
	if m.DoGenerateOAuthState == nil {
		return id.Generate62RandomString(14), nil
	}
	return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID, inviteCode)
}
// TestViewOauthInit exercises oauthHandler.viewOauthInit with a
// writeAsOauthClient: once with the default mocks, and once with a datastore
// whose state generation fails.
func TestViewOauthInit(t *testing.T) {
t.Run("success", func(t *testing.T) {
app := &MockOAuthDatastoreProvider{}
h := oauthHandler{
Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
oauthClient: writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
}
req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err)
rr := httptest.NewRecorder()
err = h.viewOauthInit(nil, rr, req)
// The redirect is reported as an impart.HTTPError whose Message holds
// the target URL, so a non-nil error is the expected outcome here.
assert.NotNil(t, err)
httpErr, ok := err.(impart.HTTPError)
assert.True(t, ok)
assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status)
assert.NotEmpty(t, httpErr.Message)
locURI, err := url.Parse(httpErr.Message)
assert.NoError(t, err)
assert.Equal(t, "/oauth/login", locURI.Path)
assert.Equal(t, "development", locURI.Query().Get("client_id"))
assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri"))
assert.Equal(t, "code", locURI.Query().Get("response_type"))
assert.NotEmpty(t, locURI.Query().Get("state"))
})
t.Run("state failure", func(t *testing.T) {
// The datastore hook makes state generation fail.
app := &MockOAuthDatastoreProvider{
DoDB: func() OAuthDatastore {
return &MockOAuthDatastore{
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64, inviteCode string) (string, error) {
return "", fmt.Errorf("pretend unable to write state error")
},
}
},
}
h := oauthHandler{
Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
oauthClient: writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
}
req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err)
rr := httptest.NewRecorder()
err = h.viewOauthInit(nil, rr, req)
httpErr, ok := err.(impart.HTTPError)
assert.True(t, ok)
assert.NotEmpty(t, httpErr.Message)
assert.Equal(t, http.StatusInternalServerError, httpErr.Status)
assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message)
})
}
// TestViewOauthCallback exercises oauthHandler.viewOauthCallback with a
// writeAsOauthClient whose HTTP client is mocked to answer the token and
// inspect endpoints with canned JSON.
func TestViewOauthCallback(t *testing.T) {
t.Run("success", func(t *testing.T) {
app := &MockOAuthDatastoreProvider{}
h := oauthHandler{
Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
oauthClient: writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: &MockHTTPClient{
DoDo: func(req *http.Request) (*http.Response, error) {
// Route on the full request URL; anything else gets a 404
// response with no body.
switch req.URL.String() {
case "https://write.as/oauth/token":
return &http.Response{
StatusCode: 200,
Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)},
}, nil
case "https://write.as/oauth/inspect":
return &http.Response{
StatusCode: 200,
Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
}, nil
}
return &http.Response{
StatusCode: http.StatusNotFound,
}, nil
},
},
},
}
req, err := http.NewRequest("GET", "/oauth/callback", nil)
assert.NoError(t, err)
rr := httptest.NewRecorder()
err = h.viewOauthCallback(&App{cfg: app.Config(), sessionStore: app.SessionStore()}, rr, req)
assert.NoError(t, err)
assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
})
}
diff --git a/page/page.go b/page/page.go
index 5ab7750..057f034 100644
--- a/page/page.go
+++ b/page/page.go
@@ -1,48 +1,49 @@
/*
* Copyright © 2018-2019, 2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
// package page provides mechanisms and data for generating a WriteFreely page.
package page
import (
- "github.com/writefreely/writefreely/config"
"strings"
+
+ "github.com/writefreely/writefreely/config"
)
// StaticPage holds the data shared by rendered pages: the embedded app
// configuration plus values derived from the current request.
type StaticPage struct {
// App configuration
config.AppCfg
// Version is the software version string; see OfficialVersion.
Version string
HeaderNav bool
CustomCSS bool
// Request values
Path string
Username string
Values map[string]string
Flashes []string
CanViewReader bool
IsAdmin bool
CanInvite bool
}
// SanitizeHost rewrites sp.Host to the configured hidden-service host whenever
// sp.Host starts with it. This matters for the Tor hidden service, which can
// be served over proxies that mangle the apparent hostname. Nothing happens
// when no hidden host is configured.
func (sp *StaticPage) SanitizeHost(cfg *config.Config) {
	hidden := cfg.Server.HiddenHost
	if hidden == "" || !strings.HasPrefix(sp.Host, hidden) {
		return
	}
	sp.Host = hidden
}
// OfficialVersion returns the part of sp.Version that precedes the first "-",
// or the whole version when it contains no "-".
func (sp StaticPage) OfficialVersion() string {
	return strings.Split(sp.Version, "-")[0]
}
diff --git a/pages.go b/pages.go
index 6ebc3dc..d835545 100644
--- a/pages.go
+++ b/pages.go
@@ -1,198 +1,199 @@
/*
* Copyright © 2018-2019, 2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"database/sql"
- "github.com/writefreely/writefreely/config"
"time"
+
+ "github.com/writefreely/writefreely/config"
)
// defaultPageUpdatedTime is the Updated timestamp given to built-in default
// content that has no stored counterpart.
var defaultPageUpdatedTime = time.Date(2018, 11, 8, 12, 0, 0, 0, time.Local)
// getAboutPage loads the "about" page from the database, substituting the
// built-in default content when none is stored and the default title when the
// stored title is not valid.
func getAboutPage(app *App) (*instanceContent, error) {
	page, err := app.db.GetDynamicContent("about")
	if err != nil {
		return nil, err
	}
	if page == nil {
		page = &instanceContent{
			ID:      "about",
			Type:    "page",
			Content: defaultAboutPage(app.cfg),
		}
	}
	if page.Title.Valid {
		return page, nil
	}
	page.Title = defaultAboutTitle(app.cfg)
	return page, nil
}
// defaultAboutTitle returns the valid title "About <site name>".
func defaultAboutTitle(cfg *config.Config) sql.NullString {
return sql.NullString{String: "About " + cfg.App.SiteName, Valid: true}
}
// getContactPage loads the "contact" page from the database, substituting the
// built-in default content when none is stored and the default title when the
// stored title is not valid.
func getContactPage(app *App) (*instanceContent, error) {
	page, err := app.db.GetDynamicContent("contact")
	if err != nil {
		return nil, err
	}
	if page == nil {
		page = &instanceContent{
			ID:      "contact",
			Type:    "page",
			Content: defaultContactPage(app),
		}
	}
	if page.Title.Valid {
		return page, nil
	}
	page.Title = defaultContactTitle()
	return page, nil
}
// defaultContactTitle returns the valid title "Contact Us".
func defaultContactTitle() sql.NullString {
return sql.NullString{String: "Contact Us", Valid: true}
}
// getPrivacyPage loads the "privacy" page from the database. When none is
// stored it returns the built-in policy stamped with defaultPageUpdatedTime;
// when the title is not valid it fills in the default title.
func getPrivacyPage(app *App) (*instanceContent, error) {
	page, err := app.db.GetDynamicContent("privacy")
	if err != nil {
		return nil, err
	}
	if page == nil {
		page = &instanceContent{
			ID:      "privacy",
			Type:    "page",
			Content: defaultPrivacyPolicy(app.cfg),
			Updated: defaultPageUpdatedTime,
		}
	}
	if page.Title.Valid {
		return page, nil
	}
	page.Title = defaultPrivacyTitle()
	return page, nil
}
// defaultPrivacyTitle returns the valid title "Privacy Policy".
func defaultPrivacyTitle() sql.NullString {
return sql.NullString{String: "Privacy Policy", Valid: true}
}
// defaultAboutPage returns the built-in About text for the site. The wording
// mentions ActivityPub only when federation is enabled.
func defaultAboutPage(cfg *config.Config) string {
	site := cfg.App.SiteName
	if !cfg.App.Federation {
		return `_` + site + `_ is a place for you to write and publish, powered by [WriteFreely](https://writefreely.org).`
	}
	return `_` + site + `_ is an interconnected place for you to write and publish, powered by [WriteFreely](https://writefreely.org) and ActivityPub.`
}
// defaultContactPage returns the built-in Contact text, naming the alias of
// the collection with ID 1 as the administrator. It returns an empty string
// when that collection cannot be loaded.
func defaultContactPage(app *App) string {
c, err := app.db.GetCollectionByID(1)
if err != nil {
return ""
}
return `_` + app.cfg.App.SiteName + `_ is administered by: [**` + c.Alias + `**](/` + c.Alias + `/).
Contact them at this email address: _EMAIL GOES HERE_.
You can also reach them here...`
}
// defaultPrivacyPolicy returns the built-in privacy policy text, with the
// configured site name inserted.
func defaultPrivacyPolicy(cfg *config.Config) string {
return `[WriteFreely](https://writefreely.org), the software that powers this site, is built to enforce your right to privacy by default.
It retains as little data about you as possible, not even requiring an email address to sign up. However, if you _do_ give us your email address, it is stored encrypted in our database. We salt and hash your account's password.
We store log files, or data about what happens on our servers. We also use cookies to keep you logged into your account.
Beyond this, it's important that you trust whoever runs **` + cfg.App.SiteName + `**. Software can only do so much to protect you -- your level of privacy protections will ultimately fall on the humans that run this particular service.`
}
// getLandingBanner loads the "landing-banner" section from the database, or
// returns the built-in banner stamped with defaultPageUpdatedTime when none
// is stored.
func getLandingBanner(app *App) (*instanceContent, error) {
	banner, err := app.db.GetDynamicContent("landing-banner")
	if err != nil {
		return nil, err
	}
	if banner != nil {
		return banner, nil
	}
	return &instanceContent{
		ID:      "landing-banner",
		Type:    "section",
		Content: defaultLandingBanner(app.cfg),
		Updated: defaultPageUpdatedTime,
	}, nil
}
// getLandingBody loads the "landing-body" section from the database, or
// returns the built-in body stamped with defaultPageUpdatedTime when none is
// stored.
func getLandingBody(app *App) (*instanceContent, error) {
	body, err := app.db.GetDynamicContent("landing-body")
	if err != nil {
		return nil, err
	}
	if body != nil {
		return body, nil
	}
	return &instanceContent{
		ID:      "landing-body",
		Type:    "section",
		Content: defaultLandingBody(app.cfg),
		Updated: defaultPageUpdatedTime,
	}, nil
}
// defaultLandingBanner returns the built-in landing page heading, which
// mentions the fediverse only when federation is enabled.
func defaultLandingBanner(cfg *config.Config) string {
if cfg.App.Federation {
return "# Start your blog in the fediverse"
}
return "# Start your blog"
}
// defaultLandingBody returns the built-in landing page body: an introduction
// to the fediverse when federation is enabled, and an empty string otherwise.
func defaultLandingBody(cfg *config.Config) string {
if cfg.App.Federation {
return `## Join the Fediverse
The fediverse is a large network of platforms that all speak a common language. Imagine if you could reply to Instagram posts from Twitter, or interact with your favorite Medium blogs from Facebook -- federated alternatives like [PixelFed](https://pixelfed.org), [Mastodon](https://joinmastodon.org), and WriteFreely enable you to do these types of things.
<div style="text-align:center">
<iframe style="width: 560px; height: 315px; max-width: 100%;" sandbox="allow-same-origin allow-scripts" src="https://video.writeas.org/videos/embed/cc55e615-d204-417c-9575-7b57674cc6f3" frameborder="0" allowfullscreen></iframe>
</div>
## Write More Socially
WriteFreely can communicate with other federated platforms like Mastodon, so people can follow your blogs, bookmark their favorite posts, and boost them to their followers. Sign up above to create a blog and join the fediverse.`
}
return ""
}
// getReaderSection loads the "reader" section from the database. When none is
// stored it returns the built-in banner stamped with defaultPageUpdatedTime;
// when the title is not valid it fills in the default title.
func getReaderSection(app *App) (*instanceContent, error) {
	section, err := app.db.GetDynamicContent("reader")
	if err != nil {
		return nil, err
	}
	if section == nil {
		section = &instanceContent{
			ID:      "reader",
			Type:    "section",
			Content: defaultReaderBanner(app.cfg),
			Updated: defaultPageUpdatedTime,
		}
	}
	if section.Title.Valid {
		return section, nil
	}
	section.Title = defaultReaderTitle(app.cfg)
	return section, nil
}
// defaultReaderTitle returns the valid title "Reader". The cfg argument is
// not used.
func defaultReaderTitle(cfg *config.Config) sql.NullString {
return sql.NullString{String: "Reader", Valid: true}
}
// defaultReaderBanner returns the built-in reader banner, naming the site.
func defaultReaderBanner(cfg *config.Config) string {
return "Read the latest posts from " + cfg.App.SiteName + "."
}
diff --git a/parse/posts.go b/parse/posts.go
index 5a39a8d..d28c5fd 100644
--- a/parse/posts.go
+++ b/parse/posts.go
@@ -1,91 +1,92 @@
/*
* Copyright © 2018-2020 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
// Package parse assists in the parsing of plain text posts
package parse
import (
- "github.com/writeas/web-core/stringmanip"
"regexp"
"strings"
+
+ "github.com/writeas/web-core/stringmanip"
)
// Patterns used by PostLede to clean up the first line of a post.
var (
// titleElementReg matches opening and closing <p> tags.
titleElementReg = regexp.MustCompile("</?p>")
// urlReg matches the "http://" or "https://" protocol prefix.
urlReg = regexp.MustCompile("https?://")
// imgReg matches Markdown image syntax, capturing the alt text.
imgReg = regexp.MustCompile(`!\[([^]]+)\]\([^)]+\)`)
)
// PostLede attempts to extract the first thought of the given post, generally
// contained within the first line or sentence of text.
//
// Only the first line of t is considered. <p> tags and URL protocols are
// removed and Markdown images are replaced by their alt text. The result is
// then cut at the first ". ", then at the first '。', then at the first '?'.
// When includePunc is true the terminating punctuation mark is kept.
func PostLede(t string, includePunc bool) string {
// Adjust where we truncate if we want to include punctuation
iAdj := 0
if includePunc {
iAdj = 1
}
// Find lede within first line of text
nl := strings.IndexRune(t, '\n')
if nl > -1 {
t = t[:nl]
}
// Strip certain HTML tags
t = titleElementReg.ReplaceAllString(t, "")
// Strip URL protocols
t = urlReg.ReplaceAllString(t, "")
// Strip image URL, leaving only alt text
t = imgReg.ReplaceAllString(t, " $1 ")
// Find lede within first sentence
// strings.Index yields a byte offset; "." is one byte, so adding iAdj
// includes exactly the period.
punc := strings.Index(t, ". ")
if punc > -1 {
t = t[:punc+iAdj]
}
// NOTE(review): stringmanip.IndexRune presumably returns a rune (not byte)
// offset, since its result is used to slice a []rune — confirm.
punc = stringmanip.IndexRune(t, '。')
if punc > -1 {
c := []rune(t)
t = string(c[:punc+iAdj])
}
punc = stringmanip.IndexRune(t, '?')
if punc > -1 {
c := []rune(t)
t = string(c[:punc+iAdj])
}
return t
}
// TruncToWord shortens s to at most l runes and then drops everything from
// the last space onward, so the result does not end in a partial word. The
// boolean reports whether s was longer than l runes. Text that already fits
// is returned unchanged.
func TruncToWord(s string, l int) (string, bool) {
	runes := []rune(s)
	if len(runes) <= l {
		return s, false
	}
	cut := string(runes[:l])
	if i := strings.LastIndexByte(cut, ' '); i > -1 {
		cut = cut[:i]
	}
	return cut, true
}
// Truncate returns the first l runes of s, or s itself when it holds l runes
// or fewer.
func Truncate(s string, l int) string {
	if runes := []rune(s); len(runes) > l {
		return string(runes[:l])
	}
	return s
}
diff --git a/session.go b/session.go
index 03d3963..9f8bbba 100644
--- a/session.go
+++ b/session.go
@@ -1,146 +1,147 @@
/*
* Copyright © 2018-2019 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"encoding/gob"
- "github.com/gorilla/sessions"
- "github.com/writeas/web-core/log"
"net/http"
"strings"
+
+ "github.com/gorilla/sessions"
+ "github.com/writeas/web-core/log"
)
// Session and cookie settings.
const (
// day is 86400, presumably the number of seconds in a day — confirm.
day = 86400
// sessionLength is assigned to the session cookie's MaxAge.
sessionLength = 180 * day
userEmailCookieName = "ue"
userEmailCookieVal = "email"
// cookieName is the name under which the user session is stored.
cookieName = "wfu"
// cookieUserVal is the session value key holding the *User.
cookieUserVal = "u"
blogPassCookieName = "ub"
)
// InitSession builds the cookie-backed session store and assigns it to
// app.sessionStore. The keychain must already be loaded, because the store is
// keyed with app.keys. Cookies are marked Secure, with SameSite=None, only
// when the configured host starts with "https://".
func (app *App) InitSession() {
	// Complex types stored in cookies have to be registered with gob.
	gob.Register(&User{})

	secure := strings.HasPrefix(app.cfg.App.Host, "https://")
	opts := &sessions.Options{
		Path:     "/",
		MaxAge:   sessionLength,
		HttpOnly: true,
		Secure:   secure,
	}
	if secure {
		opts.SameSite = http.SameSiteNoneMode
	}

	store := sessions.NewCookieStore(app.keys.CookieAuthKey, app.keys.CookieKey)
	store.Options = opts
	app.sessionStore = store
}
// getSessionFlashes returns the string flash messages held in the session and
// then saves the user session. When session is nil it is first looked up from
// the request; a lookup failure is returned as the error. Flashes that are
// not strings are skipped. The returned slice is never nil on success.
func getSessionFlashes(app *App, w http.ResponseWriter, r *http.Request, session *sessions.Session) ([]string, error) {
	if session == nil {
		var err error
		if session, err = app.sessionStore.Get(r, cookieName); err != nil {
			return nil, err
		}
	}

	msgs := []string{}
	for _, flash := range session.Flashes() {
		if s, ok := flash.(string); ok {
			msgs = append(msgs, s)
		}
	}
	saveUserSession(app, r, w)
	return msgs, nil
}
// addSessionFlash adds the flash message m to the session and saves the user
// session. When session is nil it is first looked up from the request; a
// lookup failure is logged and returned.
func addSessionFlash(app *App, w http.ResponseWriter, r *http.Request, m string, session *sessions.Session) error {
	if session == nil {
		fetched, err := app.sessionStore.Get(r, cookieName)
		if err != nil {
			log.Error("Unable to add flash '%s': %v", m, err)
			return err
		}
		session = fetched
	}
	session.AddFlash(m)
	saveUserSession(app, r, w)
	return nil
}
// getUserAndSession returns the logged-in user stored in the request's
// session together with that session. Both results are nil when the session
// cannot be read or holds no *User.
func getUserAndSession(app *App, r *http.Request) (*User, *sessions.Session) {
	session, err := app.sessionStore.Get(r, cookieName)
	if err != nil {
		return nil, nil
	}
	// Got the currently logged-in user
	if u, ok := session.Values[cookieUserVal].(*User); ok {
		return u, session
	}
	return nil, nil
}
// getUserSession returns the user from getUserAndSession, discarding the
// session. The result is nil when no user is available.
func getUserSession(app *App, r *http.Request) *User {
u, _ := getUserAndSession(app, r)
return u
}
// saveUserSession resets the session's MaxAge to sessionLength, replaces any
// stored *User with its Cookie() form and saves the session. It returns
// ErrInternalCookieSession when the session cannot be read; a save failure is
// logged and returned.
func saveUserSession(app *App, r *http.Request, w http.ResponseWriter) error {
	session, err := app.sessionStore.Get(r, cookieName)
	if err != nil {
		return ErrInternalCookieSession
	}

	// Extend the session
	session.Options.MaxAge = int(sessionLength)

	// Remove any information that accidentally got added
	// FIXME: find where Plan information is getting saved to cookie.
	if u, ok := session.Values[cookieUserVal].(*User); ok {
		session.Values[cookieUserVal] = u.Cookie()
	}

	if err = session.Save(r, w); err != nil {
		log.Error("Couldn't saveUserSession: %v", err)
	}
	return err
}
// getFullUserSession looks up the session's user in the database by ID and
// returns that record. It returns nil, nil when the session holds no user.
func getFullUserSession(app *App, r *http.Request) (*User, error) {
	if cookieUser := getUserSession(app, r); cookieUser != nil {
		return app.db.GetUserByID(cookieUser.ID)
	}
	return nil, nil
}
diff --git a/spam/email.go b/spam/email.go
index de017ab..ca8ac94 100644
--- a/spam/email.go
+++ b/spam/email.go
@@ -1,43 +1,44 @@
/*
* Copyright © 2020-2021 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package spam
import (
- "github.com/writeas/web-core/id"
"strings"
+
+ "github.com/writeas/web-core/id"
)
// honeypotField caches the name generated by HoneypotFieldName; it is empty
// until that function is first called.
var honeypotField string
// HoneypotFieldName returns the honeypot field name, generating it with
// id.Generate62RandomString(39) on first use and returning the same value on
// later calls. No synchronization is performed around the cached value.
func HoneypotFieldName() string {
if honeypotField == "" {
honeypotField = id.Generate62RandomString(39)
}
return honeypotField
}
// CleanEmail reduces an email address to a canonical form that can be
// blocked: the address is lowercased, anything from the first '+' in the
// local part is dropped, and dots are removed from the local part. It returns
// an empty string when the input contains no "@".
func CleanEmail(email string) string {
	parts := strings.Split(strings.ToLower(email), "@")
	if len(parts) < 2 {
		return ""
	}
	local, domain := parts[0], parts[1]

	// Ignore anything after '+'
	if i := strings.IndexRune(local, '+'); i > -1 {
		local = local[:i]
	}

	// Strip dots in email address
	return strings.ReplaceAll(local, ".", "") + "@" + domain
}
diff --git a/updates.go b/updates.go
index e29e13b..791f245 100644
--- a/updates.go
+++ b/updates.go
@@ -1,131 +1,132 @@
/*
* Copyright © 2019-2020 Musing Studio LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
- "github.com/writeas/web-core/log"
"io"
"net/http"
"strings"
"sync"
"time"
+
+ "github.com/writeas/web-core/log"
)
// defaultUpdatesCacheTime is the default interval between cache updates for
// new software versions
const defaultUpdatesCacheTime = 12 * time.Hour
// updatesCache holds data about current and new releases of the writefreely
// software
type updatesCache struct {
// mu is held by CheckNow while it refreshes the cache.
mu sync.Mutex
// frequency is how old lastCheck may get before AreAvailable re-checks.
frequency time.Duration
// lastCheck is when CheckNow last started.
lastCheck time.Time
latestVersion string
currentVersion string
// checkError is the error from the most recent failed check.
checkError error
}
// CheckNow asks for the latest released version of writefreely and updates
// the cache last checked time. If the version postdates the current 'latest'
// the version value is replaced.
//
// The cache mutex is held for the whole call, including the remote request.
// On failure the error is logged, stored in checkError and returned.
func (uc *updatesCache) CheckNow() error {
if debugging {
log.Info("[update check] Checking for update now.")
}
uc.mu.Lock()
defer uc.mu.Unlock()
// Recorded before the request, so a failed check also resets the timer.
uc.lastCheck = time.Now()
latestRemote, err := newVersionCheck()
if err != nil {
log.Error("[update check] Failed: %v", err)
uc.checkError = err
return err
}
// NOTE(review): checkError is not cleared on success, so an error from an
// earlier failed check stays stored — confirm whether that is intended.
if CompareSemver(latestRemote, uc.latestVersion) == 1 {
uc.latestVersion = latestRemote
}
return nil
}
// AreAvailable updates the cache if the frequency duration has passed
// then returns if the latest release is newer than the current running version.
//
// The receiver is a pointer so that the refresh performed by CheckNow is
// stored in the shared cache. With a value receiver the refresh only updated
// a temporary copy (mutex included), so lastCheck never advanced and every
// call after the first expiry performed a new remote check.
//
// The cached fields are read here without holding mu.
func (uc *updatesCache) AreAvailable() bool {
	if time.Since(uc.lastCheck) > uc.frequency {
		// A failed check is logged and recorded in checkError by CheckNow;
		// the comparison below then uses the previously cached version.
		uc.CheckNow()
	}
	return CompareSemver(uc.latestVersion, uc.currentVersion) == 1
}
// AreAvailableNoCheck returns if the latest release is newer than the current
// running version. It compares the cached values only and never triggers a
// remote check.
func (uc updatesCache) AreAvailableNoCheck() bool {
return CompareSemver(uc.latestVersion, uc.currentVersion) == 1
}
// LatestVersion returns the latest stored version available. The value is
// empty until a check has stored one.
func (uc updatesCache) LatestVersion() string {
return uc.latestVersion
}
// ReleaseURL returns the writefreely.org releases URL for the latest version
// as stored in the cache.
func (uc updatesCache) ReleaseURL() string {
return "https://writefreely.org/releases/" + uc.latestVersion
}
// ReleaseNotesURL returns the full URL to the blog.writefreely.org release notes
// for the latest version as stored in the cache. See wfReleaseNotesURL for
// how the version is turned into a URL.
func (uc updatesCache) ReleaseNotesURL() string {
return wfReleaseNotesURL(uc.latestVersion)
}
// wfReleaseNotesURL builds the blog.writefreely.org release-notes URL for
// version v. A leading "v" and one trailing ".0" are removed, and the
// remaining dots become dashes, e.g. "v0.12.0" gives ".../version-0-12".
func wfReleaseNotesURL(v string) string {
	ver := strings.TrimSuffix(strings.TrimPrefix(v, "v"), ".0")
	return "https://blog.writefreely.org/version-" + strings.ReplaceAll(ver, ".", "-")
}
// newUpdatesCache returns an initialized updates cache
// with the given expiry as its check frequency and "v"+softwareVer as the
// current version. A first check is started in a separate goroutine, so the
// cache may be returned before that check has finished.
func newUpdatesCache(expiry time.Duration) *updatesCache {
cache := updatesCache{
frequency: expiry,
currentVersion: "v" + softwareVer,
}
go cache.CheckNow()
return &cache
}
// InitUpdates initializes the updates cache, if the config value is set
// It uses the defaultUpdatesCacheTime for the cache expiry
// When update checks are disabled, app.updates is left untouched.
func (app *App) InitUpdates() {
if app.cfg.App.UpdateChecks {
app.updates = newUpdatesCache(defaultUpdatesCacheTime)
}
}
// newVersionCheck fetches https://version.writefreely.org and returns the
// response body as the latest version string. It returns the request or read
// error on failure. For a response whose status is not 200 it returns an
// empty string and a nil error.
func newVersionCheck() (string, error) {
	res, err := http.Get("https://version.writefreely.org")
	if debugging {
		log.Info("[update check] GET https://version.writefreely.org")
	}
	if err != nil {
		return "", err
	}
	// Close the body on every path, including non-OK responses, so the
	// underlying connection is not leaked.
	defer res.Body.Close()

	// TODO: return error if statusCode != OK
	if res.StatusCode != http.StatusOK {
		return "", nil
	}

	body, err := io.ReadAll(res.Body)
	if err != nil {
		return "", err
	}
	return string(body), nil
}

File Metadata

Mime Type
text/x-diff
Expires
Fri, Feb 6, 7:14 PM (1 d, 14 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3620163

Event Timeline