Prodia
This commit is contained in:
parent
ebf0812920
commit
c538b35342
|
@ -0,0 +1,161 @@
|
||||||
|
package prodia
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"danlegt.com/stablediffusion-friends/lib/rtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProdiaURL = "https://api.prodia.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getProdiaKey() string {
|
||||||
|
return os.Getenv("PRODIA_KEY")
|
||||||
|
}
|
||||||
|
|
||||||
|
func RequestGeneration() (*rtypes.ProdiaGenerateResponse, error) {
|
||||||
|
req := rtypes.ProdiaGenerateRequest{
|
||||||
|
Model: rtypes.Deliberate_v2,
|
||||||
|
Prompt: "beautiful woman wearing military clothing, woman, girl, beautiful, masterpiece, pretty, blue eyes, military, uniform, army",
|
||||||
|
NegativePrompt: "man, ugly, destroyed, nsfw, nudity",
|
||||||
|
Steps: 25,
|
||||||
|
CFGScale: 6.5,
|
||||||
|
Seed: -1,
|
||||||
|
Upscale: true,
|
||||||
|
Sampler: rtypes.DPMpp2MK,
|
||||||
|
AspectRatio: rtypes.Square,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest("POST", fmt.Sprintf("%s/v1/job", ProdiaURL), bytes.NewBuffer(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var prodiaKey string = getProdiaKey()
|
||||||
|
if prodiaKey == "" {
|
||||||
|
panic("Missing Prodia key, please set the PRODIA_KEY variable in the env")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Header.Add("Accept", "application/json")
|
||||||
|
r.Header.Add("Content-Type", "application/json")
|
||||||
|
r.Header.Add("X-Prodia-Key", prodiaKey)
|
||||||
|
client := &http.Client{}
|
||||||
|
res, err := client.Do(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
responseJson := &rtypes.ProdiaGenerateResponse{}
|
||||||
|
err = json.NewDecoder(res.Body).Decode(responseJson)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return responseJson, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchGeneration(jobId string) (*rtypes.ProdiaRetrieveResponseX, error) {
|
||||||
|
|
||||||
|
// We need to wait for this, so we will go ahead and do a while loop
|
||||||
|
poolingLimit := 30
|
||||||
|
|
||||||
|
r, err := http.NewRequest("GET", fmt.Sprintf("%s/v1/job/%s", ProdiaURL, jobId), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var prodiaKey string = getProdiaKey()
|
||||||
|
if prodiaKey == "" {
|
||||||
|
panic("Missing Prodia key, please set the PRODIA_KEY variable in the env")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Header.Add("Accept", "application/json")
|
||||||
|
r.Header.Add("Content-Type", "application/json")
|
||||||
|
r.Header.Add("X-Prodia-Key", prodiaKey)
|
||||||
|
client := &http.Client{}
|
||||||
|
|
||||||
|
responseJson := &rtypes.ProdiaRetrieveResponse{}
|
||||||
|
|
||||||
|
for poolingLimit > 0 {
|
||||||
|
res, err := client.Do(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer res.Body.Close()
|
||||||
|
err = json.NewDecoder(res.Body).Decode(responseJson)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if responseJson.Status == "succeeded" {
|
||||||
|
// Download the image and turn it into a base64
|
||||||
|
b64Image, err := imageUrlToBase64(responseJson.ImageUrl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newResp := &rtypes.ProdiaRetrieveResponseX{
|
||||||
|
Images: []string{b64Image},
|
||||||
|
Info: "Generated with Prodia",
|
||||||
|
}
|
||||||
|
|
||||||
|
return newResp, nil
|
||||||
|
} else if responseJson.Status == "error" {
|
||||||
|
return nil, errors.New("API Error")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Not done yet, waiting... ::" + responseJson.Status)
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("timed out")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func imageUrlToBase64(imageUrl string) (string, error) {
|
||||||
|
|
||||||
|
r, err := http.NewRequest("GET", imageUrl, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var prodiaKey string = getProdiaKey()
|
||||||
|
if prodiaKey == "" {
|
||||||
|
panic("Missing Prodia key, please set the PRODIA_KEY variable in the env")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Header.Add("Accept", "application/json")
|
||||||
|
r.Header.Add("Content-Type", "application/json")
|
||||||
|
r.Header.Add("X-Prodia-Key", prodiaKey)
|
||||||
|
client := &http.Client{}
|
||||||
|
res, err := client.Do(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
imageData, err := ioutil.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var base64Encoding string = base64.StdEncoding.EncodeToString(imageData)
|
||||||
|
|
||||||
|
return base64Encoding, nil
|
||||||
|
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
package rtypes
|
||||||
|
|
||||||
|
type ProdiaAspectRatio string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Square = "square"
|
||||||
|
Portrait = "portrait"
|
||||||
|
Landscape = "landscape"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProdiaModel string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SDV1_4 = "sdv1_4.ckpt [7460a6fa]"
|
||||||
|
Pruned15 = "v1-5-pruned-emaonly.ckpt [81761151]"
|
||||||
|
Anythingv3_0 = "anythingv3_0-pruned.ckpt [2700c435]"
|
||||||
|
Anything = "anything-v4.5-pruned.ckpt [65745d25]"
|
||||||
|
Analog = "analog-diffusion-1.0.ckpt [9ca13f02]"
|
||||||
|
Theallys = "theallys-mix-ii-churned.safetensors [5d9225a4]"
|
||||||
|
Elldreths = "elldreths-vivid-mix.safetensors [342d9d26]"
|
||||||
|
Deliberate_v2 = "deliberate_v2.safetensors [10ec4b29]"
|
||||||
|
Openjourney_V4 = "openjourney_V4.ckpt [ca2f377f]"
|
||||||
|
Dreamlike1 = "dreamlike-diffusion-1.0.safetensors [5c9fd6e0]"
|
||||||
|
Dreamlike2 = "dreamlike-diffusion-2.0.safetensors [fdcf65e7]"
|
||||||
|
Portrait1 = "portrait+1.0.safetensors [1400e684]"
|
||||||
|
Riffusion = "riffusion-model-v1.ckpt [3aafa6fe]"
|
||||||
|
Timeless = "timeless-1.0.ckpt [7c4971d4]"
|
||||||
|
Dreamshaper_5BakedVae = "dreamshaper_5BakedVae.safetensors [a3fbf318]"
|
||||||
|
RevAnimated_v122 = "revAnimated_v122.safetensors [3f4fefd9]"
|
||||||
|
Meinamix_meinaV9 = "meinamix_meinaV9.safetensors [2ec66ab0]"
|
||||||
|
Lyriel_v15 = "lyriel_v15.safetensors [65d547c5]"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProdiaSampler string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Euler = "Euler"
|
||||||
|
EulerA = "Euler a"
|
||||||
|
Heun = "Heun"
|
||||||
|
DPMpp2MK = "DPM++ 2M Karras"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProdiaGenerateRequest struct {
|
||||||
|
Model ProdiaModel `json:"model"`
|
||||||
|
Prompt ProdiaSampler `json:"prompt"`
|
||||||
|
NegativePrompt string `json:"negative_prompt"`
|
||||||
|
Steps int `json:"steps"`
|
||||||
|
CFGScale float32 `json:"cfg_scale"`
|
||||||
|
Seed int64 `json:"seed"`
|
||||||
|
Upscale bool `json:"upscale"`
|
||||||
|
Sampler string `json:"sampler"`
|
||||||
|
AspectRatio ProdiaAspectRatio `json:"aspect_ratio"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProdiaGenerateResponseParamOptions struct {
|
||||||
|
SDModelCheckpoint string `json:"sd_model_checkpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProdiaGenerateResponseParams struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
NegativePrompt string `json:"negative_prompt"`
|
||||||
|
Steps int `json:"steps"`
|
||||||
|
CFGScale float32 `json:"cfg_scale"`
|
||||||
|
Seed int64 `json:"seed"`
|
||||||
|
Upscale bool `json:"upscale"`
|
||||||
|
Sampler string `json:"sampler"`
|
||||||
|
AspectRatio string `json:"aspect_ratio"`
|
||||||
|
Options ProdiaGenerateResponseParamOptions `json:"options"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProdiaGenerateResponse struct {
|
||||||
|
Job string `json:"job"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Params ProdiaGenerateResponseParams `json:"params"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProdiaRetrieveRequest struct {
|
||||||
|
JobID string `json:"jobid"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProdiaRetrieveResponse struct {
|
||||||
|
Job string `json:"job"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
ImageUrl string `json:"imageUrl"`
|
||||||
|
Params ProdiaGenerateResponseParams `json:"params"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProdiaRetrieveResponseX struct {
|
||||||
|
Images []string `json:"images"`
|
||||||
|
Info string `json:"info"`
|
||||||
|
}
|
134
main.go
134
main.go
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"danlegt.com/stablediffusion-friends/lib/prodia"
|
||||||
"danlegt.com/stablediffusion-friends/lib/rtypes"
|
"danlegt.com/stablediffusion-friends/lib/rtypes"
|
||||||
goaway "github.com/TwiN/go-away"
|
goaway "github.com/TwiN/go-away"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
@ -22,6 +23,7 @@ var nodes [1]rtypes.SDNode = [...]rtypes.SDNode{
|
||||||
var imagesToDisplay []string
|
var imagesToDisplay []string
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
runLoaders(r)
|
runLoaders(r)
|
||||||
|
|
||||||
|
@ -29,62 +31,7 @@ func main() {
|
||||||
c.HTML(http.StatusOK, "index.html", nil)
|
c.HTML(http.StatusOK, "index.html", nil)
|
||||||
})
|
})
|
||||||
|
|
||||||
r.POST("/api/image/generate", func(c *gin.Context) {
|
r.POST("/api/image/generate", generateProdiaImage)
|
||||||
|
|
||||||
imageDescriptior := c.PostForm("image_description")
|
|
||||||
|
|
||||||
if goaway.IsProfane(imageDescriptior) {
|
|
||||||
imageDescriptior = goaway.Censor(imageDescriptior)
|
|
||||||
fmt.Printf("Found profanity, censoring to: %s\n", imageDescriptior)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the request
|
|
||||||
request := rtypes.SDTextToImageRequest{
|
|
||||||
Prompt: imageDescriptior,
|
|
||||||
SamplerName: "DPM2",
|
|
||||||
BatchSize: 1,
|
|
||||||
Steps: 25,
|
|
||||||
CfgScale: 8,
|
|
||||||
Width: 512,
|
|
||||||
Height: 512,
|
|
||||||
NegativePrompt: "nsfw, porn, naked, nude, nipple, penis, dick, vagina, asshole, visible nipple, nsfl, not safe for work, nudity, artifact, deformed, multiple limbs, ugly, gore, blood, sex, pornography, penis, dick, genitalia, male genitalia, anus, penetration, double penetration, cock",
|
|
||||||
SamplerIndex: "Euler",
|
|
||||||
SendImages: true,
|
|
||||||
SaveImages: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := json.Marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithError(500, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := http.NewRequest("POST", nodes[0].GetTextToImageLink(), bytes.NewBuffer(body))
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithError(500, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Header.Add("Content-Type", "application/json")
|
|
||||||
client := &http.Client{}
|
|
||||||
res, err := client.Do(r)
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithError(500, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer res.Body.Close()
|
|
||||||
|
|
||||||
responseJson := &rtypes.SDTextToImageResponse{}
|
|
||||||
err = json.NewDecoder(res.Body).Decode(responseJson)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("Failed to decode the JSON")
|
|
||||||
c.AbortWithError(500, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, responseJson)
|
|
||||||
})
|
|
||||||
|
|
||||||
r.POST("/api/image/submit", func(c *gin.Context) {
|
r.POST("/api/image/submit", func(c *gin.Context) {
|
||||||
|
|
||||||
|
@ -127,6 +74,81 @@ func main() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateProdiaImage(c *gin.Context) {
|
||||||
|
resp, err := prodia.RequestGeneration()
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(500, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(resp.Job)
|
||||||
|
|
||||||
|
imgResult, err := prodia.FetchGeneration(resp.Job)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(500, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, imgResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSDImage(c *gin.Context) {
|
||||||
|
|
||||||
|
imageDescriptior := c.PostForm("image_description")
|
||||||
|
|
||||||
|
if goaway.IsProfane(imageDescriptior) {
|
||||||
|
imageDescriptior = goaway.Censor(imageDescriptior)
|
||||||
|
fmt.Printf("Found profanity, censoring to: %s\n", imageDescriptior)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the request
|
||||||
|
request := rtypes.SDTextToImageRequest{
|
||||||
|
Prompt: imageDescriptior,
|
||||||
|
SamplerName: "DPM2",
|
||||||
|
BatchSize: 1,
|
||||||
|
Steps: 25,
|
||||||
|
CfgScale: 8,
|
||||||
|
Width: 512,
|
||||||
|
Height: 512,
|
||||||
|
NegativePrompt: "(nsfw:1.1), (porn:1.1), (naked:1.1), (nude:1.1), (nipple:1.1), (penis:1.1), (dick:1.1), (vagina:1.1), (asshole:1.1), visible nipple, nsfl, not safe for work, nudity, artifact, deformed, multiple limbs, ugly, gore, blood, sex, pornography, penis, dick, genitalia, male genitalia, anus, penetration, double penetration, cock",
|
||||||
|
SamplerIndex: "Euler",
|
||||||
|
SendImages: true,
|
||||||
|
SaveImages: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(500, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest("POST", nodes[0].GetTextToImageLink(), bytes.NewBuffer(body))
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(500, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Header.Add("Content-Type", "application/json")
|
||||||
|
client := &http.Client{}
|
||||||
|
res, err := client.Do(r)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithError(500, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
responseJson := &rtypes.SDTextToImageResponse{}
|
||||||
|
err = json.NewDecoder(res.Body).Decode(responseJson)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Failed to decode the JSON")
|
||||||
|
c.AbortWithError(500, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, responseJson)
|
||||||
|
}
|
||||||
|
|
||||||
func runLoaders(r *gin.Engine) {
|
func runLoaders(r *gin.Engine) {
|
||||||
r.Static("/public", "./public")
|
r.Static("/public", "./public")
|
||||||
r.LoadHTMLGlob("templates/**/*.html")
|
r.LoadHTMLGlob("templates/**/*.html")
|
||||||
|
|
Reference in New Issue