670 lines
15 KiB
Go
670 lines
15 KiB
Go
// Package metastore stores application metadata in an embedded Badger
// key-value database: settings, user accounts together with their
// username/email lookup indexes, and the sessions linked to those users.
package metastore
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/dgraph-io/badger/v4"
|
|
"golang.org/x/crypto/bcrypt"
|
|
|
|
"warpbox/lib/helpers"
|
|
)
|
|
|
|
// ErrNotFound reports that a requested record does not exist.
var ErrNotFound = errors.New("not found")

// ErrDuplicate reports that a write would violate a uniqueness constraint,
// such as an already-indexed username or email. It is usually returned
// wrapped with context, so test for it with errors.Is.
var ErrDuplicate = errors.New("duplicate")

// ErrInvalid reports that caller-supplied input failed validation. It is
// usually returned wrapped with context, so test for it with errors.Is.
var ErrInvalid = errors.New("invalid")
|
|
|
|
type Store struct {
|
|
db *badger.DB
|
|
}
|
|
|
|
func Open(path string) (*Store, error) {
|
|
opts := badger.DefaultOptions(path).WithLogger(nil)
|
|
db, err := badger.Open(opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Store{db: db}, nil
|
|
}
|
|
|
|
func (store *Store) Close() error {
|
|
if store == nil || store.db == nil {
|
|
return nil
|
|
}
|
|
return store.db.Close()
|
|
}
|
|
|
|
func (store *Store) SetSetting(name string, value string) error {
|
|
name = strings.TrimSpace(name)
|
|
if name == "" {
|
|
return fmt.Errorf("%w: setting name cannot be empty", ErrInvalid)
|
|
}
|
|
|
|
return store.db.Update(func(txn *badger.Txn) error {
|
|
return txn.Set(settingKey(name), []byte(value))
|
|
})
|
|
}
|
|
|
|
func (store *Store) DeleteSetting(name string) error {
|
|
return store.db.Update(func(txn *badger.Txn) error {
|
|
return txn.Delete(settingKey(name))
|
|
})
|
|
}
|
|
|
|
func (store *Store) GetSetting(name string) (string, bool, error) {
|
|
var value string
|
|
err := store.db.View(func(txn *badger.Txn) error {
|
|
item, err := txn.Get(settingKey(name))
|
|
if errors.Is(err, badger.ErrKeyNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return item.Value(func(data []byte) error {
|
|
value = string(data)
|
|
return nil
|
|
})
|
|
})
|
|
if errors.Is(err, ErrNotFound) {
|
|
return "", false, nil
|
|
}
|
|
return value, err == nil, err
|
|
}
|
|
|
|
func (store *Store) ListSettings() (map[string]string, error) {
|
|
settings := make(map[string]string)
|
|
err := store.db.View(func(txn *badger.Txn) error {
|
|
opts := badger.DefaultIteratorOptions
|
|
opts.Prefix = []byte("setting/")
|
|
it := txn.NewIterator(opts)
|
|
defer it.Close()
|
|
|
|
for it.Rewind(); it.Valid(); it.Next() {
|
|
item := it.Item()
|
|
name := strings.TrimPrefix(string(item.Key()), "setting/")
|
|
if err := item.Value(func(data []byte) error {
|
|
settings[name] = string(data)
|
|
return nil
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
return settings, err
|
|
}
|
|
|
|
func HashPassword(password string) (string, error) {
|
|
if strings.TrimSpace(password) == "" {
|
|
return "", fmt.Errorf("%w: password cannot be empty", ErrInvalid)
|
|
}
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(hash), nil
|
|
}
|
|
|
|
func VerifyPassword(hash string, password string) bool {
|
|
if hash == "" || password == "" {
|
|
return false
|
|
}
|
|
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil
|
|
}
|
|
|
|
func (store *Store) CreateUserWithPassword(username string, email string, password string, tagIDs []string) (User, error) {
|
|
hash, err := HashPassword(password)
|
|
if err != nil {
|
|
return User{}, err
|
|
}
|
|
user := User{
|
|
Username: username,
|
|
Email: email,
|
|
PasswordHash: hash,
|
|
TagIDs: uniqueStrings(tagIDs),
|
|
}
|
|
if err := store.CreateUser(&user); err != nil {
|
|
return User{}, err
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (store *Store) CreateUser(user *User) error {
|
|
if user == nil {
|
|
return fmt.Errorf("%w: user cannot be nil", ErrInvalid)
|
|
}
|
|
username := strings.TrimSpace(user.Username)
|
|
if username == "" {
|
|
return fmt.Errorf("%w: username cannot be empty", ErrInvalid)
|
|
}
|
|
email := strings.TrimSpace(user.Email)
|
|
if user.PasswordHash == "" {
|
|
return fmt.Errorf("%w: password hash cannot be empty", ErrInvalid)
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
if user.ID == "" {
|
|
id, err := helpers.RandomHexID(16)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.ID = id
|
|
}
|
|
user.Username = username
|
|
user.Email = email
|
|
user.TagIDs = uniqueStrings(user.TagIDs)
|
|
user.CreatedAt = now
|
|
user.UpdatedAt = now
|
|
|
|
return store.db.Update(func(txn *badger.Txn) error {
|
|
if exists, err := keyExists(txn, usernameKey(username)); err != nil || exists {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return fmt.Errorf("%w: username already exists", ErrDuplicate)
|
|
}
|
|
if email != "" {
|
|
if exists, err := keyExists(txn, emailKey(email)); err != nil || exists {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return fmt.Errorf("%w: email already exists", ErrDuplicate)
|
|
}
|
|
}
|
|
if err := putJSON(txn, userKey(user.ID), user); err != nil {
|
|
return err
|
|
}
|
|
if err := txn.Set(usernameKey(username), []byte(user.ID)); err != nil {
|
|
return err
|
|
}
|
|
if email != "" {
|
|
return txn.Set(emailKey(email), []byte(user.ID))
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (store *Store) UpdateUser(user User) error {
|
|
if strings.TrimSpace(user.ID) == "" {
|
|
return fmt.Errorf("%w: user id cannot be empty", ErrInvalid)
|
|
}
|
|
user.Username = strings.TrimSpace(user.Username)
|
|
user.Email = strings.TrimSpace(user.Email)
|
|
if user.Username == "" {
|
|
return fmt.Errorf("%w: username cannot be empty", ErrInvalid)
|
|
}
|
|
user.TagIDs = uniqueStrings(user.TagIDs)
|
|
user.UpdatedAt = time.Now().UTC()
|
|
|
|
return store.db.Update(func(txn *badger.Txn) error {
|
|
var existing User
|
|
if err := getJSON(txn, userKey(user.ID), &existing); err != nil {
|
|
return err
|
|
}
|
|
|
|
oldUsername := normalizeIndex(existing.Username)
|
|
newUsername := normalizeIndex(user.Username)
|
|
if oldUsername != newUsername {
|
|
if exists, err := keyExists(txn, usernameKey(user.Username)); err != nil || exists {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return fmt.Errorf("%w: username already exists", ErrDuplicate)
|
|
}
|
|
if err := txn.Delete(usernameKey(existing.Username)); err != nil && !errors.Is(err, badger.ErrKeyNotFound) {
|
|
return err
|
|
}
|
|
if err := txn.Set(usernameKey(user.Username), []byte(user.ID)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
oldEmail := normalizeIndex(existing.Email)
|
|
newEmail := normalizeIndex(user.Email)
|
|
if oldEmail != newEmail {
|
|
if newEmail != "" {
|
|
if exists, err := keyExists(txn, emailKey(user.Email)); err != nil || exists {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return fmt.Errorf("%w: email already exists", ErrDuplicate)
|
|
}
|
|
if err := txn.Set(emailKey(user.Email), []byte(user.ID)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if oldEmail != "" {
|
|
if err := txn.Delete(emailKey(existing.Email)); err != nil && !errors.Is(err, badger.ErrKeyNotFound) {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return putJSON(txn, userKey(user.ID), user)
|
|
})
|
|
}
|
|
|
|
func (store *Store) GetUser(id string) (User, bool, error) {
|
|
var user User
|
|
err := store.db.View(func(txn *badger.Txn) error {
|
|
return getJSON(txn, userKey(id), &user)
|
|
})
|
|
if errors.Is(err, ErrNotFound) {
|
|
return User{}, false, nil
|
|
}
|
|
return user, err == nil, err
|
|
}
|
|
|
|
func (store *Store) GetUserByUsername(username string) (User, bool, error) {
|
|
return store.getUserByIndex(usernameKey(username))
|
|
}
|
|
|
|
func (store *Store) GetUserByEmail(email string) (User, bool, error) {
|
|
return store.getUserByIndex(emailKey(email))
|
|
}
|
|
|
|
func (store *Store) ListUsers() ([]User, error) {
|
|
users := []User{}
|
|
err := store.db.View(func(txn *badger.Txn) error {
|
|
opts := badger.DefaultIteratorOptions
|
|
opts.Prefix = []byte("user/")
|
|
it := txn.NewIterator(opts)
|
|
defer it.Close()
|
|
|
|
for it.Rewind(); it.Valid(); it.Next() {
|
|
var user User
|
|
if err := it.Item().Value(func(data []byte) error {
|
|
return json.Unmarshal(data, &user)
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
users = append(users, user)
|
|
}
|
|
return nil
|
|
})
|
|
return users, err
|
|
}
|
|
|
|
func (store *Store) ListUsersPaginated(filters UserFilters, pageReq UserPageRequest) (UserPage, error) {
|
|
users, err := store.ListUsers()
|
|
if err != nil {
|
|
return UserPage{}, err
|
|
}
|
|
|
|
tags, err := store.ListTags()
|
|
if err != nil {
|
|
return UserPage{}, err
|
|
}
|
|
tagMap := make(map[string]Tag, len(tags))
|
|
for _, tag := range tags {
|
|
tagMap[tag.ID] = tag
|
|
}
|
|
|
|
query := strings.ToLower(strings.TrimSpace(filters.Query))
|
|
filtered := make([]User, 0, len(users))
|
|
for _, user := range users {
|
|
if query != "" {
|
|
if !strings.Contains(strings.ToLower(user.Username), query) &&
|
|
!strings.Contains(strings.ToLower(user.Email), query) {
|
|
continue
|
|
}
|
|
}
|
|
switch filters.Status {
|
|
case "active":
|
|
if user.Disabled || strings.HasPrefix(user.PasswordHash, "invite/") {
|
|
continue
|
|
}
|
|
case "disabled":
|
|
if !user.Disabled || strings.HasPrefix(user.PasswordHash, "invite/") {
|
|
continue
|
|
}
|
|
case "pending":
|
|
if !strings.HasPrefix(user.PasswordHash, "invite/") {
|
|
continue
|
|
}
|
|
}
|
|
if filters.Role != "" && filters.Role != "all" {
|
|
match := false
|
|
for _, tagID := range user.TagIDs {
|
|
if tag, ok := tagMap[tagID]; ok && strings.EqualFold(tag.Name, filters.Role) {
|
|
match = true
|
|
break
|
|
}
|
|
}
|
|
if !match {
|
|
continue
|
|
}
|
|
}
|
|
filtered = append(filtered, user)
|
|
}
|
|
|
|
switch filters.Sort {
|
|
case "createdDesc":
|
|
sort.Slice(filtered, func(i, j int) bool {
|
|
return filtered[i].CreatedAt.After(filtered[j].CreatedAt)
|
|
})
|
|
case "username":
|
|
fallthrough
|
|
default:
|
|
sort.Slice(filtered, func(i, j int) bool {
|
|
return strings.ToLower(filtered[i].Username) < strings.ToLower(filtered[j].Username)
|
|
})
|
|
}
|
|
|
|
total := len(filtered)
|
|
pageSize := pageReq.PageSize
|
|
if pageSize <= 0 {
|
|
pageSize = 12
|
|
}
|
|
if pageSize > 100 {
|
|
pageSize = 100
|
|
}
|
|
totalPages := (total + pageSize - 1) / pageSize
|
|
if totalPages < 1 {
|
|
totalPages = 1
|
|
}
|
|
page := pageReq.Page
|
|
if page < 1 {
|
|
page = 1
|
|
}
|
|
if page > totalPages {
|
|
page = totalPages
|
|
}
|
|
|
|
start := (page - 1) * pageSize
|
|
end := start + pageSize
|
|
if end > total {
|
|
end = total
|
|
}
|
|
pageUsers := filtered[start:end]
|
|
|
|
stats := UserPageStats{TotalUsers: len(users)}
|
|
for _, user := range users {
|
|
if strings.HasPrefix(user.PasswordHash, "invite/") {
|
|
stats.PendingInvites++
|
|
} else if user.Disabled {
|
|
stats.DisabledUsers++
|
|
} else {
|
|
stats.ActiveUsers++
|
|
}
|
|
}
|
|
|
|
rows := make([]UserRow, len(pageUsers))
|
|
for i, user := range pageUsers {
|
|
role := ""
|
|
tagNames := make([]string, 0, len(user.TagIDs))
|
|
for _, tagID := range user.TagIDs {
|
|
if tag, ok := tagMap[tagID]; ok {
|
|
tagNames = append(tagNames, tag.Name)
|
|
if tag.Permissions.AdminAccess && role == "" {
|
|
role = tag.Name
|
|
} else if role == "" {
|
|
role = tag.Name
|
|
}
|
|
}
|
|
}
|
|
if role == "" {
|
|
role = "user"
|
|
}
|
|
|
|
plan := "standard"
|
|
for _, tagID := range user.TagIDs {
|
|
if tag, ok := tagMap[tagID]; ok && strings.EqualFold(tag.Name, "admin") {
|
|
plan = "unlimited"
|
|
break
|
|
}
|
|
}
|
|
|
|
isInvite := strings.HasPrefix(user.PasswordHash, "invite/")
|
|
status := userStatus(user.Disabled)
|
|
if isInvite {
|
|
status = "pending"
|
|
}
|
|
|
|
rows[i] = UserRow{
|
|
ID: user.ID,
|
|
Username: user.Username,
|
|
Email: user.Email,
|
|
Status: status,
|
|
Role: role,
|
|
TagIDs: user.TagIDs,
|
|
Tags: strings.Join(tagNames, ", "),
|
|
Plan: plan,
|
|
PolicySummary: "system default",
|
|
BoxCount: 0,
|
|
APIKeyCount: 0,
|
|
CreatedAt: formatTime(user.CreatedAt),
|
|
LastSeen: "-",
|
|
Disabled: user.Disabled,
|
|
IsCurrent: false,
|
|
IsInvite: isInvite,
|
|
}
|
|
}
|
|
|
|
return UserPage{
|
|
Rows: rows,
|
|
Page: page,
|
|
PageSize: pageSize,
|
|
Total: total,
|
|
HasPrev: page > 1,
|
|
HasNext: page < totalPages,
|
|
PrevPage: page - 1,
|
|
NextPage: page + 1,
|
|
TotalPages: totalPages,
|
|
Stats: stats,
|
|
}, nil
|
|
}
|
|
|
|
// userStatus maps a user's disabled flag to its status label: "disabled"
// when set, "active" otherwise.
func userStatus(disabled bool) string {
	status := "active"
	if disabled {
		status = "disabled"
	}
	return status
}
|
|
|
|
// formatTime renders t in UTC as "YYYY-MM-DD HH:MM" (minute precision).
// The zero time is rendered as "-".
func formatTime(t time.Time) string {
	if !t.IsZero() {
		return t.UTC().Format("2006-01-02 15:04")
	}
	return "-"
}
|
|
|
|
func (store *Store) BulkSetUsersDisabled(ids []string, disabled bool) error {
|
|
return store.db.Update(func(txn *badger.Txn) error {
|
|
for _, id := range ids {
|
|
var user User
|
|
if err := getJSON(txn, userKey(id), &user); err != nil {
|
|
if errors.Is(err, ErrNotFound) {
|
|
continue
|
|
}
|
|
return err
|
|
}
|
|
user.Disabled = disabled
|
|
user.UpdatedAt = time.Now().UTC()
|
|
if err := putJSON(txn, userKey(id), user); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (store *Store) RevokeUserSessions(userID string) error {
|
|
tokens, err := store.sessionTokensForUser(userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return store.db.Update(func(txn *badger.Txn) error {
|
|
for _, token := range tokens {
|
|
if err := txn.Delete(sessionKey(token)); err != nil && !errors.Is(err, badger.ErrKeyNotFound) {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (store *Store) BulkRevokeUserSessions(ids []string) error {
|
|
for _, id := range ids {
|
|
if err := store.RevokeUserSessions(id); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (store *Store) sessionTokensForUser(userID string) ([]string, error) {
|
|
tokens := []string{}
|
|
err := store.db.View(func(txn *badger.Txn) error {
|
|
opts := badger.DefaultIteratorOptions
|
|
opts.Prefix = []byte("session/")
|
|
it := txn.NewIterator(opts)
|
|
defer it.Close()
|
|
for it.Rewind(); it.Valid(); it.Next() {
|
|
var session Session
|
|
if err := it.Item().Value(func(data []byte) error {
|
|
return json.Unmarshal(data, &session)
|
|
}); err != nil {
|
|
continue
|
|
}
|
|
if session.UserID == userID {
|
|
tokens = append(tokens, session.Token)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
return tokens, err
|
|
}
|
|
|
|
func (store *Store) CountAdminUsers(adminTagID string) (int, error) {
|
|
users, err := store.ListUsers()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
count := 0
|
|
for _, user := range users {
|
|
if user.Disabled {
|
|
continue
|
|
}
|
|
for _, tagID := range user.TagIDs {
|
|
if tagID == adminTagID {
|
|
count++
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
func (store *Store) CreateUserWithoutPassword(username string, email string, tagIDs []string) (User, error) {
|
|
hash, err := helpers.RandomHexID(32)
|
|
if err != nil {
|
|
return User{}, err
|
|
}
|
|
user := User{
|
|
Username: username,
|
|
Email: email,
|
|
PasswordHash: "invite/" + hash,
|
|
TagIDs: uniqueStrings(tagIDs),
|
|
Disabled: true,
|
|
}
|
|
if err := store.CreateUser(&user); err != nil {
|
|
return User{}, err
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (store *Store) getUserByIndex(key []byte) (User, bool, error) {
|
|
var id string
|
|
err := store.db.View(func(txn *badger.Txn) error {
|
|
item, err := txn.Get(key)
|
|
if errors.Is(err, badger.ErrKeyNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return item.Value(func(data []byte) error {
|
|
id = string(data)
|
|
return nil
|
|
})
|
|
})
|
|
if errors.Is(err, ErrNotFound) {
|
|
return User{}, false, nil
|
|
}
|
|
if err != nil {
|
|
return User{}, false, err
|
|
}
|
|
return store.GetUser(id)
|
|
}
|
|
|
|
func putJSON(txn *badger.Txn, key []byte, value any) error {
|
|
data, err := json.Marshal(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return txn.Set(key, data)
|
|
}
|
|
|
|
func getJSON(txn *badger.Txn, key []byte, value any) error {
|
|
item, err := txn.Get(key)
|
|
if errors.Is(err, badger.ErrKeyNotFound) {
|
|
return ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return item.Value(func(data []byte) error {
|
|
return json.Unmarshal(data, value)
|
|
})
|
|
}
|
|
|
|
func keyExists(txn *badger.Txn, key []byte) (bool, error) {
|
|
_, err := txn.Get(key)
|
|
if errors.Is(err, badger.ErrKeyNotFound) {
|
|
return false, nil
|
|
}
|
|
return err == nil, err
|
|
}
|
|
|
|
// settingKey builds the storage key for a setting: "setting/" followed by
// name with surrounding whitespace removed.
func settingKey(name string) []byte {
	trimmed := strings.TrimSpace(name)
	return append([]byte("setting/"), trimmed...)
}
|
|
|
|
// userKey builds the storage key for a user record: "user/" followed by id
// with surrounding whitespace removed.
func userKey(id string) []byte {
	trimmed := strings.TrimSpace(id)
	return append([]byte("user/"), trimmed...)
}
|
|
|
|
func usernameKey(username string) []byte {
|
|
return []byte("user_by_name/" + normalizeIndex(username))
|
|
}
|
|
|
|
func emailKey(email string) []byte {
|
|
return []byte("user_by_email/" + normalizeIndex(email))
|
|
}
|
|
|
|
// normalizeIndex canonicalizes a value for use in an index key. Surrounding
// whitespace is trimmed and the result is lower-cased, so index lookups
// ignore case and padding.
func normalizeIndex(value string) string {
	trimmed := strings.TrimSpace(value)
	return strings.ToLower(trimmed)
}
|
|
|
|
// uniqueStrings trims every element of values and drops blank results and
// duplicates, keeping the first occurrence in its original order. The
// result is never nil, so it JSON-encodes as [] rather than null.
func uniqueStrings(values []string) []string {
	result := make([]string, 0, len(values))
	seen := make(map[string]struct{}, len(values))
	for _, raw := range values {
		v := strings.TrimSpace(raw)
		if v == "" {
			continue
		}
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		result = append(result, v)
	}
	return result
}
|