// Package main serves a small web UI and JSON API that turns text
// descriptions into images (via Prodia, or a local Stable Diffusion node) and
// keeps a FIFO queue of submitted images for display.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"sync"
	"time"

	"danlegt.com/stablediffusion-friends/lib/prodia"
	"danlegt.com/stablediffusion-friends/lib/rtypes"
	goaway "github.com/TwiN/go-away"
	"github.com/gin-gonic/gin"
)

const (
	// NEGATIVE_PROMPT is sent as the negative prompt with every generation
	// request, listing NSFW and malformed-output terms the model should avoid.
	NEGATIVE_PROMPT = "(nsfw:1.1), (porn:1.1), (naked:1.1), (nude:1.1), (nipple:1.1), (penis:1.1), (dick:1.1), (vagina:1.1), (asshole:1.1), visible nipple, nsfl, not safe for work, nudity, artifact, deformed, multiple limbs, ugly, gore, blood, sex, pornography, penis, dick, genitalia, male genitalia, anus, penetration, double penetration, cock"
)

// nodes lists the Stable Diffusion backends; only nodes[0] is used below.
var nodes = [1]rtypes.SDNode{
	{
		URL:    "127.0.0.1",
		Port:   7860,
		Status: false,
	},
}

// sdHTTPClient is shared by all requests to Stable Diffusion nodes so that
// connections are reused. The timeout bounds how long an unresponsive node can
// hold a handler; image generation is slow, so it is deliberately generous.
var sdHTTPClient = &http.Client{Timeout: 5 * time.Minute}

var (
	// imageQueueMu guards imagesToDisplay. net/http (and therefore gin) runs
	// handlers on separate goroutines, so every queue access must hold it.
	imageQueueMu sync.Mutex

	// imagesToDisplay is a FIFO of image payloads received in the
	// "image_data" form field, oldest first.
	imagesToDisplay []string
)

func main() {
	r := gin.Default()
	runLoaders(r)

	r.GET("/", func(c *gin.Context) {
		c.HTML(http.StatusOK, "index.html", nil)
	})

	r.POST("/api/image/generate", generateProdiaImage)

	// Enqueue an image for display and report its 1-based queue position.
	r.POST("/api/image/submit", func(c *gin.Context) {
		position := enqueueImage(c.PostForm("image_data"))
		fmt.Printf("Got a new image to display, currently %v in queue\n", position)
		c.JSON(http.StatusOK, gin.H{
			"queue_position": position,
		})
	})

	// Pop the oldest queued image. An empty queue answers with an empty JSON
	// array, the same shape the endpoint has always returned in that case.
	r.GET("/api/image", func(c *gin.Context) {
		image, remaining, ok := dequeueImage()
		if !ok {
			c.JSON(http.StatusOK, [0]string{})
			return
		}
		c.JSON(http.StatusOK, gin.H{
			"image":      image,
			"queue_left": remaining,
		})
	})

	r.GET("/api/image/queue", func(c *gin.Context) {
		c.JSON(http.StatusOK, snapshotImageQueue())
	})

	if err := r.Run(); err != nil {
		log.Fatalf("server stopped: %v", err)
	}
}

// enqueueImage appends data to the display queue and returns the new queue
// length, which is the 1-based position of the item just added.
func enqueueImage(data string) int {
	imageQueueMu.Lock()
	defer imageQueueMu.Unlock()
	imagesToDisplay = append(imagesToDisplay, data)
	return len(imagesToDisplay)
}

// dequeueImage removes and returns the oldest queued image together with the
// number of images still queued. ok is false when the queue is empty.
func dequeueImage() (image string, remaining int, ok bool) {
	imageQueueMu.Lock()
	defer imageQueueMu.Unlock()
	if len(imagesToDisplay) == 0 {
		return "", 0, false
	}
	image = imagesToDisplay[0]
	// Clear the slot so the (possibly large) payload is not kept alive by the
	// backing array after it has been handed out.
	imagesToDisplay[0] = ""
	imagesToDisplay = imagesToDisplay[1:]
	return image, len(imagesToDisplay), true
}

// snapshotImageQueue returns a copy of the queue that is safe to read after
// the lock is released. A nil queue stays nil so it still encodes as JSON
// null; an emptied (non-nil) queue encodes as [] exactly as before.
func snapshotImageQueue() []string {
	imageQueueMu.Lock()
	defer imageQueueMu.Unlock()
	if imagesToDisplay == nil {
		return nil
	}
	out := make([]string, len(imagesToDisplay))
	copy(out, imagesToDisplay)
	return out
}

// censorImageDescription returns prompt with profanity masked by go-away,
// logging the censored text when anything was replaced.
func censorImageDescription(prompt string) string {
	if !goaway.IsProfane(prompt) {
		return prompt
	}
	censored := goaway.Censor(prompt)
	fmt.Printf("Found profanity, censoring to: %s\n", censored)
	return censored
}

// generateProdiaImage handles POST /api/image/generate. It reads the
// "image_description" form field, censors profanity, submits a Prodia
// generation job with NEGATIVE_PROMPT, then fetches that job's result and
// returns it as JSON. Any Prodia error aborts the request with 500.
func generateProdiaImage(c *gin.Context) {
	prompt := censorImageDescription(c.PostForm("image_description"))

	resp, err := prodia.RequestGeneration(prompt, NEGATIVE_PROMPT)
	if err != nil {
		c.AbortWithError(http.StatusInternalServerError, err)
		return
	}
	fmt.Println(resp.Job)

	imgResult, err := prodia.FetchGeneration(resp.Job)
	if err != nil {
		c.AbortWithError(http.StatusInternalServerError, err)
		return
	}
	c.JSON(http.StatusOK, imgResult)
}

// generateSDImage asks the first Stable Diffusion node (nodes[0]) to render
// the "image_description" form field as a 512x512 image and returns the
// decoded node response as JSON. It is not registered by main in this file.
//
// Failures: request construction, transport or decode errors abort with 500;
// a non-2xx status from the node aborts with 502 Bad Gateway.
func generateSDImage(c *gin.Context) {
	prompt := censorImageDescription(c.PostForm("image_description"))

	// NOTE(review): SamplerName and SamplerIndex disagree ("DPM2" vs "Euler");
	// confirm which one the node honours.
	payload := rtypes.SDTextToImageRequest{
		Prompt:         prompt,
		SamplerName:    "DPM2",
		BatchSize:      1,
		Steps:          25,
		CfgScale:       8,
		Width:          512,
		Height:         512,
		NegativePrompt: NEGATIVE_PROMPT,
		SamplerIndex:   "Euler",
		SendImages:     true,
		SaveImages:     false,
	}

	body, err := json.Marshal(payload)
	if err != nil {
		c.AbortWithError(http.StatusInternalServerError, err)
		return
	}

	// Tie the upstream call to the client's request so it is cancelled if the
	// client goes away.
	req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost,
		nodes[0].GetTextToImageLink(), bytes.NewReader(body))
	if err != nil {
		c.AbortWithError(http.StatusInternalServerError, err)
		return
	}
	req.Header.Set("Content-Type", "application/json")

	res, err := sdHTTPClient.Do(req)
	if err != nil {
		c.AbortWithError(http.StatusInternalServerError, err)
		return
	}
	defer res.Body.Close()

	if res.StatusCode < 200 || res.StatusCode >= 300 {
		c.AbortWithError(http.StatusBadGateway,
			fmt.Errorf("stable diffusion node returned %s", res.Status))
		return
	}

	responseJSON := &rtypes.SDTextToImageResponse{}
	if err := json.NewDecoder(res.Body).Decode(responseJSON); err != nil {
		fmt.Println("Failed to decode the JSON")
		c.AbortWithError(http.StatusInternalServerError, err)
		return
	}
	c.JSON(http.StatusOK, responseJSON)
}

// runLoaders registers static assets under /public and loads HTML templates.
func runLoaders(r *gin.Engine) {
	r.Static("/public", "./public")
	r.LoadHTMLGlob("templates/**/*.html")
}