feat: doubao-seedream support image edit

This commit is contained in:
feitianbubu 2025-10-23 21:18:11 +08:00
parent a4c46e999e
commit fe9b092b0b
3 changed files with 170 additions and 106 deletions

View File

@ -2,9 +2,11 @@ package common
import ( import (
"bytes" "bytes"
"encoding/json"
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
@ -40,6 +42,10 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
contentType := c.Request.Header.Get("Content-Type") contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") { if strings.HasPrefix(contentType, "application/json") {
err = Unmarshal(requestBody, &v) err = Unmarshal(requestBody, &v)
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
err = parseFormData(requestBody, &v)
} else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
err = parseMultipartFormData(c, requestBody, &v)
} else { } else {
// skip for now // skip for now
// TODO: someday non json request have variant model, we will need to implementation this // TODO: someday non json request have variant model, we will need to implementation this
@ -138,3 +144,57 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return form, nil return form, nil
} }
func parseFormData(data []byte, v any) error {
values, err := url.ParseQuery(string(data))
if err != nil {
return err
}
formMap := make(map[string]any)
for key, vals := range values {
if len(vals) == 1 {
formMap[key] = vals[0]
} else {
formMap[key] = vals
}
}
jsonData, err := json.Marshal(formMap)
if err != nil {
return err
}
return json.Unmarshal(jsonData, v)
}
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
contentType := c.Request.Header.Get("Content-Type")
boundary := ""
if idx := strings.Index(contentType, "boundary="); idx != -1 {
boundary = contentType[idx+9:]
}
if boundary == "" {
return json.Unmarshal(data, v) // Fallback to JSON
}
reader := multipart.NewReader(bytes.NewReader(data), boundary)
form, err := reader.ReadForm(32 << 20) // 32 MB max memory
if err != nil {
return err
}
defer form.RemoveAll()
formMap := make(map[string]any)
for key, vals := range form.Value {
if len(vals) == 1 {
formMap[key] = vals[0]
} else {
formMap[key] = vals
}
}
jsonData, err := json.Marshal(formMap)
if err != nil {
return err
}
return json.Unmarshal(jsonData, v)
}

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -245,7 +246,8 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1") //modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { contentType := c.Request.Header.Get("Content-Type")
if slices.Contains([]string{gin.MIMEPOSTForm, gin.MIMEMultipartPOSTForm}, contentType) {
modelRequest.Model = c.PostForm("model") modelRequest.Model = c.PostForm("model")
} }
} }

View File

@ -6,9 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"mime/multipart"
"net/http" "net/http"
"net/textproto"
"path/filepath" "path/filepath"
"strings" "strings"
@ -104,106 +102,107 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
return request, nil return request, nil
case constant.RelayModeImagesEdits: // 根据官方文档,并没有发现豆包生图支持表单请求:https://www.volcengine.com/docs/82379/1824121
//case constant.RelayModeImagesEdits:
var requestBody bytes.Buffer //
writer := multipart.NewWriter(&requestBody) // var requestBody bytes.Buffer
// writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model) //
// writer.WriteField("model", request.Model)
formData := c.Request.PostForm //
for key, values := range formData { // formData := c.Request.PostForm
if key == "model" { // for key, values := range formData {
continue // if key == "model" {
} // continue
for _, value := range values { // }
writer.WriteField(key, value) // for _, value := range values {
} // writer.WriteField(key, value)
} // }
// }
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { //
return nil, errors.New("failed to parse multipart form") // if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
} // return nil, errors.New("failed to parse multipart form")
// }
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil { //
var imageFiles []*multipart.FileHeader // if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
var exists bool // var imageFiles []*multipart.FileHeader
// var exists bool
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 { //
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 { // if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
foundArrayImages := false // if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
for fieldName, files := range c.Request.MultipartForm.File { // foundArrayImages := false
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 { // for fieldName, files := range c.Request.MultipartForm.File {
foundArrayImages = true // if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
for _, file := range files { // foundArrayImages = true
imageFiles = append(imageFiles, file) // for _, file := range files {
} // imageFiles = append(imageFiles, file)
} // }
} // }
// }
if !foundArrayImages && (len(imageFiles) == 0) { //
return nil, errors.New("image is required") // if !foundArrayImages && (len(imageFiles) == 0) {
} // return nil, errors.New("image is required")
} // }
} // }
// }
for i, fileHeader := range imageFiles { //
file, err := fileHeader.Open() // for i, fileHeader := range imageFiles {
if err != nil { // file, err := fileHeader.Open()
return nil, fmt.Errorf("failed to open image file %d: %w", i, err) // if err != nil {
} // return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
defer file.Close() // }
// defer file.Close()
fieldName := "image" //
if len(imageFiles) > 1 { // fieldName := "image"
fieldName = "image[]" // if len(imageFiles) > 1 {
} // fieldName = "image[]"
// }
mimeType := detectImageMimeType(fileHeader.Filename) //
// mimeType := detectImageMimeType(fileHeader.Filename)
h := make(textproto.MIMEHeader) //
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename)) // h := make(textproto.MIMEHeader)
h.Set("Content-Type", mimeType) // h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
// h.Set("Content-Type", mimeType)
part, err := writer.CreatePart(h) //
if err != nil { // part, err := writer.CreatePart(h)
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err) // if err != nil {
} // return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
// }
if _, err := io.Copy(part, file); err != nil { //
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err) // if _, err := io.Copy(part, file); err != nil {
} // return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
} // }
// }
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 { //
maskFile, err := maskFiles[0].Open() // if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
if err != nil { // maskFile, err := maskFiles[0].Open()
return nil, errors.New("failed to open mask file") // if err != nil {
} // return nil, errors.New("failed to open mask file")
defer maskFile.Close() // }
// defer maskFile.Close()
mimeType := detectImageMimeType(maskFiles[0].Filename) //
// mimeType := detectImageMimeType(maskFiles[0].Filename)
h := make(textproto.MIMEHeader) //
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename)) // h := make(textproto.MIMEHeader)
h.Set("Content-Type", mimeType) // h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
// h.Set("Content-Type", mimeType)
maskPart, err := writer.CreatePart(h) //
if err != nil { // maskPart, err := writer.CreatePart(h)
return nil, errors.New("create form file failed for mask") // if err != nil {
} // return nil, errors.New("create form file failed for mask")
// }
if _, err := io.Copy(maskPart, maskFile); err != nil { //
return nil, errors.New("copy mask file failed") // if _, err := io.Copy(maskPart, maskFile); err != nil {
} // return nil, errors.New("copy mask file failed")
} // }
} else { // }
return nil, errors.New("no multipart form data found") // } else {
} // return nil, errors.New("no multipart form data found")
// }
writer.Close() //
c.Request.Header.Set("Content-Type", writer.FormDataContentType()) // writer.Close()
return bytes.NewReader(requestBody.Bytes()), nil // c.Request.Header.Set("Content-Type", writer.FormDataContentType())
// return bytes.NewReader(requestBody.Bytes()), nil
default: default:
return request, nil return request, nil
@ -251,10 +250,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
case constant.RelayModeImagesGenerations: //豆包的图生图也走generations接口: https://www.volcengine.com/docs/82379/1824121
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
case constant.RelayModeImagesEdits: //case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil // return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
case constant.RelayModeRerank: case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
case constant.RelayModeAudioSpeech: case constant.RelayModeAudioSpeech:
@ -278,6 +278,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
} }
req.Set("Content-Type", "application/json") req.Set("Content-Type", "application/json")
return nil return nil
} else if info.RelayMode == constant.RelayModeImagesEdits {
req.Set("Content-Type", gin.MIMEJSON)
} }
req.Set("Authorization", "Bearer "+info.ApiKey) req.Set("Authorization", "Bearer "+info.ApiKey)