package prodia import ( "bytes" "encoding/base64" "encoding/json" "errors" "fmt" "io/ioutil" "net/http" "os" "time" "danlegt.com/stablediffusion-friends/lib/rtypes" ) const ( ProdiaURL = "https://api.prodia.com" ) func getProdiaKey() string { return os.Getenv("PRODIA_KEY") } func RequestGeneration(prompt string, negative string) (*rtypes.ProdiaGenerateResponse, error) { req := rtypes.ProdiaGenerateRequest{ Model: rtypes.Deliberate_v2, Prompt: prompt, NegativePrompt: negative, Steps: 35, CFGScale: 6.5, Seed: -1, Upscale: false, Sampler: rtypes.DPMpp2MK, AspectRatio: rtypes.Square, } body, err := json.Marshal(req) if err != nil { return nil, err } r, err := http.NewRequest("POST", fmt.Sprintf("%s/v1/job", ProdiaURL), bytes.NewBuffer(body)) if err != nil { return nil, err } var prodiaKey string = getProdiaKey() if prodiaKey == "" { panic("Missing Prodia key, please set the PRODIA_KEY variable in the env") } r.Header.Add("Accept", "application/json") r.Header.Add("Content-Type", "application/json") r.Header.Add("X-Prodia-Key", prodiaKey) client := &http.Client{} res, err := client.Do(r) if err != nil { return nil, err } defer res.Body.Close() responseJson := &rtypes.ProdiaGenerateResponse{} err = json.NewDecoder(res.Body).Decode(responseJson) if err != nil { return nil, err } return responseJson, nil } func FetchGeneration(jobId string) (*rtypes.ProdiaRetrieveResponseX, error) { // We need to wait for this, so we will go ahead and do a while loop poolingLimit := 30 r, err := http.NewRequest("GET", fmt.Sprintf("%s/v1/job/%s", ProdiaURL, jobId), nil) if err != nil { return nil, err } var prodiaKey string = getProdiaKey() if prodiaKey == "" { panic("Missing Prodia key, please set the PRODIA_KEY variable in the env") } r.Header.Add("Accept", "application/json") r.Header.Add("Content-Type", "application/json") r.Header.Add("X-Prodia-Key", prodiaKey) client := &http.Client{} responseJson := &rtypes.ProdiaRetrieveResponse{} for poolingLimit > 0 { res, err := client.Do(r) if err != nil { return nil, err } defer res.Body.Close() err = json.NewDecoder(res.Body).Decode(responseJson) if err != nil { return nil, err } if responseJson.Status == "succeeded" { // Download the image and turn it into a base64 b64Image, err := imageUrlToBase64(responseJson.ImageUrl) if err != nil { return nil, err } newResp := &rtypes.ProdiaRetrieveResponseX{ Images: []string{b64Image}, Info: "Generated with Prodia", } return newResp, nil } else if responseJson.Status == "error" { return nil, errors.New("API Error") } time.Sleep(time.Second) } return nil, errors.New("timed out") } func imageUrlToBase64(imageUrl string) (string, error) { r, err := http.NewRequest("GET", imageUrl, nil) if err != nil { return "", err } var prodiaKey string = getProdiaKey() if prodiaKey == "" { panic("Missing Prodia key, please set the PRODIA_KEY variable in the env") } r.Header.Add("Accept", "application/json") r.Header.Add("Content-Type", "application/json") r.Header.Add("X-Prodia-Key", prodiaKey) client := &http.Client{} res, err := client.Do(r) if err != nil { return "", err } imageData, err := ioutil.ReadAll(res.Body) if err != nil { return "", err } var base64Encoding string = base64.StdEncoding.EncodeToString(imageData) return base64Encoding, nil }