Prodia
This commit is contained in:
161
lib/prodia/prodia.go
Normal file
161
lib/prodia/prodia.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package prodia
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"danlegt.com/stablediffusion-friends/lib/rtypes"
|
||||
)
|
||||
|
||||
const (
|
||||
ProdiaURL = "https://api.prodia.com"
|
||||
)
|
||||
|
||||
func getProdiaKey() string {
|
||||
return os.Getenv("PRODIA_KEY")
|
||||
}
|
||||
|
||||
func RequestGeneration() (*rtypes.ProdiaGenerateResponse, error) {
|
||||
req := rtypes.ProdiaGenerateRequest{
|
||||
Model: rtypes.Deliberate_v2,
|
||||
Prompt: "beautiful woman wearing military clothing, woman, girl, beautiful, masterpiece, pretty, blue eyes, military, uniform, army",
|
||||
NegativePrompt: "man, ugly, destroyed, nsfw, nudity",
|
||||
Steps: 25,
|
||||
CFGScale: 6.5,
|
||||
Seed: -1,
|
||||
Upscale: true,
|
||||
Sampler: rtypes.DPMpp2MK,
|
||||
AspectRatio: rtypes.Square,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := http.NewRequest("POST", fmt.Sprintf("%s/v1/job", ProdiaURL), bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var prodiaKey string = getProdiaKey()
|
||||
if prodiaKey == "" {
|
||||
panic("Missing Prodia key, please set the PRODIA_KEY variable in the env")
|
||||
}
|
||||
|
||||
r.Header.Add("Accept", "application/json")
|
||||
r.Header.Add("Content-Type", "application/json")
|
||||
r.Header.Add("X-Prodia-Key", prodiaKey)
|
||||
client := &http.Client{}
|
||||
res, err := client.Do(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
|
||||
responseJson := &rtypes.ProdiaGenerateResponse{}
|
||||
err = json.NewDecoder(res.Body).Decode(responseJson)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return responseJson, nil
|
||||
}
|
||||
|
||||
func FetchGeneration(jobId string) (*rtypes.ProdiaRetrieveResponseX, error) {
|
||||
|
||||
// We need to wait for this, so we will go ahead and do a while loop
|
||||
poolingLimit := 30
|
||||
|
||||
r, err := http.NewRequest("GET", fmt.Sprintf("%s/v1/job/%s", ProdiaURL, jobId), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var prodiaKey string = getProdiaKey()
|
||||
if prodiaKey == "" {
|
||||
panic("Missing Prodia key, please set the PRODIA_KEY variable in the env")
|
||||
}
|
||||
|
||||
r.Header.Add("Accept", "application/json")
|
||||
r.Header.Add("Content-Type", "application/json")
|
||||
r.Header.Add("X-Prodia-Key", prodiaKey)
|
||||
client := &http.Client{}
|
||||
|
||||
responseJson := &rtypes.ProdiaRetrieveResponse{}
|
||||
|
||||
for poolingLimit > 0 {
|
||||
res, err := client.Do(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
err = json.NewDecoder(res.Body).Decode(responseJson)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if responseJson.Status == "succeeded" {
|
||||
// Download the image and turn it into a base64
|
||||
b64Image, err := imageUrlToBase64(responseJson.ImageUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newResp := &rtypes.ProdiaRetrieveResponseX{
|
||||
Images: []string{b64Image},
|
||||
Info: "Generated with Prodia",
|
||||
}
|
||||
|
||||
return newResp, nil
|
||||
} else if responseJson.Status == "error" {
|
||||
return nil, errors.New("API Error")
|
||||
}
|
||||
|
||||
fmt.Println("Not done yet, waiting... ::" + responseJson.Status)
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
return nil, errors.New("timed out")
|
||||
|
||||
}
|
||||
|
||||
func imageUrlToBase64(imageUrl string) (string, error) {
|
||||
|
||||
r, err := http.NewRequest("GET", imageUrl, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var prodiaKey string = getProdiaKey()
|
||||
if prodiaKey == "" {
|
||||
panic("Missing Prodia key, please set the PRODIA_KEY variable in the env")
|
||||
}
|
||||
|
||||
r.Header.Add("Accept", "application/json")
|
||||
r.Header.Add("Content-Type", "application/json")
|
||||
r.Header.Add("X-Prodia-Key", prodiaKey)
|
||||
client := &http.Client{}
|
||||
res, err := client.Do(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
imageData, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var base64Encoding string = base64.StdEncoding.EncodeToString(imageData)
|
||||
|
||||
return base64Encoding, nil
|
||||
}
|
||||
Reference in New Issue
Block a user