diff --git a/lib/boxstore/store.go b/lib/boxstore/store.go index 27e7010..26c39eb 100644 --- a/lib/boxstore/store.go +++ b/lib/boxstore/store.go @@ -18,6 +18,8 @@ import ( "sync" "time" + "golang.org/x/crypto/bcrypt" + "warpbox/lib/helpers" "warpbox/lib/models" ) @@ -76,16 +78,39 @@ func BoxPath(boxID string) string { return filepath.Join(uploadRoot, boxID) } +func safeBoxPath(boxID string) (string, bool) { + if !ValidBoxID(boxID) { + return "", false + } + return helpers.SafeChildPath(uploadRoot, boxID) +} + func ManifestPath(boxID string) string { return filepath.Join(BoxPath(boxID), manifestFile) } func SafeBoxFilePath(boxID string, filename string) (string, bool) { - return helpers.SafeChildPath(BoxPath(boxID), filename) + boxPath, ok := safeBoxPath(boxID) + if !ok { + return "", false + } + return helpers.SafeChildPath(boxPath, filename) +} + +func IsSafeRegularBoxFile(boxID string, filename string) bool { + path, ok := SafeBoxFilePath(boxID, filename) + if !ok { + return false + } + return ensureRegularFile(path) == nil } func DeleteBox(boxID string) error { - return os.RemoveAll(BoxPath(boxID)) + boxPath, ok := safeBoxPath(boxID) + if !ok { + return fmt.Errorf("Invalid box id") + } + return os.RemoveAll(boxPath) } func ListBoxSummaries() ([]models.BoxSummary, error) { @@ -218,18 +243,17 @@ func CreateManifest(boxID string, request models.CreateBoxRequest) ([]models.Box } if password := strings.TrimSpace(request.Password); password != "" { - salt, err := helpers.RandomHexID(16) - if err != nil { - return nil, fmt.Errorf("Could not secure upload box") - } - authToken, err := helpers.RandomHexID(16) if err != nil { return nil, fmt.Errorf("Could not secure upload box") } + passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, fmt.Errorf("Could not secure upload box") + } - manifest.PasswordSalt = salt - manifest.PasswordHash = passwordHash(salt, password) + manifest.PasswordHash = string(passwordHash) + manifest.PasswordHashAlg = "bcrypt" manifest.AuthToken = authToken } @@ -250,7 +274,7 @@ func IsExpired(manifest models.BoxManifest) bool { } func IsPasswordProtected(manifest models.BoxManifest) bool { - return manifest.PasswordSalt != "" && manifest.PasswordHash != "" && manifest.AuthToken != "" + return manifest.PasswordHash != "" && manifest.AuthToken != "" } func VerifyPassword(manifest models.BoxManifest, password string) bool { @@ -259,7 +283,11 @@ func VerifyPassword(manifest models.BoxManifest, password string) bool { } expected := manifest.PasswordHash - actual := passwordHash(manifest.PasswordSalt, password) + if manifest.PasswordHashAlg == "bcrypt" || strings.HasPrefix(expected, "$2") { + return bcrypt.CompareHashAndPassword([]byte(expected), []byte(password)) == nil + } + + actual := legacyPasswordHash(manifest.PasswordSalt, password) return subtle.ConstantTimeCompare([]byte(expected), []byte(actual)) == 1 } @@ -335,13 +363,25 @@ func RenewManifest(boxID string, seconds int64) (models.BoxManifest, error) { } func AddFileToZip(zipWriter *zip.Writer, boxID string, filename string) error { - source, err := os.Open(filepath.Join(BoxPath(boxID), filename)) + path, ok := SafeBoxFilePath(boxID, filename) + if !ok { + return fmt.Errorf("Invalid file") + } + if err := ensureRegularFile(path); err != nil { + return err + } + zipName, ok := safeZipEntryName(filename) + if !ok { + return fmt.Errorf("Invalid zip entry") + } + + source, err := os.Open(path) if err != nil { return err } defer source.Close() - destination, err := zipWriter.Create(filename) + destination, err := zipWriter.Create(zipName) if err != nil { return err } @@ -358,6 +398,9 @@ func SaveManifestUpload(boxID string, fileID string, file *multipart.FileHeader) if err != nil { return models.BoxFile{}, err } + if IsExpired(manifest) { + return models.BoxFile{}, fmt.Errorf("Box expired") + } fileIndex := -1 for index, manifestFile := range manifest.Files { @@ -376,7 +419,10 @@ func SaveManifestUpload(boxID string, fileID string, file *multipart.FileHeader) return models.BoxFile{}, fmt.Errorf("Could not prepare upload box") } - destination := filepath.Join(BoxPath(boxID), filename) + destination, ok := SafeBoxFilePath(boxID, filename) + if !ok { + return models.BoxFile{}, fmt.Errorf("Invalid filename") + } if err := saveMultipartFile(file, destination); err != nil { manifest.Files[fileIndex].Status = models.FileStatusFailed startRetentionIfTerminalUnlocked(&manifest) @@ -407,7 +453,10 @@ func SaveUpload(boxID string, file *multipart.FileHeader) (models.BoxFile, error } filename = helpers.UniqueFilename(boxPath, filename) - destination := filepath.Join(boxPath, filename) + destination, ok := SafeBoxFilePath(boxID, filename) + if !ok { + return models.BoxFile{}, fmt.Errorf("Invalid filename") + } if err := saveMultipartFile(file, destination); err != nil { return models.BoxFile{}, fmt.Errorf("Could not save uploaded file") } @@ -423,7 +472,9 @@ func SaveUpload(boxID string, file *multipart.FileHeader) (models.BoxFile, error func DecorateFile(boxID string, file models.BoxFile) models.BoxFile { if file.MimeType == "" { - file.MimeType = helpers.MimeTypeForFile(filepath.Join(BoxPath(boxID), file.Name), file.Name) + if path, ok := SafeBoxFilePath(boxID, file.Name); ok { + file.MimeType = helpers.MimeTypeForFile(path, file.Name) + } } if file.SizeLabel == "" { @@ -495,9 +546,12 @@ func reconcileManifest(boxID string) (models.BoxManifest, error) { changed := false for index, file := range manifest.Files { - path := filepath.Join(BoxPath(boxID), file.Name) + path, ok := SafeBoxFilePath(boxID, file.Name) + if !ok || ensureRegularFile(path) != nil { + continue + } info, err := os.Stat(path) - if err != nil { + if err != nil || !info.Mode().IsRegular() { continue } @@ -531,7 +585,7 @@ func listCompletedFilesFromDisk(boxID string) ([]models.BoxFile, error) { files := make([]models.BoxFile, 0, len(entries)) for _, entry := range entries { - if entry.IsDir() || entry.Name() == manifestFile { + if entry.IsDir() || entry.Name() == manifestFile || entry.Type()&os.ModeSymlink != 0 { continue } @@ -539,6 +593,9 @@ func listCompletedFilesFromDisk(boxID string) ([]models.BoxFile, error) { if err != nil { return nil, err } + if !info.Mode().IsRegular() { + continue + } name := entry.Name() files = append(files, DecorateFile(boxID, models.BoxFile{ @@ -601,7 +658,7 @@ func startRetentionIfTerminalUnlocked(manifest *models.BoxManifest) { manifest.ExpiresAt = time.Now().UTC().Add(time.Duration(seconds) * time.Second) } -func passwordHash(salt string, password string) string { +func legacyPasswordHash(salt string, password string) string { sum := sha256.Sum256([]byte(salt + ":" + password)) return hex.EncodeToString(sum[:]) } @@ -624,12 +681,64 @@ func saveMultipartFile(file *multipart.FileHeader, destination string) error { } defer source.Close() - target, err := os.Create(destination) + target, tempPath, err := createTempSibling(destination) if err != nil { return err } - defer target.Close() + committed := false + defer func() { + target.Close() + if !committed { + os.Remove(tempPath) + } + }() - _, err = io.Copy(target, source) - return err + if _, err := io.Copy(target, source); err != nil { + return err + } + if err := target.Close(); err != nil { + return err + } + if err := os.Rename(tempPath, destination); err != nil { + return err + } + committed = true + return nil +} + +func createTempSibling(destination string) (*os.File, string, error) { + directory := filepath.Dir(destination) + if err := os.MkdirAll(directory, 0755); err != nil { + return nil, "", err + } + + target, err := os.CreateTemp(directory, ".warpbox-upload-*") + if err != nil { + return nil, "", err + } + return target, target.Name(), nil +} + +func safeZipEntryName(filename string) (string, bool) { + filename = strings.TrimSpace(filename) + if filename == "" || filepath.IsAbs(filename) { + return "", false + } + + cleaned := filepath.ToSlash(filepath.Clean(filename)) + if cleaned == "." || cleaned == ".." || strings.HasPrefix(cleaned, "../") || strings.HasPrefix(cleaned, "/") { + return "", false + } + return cleaned, true +} + +func ensureRegularFile(path string) error { + info, err := os.Lstat(path) + if err != nil { + return err + } + if info.Mode()&os.ModeSymlink != 0 || !info.Mode().IsRegular() { + return fmt.Errorf("Invalid file") + } + return nil } diff --git a/lib/boxstore/store_test.go b/lib/boxstore/store_test.go index ba003f4..e761c46 100644 --- a/lib/boxstore/store_test.go +++ b/lib/boxstore/store_test.go @@ -1,6 +1,10 @@ package boxstore import ( + "archive/zip" + "bytes" + "os" + "path/filepath" "testing" "time" @@ -59,3 +63,87 @@ func TestStartRetentionSkipsOneTimeDownload(t *testing.T) { t.Fatalf("expected one-time download box to avoid retention expiry, got %s", manifest.ExpiresAt) } } + +func TestSafeBoxFilePathRejectsTraversal(t *testing.T) { + restoreUploadRoot := UploadRoot() + defer SetUploadRoot(restoreUploadRoot) + SetUploadRoot(t.TempDir()) + + boxID := "0123456789abcdef0123456789abcdef" + if _, ok := SafeBoxFilePath(boxID, "../outside.txt"); ok { + t.Fatal("expected traversal to be rejected") + } + if _, ok := SafeBoxFilePath("../bad", "file.txt"); ok { + t.Fatal("expected invalid box id to be rejected") + } +} + +func TestAddFileToZipRejectsUnsafeManifestName(t *testing.T) { + restoreUploadRoot := UploadRoot() + defer SetUploadRoot(restoreUploadRoot) + SetUploadRoot(t.TempDir()) + + var buffer bytes.Buffer + zipWriter := zip.NewWriter(&buffer) + if err := AddFileToZip(zipWriter, "0123456789abcdef0123456789abcdef", "../outside.txt"); err == nil { + t.Fatal("expected unsafe zip filename to be rejected") + } +} + +func TestListFilesSkipsSymlinks(t *testing.T) { + restoreUploadRoot := UploadRoot() + defer SetUploadRoot(restoreUploadRoot) + SetUploadRoot(t.TempDir()) + + boxID := "0123456789abcdef0123456789abcdef" + if err := os.MkdirAll(BoxPath(boxID), 0755); err != nil { + t.Fatalf("MkdirAll returned error: %v", err) + } + if err := os.WriteFile(filepath.Join(BoxPath(boxID), "safe.txt"), []byte("safe"), 0644); err != nil { + t.Fatalf("WriteFile returned error: %v", err) + } + if err := os.Symlink(filepath.Join(BoxPath(boxID), "safe.txt"), filepath.Join(BoxPath(boxID), "link.txt")); err != nil { + t.Skipf("symlink unavailable: %v", err) + } + + files, err := ListFiles(boxID) + if err != nil { + t.Fatalf("ListFiles returned error: %v", err) + } + if len(files) != 1 || files[0].Name != "safe.txt" { + t.Fatalf("expected only regular file, got %#v", files) + } +} + +func TestBoxPasswordUsesBcryptAndVerifiesLegacy(t *testing.T) { + restoreUploadRoot := UploadRoot() + defer SetUploadRoot(restoreUploadRoot) + SetUploadRoot(t.TempDir()) + + boxID := "0123456789abcdef0123456789abcdef" + if err := os.MkdirAll(BoxPath(boxID), 0755); err != nil { + t.Fatalf("MkdirAll returned error: %v", err) + } + if _, err := CreateManifest(boxID, models.CreateBoxRequest{Password: "secret"}); err != nil { + t.Fatalf("CreateManifest returned error: %v", err) + } + manifest, err := ReadManifest(boxID) + if err != nil { + t.Fatalf("ReadManifest returned error: %v", err) + } + if manifest.PasswordHashAlg != "bcrypt" { + t.Fatalf("expected bcrypt password hash, got %q", manifest.PasswordHashAlg) + } + if !VerifyPassword(manifest, "secret") { + t.Fatal("expected bcrypt password to verify") + } + + legacy := models.BoxManifest{ + PasswordSalt: "salt", + PasswordHash: legacyPasswordHash("salt", "secret"), + AuthToken: "token", + } + if !VerifyPassword(legacy, "secret") { + t.Fatal("expected legacy password hash to verify") + } +} diff --git a/lib/boxstore/thumbnails.go b/lib/boxstore/thumbnails.go index 2299c52..d7dd289 100644 --- a/lib/boxstore/thumbnails.go +++ b/lib/boxstore/thumbnails.go @@ -140,7 +140,15 @@ func canGenerateThumbnail(file models.BoxFile) bool { } func generateThumbnail(task thumbnailTask) error { - source, err := os.Open(filepath.Join(BoxPath(task.BoxID), task.Name)) + sourcePath, ok := SafeBoxFilePath(task.BoxID, task.Name) + if !ok { + return os.ErrInvalid + } + if err := ensureRegularFile(sourcePath); err != nil { + return err + } + + source, err := os.Open(sourcePath) if err != nil { return err } @@ -161,15 +169,28 @@ func generateThumbnail(task thumbnailTask) error { return os.ErrInvalid } - target, err := os.Create(path) + target, tempPath, err := createTempSibling(path) if err != nil { return err } - defer target.Close() + committed := false + defer func() { + target.Close() + if !committed { + os.Remove(tempPath) + } + }() if err := jpeg.Encode(target, thumb, &jpeg.Options{Quality: 82}); err != nil { return err } + if err := target.Close(); err != nil { + return err + } + if err := os.Rename(tempPath, path); err != nil { + return err + } + committed = true return markThumbnailReady(task.BoxID, task.FileID) } diff --git a/lib/helpers/paths.go b/lib/helpers/paths.go index a1b2d45..af95f99 100644 --- a/lib/helpers/paths.go +++ b/lib/helpers/paths.go @@ -14,8 +14,19 @@ func SafeFilename(name string) (string, bool) { } func SafeChildPath(parent string, filename string) (string, bool) { - path := filepath.Join(parent, filename) - return path, strings.HasPrefix(path, parent+string(filepath.Separator)) + parent = filepath.Clean(parent) + filename = strings.TrimSpace(filename) + if parent == "" || filename == "" || filepath.IsAbs(filename) { + return "", false + } + + path := filepath.Clean(filepath.Join(parent, filename)) + relative, err := filepath.Rel(parent, path) + if err != nil || relative == "." || strings.HasPrefix(relative, ".."+string(filepath.Separator)) || relative == ".." { + return "", false + } + + return path, true } func UniqueFilename(directory string, filename string) string { diff --git a/lib/helpers/paths_test.go b/lib/helpers/paths_test.go new file mode 100644 index 0000000..dc3bf69 --- /dev/null +++ b/lib/helpers/paths_test.go @@ -0,0 +1,20 @@ +package helpers + +import ( + "path/filepath" + "testing" +) + +func TestSafeChildPathRejectsTraversalAndAbsolutePaths(t *testing.T) { + parent := filepath.Join(t.TempDir(), "parent") + + if _, ok := SafeChildPath(parent, "../outside.txt"); ok { + t.Fatal("expected traversal to be rejected") + } + if _, ok := SafeChildPath(parent, filepath.Join(string(filepath.Separator), "tmp", "outside.txt")); ok { + t.Fatal("expected absolute path to be rejected") + } + if path, ok := SafeChildPath(parent, "inside.txt"); !ok || path != filepath.Join(parent, "inside.txt") { + t.Fatalf("expected safe child path, got path=%q ok=%v", path, ok) + } +} diff --git a/lib/metastore/models.go b/lib/metastore/models.go index dc32ecc..1b39098 100644 --- a/lib/metastore/models.go +++ b/lib/metastore/models.go @@ -46,6 +46,7 @@ type TagPermissions struct { type Session struct { Token string `json:"token"` + CSRFToken string `json:"csrf_token"` UserID string `json:"user_id"` CreatedAt time.Time `json:"created_at"` ExpiresAt time.Time `json:"expires_at"` diff --git a/lib/metastore/sessions.go b/lib/metastore/sessions.go index 87a767b..b4684d5 100644 --- a/lib/metastore/sessions.go +++ b/lib/metastore/sessions.go @@ -23,9 +23,14 @@ func (store *Store) CreateSession(userID string, ttl time.Duration) (Session, er if err != nil { return Session{}, err } + csrfToken, err := helpers.RandomHexID(32) + if err != nil { + return Session{}, err + } now := time.Now().UTC() session := Session{ Token: token, + CSRFToken: csrfToken, UserID: userID, CreatedAt: now, ExpiresAt: now.Add(ttl), diff --git a/lib/models/models.go b/lib/models/models.go index 9789d38..1df17d4 100644 --- a/lib/models/models.go +++ b/lib/models/models.go @@ -49,6 +49,7 @@ type BoxManifest struct { RetentionSecs int64 `json:"retention_seconds"` PasswordSalt string `json:"password_salt,omitempty"` PasswordHash string `json:"password_hash,omitempty"` + PasswordHashAlg string `json:"password_hash_alg,omitempty"` AuthToken string `json:"auth_token,omitempty"` DisableZip bool `json:"disable_zip,omitempty"` OneTimeDownload bool `json:"one_time_download,omitempty"` diff --git a/lib/server/admin.go b/lib/server/admin.go index ce6eecd..eb70857 100644 --- a/lib/server/admin.go +++ b/lib/server/admin.go @@ -1,6 +1,7 @@ package server import ( + "crypto/subtle" "errors" "fmt" "net/http" @@ -57,6 +58,7 @@ type adminBoxRow struct { func (app *App) registerAdminRoutes(router *gin.Engine) { admin := router.Group("/admin") + admin.Use(noStoreAdminHeaders) admin.GET("/login", app.handleAdminLogin) admin.POST("/login", app.handleAdminLoginPost) @@ -132,6 +134,7 @@ func (app *App) handleAdminLogout(ctx *gin.Context) { func (app *App) handleAdminDashboard(ctx *gin.Context) { ctx.HTML(http.StatusOK, "admin.html", gin.H{ "CurrentUser": app.currentAdminUsername(ctx), + "CSRFToken": app.currentCSRFToken(ctx), }) } @@ -267,6 +270,7 @@ func (app *App) renderAdminUsers(ctx *gin.Context, errorMessage string) { ctx.HTML(http.StatusOK, "admin_users.html", gin.H{ "CurrentUser": app.currentAdminUsername(ctx), + "CSRFToken": app.currentCSRFToken(ctx), "Users": rows, "Tags": tags, "Error": errorMessage, @@ -330,6 +334,7 @@ func (app *App) renderAdminTags(ctx *gin.Context, errorMessage string) { } ctx.HTML(http.StatusOK, "admin_tags.html", gin.H{ "CurrentUser": app.currentAdminUsername(ctx), + "CSRFToken": app.currentCSRFToken(ctx), "Tags": rows, "Error": errorMessage, }) @@ -374,6 +379,7 @@ func (app *App) handleAdminSettingsPost(ctx *gin.Context) { func (app *App) renderAdminSettings(ctx *gin.Context, errorMessage string) { ctx.HTML(http.StatusOK, "admin_settings.html", gin.H{ "CurrentUser": app.currentAdminUsername(ctx), + "CSRFToken": app.currentCSRFToken(ctx), "Rows": app.config.SettingRows(), "OverridesAllowed": app.config.AllowAdminSettingsOverride, "Error": errorMessage, @@ -393,6 +399,11 @@ func (app *App) requireAdminSession(ctx *gin.Context) { ctx.Abort() return } + if !validAdminCSRF(ctx, session) { + ctx.String(http.StatusForbidden, "Permission denied") + ctx.Abort() + return + } user, ok, err := app.store.GetUser(session.UserID) if err != nil || !ok || user.Disabled { ctx.Redirect(http.StatusSeeOther, "/admin/login") @@ -407,6 +418,7 @@ func (app *App) requireAdminSession(ctx *gin.Context) { } ctx.Set("adminUser", user) ctx.Set("adminPerms", perms) + ctx.Set("adminCSRFToken", session.CSRFToken) ctx.Next() } @@ -458,6 +470,15 @@ func (app *App) currentAdminUsername(ctx *gin.Context) string { return "" } +func (app *App) currentCSRFToken(ctx *gin.Context) string { + if value, ok := ctx.Get("adminCSRFToken"); ok { + if token, ok := value.(string); ok { + return token + } + } + return "" +} + func (app *App) renderAdminLogin(ctx *gin.Context, errorMessage string) { ctx.HTML(http.StatusOK, "admin_login.html", gin.H{ "AdminLoginEnabled": app.adminLoginEnabled, @@ -465,6 +486,30 @@ func (app *App) renderAdminLogin(ctx *gin.Context, errorMessage string) { }) } +func noStoreAdminHeaders(ctx *gin.Context) { + ctx.Header("Cache-Control", "no-store") + ctx.Header("Pragma", "no-cache") + ctx.Header("X-Content-Type-Options", "nosniff") + ctx.Next() +} + +func validAdminCSRF(ctx *gin.Context, session metastore.Session) bool { + switch ctx.Request.Method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + return true + } + + token := ctx.PostForm("csrf_token") + return token != "" && subtleConstantTimeEqual(token, session.CSRFToken) +} + +func subtleConstantTimeEqual(a string, b string) bool { + if len(a) != len(b) { + return false + } + return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1 +} + func parseTagPermissions(ctx *gin.Context) (metastore.TagPermissions, error) { maxFileSize, err := parseOptionalInt64(ctx.PostForm("max_file_size_bytes")) if err != nil { diff --git a/lib/server/handlers.go b/lib/server/handlers.go index 967ee16..f6210d5 100644 --- a/lib/server/handlers.go +++ b/lib/server/handlers.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "strings" + "sync" "time" "github.com/gin-gonic/gin" @@ -18,6 +19,8 @@ import ( const boxAuthCookiePrefix = "warpbox_box_" +var oneTimeDownloadLocks sync.Map + func (app *App) handleIndex(ctx *gin.Context) { ctx.HTML(http.StatusOK, "index.html", gin.H{ "RetentionOptions": app.retentionOptions(), @@ -166,6 +169,10 @@ func (app *App) handleDownloadBox(ctx *gin.Context) { if !ok { return } + if hasManifest && manifest.OneTimeDownload { + app.handleOneTimeDownloadBox(ctx, boxID) + return + } if hasManifest && manifest.DisableZip { ctx.String(http.StatusForbidden, "Zip download disabled for this box") @@ -177,11 +184,45 @@ func (app *App) handleDownloadBox(ctx *gin.Context) { ctx.String(http.StatusNotFound, "Box not found") return } - if hasManifest && manifest.OneTimeDownload && !allFilesComplete(files) { - ctx.String(http.StatusConflict, "Box is not ready yet") + if !app.writeBoxZip(ctx, boxID, files) { + return + } + if hasManifest && app.config.RenewOnDownloadEnabled { + boxstore.RenewManifest(boxID, manifest.RetentionSecs) + } +} + +func (app *App) handleOneTimeDownloadBox(ctx *gin.Context, boxID string) { + lock := oneTimeDownloadLock(boxID) + lock.Lock() + defer lock.Unlock() + defer oneTimeDownloadLocks.Delete(boxID) + + manifest, hasManifest, ok := app.authorizeBoxRequest(ctx, boxID, true) + if !ok { + return + } + if !hasManifest || !manifest.OneTimeDownload { + ctx.String(http.StatusNotFound, "Box not found") return } + files, err := boxstore.ListFiles(boxID) + if err != nil { + ctx.String(http.StatusNotFound, "Box not found") + return + } + if !allFilesComplete(files) { + ctx.String(http.StatusConflict, "Box is not ready yet") + return + } + if !app.writeBoxZip(ctx, boxID, files) { + return + } + boxstore.DeleteBox(boxID) +} + +func (app *App) writeBoxZip(ctx *gin.Context, boxID string, files []models.BoxFile) bool { ctx.Header("Content-Type", "application/zip") ctx.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="warpbox-%s.zip"`, boxID)) @@ -197,25 +238,24 @@ func (app *App) handleDownloadBox(ctx *gin.Context) { if !file.IsComplete { continue } - if err := boxstore.AddFileToZip(zipWriter, boxID, file.Name); err != nil { ctx.Status(http.StatusInternalServerError) - return + return false } } if err := zipWriter.Close(); err != nil { zipClosed = true ctx.Status(http.StatusInternalServerError) - return + return false } zipClosed = true + return true +} - if hasManifest && manifest.OneTimeDownload { - boxstore.DeleteBox(boxID) - } else if hasManifest && app.config.RenewOnDownloadEnabled { - boxstore.RenewManifest(boxID, manifest.RetentionSecs) - } +func oneTimeDownloadLock(boxID string) *sync.Mutex { + lock, _ := oneTimeDownloadLocks.LoadOrStore(boxID, &sync.Mutex{}) + return lock.(*sync.Mutex) } func allFilesComplete(files []models.BoxFile) bool { @@ -259,6 +299,10 @@ func (app *App) handleDownloadFile(ctx *gin.Context) { ctx.String(http.StatusNotFound, "File not found") return } + if !boxstore.IsSafeRegularBoxFile(boxID, filename) { + ctx.String(http.StatusBadRequest, "Invalid file") + return + } ctx.FileAttachment(path, filename) if hasManifest && app.config.RenewOnDownloadEnabled { @@ -297,6 +341,7 @@ func (app *App) handleCreateBox(ctx *gin.Context) { if !app.requireAPI(ctx) || !app.requireGuestUploads(ctx) { return } + app.limitRequestBody(ctx) boxID, err := boxstore.NewBoxID() if err != nil { @@ -332,6 +377,7 @@ func (app *App) handleManifestFileUpload(ctx *gin.Context) { if !app.requireAPI(ctx) || !app.requireGuestUploads(ctx) { return } + app.limitRequestBody(ctx) boxID := ctx.Param("id") fileID := ctx.Param("file_id") @@ -366,11 +412,12 @@ func (app *App) handleFileStatusUpdate(ctx *gin.Context) { if !app.requireAPI(ctx) { return } + app.limitRequestBody(ctx) boxID := ctx.Param("id") fileID := ctx.Param("file_id") - if !boxstore.ValidBoxID(boxID) { - ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid box id"}) + if !boxstore.ValidBoxID(boxID) || !helpers.ValidLowerHexID(fileID, 16) { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file"}) return } @@ -379,6 +426,14 @@ func (app *App) handleFileStatusUpdate(ctx *gin.Context) { ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid status payload"}) return } + if request.Status == models.FileStatusReady { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Uploads must complete through the upload endpoint"}) + return + } + if err := app.rejectExpiredManifestBox(boxID); err != nil { + ctx.JSON(http.StatusGone, gin.H{"error": err.Error()}) + return + } file, err := boxstore.MarkFileStatus(boxID, fileID, request.Status) if err != nil { @@ -393,6 +448,7 @@ func (app *App) handleDirectBoxUpload(ctx *gin.Context) { if !app.requireAPI(ctx) || !app.requireGuestUploads(ctx) { return } + app.limitRequestBody(ctx) boxID := ctx.Param("id") if !boxstore.ValidBoxID(boxID) { @@ -423,6 +479,7 @@ func (app *App) handleLegacyUpload(ctx *gin.Context) { if !app.requireAPI(ctx) || !app.requireGuestUploads(ctx) { return } + app.limitRequestBody(ctx) form, err := ctx.MultipartForm() if err != nil { @@ -580,14 +637,18 @@ func (app *App) validateManifestFileUpload(boxID string, fileID string, size int if err := app.validateFileSize(size); err != nil { return err } - if app.config.GlobalMaxBoxSizeBytes <= 0 { - return nil - } manifest, err := boxstore.ReadManifest(boxID) if err != nil { return app.validateIncomingFile(boxID, size) } + if boxstore.IsExpired(manifest) { + _ = boxstore.DeleteBox(boxID) + return fmt.Errorf("Box expired") + } + if app.config.GlobalMaxBoxSizeBytes <= 0 { + return nil + } totalSize := int64(0) found := false for _, file := range manifest.Files { @@ -624,6 +685,37 @@ func (app *App) validateBoxSize(size int64) error { return nil } +func (app *App) rejectExpiredManifestBox(boxID string) error { + manifest, err := boxstore.ReadManifest(boxID) + if err != nil { + return nil + } + if !boxstore.IsExpired(manifest) { + return nil + } + _ = boxstore.DeleteBox(boxID) + return fmt.Errorf("Box expired") +} + +func (app *App) limitRequestBody(ctx *gin.Context) { + limit := app.maxRequestBodyBytes() + if limit <= 0 { + return + } + ctx.Request.Body = http.MaxBytesReader(ctx.Writer, ctx.Request.Body, limit) +} + +func (app *App) maxRequestBodyBytes() int64 { + limit := app.config.GlobalMaxBoxSizeBytes + if limit <= 0 || app.config.GlobalMaxFileSizeBytes > limit { + limit = app.config.GlobalMaxFileSizeBytes + } + if limit <= 0 { + return 0 + } + return limit + 10*1024*1024 +} + func (app *App) retentionAllowed(key string) bool { key = strings.TrimSpace(key) if key == "" { diff --git a/lib/server/security_test.go b/lib/server/security_test.go new file mode 100644 index 0000000..3a0d985 --- /dev/null +++ b/lib/server/security_test.go @@ -0,0 +1,79 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/gin-gonic/gin" + + "warpbox/lib/boxstore" + "warpbox/lib/config" + "warpbox/lib/metastore" + "warpbox/lib/models" +) + +func TestValidateManifestFileUploadRejectsExpiredBox(t *testing.T) { + restoreUploadRoot := boxstore.UploadRoot() + defer boxstore.SetUploadRoot(restoreUploadRoot) + boxstore.SetUploadRoot(t.TempDir()) + + boxID := "0123456789abcdef0123456789abcdef" + if err := os.MkdirAll(boxstore.BoxPath(boxID), 0755); err != nil { + t.Fatalf("MkdirAll returned error: %v", err) + } + manifest := models.BoxManifest{ + Files: []models.BoxFile{{ID: "0123456789abcdef", Name: "file.txt", Status: models.FileStatusWait}}, + ExpiresAt: time.Now().UTC().Add(-time.Second), + } + if err := boxstore.WriteManifest(boxID, manifest); err != nil { + t.Fatalf("WriteManifest returned error: %v", err) + } + + app := &App{config: &config.Config{}} + if err := app.validateManifestFileUpload(boxID, "0123456789abcdef", 1); err == nil { + t.Fatal("expected expired box upload to be rejected") + } + if _, err := os.Stat(boxstore.BoxPath(boxID)); !os.IsNotExist(err) { + t.Fatalf("expected expired box to be deleted, stat err=%v", err) + } +} + +func TestAdminProtectedPostRequiresCSRF(t *testing.T) { + gin.SetMode(gin.TestMode) + + store, err := metastore.Open(t.TempDir()) + if err != nil { + t.Fatalf("Open returned error: %v", err) + } + defer store.Close() + + adminTag, err := store.EnsureAdminTag() + if err != nil { + t.Fatalf("EnsureAdminTag returned error: %v", err) + } + user, err := store.CreateUserWithPassword("admin", "", "secret", []string{adminTag.ID}) + if err != nil { + t.Fatalf("CreateUserWithPassword returned error: %v", err) + } + session, err := store.CreateSession(user.ID, time.Hour) + if err != nil { + t.Fatalf("CreateSession returned error: %v", err) + } + + app := &App{config: &config.Config{}, store: store} + router := gin.New() + router.POST("/admin/test", app.requireAdminSession, func(ctx *gin.Context) { + ctx.Status(http.StatusNoContent) + }) + + request := httptest.NewRequest(http.MethodPost, "/admin/test", nil) + request.AddCookie(&http.Cookie{Name: adminSessionCookie, Value: session.Token}) + response := httptest.NewRecorder() + router.ServeHTTP(response, request) + if response.Code != http.StatusForbidden { + t.Fatalf("expected missing CSRF token to be forbidden, got %d", response.Code) + } +} diff --git a/templates/admin.html b/templates/admin.html index 3a398c8..bf7044a 100644 --- a/templates/admin.html +++ b/templates/admin.html @@ -20,12 +20,13 @@