package services

import (
	"context"
	"fmt"
	"io"
	"net"
	"os"
	"path"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)

// sftpStorageBackend stores objects as files beneath cfg.RemotePath on an
// SFTP server. Every operation opens its own SSH connection and closes it
// before returning, except Get, which hands the connection's closer to the
// returned Body.
type sftpStorageBackend struct {
	cfg StorageBackendConfig
}

// ID returns the configured backend identifier.
func (b sftpStorageBackend) ID() string { return b.cfg.ID }

// Type reports StorageBackendSFTP.
func (b sftpStorageBackend) Type() string { return StorageBackendSFTP }

// Put uploads body to the file for key, creating parent directories as
// needed and truncating any existing file. The size and content-type
// arguments are unused by this backend. If the copy or the final close
// fails, the partially written file is removed on a best-effort basis so a
// truncated object is not left behind, and the upload error is returned.
func (b sftpStorageBackend) Put(ctx context.Context, key string, body io.Reader, _ int64, _ string) error {
	client, closer, err := b.connect(ctx)
	if err != nil {
		return err
	}
	defer closer()
	if err := ctx.Err(); err != nil {
		return err
	}
	remotePath := b.remotePath(key)
	if err := client.MkdirAll(path.Dir(remotePath)); err != nil {
		return err
	}
	target, err := client.OpenFile(remotePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY)
	if err != nil {
		return err
	}
	_, uploadErr := io.Copy(target, body)
	// Close is a server round-trip that can fail; the upload is only
	// complete once it succeeds, so its error must not be discarded.
	if closeErr := target.Close(); uploadErr == nil {
		uploadErr = closeErr
	}
	if uploadErr != nil {
		// Best effort: the upload error is what the caller needs to see.
		_ = client.Remove(remotePath)
		return uploadErr
	}
	return nil
}

// Get opens the file for key and returns it as a StorageObject carrying the
// file's size and modification time. The SSH connection must outlive this
// call, so its closer is passed to closeWith together with the file
// (presumably both are released when Body is closed; see closeWith). On any
// error the connection is closed before returning.
func (b sftpStorageBackend) Get(ctx context.Context, key string) (StorageObject, error) {
	client, closer, err := b.connect(ctx)
	if err != nil {
		return StorageObject{}, err
	}
	if err := ctx.Err(); err != nil {
		closer()
		return StorageObject{}, err
	}
	source, err := client.Open(b.remotePath(key))
	if err != nil {
		closer()
		return StorageObject{}, err
	}
	stat, err := source.Stat()
	if err != nil {
		source.Close()
		closer()
		return StorageObject{}, err
	}
	return StorageObject{Key: key, Size: stat.Size(), ModTime: stat.ModTime(), Body: closeWith(source, closer)}, nil
}

// Delete removes the file for key. A file that does not exist is not an
// error.
func (b sftpStorageBackend) Delete(ctx context.Context, key string) error {
	client, closer, err := b.connect(ctx)
	if err != nil {
		return err
	}
	defer closer()
	if err := ctx.Err(); err != nil {
		return err
	}
	if err := client.Remove(b.remotePath(key)); err != nil && !os.IsNotExist(err) {
		return err
	}
	return nil
}

// DeletePrefix removes the file or directory tree at prefix, including the
// prefix path itself. A missing path is not an error. Entries are removed
// deepest-first. Removal continues past individual failures, and the first
// failure is returned (with a count) so callers learn that the prefix was
// not fully deleted.
func (b sftpStorageBackend) DeletePrefix(ctx context.Context, prefix string) error {
	client, closer, err := b.connect(ctx)
	if err != nil {
		return err
	}
	defer closer()
	if err := ctx.Err(); err != nil {
		return err
	}
	remotePath := b.remotePath(prefix)
	// Fast path: an empty directory is removed directly, and a missing path
	// means there is nothing to do.
	if err := client.RemoveDirectory(remotePath); err == nil || os.IsNotExist(err) {
		return nil
	}

	type entry struct {
		path  string
		isDir bool
	}
	var entries []entry
	// The walk yields remotePath itself first, then everything beneath it.
	walker := client.Walk(remotePath)
	for walker.Step() {
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := walker.Err(); err != nil {
			if os.IsNotExist(err) {
				continue // removed concurrently; nothing left to delete here
			}
			return err
		}
		info := walker.Stat()
		entries = append(entries, entry{path: walker.Path(), isDir: info != nil && info.IsDir()})
	}

	// A child's path extends its parent's with "/name", so ordering by
	// descending length removes every child before its parent directory.
	sort.Slice(entries, func(i, j int) bool { return len(entries[i].path) > len(entries[j].path) })

	var firstErr error
	failed := 0
	for _, e := range entries {
		if err := ctx.Err(); err != nil {
			return err
		}
		var rmErr error
		if e.isDir {
			rmErr = client.RemoveDirectory(e.path)
		} else {
			rmErr = client.Remove(e.path)
		}
		if rmErr != nil && !os.IsNotExist(rmErr) {
			failed++
			if firstErr == nil {
				firstErr = rmErr
			}
		}
	}
	if firstErr != nil {
		return fmt.Errorf("sftp delete prefix %q: %d of %d entries not removed: %w", prefix, failed, len(entries), firstErr)
	}
	return nil
}

// Usage returns the total size in bytes of all non-directory entries under
// the configured remote root. Any walk error aborts the computation.
func (b sftpStorageBackend) Usage(ctx context.Context) (int64, error) {
	client, closer, err := b.connect(ctx)
	if err != nil {
		return 0, err
	}
	defer closer()
	if err := ctx.Err(); err != nil {
		return 0, err
	}
	var total int64
	walker := client.Walk(cleanRemoteRoot(b.cfg.RemotePath))
	for walker.Step() {
		if err := ctx.Err(); err != nil {
			return 0, err
		}
		if err := walker.Err(); err != nil {
			return 0, err
		}
		if info := walker.Stat(); info != nil && !info.IsDir() {
			total += info.Size()
		}
	}
	return total, nil
}

// Test checks connectivity and write access by uploading a small probe
// object under a random key and then deleting it.
func (b sftpStorageBackend) Test(ctx context.Context) error {
	key := ".warpbox-storage-test-" + randomID(6)
	if err := b.Put(ctx, key, strings.NewReader("ok"), 2, "text/plain"); err != nil {
		return err
	}
	return b.Delete(ctx, key)
}

// client connects without a caller context. It is kept for existing callers;
// prefer connect so the caller's context governs dialing.
func (b sftpStorageBackend) client() (*sftp.Client, func(), error) {
	return b.connect(context.Background())
}

// connect authenticates to the configured SFTP server. It returns a client
// and a function that closes both the SFTP session and the SSH connection.
// A private key (if set) is offered before a password (if set), and at least
// one is required. The TCP dial honours ctx. Both the dial and the SSH
// handshake are bounded by a 15s timeout, or by ctx's deadline if that is
// sooner for the handshake.
func (b sftpStorageBackend) connect(ctx context.Context) (*sftp.Client, func(), error) {
	const connectTimeout = 15 * time.Second

	auth := make([]ssh.AuthMethod, 0, 2)
	if b.cfg.PrivateKey != "" {
		signer, err := ssh.ParsePrivateKey([]byte(b.cfg.PrivateKey))
		if err != nil {
			return nil, nil, err
		}
		auth = append(auth, ssh.PublicKeys(signer))
	}
	if b.cfg.Password != "" {
		auth = append(auth, ssh.Password(b.cfg.Password))
	}
	if len(auth) == 0 {
		return nil, nil, fmt.Errorf("sftp password or private key is required")
	}
	hostKeyCallback, err := b.hostKeyCallback()
	if err != nil {
		return nil, nil, err
	}
	config := &ssh.ClientConfig{
		User:            b.cfg.Username,
		Auth:            auth,
		HostKeyCallback: hostKeyCallback,
	}

	// JoinHostPort brackets IPv6 literals ("[::1]:22"); plain concatenation
	// would produce an address the dialer cannot parse.
	addr := net.JoinHostPort(b.cfg.Host, strconv.Itoa(b.cfg.Port))
	dialer := net.Dialer{Timeout: connectTimeout}
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, nil, err
	}

	// The SSH handshake has no timeout of its own (ClientConfig.Timeout only
	// covers ssh.Dial's TCP connect). A deadline on the raw connection stops
	// a server that accepts TCP but never speaks SSH from hanging us.
	deadline := time.Now().Add(connectTimeout)
	if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
		deadline = d
	}
	if err := conn.SetDeadline(deadline); err != nil {
		conn.Close()
		return nil, nil, err
	}
	sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
	if err != nil {
		conn.Close() // harmless if NewClientConn already closed it
		return nil, nil, err
	}
	// Lift the handshake deadline so long transfers are not cut off.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		sshConn.Close()
		return nil, nil, err
	}
	sshClient := ssh.NewClient(sshConn, chans, reqs)

	client, err := sftp.NewClient(sshClient)
	if err != nil {
		sshClient.Close()
		return nil, nil, err
	}
	return client, func() {
		client.Close()
		sshClient.Close()
	}, nil
}

// hostKeyCallback pins the server to cfg.HostKey, given in authorized_keys
// format.
//
// NOTE(review): when no HostKey is configured, any server key is accepted
// (ssh.InsecureIgnoreHostKey). That leaves connections open to
// man-in-the-middle attacks. It is kept for compatibility with existing
// configurations; consider requiring a host key.
func (b sftpStorageBackend) hostKeyCallback() (ssh.HostKeyCallback, error) {
	hostKey := strings.TrimSpace(b.cfg.HostKey)
	if hostKey == "" {
		return ssh.InsecureIgnoreHostKey(), nil
	}
	key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(hostKey))
	if err != nil {
		return nil, fmt.Errorf("invalid sftp host public key: %w", err)
	}
	return ssh.FixedHostKey(key), nil
}

// remotePath maps an object key to its slash-separated path under the
// configured remote root.
func (b sftpStorageBackend) remotePath(key string) string {
	return path.Join(cleanRemoteRoot(b.cfg.RemotePath), cleanObjectKey(key))
}