This repository has been archived on 2024-01-19. You can view files and clone it, but cannot push or open issues or pull requests.
image_fun/lib/prodia/prodia.go

161 lines
3.4 KiB
Go

package prodia
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"time"
"danlegt.com/stablediffusion-friends/lib/rtypes"
)
// ProdiaURL is the base URL that every Prodia API endpoint path in this
// package is appended to.
const ProdiaURL = "https://api.prodia.com"
// getProdiaKey reads the Prodia API key from the PRODIA_KEY environment
// variable. An unset variable and an empty one both yield "".
func getProdiaKey() string {
	key, _ := os.LookupEnv("PRODIA_KEY")
	return key
}
// RequestGeneration submits a new image generation job to the Prodia API.
//
// prompt and negative are sent as the prompt and negative prompt; the
// remaining request fields (model, steps, CFG scale, seed, upscale, sampler,
// aspect ratio) are fixed to the values below.
//
// It returns the decoded job-creation response. An error is returned when the
// PRODIA_KEY environment variable is empty, the request cannot be built or
// sent, the API answers with a non-2xx status, or the body is not valid JSON.
func RequestGeneration(prompt string, negative string) (*rtypes.ProdiaGenerateResponse, error) {
	// A missing key is a configuration problem the caller should be able to
	// handle, so report it as an error rather than panicking.
	prodiaKey := getProdiaKey()
	if prodiaKey == "" {
		return nil, errors.New("missing Prodia key, please set the PRODIA_KEY variable in the env")
	}

	req := rtypes.ProdiaGenerateRequest{
		Model:          rtypes.Deliberate_v2,
		Prompt:         prompt,
		NegativePrompt: negative,
		Steps:          35,
		CFGScale:       6.5,
		Seed:           -1,
		Upscale:        false,
		Sampler:        rtypes.DPMpp2MK,
		AspectRatio:    rtypes.Square,
	}
	body, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}

	r, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/v1/job", ProdiaURL), bytes.NewBuffer(body))
	if err != nil {
		return nil, err
	}
	r.Header.Add("Accept", "application/json")
	r.Header.Add("Content-Type", "application/json")
	r.Header.Add("X-Prodia-Key", prodiaKey)

	// Bound the call so a stalled connection cannot block the caller forever.
	client := &http.Client{Timeout: 30 * time.Second}
	res, err := client.Do(r)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	// Without this check an error payload would be decoded into an empty
	// response and returned as if the job had been created.
	if res.StatusCode < 200 || res.StatusCode > 299 {
		return nil, fmt.Errorf("prodia: creating job: unexpected status %s", res.Status)
	}

	responseJson := &rtypes.ProdiaGenerateResponse{}
	if err := json.NewDecoder(res.Body).Decode(responseJson); err != nil {
		return nil, err
	}
	return responseJson, nil
}
// FetchGeneration polls the Prodia API for the job identified by jobId until
// it finishes, then downloads the generated image.
//
// The job is queried at most maxPolls times with pollInterval between
// attempts. When the reported status is "succeeded", the image at the
// reported URL is fetched through imageUrlToBase64 and returned as the single
// entry of Images. A status of "error" yields the error "API Error"; running
// out of attempts yields "timed out". An error is also returned when the
// PRODIA_KEY environment variable is empty, a request fails, the API answers
// with a non-2xx status, or a response body is not valid JSON.
func FetchGeneration(jobId string) (*rtypes.ProdiaRetrieveResponseX, error) {
	const (
		maxPolls     = 30
		pollInterval = time.Second
	)

	// A missing key is a configuration problem the caller should be able to
	// handle, so report it as an error rather than panicking.
	prodiaKey := getProdiaKey()
	if prodiaKey == "" {
		return nil, errors.New("missing Prodia key, please set the PRODIA_KEY variable in the env")
	}

	r, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/v1/job/%s", ProdiaURL, jobId), nil)
	if err != nil {
		return nil, err
	}
	r.Header.Add("Accept", "application/json")
	r.Header.Add("Content-Type", "application/json")
	r.Header.Add("X-Prodia-Key", prodiaKey)

	// Bound each poll so a stalled connection cannot block the caller forever.
	client := &http.Client{Timeout: 30 * time.Second}

	// poll performs one status request. It lives in its own function so the
	// deferred Close runs at the end of every attempt instead of piling up
	// until FetchGeneration returns.
	poll := func() (*rtypes.ProdiaRetrieveResponse, error) {
		res, err := client.Do(r)
		if err != nil {
			return nil, err
		}
		defer res.Body.Close()

		if res.StatusCode < 200 || res.StatusCode > 299 {
			return nil, fmt.Errorf("prodia: fetching job %q: unexpected status %s", jobId, res.Status)
		}

		// Decode into a fresh value so fields from an earlier poll cannot
		// leak into this one.
		job := &rtypes.ProdiaRetrieveResponse{}
		if err := json.NewDecoder(res.Body).Decode(job); err != nil {
			return nil, err
		}
		return job, nil
	}

	// The attempt counter is what makes the "timed out" result reachable; the
	// loop must not run unbounded while the job is still pending.
	for attempt := 0; attempt < maxPolls; attempt++ {
		if attempt > 0 {
			time.Sleep(pollInterval)
		}

		job, err := poll()
		if err != nil {
			return nil, err
		}

		switch job.Status {
		case "succeeded":
			// Download the image and turn it into base64.
			b64Image, err := imageUrlToBase64(job.ImageUrl)
			if err != nil {
				return nil, err
			}
			return &rtypes.ProdiaRetrieveResponseX{
				Images: []string{b64Image},
				Info:   "Generated with Prodia",
			}, nil
		case "error":
			return nil, errors.New("API Error")
		}
	}
	return nil, errors.New("timed out")
}
// imageUrlToBase64 downloads the resource at imageUrl and returns its bytes
// encoded with standard base64.
//
// An error is returned when the PRODIA_KEY environment variable is empty, the
// request cannot be built or sent, the server answers with a non-2xx status,
// or the body cannot be read.
func imageUrlToBase64(imageUrl string) (string, error) {
	// A missing key is a configuration problem the caller should be able to
	// handle, so report it as an error rather than panicking.
	prodiaKey := getProdiaKey()
	if prodiaKey == "" {
		return "", errors.New("missing Prodia key, please set the PRODIA_KEY variable in the env")
	}

	r, err := http.NewRequest(http.MethodGet, imageUrl, nil)
	if err != nil {
		return "", err
	}
	r.Header.Add("Accept", "application/json")
	r.Header.Add("Content-Type", "application/json")
	// NOTE(review): the API key is sent to whatever host imageUrl points at;
	// confirm the image host actually requires it.
	r.Header.Add("X-Prodia-Key", prodiaKey)

	// Bound the download so a stalled connection cannot block the caller forever.
	client := &http.Client{Timeout: 30 * time.Second}
	res, err := client.Do(r)
	if err != nil {
		return "", err
	}
	// The body must be closed or the underlying connection is leaked.
	defer res.Body.Close()

	// Without this check an error page would be encoded and handed back as
	// if it were the image.
	if res.StatusCode < 200 || res.StatusCode > 299 {
		return "", fmt.Errorf("prodia: downloading image: unexpected status %s", res.Status)
	}

	imageData, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(imageData), nil
}