Page MenuHomeMusing Studio

No OneTemporary

diff --git a/oauth.go b/oauth.go
index e28e21a..2958721 100644
--- a/oauth.go
+++ b/oauth.go
@@ -1,467 +1,467 @@
/*
* Copyright © 2019-2021 A Bunch Tell LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/writeas/impart"
"github.com/writeas/web-core/log"
"github.com/writefreely/writefreely/config"
)
// OAuthButtons holds display information for different OAuth providers we support.
//
// NewOAuthButtons sets each *Enabled flag to true when the matching provider
// has a non-empty client ID configured, and fills each *DisplayName from the
// configuration with a fallback to that provider's default name.
type OAuthButtons struct {
// SlackEnabled reports whether a Slack OAuth client ID is configured.
SlackEnabled bool
// WriteAsEnabled reports whether a Write.as OAuth client ID is configured.
WriteAsEnabled bool
// GitLabEnabled reports whether a GitLab OAuth client ID is configured.
GitLabEnabled bool
// GitLabDisplayName is the configured GitLab label, or the default one.
GitLabDisplayName string
// GiteaEnabled reports whether a Gitea OAuth client ID is configured.
GiteaEnabled bool
// GiteaDisplayName is the configured Gitea label, or the default one.
GiteaDisplayName string
// GenericEnabled reports whether a generic OAuth client ID is configured.
GenericEnabled bool
// GenericDisplayName is the configured generic-provider label, or the default one.
GenericDisplayName string
}
// NewOAuthButtons builds the OAuthButtons view data from the app configuration.
//
// A provider counts as enabled when its OAuth client ID is non-empty. Display
// names come from the configuration and fall back to the package defaults
// when left blank.
func NewOAuthButtons(cfg *config.Config) *OAuthButtons {
	buttons := new(OAuthButtons)

	buttons.SlackEnabled = cfg.SlackOauth.ClientID != ""
	buttons.WriteAsEnabled = cfg.WriteAsOauth.ClientID != ""

	buttons.GitLabEnabled = cfg.GitlabOauth.ClientID != ""
	buttons.GitLabDisplayName = config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName)

	buttons.GiteaEnabled = cfg.GiteaOauth.ClientID != ""
	buttons.GiteaDisplayName = config.OrDefaultString(cfg.GiteaOauth.DisplayName, giteaDisplayName)

	buttons.GenericEnabled = cfg.GenericOauth.ClientID != ""
	buttons.GenericDisplayName = config.OrDefaultString(cfg.GenericOauth.DisplayName, genericOauthDisplayName)

	return buttons
}
// TokenResponse contains data returned when a token is created either
// through a code exchange or using a refresh token.
//
// Fields are decoded from the JSON keys named in the struct tags.
type TokenResponse struct {
// AccessToken is decoded from "access_token"; the callback handler passes it
// on to token inspection and stores it with the remote user record.
AccessToken string `json:"access_token"`
// ExpiresIn is decoded from "expires_in".
// NOTE(review): presumably a lifetime in seconds — confirm against the provider.
ExpiresIn int `json:"expires_in"`
// RefreshToken is decoded from "refresh_token".
RefreshToken string `json:"refresh_token"`
// TokenType is decoded from "token_type".
TokenType string `json:"token_type"`
// Error is decoded from "error".
Error string `json:"error"`
}
// InspectResponse contains data returned when an access token is inspected.
//
// The callback handler uses UserID to look up a linked local account and
// passes Username, DisplayName and Email to the signup page.
type InspectResponse struct {
// ClientID is decoded from "client_id".
ClientID string `json:"client_id"`
// UserID is decoded from "user_id"; it identifies the user at the provider.
UserID string `json:"user_id"`
// ExpiresAt is decoded from "expires_at".
ExpiresAt time.Time `json:"expires_at"`
// Username is decoded from "username".
Username string `json:"username"`
// DisplayName is excluded from JSON (tag "-"), so it is never filled by
// decoding this struct directly; any value must be set by the caller.
DisplayName string `json:"-"`
// Email is decoded from "email".
Email string `json:"email"`
// Error is decoded from "error".
Error string `json:"error"`
}
// Upper bounds on how much of a provider response body we are willing to read.
const (
	// tokenRequestMaxLen is the most bytes that we'll read from the
	// /oauth/token endpoint. One megabyte is plenty.
	tokenRequestMaxLen = 1000000

	// infoRequestMaxLen is the most bytes that we'll read from the
	// /oauth/inspect endpoint.
	infoRequestMaxLen = 1000000
)
// OAuthDatastoreProvider provides a minimal interface of data store, config,
// and session store for use with the oauth handlers.
type OAuthDatastoreProvider interface {
// DB returns the data store used by the OAuth handlers.
DB() OAuthDatastore
// Config returns the application configuration.
Config() *config.Config
// SessionStore returns the store used for user sessions.
SessionStore() sessions.Store
}
// OAuthDatastore provides a minimal interface of data store methods used in
// oauth functionality.
type OAuthDatastore interface {
// GetIDForRemoteUser is called with (ctx, remote user ID, provider, client ID)
// and returns the linked local user ID; the handlers treat -1 as "no local
// user is linked".
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
// RecordRemoteUserID is called with (ctx, local user ID, remote user ID,
// provider, client ID, access token) to link a remote account to a local user.
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
// ValidateOAuthState takes (ctx, state) and returns, in order: provider,
// client ID, the user ID to attach to, invite code, and an error.
ValidateOAuthState(context.Context, string) (string, string, int64, string, error)
// GenerateOAuthState is called with (ctx, provider, client ID, attach user ID,
// invite code) and returns the state string to send to the provider.
GenerateOAuthState(context.Context, string, string, int64, string) (string, error)
// CreateUser creates a new local user. The signup handler passes the display
// name as the third argument.
- CreateUser(*config.Config, *User, string) error
+ CreateUser(*config.Config, *User, string, string) error
// GetUserByID loads a local user by ID.
GetUserByID(int64) (*User, error)
}
// HttpClient is the minimal HTTP client abstraction used for outbound OAuth
// requests. Its single method has the same signature as (*http.Client).Do,
// which lets tests substitute a fake client.
type HttpClient interface {
// Do sends req and returns the response or an error.
Do(req *http.Request) (*http.Response, error)
}
// oauthClient is the provider-specific part of the OAuth flow. One
// implementation exists per supported provider, and the shared handlers in
// oauthHandler drive it.
type oauthClient interface {
// GetProvider returns the provider name; it is stored in the OAuth state and
// used as the final path segment of the /oauth/ and /oauth/callback/ routes.
GetProvider() string
// GetClientID returns the OAuth client ID, which is stored in the OAuth state.
GetClientID() string
// GetCallbackLocation returns the callback URL for this client.
GetCallbackLocation() string
// buildLoginURL returns the provider URL the user is redirected to, given the
// generated state value.
buildLoginURL(state string) (string, error)
// exchangeOauthCode trades the code received on the callback for a token.
exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
// inspectOauthAccessToken looks up the remote user behind an access token.
inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
}
// callbackProxyClient registers an OAuth state with a callback proxy server
// (see its register method).
type callbackProxyClient struct {
// server is the URL that register POSTs the state to.
server string
// callbackLocation is this instance's own callback URL, sent as the
// "location" form value on registration.
callbackLocation string
// httpClient performs the registration request.
httpClient HttpClient
}
// oauthHandler bundles the dependencies shared by the OAuth HTTP handlers
// (init, callback and signup) for a single provider.
type oauthHandler struct {
// Config is the application configuration.
Config *config.Config
// DB is the data store used for OAuth state and user lookups.
DB OAuthDatastore
// Store is the session store used to log users in.
Store sessions.Store
// EmailKey is passed to prepareUserEmail when creating a user at signup.
EmailKey []byte
// oauthClient is the provider-specific client driven by the handlers.
oauthClient oauthClient
// callbackProxy is optional; when non-nil, each generated state is
// registered with it before the user is redirected.
callbackProxy *callbackProxyClient
}
// viewOauthInit starts the OAuth flow for this handler's provider.
//
// It generates and stores a state value (recording the signed-in user when the
// request carries attach=t, plus any invite_code), registers the state with
// the callback proxy when one is configured, and returns the provider's login
// URL wrapped in an impart.HTTPError with StatusTemporaryRedirect. Failures
// are reported as impart.HTTPError values with StatusInternalServerError.
func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
	ctx := r.Context()

	// attach=t means "link this provider to the user who is already signed in".
	attachUser := int64(0)
	if r.URL.Query().Get("attach") == "t" {
		sessionUser, _ := getUserAndSession(app, r)
		if sessionUser == nil {
			return impart.HTTPError{
				Status:  http.StatusInternalServerError,
				Message: "cannot attach auth to user: user not found in session",
			}
		}
		attachUser = sessionUser.ID
	}

	provider := h.oauthClient.GetProvider()
	clientID := h.oauthClient.GetClientID()
	inviteCode := r.FormValue("invite_code")

	state, err := h.DB.GenerateOAuthState(ctx, provider, clientID, attachUser, inviteCode)
	if err != nil {
		log.Error("viewOauthInit error: %s", err)
		return impart.HTTPError{
			Status:  http.StatusInternalServerError,
			Message: "could not prepare oauth redirect url",
		}
	}

	if proxy := h.callbackProxy; proxy != nil {
		if regErr := proxy.register(ctx, state); regErr != nil {
			log.Error("viewOauthInit error: %s", regErr)
			return impart.HTTPError{
				Status:  http.StatusInternalServerError,
				Message: "could not register state server",
			}
		}
	}

	location, err := h.oauthClient.buildLoginURL(state)
	if err != nil {
		log.Error("viewOauthInit error: %s", err)
		return impart.HTTPError{
			Status:  http.StatusInternalServerError,
			Message: "could not prepare oauth redirect url",
		}
	}

	// The redirect is signalled through the returned HTTPError.
	return impart.HTTPError{Status: http.StatusTemporaryRedirect, Message: location}
}
// configureSlackOauth registers the Slack OAuth routes when a Slack client ID
// is configured; otherwise it does nothing.
//
// When SlackOauth.CallbackProxyAPI is set, a callback proxy client is created
// and the client's callback location becomes SlackOauth.CallbackProxy instead
// of this instance's own /oauth/callback/slack URL.
func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) {
	cfg := app.Config()
	if cfg.SlackOauth.ClientID == "" {
		return
	}

	localCallback := cfg.App.Host + "/oauth/callback/slack"
	callbackLocation := localCallback

	var stateRegisterClient *callbackProxyClient
	if cfg.SlackOauth.CallbackProxyAPI != "" {
		stateRegisterClient = &callbackProxyClient{
			server:           cfg.SlackOauth.CallbackProxyAPI,
			callbackLocation: localCallback,
			httpClient:       config.DefaultHTTPClient(),
		}
		callbackLocation = cfg.SlackOauth.CallbackProxy
	}

	client := slackOauthClient{
		ClientID:         cfg.SlackOauth.ClientID,
		ClientSecret:     cfg.SlackOauth.ClientSecret,
		TeamID:           cfg.SlackOauth.TeamID,
		HttpClient:       config.DefaultHTTPClient(),
		CallbackLocation: callbackLocation,
	}
	configureOauthRoutes(parentHandler, r, app, client, stateRegisterClient)
}
// configureWriteAsOauth registers the Write.as OAuth routes when a Write.as
// client ID is configured; otherwise it does nothing.
//
// Token, inspect and auth endpoints come from the configuration and fall back
// to the package defaults. When WriteAsOauth.CallbackProxy is set, a callback
// proxy client is created and that proxy URL replaces the local callback.
func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) {
	cfg := app.Config()
	if cfg.WriteAsOauth.ClientID == "" {
		return
	}

	localCallback := cfg.App.Host + "/oauth/callback/write.as"
	callbackLocation := localCallback

	var callbackProxy *callbackProxyClient
	if cfg.WriteAsOauth.CallbackProxy != "" {
		callbackProxy = &callbackProxyClient{
			server:           cfg.WriteAsOauth.CallbackProxyAPI,
			callbackLocation: localCallback,
			httpClient:       config.DefaultHTTPClient(),
		}
		callbackLocation = cfg.WriteAsOauth.CallbackProxy
	}

	client := writeAsOauthClient{
		ClientID:         cfg.WriteAsOauth.ClientID,
		ClientSecret:     cfg.WriteAsOauth.ClientSecret,
		ExchangeLocation: config.OrDefaultString(cfg.WriteAsOauth.TokenLocation, writeAsExchangeLocation),
		InspectLocation:  config.OrDefaultString(cfg.WriteAsOauth.InspectLocation, writeAsIdentityLocation),
		AuthLocation:     config.OrDefaultString(cfg.WriteAsOauth.AuthLocation, writeAsAuthLocation),
		HttpClient:       config.DefaultHTTPClient(),
		CallbackLocation: callbackLocation,
	}
	configureOauthRoutes(parentHandler, r, app, client, callbackProxy)
}
// configureGitlabOauth registers the GitLab OAuth routes when a GitLab client
// ID is configured; otherwise it does nothing.
//
// Endpoints are derived from GitlabOauth.Host (or the default GitLab host).
// When GitlabOauth.CallbackProxy is set, a callback proxy client is created
// and that proxy URL replaces the local callback.
func configureGitlabOauth(parentHandler *Handler, r *mux.Router, app *App) {
	cfg := app.Config()
	if cfg.GitlabOauth.ClientID == "" {
		return
	}

	localCallback := cfg.App.Host + "/oauth/callback/gitlab"
	callbackLocation := localCallback

	var callbackProxy *callbackProxyClient
	if cfg.GitlabOauth.CallbackProxy != "" {
		callbackProxy = &callbackProxyClient{
			server:           cfg.GitlabOauth.CallbackProxyAPI,
			callbackLocation: localCallback,
			httpClient:       config.DefaultHTTPClient(),
		}
		callbackLocation = cfg.GitlabOauth.CallbackProxy
	}

	host := config.OrDefaultString(cfg.GitlabOauth.Host, gitlabHost)
	client := gitlabOauthClient{
		ClientID:         cfg.GitlabOauth.ClientID,
		ClientSecret:     cfg.GitlabOauth.ClientSecret,
		ExchangeLocation: host + "/oauth/token",
		InspectLocation:  host + "/api/v4/user",
		AuthLocation:     host + "/oauth/authorize",
		HttpClient:       config.DefaultHTTPClient(),
		CallbackLocation: callbackLocation,
	}
	configureOauthRoutes(parentHandler, r, app, client, callbackProxy)
}
// configureGenericOauth registers the routes for the generic OAuth provider
// when a client ID is configured; otherwise it does nothing.
//
// Endpoints are built by joining GenericOauth.Host with the configured token,
// inspect and auth endpoint paths. The scope and the response field mappings
// fall back to the defaults given below when not configured. When
// GenericOauth.CallbackProxy is set, a callback proxy client is created and
// that proxy URL replaces the local callback.
func configureGenericOauth(parentHandler *Handler, r *mux.Router, app *App) {
	cfg := app.Config()
	if cfg.GenericOauth.ClientID == "" {
		return
	}

	localCallback := cfg.App.Host + "/oauth/callback/generic"
	callbackLocation := localCallback

	var callbackProxy *callbackProxyClient
	if cfg.GenericOauth.CallbackProxy != "" {
		callbackProxy = &callbackProxyClient{
			server:           cfg.GenericOauth.CallbackProxyAPI,
			callbackLocation: localCallback,
			httpClient:       config.DefaultHTTPClient(),
		}
		callbackLocation = cfg.GenericOauth.CallbackProxy
	}

	client := genericOauthClient{
		ClientID:         cfg.GenericOauth.ClientID,
		ClientSecret:     cfg.GenericOauth.ClientSecret,
		ExchangeLocation: cfg.GenericOauth.Host + cfg.GenericOauth.TokenEndpoint,
		InspectLocation:  cfg.GenericOauth.Host + cfg.GenericOauth.InspectEndpoint,
		AuthLocation:     cfg.GenericOauth.Host + cfg.GenericOauth.AuthEndpoint,
		HttpClient:       config.DefaultHTTPClient(),
		CallbackLocation: callbackLocation,
		Scope:            config.OrDefaultString(cfg.GenericOauth.Scope, "read_user"),
		MapUserID:        config.OrDefaultString(cfg.GenericOauth.MapUserID, "user_id"),
		MapUsername:      config.OrDefaultString(cfg.GenericOauth.MapUsername, "username"),
		MapDisplayName:   config.OrDefaultString(cfg.GenericOauth.MapDisplayName, "-"),
		MapEmail:         config.OrDefaultString(cfg.GenericOauth.MapEmail, "email"),
	}
	configureOauthRoutes(parentHandler, r, app, client, callbackProxy)
}
// configureGiteaOauth registers the Gitea OAuth routes when a Gitea client ID
// is configured; otherwise it does nothing.
//
// Endpoints are derived from GiteaOauth.Host. When GiteaOauth.CallbackProxy is
// set, a callback proxy client is created and that proxy URL replaces the
// local callback.
func configureGiteaOauth(parentHandler *Handler, r *mux.Router, app *App) {
	cfg := app.Config()
	if cfg.GiteaOauth.ClientID == "" {
		return
	}

	localCallback := cfg.App.Host + "/oauth/callback/gitea"
	callbackLocation := localCallback

	var callbackProxy *callbackProxyClient
	if cfg.GiteaOauth.CallbackProxy != "" {
		callbackProxy = &callbackProxyClient{
			server:           cfg.GiteaOauth.CallbackProxyAPI,
			callbackLocation: localCallback,
			httpClient:       config.DefaultHTTPClient(),
		}
		callbackLocation = cfg.GiteaOauth.CallbackProxy
	}

	client := giteaOauthClient{
		ClientID:         cfg.GiteaOauth.ClientID,
		ClientSecret:     cfg.GiteaOauth.ClientSecret,
		ExchangeLocation: cfg.GiteaOauth.Host + "/login/oauth/access_token",
		InspectLocation:  cfg.GiteaOauth.Host + "/api/v1/user",
		AuthLocation:     cfg.GiteaOauth.Host + "/login/oauth/authorize",
		HttpClient:       config.DefaultHTTPClient(),
		CallbackLocation: callbackLocation,
	}
	configureOauthRoutes(parentHandler, r, app, client, callbackProxy)
}
// configureOauthRoutes builds an oauthHandler for the given provider client
// and mounts its routes on r:
//
//	GET  /oauth/<provider>           starts the flow
//	GET  /oauth/callback/<provider>  handles the provider's redirect back
//	POST /oauth/signup               completes new-user signup
//
// callbackProxy may be nil when no callback proxy is configured.
func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) {
	h := new(oauthHandler)
	h.Config = app.Config()
	h.DB = app.DB()
	h.Store = app.SessionStore()
	h.oauthClient = oauthClient
	h.EmailKey = app.keys.EmailKey
	h.callbackProxy = callbackProxy

	provider := oauthClient.GetProvider()
	r.HandleFunc("/oauth/"+provider, parentHandler.OAuth(h.viewOauthInit)).Methods("GET")
	r.HandleFunc("/oauth/callback/"+provider, parentHandler.OAuth(h.viewOauthCallback)).Methods("GET")
	r.HandleFunc("/oauth/signup", parentHandler.OAuth(h.viewOauthSignup)).Methods("POST")
}
// viewOauthCallback handles the provider's redirect back to this instance.
//
// It validates the state value, exchanges the code for an access token and
// inspects that token to identify the remote user. It then does one of:
//   - logs in the local user already linked to the remote account;
//   - links the remote account to the user recorded in the state (attach flow)
//     and redirects to /me/settings;
//   - shows the OAuth signup page so a new local account can be created, which
//     requires a valid invite code or open registration.
//
// Failures are returned as impart.HTTPError values. The same handler serves
// every configured provider, so user-facing messages are provider-neutral.
func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error {
	ctx := r.Context()

	code := r.FormValue("code")
	state := r.FormValue("state")

	provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state)
	if err != nil {
		log.Error("Unable to ValidateOAuthState: %s", err)
		return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
	}

	tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
	if err != nil {
		log.Error("Unable to exchangeOauthCode: %s", err)
		// TODO: show user friendly message if needed
		// TODO: show NO message for cases like user pressing "Cancel" on authorize step
		// The flash is best-effort: a failure to store it must not mask the
		// exchange error, so it is only logged.
		if flashErr := addSessionFlash(app, w, r, err.Error(), nil); flashErr != nil {
			log.Error("Unable to add session flash: %s", flashErr)
		}
		if attachUserID > 0 {
			return impart.HTTPError{Status: http.StatusFound, Message: "/me/settings"}
		}
		return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
	}

	// Now that we have the access token, let's use it real quick to make sure
	// it really really works.
	tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
	if err != nil {
		log.Error("Unable to inspectOauthAccessToken: %s", err)
		return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
	}

	// A result of -1 means no local user is linked to this remote account.
	localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID)
	if err != nil {
		log.Error("Unable to GetIDForRemoteUser: %s", err)
		return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
	}

	if localUserID != -1 && attachUserID > 0 {
		// The remote account is already linked, so it cannot be attached again.
		if err = addSessionFlash(app, w, r, "This account is already attached to another user.", nil); err != nil {
			return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
		}
		return impart.HTTPError{Status: http.StatusFound, Message: "/me/settings"}
	}

	if localUserID != -1 {
		// Existing user, so log in now
		user, err := h.DB.GetUserByID(localUserID)
		if err != nil {
			log.Error("Unable to GetUserByID %d: %s", localUserID, err)
			return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
		}
		if err = loginOrFail(h.Store, w, r, user); err != nil {
			log.Error("Unable to loginOrFail %d: %s", localUserID, err)
			return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
		}
		return nil
	}

	if attachUserID > 0 {
		log.Info("attaching to user %d", attachUserID)
		err = h.DB.RecordRemoteUserID(ctx, attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken)
		if err != nil {
			return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
		}
		return impart.HTTPError{Status: http.StatusFound, Message: "/me/settings"}
	}

	// New user registration below.
	// First, verify that user is allowed to register
	if inviteCode != "" {
		// Verify invite code is valid
		i, err := app.db.GetUserInvite(inviteCode)
		if err != nil {
			return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
		}
		if !i.Active(app.db) {
			return impart.HTTPError{Status: http.StatusNotFound, Message: "Invite link has expired."}
		}
	} else if !app.cfg.App.OpenRegistration {
		// Best-effort flash; the redirect to /login happens either way.
		if flashErr := addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil); flashErr != nil {
			log.Error("Unable to add session flash: %s", flashErr)
		}
		return impart.HTTPError{Status: http.StatusFound, Message: "/login"}
	}

	// The signup page and handler fall back to the username when the alias is
	// empty, so the raw display name is passed through unchanged here.
	tp := &oauthSignupPageParams{
		AccessToken:     tokenResponse.AccessToken,
		TokenUsername:   tokenInfo.Username,
		TokenAlias:      tokenInfo.DisplayName,
		TokenEmail:      tokenInfo.Email,
		TokenRemoteUser: tokenInfo.UserID,
		Provider:        provider,
		ClientID:        clientID,
		InviteCode:      inviteCode,
	}
	// Sign the parameters so the signup POST can detect tampering.
	tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)

	return h.showOauthSignupPage(app, w, r, tp, nil)
}
// register tells the callback proxy server which instance a given OAuth state
// belongs to, by POSTing the state and this instance's callback location as a
// URL-encoded form to r.server.
//
// It returns an error when the request cannot be built or sent, or when the
// server answers with anything other than 201 Created.
func (r *callbackProxyClient) register(ctx context.Context, state string) error {
	form := url.Values{}
	form.Add("state", state)
	form.Add("location", r.callbackLocation)

	req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode()))
	if err != nil {
		return err
	}
	req.Header.Set("User-Agent", ServerUserAgent(""))
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := r.httpClient.Do(req)
	if err != nil {
		return err
	}
	// Always release the response body so the underlying connection is not
	// leaked. The nil check tolerates HttpClient fakes that omit a body.
	if resp.Body != nil {
		defer resp.Body.Close()
	}

	if resp.StatusCode != http.StatusCreated {
		return fmt.Errorf("unable register state location: %d", resp.StatusCode)
	}
	return nil
}
// limitedJsonUnmarshal reads at most n bytes of JSON from body and decodes it
// into thing.
//
// One byte beyond the allowance is requested so that an oversized body can be
// told apart from one that is exactly n bytes long; an oversized body yields
// an error without being decoded. The body is not closed here.
func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
	limit := int64(n + 1)
	raw, err := ioutil.ReadAll(io.LimitReader(body, limit))
	switch {
	case err != nil:
		return err
	case len(raw) == n+1:
		return fmt.Errorf("content larger than max read allowance: %d", n)
	}
	return json.Unmarshal(raw, thing)
}
// loginOrFail stores user in the session cookie and redirects the client to
// "/" with StatusTemporaryRedirect.
//
// If the session cannot be saved, the error is logged and returned, and no
// redirect is written.
func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
	// An error may be returned, but a valid session should always be returned.
	session, _ := store.Get(r, cookieName)
	session.Values[cookieUserVal] = user.Cookie()
	if err := session.Save(r, w); err != nil {
		// Use the application logger (not stdout) like the rest of this file.
		log.Error("error saving session: %s", err)
		return err
	}
	http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
	return nil
}
diff --git a/oauth_signup.go b/oauth_signup.go
index b1256be..8dff416 100644
--- a/oauth_signup.go
+++ b/oauth_signup.go
@@ -1,231 +1,231 @@
/*
* Copyright © 2020-2021 A Bunch Tell LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/writeas/impart"
"github.com/writeas/web-core/auth"
"github.com/writeas/web-core/log"
"github.com/writefreely/writefreely/page"
"html/template"
"net/http"
"strings"
"time"
)
// viewOauthSignupVars is the data passed to the "signup-oauth.tmpl" template
// by showOauthSignupPage.
//
// The Token*, Provider, ClientID, TokenHash and InviteCode fields are copied
// from oauthSignupPageParams; LoginUsername, Alias and Email hold the values
// to show in the form (the submitted value when present, else the token's).
type viewOauthSignupVars struct {
page.StaticPage
// To is read from the "to" form value.
To string
Message template.HTML
// Flashes holds session flashes plus any error passed to showOauthSignupPage.
Flashes []template.HTML
AccessToken string
TokenUsername string
TokenAlias string // TODO: rename this to match the data it represents: the collection title
TokenEmail string
TokenRemoteUser string
Provider string
ClientID string
TokenHash string
InviteCode string
LoginUsername string
Alias string // TODO: rename this to match the data it represents: the collection title
Email string
}
// Form field names read from the OAuth signup POST (see viewOauthSignup).
// The first group carries the provider-supplied values and their signature;
// username, alias, email and password are the values entered by the user.
const (
oauthParamAccessToken = "access_token"
oauthParamTokenUsername = "token_username"
oauthParamTokenAlias = "token_alias"
oauthParamTokenEmail = "token_email"
oauthParamTokenRemoteUserID = "token_remote_user"
oauthParamClientID = "client_id"
oauthParamProvider = "provider"
// oauthParamHash carries the hash that viewOauthSignup compares against
// HashTokenParams to detect tampering.
oauthParamHash = "signature"
oauthParamUsername = "username"
oauthParamAlias = "alias"
oauthParamEmail = "email"
oauthParamPassword = "password"
oauthParamInviteCode = "invite_code"
)
// oauthSignupPageParams carries the provider-supplied values through the
// OAuth signup page and back, together with the hash that signs them.
type oauthSignupPageParams struct {
	AccessToken     string
	TokenUsername   string
	TokenAlias      string // TODO: rename this to match the data it represents: the collection title
	TokenEmail      string
	TokenRemoteUser string
	ClientID        string
	Provider        string
	TokenHash       string
	InviteCode      string
}

// HashTokenParams returns the hex-encoded SHA-256 digest of key followed by
// the access token, token username, alias, email, remote user ID, client ID
// and provider, in that order. TokenHash and InviteCode are not part of the
// digest.
func (p oauthSignupPageParams) HashTokenParams(key string) string {
	signed := []string{
		key,
		p.AccessToken,
		p.TokenUsername,
		p.TokenAlias,
		p.TokenEmail,
		p.TokenRemoteUser,
		p.ClientID,
		p.Provider,
	}

	digest := sha256.New()
	for _, part := range signed {
		digest.Write([]byte(part))
	}
	return hex.EncodeToString(digest.Sum(nil))
}
// viewOauthSignup handles the POST from the OAuth signup page.
//
// It rebuilds the provider-supplied parameters from the form, rejects the
// request with StatusBadRequest when their hash does not match the submitted
// signature, validates the user-entered fields, creates the local user,
// records the invite (if any), links the remote account and logs the user in.
// Most failures re-render the signup page with the error shown.
func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.Request) error {
tp := &oauthSignupPageParams{
AccessToken: r.FormValue(oauthParamAccessToken),
TokenUsername: r.FormValue(oauthParamTokenUsername),
TokenAlias: r.FormValue(oauthParamTokenAlias),
TokenEmail: r.FormValue(oauthParamTokenEmail),
TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID),
ClientID: r.FormValue(oauthParamClientID),
Provider: r.FormValue(oauthParamProvider),
InviteCode: r.FormValue(oauthParamInviteCode),
}
// The signature must match the hash of the submitted token parameters.
if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) {
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."}
}
tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
if err := h.validateOauthSignup(r); err != nil {
return h.showOauthSignupPage(app, w, r, tp, err)
}
var err error
hashedPass := []byte{}
clearPass := r.FormValue(oauthParamPassword)
// A password is optional; it is only hashed when one was submitted.
hasPass := clearPass != ""
if hasPass {
hashedPass, err = auth.HashPass([]byte(clearPass))
if err != nil {
return h.showOauthSignupPage(app, w, r, tp, fmt.Errorf("unable to hash password"))
}
}
newUser := &User{
Username: r.FormValue(oauthParamUsername),
HashedPass: hashedPass,
HasPass: hasPass,
Email: prepareUserEmail(r.FormValue(oauthParamEmail), h.EmailKey),
Created: time.Now().Truncate(time.Second).UTC(),
}
// Fall back to the username when no alias was submitted.
displayName := r.FormValue(oauthParamAlias)
if len(displayName) == 0 {
displayName = r.FormValue(oauthParamUsername)
}
- err = h.DB.CreateUser(h.Config, newUser, displayName)
+ err = h.DB.CreateUser(h.Config, newUser, displayName, "")
if err != nil {
return h.showOauthSignupPage(app, w, r, tp, err)
}
// Log invite if needed
if tp.InviteCode != "" {
err = app.db.CreateInvitedUser(tp.InviteCode, newUser.ID)
if err != nil {
return err
}
}
// Link the remote account to the newly created local user.
err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken))
if err != nil {
return h.showOauthSignupPage(app, w, r, tp, err)
}
if err := loginOrFail(h.Store, w, r, newUser); err != nil {
return h.showOauthSignupPage(app, w, r, tp, err)
}
return nil
}
// validateOauthSignup checks the user-entered fields of the OAuth signup form.
//
// The username must be at least Config.App.MinUsernameLen and at most 100
// bytes long. The email is optional; when present it must consist of exactly
// one "@" with a non-empty part on each side. Violations are returned as
// impart.HTTPError values with StatusBadRequest.
func (h oauthHandler) validateOauthSignup(r *http.Request) error {
	username := r.FormValue(oauthParamUsername)
	if len(username) < h.Config.App.MinUsernameLen {
		return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too short."}
	}
	if len(username) > 100 {
		return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too long."}
	}

	// The alias (collection title) has no constraints of its own: callers
	// fall back to the username when it is empty.

	email := r.FormValue(oauthParamEmail)
	if len(email) > 0 {
		parts := strings.Split(email, "@")
		if len(parts) != 2 || (len(parts[0]) < 1 || len(parts[1]) < 1) {
			return impart.HTTPError{Status: http.StatusBadRequest, Message: "Invalid email address"}
		}
	}
	return nil
}
// showOauthSignupPage renders the "signup-oauth.tmpl" template.
//
// The username, alias and email shown in the form default to the values in tp
// and are overridden by any non-empty value submitted in the request. Session
// flashes and errMsg (when non-nil) are shown as messages. A render failure is
// logged and returned.
func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *http.Request, tp *oauthSignupPageParams, errMsg error) error {
	session, err := app.sessionStore.Get(r, cookieName)
	if err != nil {
		// Ignore this
		log.Error("Unable to get session; ignoring: %v", err)
	}

	// formOr returns the submitted value for key, or fallback when it is empty.
	formOr := func(key, fallback string) string {
		if submitted := r.FormValue(key); len(submitted) > 0 {
			return submitted
		}
		return fallback
	}
	loginUsername := formOr(oauthParamUsername, tp.TokenUsername)
	alias := formOr(oauthParamAlias, tp.TokenAlias)
	emailAddr := formOr(oauthParamEmail, tp.TokenEmail)

	vars := &viewOauthSignupVars{
		StaticPage: pageForReq(app, r),
		To:         r.FormValue("to"),
		Flashes:    []template.HTML{},

		AccessToken:     tp.AccessToken,
		TokenUsername:   tp.TokenUsername,
		TokenAlias:      tp.TokenAlias,
		TokenEmail:      tp.TokenEmail,
		TokenRemoteUser: tp.TokenRemoteUser,
		Provider:        tp.Provider,
		ClientID:        tp.ClientID,
		TokenHash:       tp.TokenHash,
		InviteCode:      tp.InviteCode,

		LoginUsername: loginUsername,
		Alias:         alias,
		Email:         emailAddr,
	}

	// Display any error messages
	flashes, _ := getSessionFlashes(app, w, r, session)
	for i := range flashes {
		vars.Flashes = append(vars.Flashes, template.HTML(flashes[i]))
	}
	if errMsg != nil {
		vars.Flashes = append(vars.Flashes, template.HTML(errMsg.Error()))
	}

	if err = pages["signup-oauth.tmpl"].ExecuteTemplate(w, "base", vars); err != nil {
		log.Error("Unable to render signup-oauth: %v", err)
		return err
	}
	return nil
}
diff --git a/oauth_test.go b/oauth_test.go
index 5694416..553c1cb 100644
--- a/oauth_test.go
+++ b/oauth_test.go
@@ -1,261 +1,261 @@
/*
* Copyright © 2019-2021 A Bunch Tell LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package writefreely
import (
"context"
"fmt"
"github.com/gorilla/sessions"
"github.com/stretchr/testify/assert"
"github.com/writeas/impart"
"github.com/writeas/web-core/id"
"github.com/writefreely/writefreely/config"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
// MockOAuthDatastoreProvider is a test double for OAuthDatastoreProvider.
// Each Do* field, when non-nil, overrides the method of the same name;
// otherwise the method returns a built-in default.
type MockOAuthDatastoreProvider struct {
DoDB func() OAuthDatastore
DoConfig func() *config.Config
DoSessionStore func() sessions.Store
}
// MockOAuthDatastore is a test double for OAuthDatastore. Each Do* field, when
// non-nil, overrides the method of the same name; otherwise the method
// returns a built-in default.
type MockOAuthDatastore struct {
DoGenerateOAuthState func(context.Context, string, string, int64, string) (string, error)
DoValidateOAuthState func(context.Context, string) (string, string, int64, string, error)
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
DoCreateUser func(*config.Config, *User, string) error
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
DoGetUserByID func(int64) (*User, error)
}
// Compile-time check that the mock satisfies OAuthDatastore.
var _ OAuthDatastore = &MockOAuthDatastore{}
// StringReadCloser adapts a *strings.Reader to io.ReadCloser so that a string
// can stand in for an HTTP response body in tests.
type StringReadCloser struct {
	*strings.Reader
}

// Close is a no-op that always succeeds; there is nothing to release.
func (rc *StringReadCloser) Close() error {
	var noErr error
	return noErr
}
// MockHTTPClient is a test double for HttpClient. When DoDo is set, Do
// delegates to it; otherwise Do returns an empty response and no error.
type MockHTTPClient struct {
	DoDo func(req *http.Request) (*http.Response, error)
}

// Do forwards req to the DoDo hook, or returns a zero-value response when no
// hook is installed.
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
	hook := m.DoDo
	if hook == nil {
		return &http.Response{}, nil
	}
	return hook(req)
}
// SessionStore returns the DoSessionStore override when set, otherwise a new
// cookie store keyed with a fixed test secret.
func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store {
	if m.DoSessionStore == nil {
		return sessions.NewCookieStore([]byte("secret-key"))
	}
	return m.DoSessionStore()
}
// DB returns the DoDB override when set, otherwise a fresh MockOAuthDatastore
// with no hooks installed.
func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore {
	if m.DoDB == nil {
		return &MockOAuthDatastore{}
	}
	return m.DoDB()
}
// Config returns the DoConfig override when set. Otherwise it builds a new
// SQLite-flavoured configuration with "development" credentials for the
// Write.as and Slack OAuth providers.
func (m *MockOAuthDatastoreProvider) Config() *config.Config {
	if override := m.DoConfig; override != nil {
		return override()
	}

	testCfg := config.New()
	testCfg.UseSQLite(true)

	testCfg.WriteAsOauth = config.WriteAsOauthCfg{
		ClientID:        "development",
		ClientSecret:    "development",
		AuthLocation:    "https://write.as/oauth/login",
		TokenLocation:   "https://write.as/oauth/token",
		InspectLocation: "https://write.as/oauth/inspect",
	}
	testCfg.SlackOauth = config.SlackOauthCfg{
		ClientID:     "development",
		ClientSecret: "development",
		TeamID:       "development",
	}

	return testCfg
}
// ValidateOAuthState delegates to DoValidateOAuthState when set; otherwise it
// reports an empty provider, client ID and invite code, a zero attach user ID
// and no error.
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) {
	if m.DoValidateOAuthState == nil {
		return "", "", 0, "", nil
	}
	return m.DoValidateOAuthState(ctx, state)
}
// GetIDForRemoteUser delegates to DoGetIDForRemoteUser when set; otherwise it
// returns -1 and no error.
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
	if m.DoGetIDForRemoteUser == nil {
		return -1, nil
	}
	return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID)
}
// CreateUser delegates to DoCreateUser when set; otherwise it assigns ID 1 to
// u and returns no error. The description argument is not forwarded to the
// DoCreateUser hook.
-func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username string) error {
+func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username, description string) error {
if m.DoCreateUser != nil {
return m.DoCreateUser(cfg, u, username)
}
// Default behaviour: pretend the insert succeeded and produced ID 1.
u.ID = 1
return nil
}
// RecordRemoteUserID delegates to DoRecordRemoteUserID when set; otherwise it
// does nothing and returns no error.
func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
	if m.DoRecordRemoteUserID == nil {
		return nil
	}
	return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken)
}
// GetUserByID delegates to DoGetUserByID when set; otherwise it returns a
// zero-value User and no error.
func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
	if m.DoGetUserByID == nil {
		return &User{}, nil
	}
	return m.DoGetUserByID(userID)
}
// GenerateOAuthState delegates to DoGenerateOAuthState when set; otherwise it
// returns a random string from id.Generate62RandomString(14) and no error.
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64, inviteCode string) (string, error) {
	if m.DoGenerateOAuthState == nil {
		return id.Generate62RandomString(14), nil
	}
	return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID, inviteCode)
}
// TestViewOauthInit exercises oauthHandler.viewOauthInit with a Write.as
// client backed by mock data stores.
//
//   - "success": the handler returns an impart.HTTPError carrying a temporary
//     redirect whose URL has the expected path and query parameters.
//   - "state failure": when state generation fails, the handler returns a 500
//     with the generic "could not prepare oauth redirect url" message.
func TestViewOauthInit(t *testing.T) {
t.Run("success", func(t *testing.T) {
app := &MockOAuthDatastoreProvider{}
h := oauthHandler{
Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
oauthClient: writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
}
req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err)
rr := httptest.NewRecorder()
// The redirect is communicated through the returned error value.
err = h.viewOauthInit(nil, rr, req)
assert.NotNil(t, err)
httpErr, ok := err.(impart.HTTPError)
assert.True(t, ok)
assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status)
assert.NotEmpty(t, httpErr.Message)
// The error message holds the redirect target; check its components.
locURI, err := url.Parse(httpErr.Message)
assert.NoError(t, err)
assert.Equal(t, "/oauth/login", locURI.Path)
assert.Equal(t, "development", locURI.Query().Get("client_id"))
assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri"))
assert.Equal(t, "code", locURI.Query().Get("response_type"))
assert.NotEmpty(t, locURI.Query().Get("state"))
})
t.Run("state failure", func(t *testing.T) {
// Data store whose state generation always fails.
app := &MockOAuthDatastoreProvider{
DoDB: func() OAuthDatastore {
return &MockOAuthDatastore{
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64, inviteCode string) (string, error) {
return "", fmt.Errorf("pretend unable to write state error")
},
}
},
}
h := oauthHandler{
Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
oauthClient: writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
}
req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err)
rr := httptest.NewRecorder()
err = h.viewOauthInit(nil, rr, req)
httpErr, ok := err.(impart.HTTPError)
assert.True(t, ok)
assert.NotEmpty(t, httpErr.Message)
assert.Equal(t, http.StatusInternalServerError, httpErr.Status)
assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message)
})
}
// TestViewOauthCallback exercises oauthHandler.viewOauthCallback with a
// Write.as client whose HTTP calls are answered by a MockHTTPClient: canned
// JSON for the token and inspect endpoints, 404 for anything else. The
// "success" case expects no error and a temporary redirect on the recorder.
func TestViewOauthCallback(t *testing.T) {
t.Run("success", func(t *testing.T) {
app := &MockOAuthDatastoreProvider{}
h := oauthHandler{
Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd},
oauthClient: writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: &MockHTTPClient{
DoDo: func(req *http.Request) (*http.Response, error) {
// Route by full request URL to the matching canned response.
switch req.URL.String() {
case "https://write.as/oauth/token":
return &http.Response{
StatusCode: 200,
Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)},
}, nil
case "https://write.as/oauth/inspect":
return &http.Response{
StatusCode: 200,
Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
}, nil
}
return &http.Response{
StatusCode: http.StatusNotFound,
}, nil
},
},
},
}
req, err := http.NewRequest("GET", "/oauth/callback", nil)
assert.NoError(t, err)
rr := httptest.NewRecorder()
err = h.viewOauthCallback(&App{cfg: app.Config(), sessionStore: app.SessionStore()}, rr, req)
assert.NoError(t, err)
assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
})
}

File Metadata

Mime Type
text/x-diff
Expires
Sat, Nov 23, 3:33 PM (1 d, 4 h)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3104583

Event Timeline