feat(accounts): implement user accounts, sessions, and dashboards
All checks were successful
Build and Publish Docker Image / deploy (push) Successful in 1m8s
All checks were successful
Build and Publish Docker Image / deploy (push) Successful in 1m8s
Introduce Stage 4 features to support multi-user accounts, cookie-based web sessions, and personal dashboards. Changes include: - Adding `/register` to bootstrap the first admin account and `/login`/`/logout` for session management. - Creating a personal dashboard (`/app`) to display owned boxes, storage usage, and upload history. - Implementing admin user management (`/admin/users`) for generating invite links and managing user states. - Updating the bbolt database schema to store users, sessions, invites, and collections. - Adding `golang.org/x/crypto` for password hashing and introducing unit tests for account handlers.
This commit is contained in:
579
backend/libs/services/auth.go
Normal file
579
backend/libs/services/auth.go
Normal file
@@ -0,0 +1,579 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.etcd.io/bbolt"
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
var (
|
||||
usersBucket = []byte("users")
|
||||
userEmailsBucket = []byte("user_emails")
|
||||
sessionsBucket = []byte("sessions")
|
||||
invitesBucket = []byte("invites")
|
||||
collectionsBucket = []byte("collections")
|
||||
)
|
||||
|
||||
const (
|
||||
UserRoleAdmin = "admin"
|
||||
UserRoleUser = "user"
|
||||
|
||||
UserStatusActive = "active"
|
||||
UserStatusDisabled = "disabled"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
ErrRegistrationClosed = errors.New("registration is closed")
|
||||
ErrInviteInvalid = errors.New("invite is invalid")
|
||||
ErrUserDisabled = errors.New("user is disabled")
|
||||
)
|
||||
|
||||
type AuthService struct {
|
||||
db *bbolt.DB
|
||||
baseURL string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
PasswordHash string `json:"passwordHash"`
|
||||
Role string `json:"role"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
type PublicUser struct {
|
||||
ID string
|
||||
Username string
|
||||
Email string
|
||||
Role string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"userId"`
|
||||
TokenHash string `json:"tokenHash"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
type Invite struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"userId,omitempty"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
TokenHash string `json:"tokenHash"`
|
||||
CreatedBy string `json:"createdBy"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
UsedAt *time.Time `json:"usedAt,omitempty"`
|
||||
UsedByUserID string `json:"usedByUserId,omitempty"`
|
||||
}
|
||||
|
||||
type Collection struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"userId"`
|
||||
Name string `json:"name"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
type InviteResult struct {
|
||||
Invite Invite
|
||||
URL string
|
||||
Token string
|
||||
}
|
||||
|
||||
func NewAuthService(db *bbolt.DB, baseURL string) (*AuthService, error) {
|
||||
service := &AuthService{db: db, baseURL: strings.TrimRight(baseURL, "/")}
|
||||
err := db.Update(func(tx *bbolt.Tx) error {
|
||||
for _, bucket := range [][]byte{usersBucket, userEmailsBucket, sessionsBucket, invitesBucket, collectionsBucket} {
|
||||
if _, err := tx.CreateBucketIfNotExists(bucket); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) BootstrapAvailable() (bool, error) {
|
||||
count := 0
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(usersBucket).ForEach(func(_, _ []byte) error {
|
||||
count++
|
||||
return nil
|
||||
})
|
||||
})
|
||||
return count == 0, err
|
||||
}
|
||||
|
||||
func (s *AuthService) CreateBootstrapUser(username, email, password string) (User, error) {
|
||||
available, err := s.BootstrapAvailable()
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
if !available {
|
||||
return User{}, ErrRegistrationClosed
|
||||
}
|
||||
return s.createUser(username, email, password, UserRoleAdmin)
|
||||
}
|
||||
|
||||
func (s *AuthService) Login(email, password string) (User, string, error) {
|
||||
user, err := s.UserByEmail(email)
|
||||
if err != nil {
|
||||
return User{}, "", ErrInvalidCredentials
|
||||
}
|
||||
if user.Status != UserStatusActive {
|
||||
return User{}, "", ErrUserDisabled
|
||||
}
|
||||
if !VerifyPasswordHash(user.PasswordHash, password) {
|
||||
return User{}, "", ErrInvalidCredentials
|
||||
}
|
||||
|
||||
token := randomID(32)
|
||||
session := Session{
|
||||
ID: randomID(12),
|
||||
UserID: user.ID,
|
||||
TokenHash: tokenHash(token),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(30 * 24 * time.Hour),
|
||||
}
|
||||
err = s.db.Update(func(tx *bbolt.Tx) error {
|
||||
data, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Bucket(sessionsBucket).Put([]byte(session.ID), data)
|
||||
})
|
||||
return user, session.ID + "." + token, err
|
||||
}
|
||||
|
||||
func (s *AuthService) UserForSession(raw string) (User, Session, error) {
|
||||
sessionID, token, ok := strings.Cut(raw, ".")
|
||||
if !ok || sessionID == "" || token == "" {
|
||||
return User{}, Session{}, os.ErrNotExist
|
||||
}
|
||||
|
||||
var session Session
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
data := tx.Bucket(sessionsBucket).Get([]byte(sessionID))
|
||||
if data == nil {
|
||||
return os.ErrNotExist
|
||||
}
|
||||
return json.Unmarshal(data, &session)
|
||||
})
|
||||
if err != nil {
|
||||
return User{}, Session{}, err
|
||||
}
|
||||
if time.Now().UTC().After(session.ExpiresAt) || subtle.ConstantTimeCompare([]byte(tokenHash(token)), []byte(session.TokenHash)) != 1 {
|
||||
return User{}, Session{}, os.ErrPermission
|
||||
}
|
||||
user, err := s.UserByID(session.UserID)
|
||||
if err != nil {
|
||||
return User{}, Session{}, err
|
||||
}
|
||||
if user.Status != UserStatusActive {
|
||||
return User{}, Session{}, ErrUserDisabled
|
||||
}
|
||||
return user, session, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) Logout(raw string) error {
|
||||
sessionID, _, ok := strings.Cut(raw, ".")
|
||||
if !ok || sessionID == "" {
|
||||
return nil
|
||||
}
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(sessionsBucket).Delete([]byte(sessionID))
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthService) CreateInvite(email, role, createdBy string, expiresIn time.Duration) (InviteResult, error) {
|
||||
email, err := normalizeEmail(email)
|
||||
if err != nil {
|
||||
return InviteResult{}, err
|
||||
}
|
||||
if role == "" {
|
||||
role = UserRoleUser
|
||||
}
|
||||
if role != UserRoleAdmin && role != UserRoleUser {
|
||||
role = UserRoleUser
|
||||
}
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 7 * 24 * time.Hour
|
||||
}
|
||||
|
||||
token := randomID(32)
|
||||
invite := Invite{
|
||||
ID: randomID(12),
|
||||
Email: email,
|
||||
Role: role,
|
||||
TokenHash: tokenHash(token),
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(expiresIn),
|
||||
}
|
||||
err = s.saveInvite(invite)
|
||||
if err != nil {
|
||||
return InviteResult{}, err
|
||||
}
|
||||
return InviteResult{
|
||||
Invite: invite,
|
||||
Token: token,
|
||||
URL: fmt.Sprintf("%s/invite/%s", s.baseURL, token),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) AcceptInvite(token, username, password string) (User, error) {
|
||||
invite, err := s.InviteByToken(token)
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
if invite.UsedAt != nil || time.Now().UTC().After(invite.ExpiresAt) {
|
||||
return User{}, ErrInviteInvalid
|
||||
}
|
||||
|
||||
var user User
|
||||
if invite.UserID != "" {
|
||||
user, err = s.UserByID(invite.UserID)
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
if err := s.SetPassword(user.ID, password); err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
user, _ = s.UserByID(user.ID)
|
||||
} else {
|
||||
user, err = s.createUser(username, invite.Email, password, invite.Role)
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
invite.UsedAt = &now
|
||||
invite.UsedByUserID = user.ID
|
||||
if err := s.saveInvite(invite); err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) InviteByToken(token string) (Invite, error) {
|
||||
hash := tokenHash(token)
|
||||
var match Invite
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(invitesBucket).ForEach(func(_, value []byte) error {
|
||||
var invite Invite
|
||||
if err := json.Unmarshal(value, &invite); err != nil {
|
||||
return err
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(hash), []byte(invite.TokenHash)) == 1 {
|
||||
match = invite
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return Invite{}, err
|
||||
}
|
||||
if match.ID == "" {
|
||||
return Invite{}, ErrInviteInvalid
|
||||
}
|
||||
return match, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) CreatePasswordResetInvite(userID, createdBy string) (InviteResult, error) {
|
||||
user, err := s.UserByID(userID)
|
||||
if err != nil {
|
||||
return InviteResult{}, err
|
||||
}
|
||||
result, err := s.CreateInvite(user.Email, user.Role, createdBy, 24*time.Hour)
|
||||
if err != nil {
|
||||
return InviteResult{}, err
|
||||
}
|
||||
result.Invite.UserID = user.ID
|
||||
if err := s.saveInvite(result.Invite); err != nil {
|
||||
return InviteResult{}, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) ListUsers() ([]User, error) {
|
||||
users := make([]User, 0)
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(usersBucket).ForEach(func(_, value []byte) error {
|
||||
var user User
|
||||
if err := json.Unmarshal(value, &user); err != nil {
|
||||
return err
|
||||
}
|
||||
users = append(users, user)
|
||||
return nil
|
||||
})
|
||||
})
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].CreatedAt.After(users[j].CreatedAt)
|
||||
})
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (s *AuthService) DisableUser(userID string, disabled bool) error {
|
||||
user, err := s.UserByID(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if disabled {
|
||||
user.Status = UserStatusDisabled
|
||||
} else {
|
||||
user.Status = UserStatusActive
|
||||
}
|
||||
user.UpdatedAt = time.Now().UTC()
|
||||
return s.saveUser(user)
|
||||
}
|
||||
|
||||
func (s *AuthService) SetPassword(userID, password string) error {
|
||||
if len(password) < 8 {
|
||||
return fmt.Errorf("password must be at least 8 characters")
|
||||
}
|
||||
user, err := s.UserByID(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.PasswordHash = HashPassword(password)
|
||||
user.UpdatedAt = time.Now().UTC()
|
||||
return s.saveUser(user)
|
||||
}
|
||||
|
||||
func (s *AuthService) UserByID(id string) (User, error) {
|
||||
var user User
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
data := tx.Bucket(usersBucket).Get([]byte(id))
|
||||
if data == nil {
|
||||
return os.ErrNotExist
|
||||
}
|
||||
return json.Unmarshal(data, &user)
|
||||
})
|
||||
return user, err
|
||||
}
|
||||
|
||||
func (s *AuthService) UserByEmail(email string) (User, error) {
|
||||
email, err := normalizeEmail(email)
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
var userID string
|
||||
err = s.db.View(func(tx *bbolt.Tx) error {
|
||||
data := tx.Bucket(userEmailsBucket).Get([]byte(email))
|
||||
if data == nil {
|
||||
return os.ErrNotExist
|
||||
}
|
||||
userID = string(data)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
return s.UserByID(userID)
|
||||
}
|
||||
|
||||
func (s *AuthService) CreateCollection(userID, name string) (Collection, error) {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return Collection{}, fmt.Errorf("collection name is required")
|
||||
}
|
||||
collection := Collection{
|
||||
ID: randomID(10),
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
return collection, s.saveCollection(collection)
|
||||
}
|
||||
|
||||
func (s *AuthService) ListCollections(userID string) ([]Collection, error) {
|
||||
collections := make([]Collection, 0)
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(collectionsBucket).ForEach(func(_, value []byte) error {
|
||||
var collection Collection
|
||||
if err := json.Unmarshal(value, &collection); err != nil {
|
||||
return err
|
||||
}
|
||||
if collection.UserID == userID {
|
||||
collections = append(collections, collection)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
sort.Slice(collections, func(i, j int) bool {
|
||||
return strings.ToLower(collections[i].Name) < strings.ToLower(collections[j].Name)
|
||||
})
|
||||
return collections, err
|
||||
}
|
||||
|
||||
func (s *AuthService) CollectionOwnedBy(collectionID, userID string) bool {
|
||||
if collectionID == "" {
|
||||
return true
|
||||
}
|
||||
collection, err := s.CollectionByID(collectionID)
|
||||
return err == nil && collection.UserID == userID
|
||||
}
|
||||
|
||||
func (s *AuthService) CollectionByID(id string) (Collection, error) {
|
||||
var collection Collection
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
data := tx.Bucket(collectionsBucket).Get([]byte(id))
|
||||
if data == nil {
|
||||
return os.ErrNotExist
|
||||
}
|
||||
return json.Unmarshal(data, &collection)
|
||||
})
|
||||
return collection, err
|
||||
}
|
||||
|
||||
func (s *AuthService) PublicUser(user User) PublicUser {
|
||||
return PublicUser{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
CreatedAt: user.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) createUser(username, email, password, role string) (User, error) {
|
||||
username = strings.TrimSpace(username)
|
||||
if username == "" {
|
||||
return User{}, fmt.Errorf("username is required")
|
||||
}
|
||||
email, err := normalizeEmail(email)
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
if len(password) < 8 {
|
||||
return User{}, fmt.Errorf("password must be at least 8 characters")
|
||||
}
|
||||
if role != UserRoleAdmin && role != UserRoleUser {
|
||||
role = UserRoleUser
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
user := User{
|
||||
ID: randomID(12),
|
||||
Username: username,
|
||||
Email: email,
|
||||
PasswordHash: HashPassword(password),
|
||||
Role: role,
|
||||
Status: UserStatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
return user, s.db.Update(func(tx *bbolt.Tx) error {
|
||||
if existing := tx.Bucket(userEmailsBucket).Get([]byte(email)); existing != nil {
|
||||
return fmt.Errorf("email is already registered")
|
||||
}
|
||||
data, err := json.Marshal(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Bucket(usersBucket).Put([]byte(user.ID), data); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Bucket(userEmailsBucket).Put([]byte(email), []byte(user.ID))
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthService) saveUser(user User) error {
|
||||
data, err := json.Marshal(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(usersBucket).Put([]byte(user.ID), data)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthService) saveInvite(invite Invite) error {
|
||||
data, err := json.Marshal(invite)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(invitesBucket).Put([]byte(invite.ID), data)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthService) saveCollection(collection Collection) error {
|
||||
data, err := json.Marshal(collection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(collectionsBucket).Put([]byte(collection.ID), data)
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeEmail(email string) (string, error) {
|
||||
email = strings.ToLower(strings.TrimSpace(email))
|
||||
if email == "" {
|
||||
return "", fmt.Errorf("email is required")
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return "", fmt.Errorf("email is invalid")
|
||||
}
|
||||
return email, nil
|
||||
}
|
||||
|
||||
func tokenHash(token string) string {
|
||||
sum := sha256.Sum256([]byte("warpbox-session:" + token))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func HashPassword(password string) string {
|
||||
salt := make([]byte, 16)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
salt = []byte(randomID(16))[:16]
|
||||
}
|
||||
hash := argon2.IDKey([]byte(password), salt, 1, 64*1024, 4, 32)
|
||||
return "argon2id$v=19$m=65536,t=1,p=4$" + base64.RawStdEncoding.EncodeToString(salt) + "$" + base64.RawStdEncoding.EncodeToString(hash)
|
||||
}
|
||||
|
||||
func VerifyPasswordHash(encoded, password string) bool {
|
||||
parts := strings.Split(encoded, "$")
|
||||
if len(parts) != 5 || parts[0] != "argon2id" {
|
||||
return false
|
||||
}
|
||||
salt, err := base64.RawStdEncoding.DecodeString(parts[3])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
expected, err := base64.RawStdEncoding.DecodeString(parts[4])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
actual := argon2.IDKey([]byte(password), salt, 1, 64*1024, 4, uint32(len(expected)))
|
||||
return subtle.ConstantTimeCompare(actual, expected) == 1
|
||||
}
|
||||
123
backend/libs/services/auth_test.go
Normal file
123
backend/libs/services/auth_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPasswordHashVerification(t *testing.T) {
|
||||
hash := HashPassword("correct-horse")
|
||||
if !VerifyPasswordHash(hash, "correct-horse") {
|
||||
t.Fatalf("VerifyPasswordHash rejected the correct password")
|
||||
}
|
||||
if VerifyPasswordHash(hash, "wrong-password") {
|
||||
t.Fatalf("VerifyPasswordHash accepted the wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapCreatesAdminAndClosesRegistration(t *testing.T) {
|
||||
auth := newTestAuthService(t)
|
||||
available, err := auth.BootstrapAvailable()
|
||||
if err != nil {
|
||||
t.Fatalf("BootstrapAvailable returned error: %v", err)
|
||||
}
|
||||
if !available {
|
||||
t.Fatalf("BootstrapAvailable = false, want true")
|
||||
}
|
||||
|
||||
user, err := auth.CreateBootstrapUser("daniel", "daniel@example.test", "password123")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateBootstrapUser returned error: %v", err)
|
||||
}
|
||||
if user.Role != UserRoleAdmin {
|
||||
t.Fatalf("role = %q, want admin", user.Role)
|
||||
}
|
||||
|
||||
if _, err := auth.CreateBootstrapUser("other", "other@example.test", "password123"); err == nil {
|
||||
t.Fatalf("second bootstrap unexpectedly succeeded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginSessionAndDisabledUser(t *testing.T) {
|
||||
auth := newTestAuthService(t)
|
||||
user, err := auth.CreateBootstrapUser("daniel", "daniel@example.test", "password123")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateBootstrapUser returned error: %v", err)
|
||||
}
|
||||
if _, _, err := auth.Login("daniel@example.test", "wrong"); err == nil {
|
||||
t.Fatalf("Login accepted wrong password")
|
||||
}
|
||||
|
||||
_, token, err := auth.Login("daniel@example.test", "password123")
|
||||
if err != nil {
|
||||
t.Fatalf("Login returned error: %v", err)
|
||||
}
|
||||
sessionUser, _, err := auth.UserForSession(token)
|
||||
if err != nil {
|
||||
t.Fatalf("UserForSession returned error: %v", err)
|
||||
}
|
||||
if sessionUser.ID != user.ID {
|
||||
t.Fatalf("session user = %q, want %q", sessionUser.ID, user.ID)
|
||||
}
|
||||
|
||||
if err := auth.DisableUser(user.ID, true); err != nil {
|
||||
t.Fatalf("DisableUser returned error: %v", err)
|
||||
}
|
||||
if _, _, err := auth.UserForSession(token); err == nil {
|
||||
t.Fatalf("disabled user session still resolved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteAcceptsOnceAndResetChangesPassword(t *testing.T) {
|
||||
auth := newTestAuthService(t)
|
||||
admin, err := auth.CreateBootstrapUser("admin", "admin@example.test", "password123")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateBootstrapUser returned error: %v", err)
|
||||
}
|
||||
invite, err := auth.CreateInvite("friend@example.test", UserRoleUser, admin.ID, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInvite returned error: %v", err)
|
||||
}
|
||||
user, err := auth.AcceptInvite(invite.Token, "friend", "password123")
|
||||
if err != nil {
|
||||
t.Fatalf("AcceptInvite returned error: %v", err)
|
||||
}
|
||||
if user.Email != "friend@example.test" {
|
||||
t.Fatalf("email = %q, want friend@example.test", user.Email)
|
||||
}
|
||||
if _, err := auth.AcceptInvite(invite.Token, "friend", "password123"); err == nil {
|
||||
t.Fatalf("AcceptInvite allowed token reuse")
|
||||
}
|
||||
|
||||
reset, err := auth.CreatePasswordResetInvite(user.ID, admin.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePasswordResetInvite returned error: %v", err)
|
||||
}
|
||||
if _, err := auth.AcceptInvite(reset.Token, "", "newpassword123"); err != nil {
|
||||
t.Fatalf("AcceptInvite reset returned error: %v", err)
|
||||
}
|
||||
if _, _, err := auth.Login("friend@example.test", "newpassword123"); err != nil {
|
||||
t.Fatalf("Login with reset password returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestAuthService(t *testing.T) *AuthService {
|
||||
t.Helper()
|
||||
root := t.TempDir()
|
||||
upload, err := NewUploadService(1024*1024, filepath.Join(root, "data"), "http://example.test", slog.Default())
|
||||
if err != nil {
|
||||
t.Fatalf("NewUploadService returned error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := upload.Close(); err != nil {
|
||||
t.Fatalf("Close returned error: %v", err)
|
||||
}
|
||||
})
|
||||
auth, err := NewAuthService(upload.DB(), "http://example.test")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthService returned error: %v", err)
|
||||
}
|
||||
return auth
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -37,10 +38,16 @@ type UploadOptions struct {
|
||||
MaxDownloads int
|
||||
Password string
|
||||
ObfuscateMetadata bool
|
||||
OwnerID string
|
||||
CollectionID string
|
||||
SkipSizeLimit bool
|
||||
}
|
||||
|
||||
type Box struct {
|
||||
ID string `json:"id"`
|
||||
OwnerID string `json:"ownerId,omitempty"`
|
||||
CollectionID string `json:"collectionId,omitempty"`
|
||||
Title string `json:"title,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
MaxDownloads int `json:"maxDownloads"`
|
||||
@@ -93,6 +100,7 @@ type AdminStats struct {
|
||||
|
||||
type AdminBox struct {
|
||||
ID string
|
||||
OwnerID string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
FileCount int
|
||||
@@ -104,6 +112,12 @@ type AdminBox struct {
|
||||
Expired bool
|
||||
}
|
||||
|
||||
type UserBox struct {
|
||||
Box Box
|
||||
CollectionName string
|
||||
TotalSizeLabel string
|
||||
}
|
||||
|
||||
func NewUploadService(maxUploadSize int64, dataDir, baseURL string, logger *slog.Logger) (*UploadService, error) {
|
||||
filesDir := filepath.Join(dataDir, "files")
|
||||
dbDir := filepath.Join(dataDir, "db")
|
||||
@@ -141,6 +155,10 @@ func (s *UploadService) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
func (s *UploadService) DB() *bbolt.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
func (s *UploadService) MaxUploadSize() int64 {
|
||||
return s.maxUploadSize
|
||||
}
|
||||
@@ -166,6 +184,8 @@ func (s *UploadService) CreateBox(files []*multipart.FileHeader, opts UploadOpti
|
||||
|
||||
box := Box{
|
||||
ID: randomID(10),
|
||||
OwnerID: strings.TrimSpace(opts.OwnerID),
|
||||
CollectionID: strings.TrimSpace(opts.CollectionID),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(time.Duration(opts.MaxDays) * 24 * time.Hour),
|
||||
MaxDownloads: opts.MaxDownloads,
|
||||
@@ -186,8 +206,15 @@ func (s *UploadService) CreateBox(files []*multipart.FileHeader, opts UploadOpti
|
||||
}
|
||||
|
||||
for _, header := range files {
|
||||
if err := s.ValidateSize(header.Size); err != nil {
|
||||
return UploadResult{}, err
|
||||
if !opts.SkipSizeLimit {
|
||||
if err := s.ValidateSize(header.Size); err != nil {
|
||||
return UploadResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
maxSize := s.maxUploadSize
|
||||
if opts.SkipSizeLimit {
|
||||
maxSize = 0
|
||||
}
|
||||
|
||||
file, err := header.Open()
|
||||
@@ -203,7 +230,7 @@ func (s *UploadService) CreateBox(files []*multipart.FileHeader, opts UploadOpti
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
if err := writeUploadedFile(storedPath, file, s.maxUploadSize); err != nil {
|
||||
if err := writeUploadedFile(storedPath, file, maxSize); err != nil {
|
||||
file.Close()
|
||||
return UploadResult{}, err
|
||||
}
|
||||
@@ -314,6 +341,7 @@ func (s *UploadService) AdminBoxes(limit int) ([]AdminBox, error) {
|
||||
}
|
||||
rows = append(rows, AdminBox{
|
||||
ID: box.ID,
|
||||
OwnerID: box.OwnerID,
|
||||
CreatedAt: box.CreatedAt,
|
||||
ExpiresAt: box.ExpiresAt,
|
||||
FileCount: len(box.Files),
|
||||
@@ -328,6 +356,85 @@ func (s *UploadService) AdminBoxes(limit int) ([]AdminBox, error) {
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (s *UploadService) UserBoxes(userID string, collectionNames map[string]string) ([]UserBox, error) {
|
||||
boxes, err := s.ListBoxes(0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows := make([]UserBox, 0)
|
||||
for _, box := range boxes {
|
||||
if box.OwnerID != userID {
|
||||
continue
|
||||
}
|
||||
var size int64
|
||||
for _, file := range box.Files {
|
||||
size += file.Size
|
||||
}
|
||||
rows = append(rows, UserBox{
|
||||
Box: box,
|
||||
CollectionName: collectionNames[box.CollectionID],
|
||||
TotalSizeLabel: helpers.FormatBytes(size),
|
||||
})
|
||||
}
|
||||
sort.Slice(rows, func(i, j int) bool {
|
||||
return rows[i].Box.CreatedAt.After(rows[j].Box.CreatedAt)
|
||||
})
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (s *UploadService) UserStorageUsed(userID string) (int64, error) {
|
||||
boxes, err := s.ListBoxes(0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var total int64
|
||||
for _, box := range boxes {
|
||||
if box.OwnerID != userID {
|
||||
continue
|
||||
}
|
||||
for _, file := range box.Files {
|
||||
total += file.Size
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (s *UploadService) RenameOwnedBox(boxID, userID, title string) error {
|
||||
box, err := s.GetBox(boxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if box.OwnerID != userID {
|
||||
return os.ErrPermission
|
||||
}
|
||||
box.Title = strings.TrimSpace(title)
|
||||
return s.SaveBox(box)
|
||||
}
|
||||
|
||||
func (s *UploadService) MoveOwnedBox(boxID, userID, collectionID string) error {
|
||||
box, err := s.GetBox(boxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if box.OwnerID != userID {
|
||||
return os.ErrPermission
|
||||
}
|
||||
box.CollectionID = strings.TrimSpace(collectionID)
|
||||
return s.SaveBox(box)
|
||||
}
|
||||
|
||||
func (s *UploadService) DeleteOwnedBox(boxID, userID string) error {
|
||||
box, err := s.GetBox(boxID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if box.OwnerID != userID {
|
||||
return os.ErrPermission
|
||||
}
|
||||
return s.DeleteBoxWithSource(boxID, "user-delete")
|
||||
}
|
||||
|
||||
func (s *UploadService) DeleteBox(boxID string) error {
|
||||
return s.DeleteBoxWithSource(boxID, "admin")
|
||||
}
|
||||
@@ -518,12 +625,17 @@ func writeUploadedFile(path string, source multipart.File, maxSize int64) error
|
||||
}
|
||||
defer target.Close()
|
||||
|
||||
written, err := io.Copy(target, io.LimitReader(source, maxSize+1))
|
||||
var written int64
|
||||
if maxSize <= 0 {
|
||||
written, err = io.Copy(target, source)
|
||||
} else {
|
||||
written, err = io.Copy(target, io.LimitReader(source, maxSize+1))
|
||||
}
|
||||
if err != nil {
|
||||
os.Remove(path)
|
||||
return err
|
||||
}
|
||||
if written > maxSize {
|
||||
if maxSize > 0 && written > maxSize {
|
||||
os.Remove(path)
|
||||
return fmt.Errorf("file exceeds max upload size")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user