feat: doubao-seedream support image edit
This commit is contained in:
parent
a4c46e999e
commit
fe9b092b0b
@ -2,9 +2,11 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -40,6 +42,10 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|||||||
contentType := c.Request.Header.Get("Content-Type")
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
if strings.HasPrefix(contentType, "application/json") {
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
err = Unmarshal(requestBody, &v)
|
err = Unmarshal(requestBody, &v)
|
||||||
|
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
|
||||||
|
err = parseFormData(requestBody, &v)
|
||||||
|
} else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
|
||||||
|
err = parseMultipartFormData(c, requestBody, &v)
|
||||||
} else {
|
} else {
|
||||||
// skip for now
|
// skip for now
|
||||||
// TODO: someday non json request have variant model, we will need to implementation this
|
// TODO: someday non json request have variant model, we will need to implementation this
|
||||||
@ -138,3 +144,57 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
|
|||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
return form, nil
|
return form, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseFormData(data []byte, v any) error {
|
||||||
|
values, err := url.ParseQuery(string(data))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
formMap := make(map[string]any)
|
||||||
|
for key, vals := range values {
|
||||||
|
if len(vals) == 1 {
|
||||||
|
formMap[key] = vals[0]
|
||||||
|
} else {
|
||||||
|
formMap[key] = vals
|
||||||
|
}
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(formMap)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(jsonData, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
|
||||||
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
|
boundary := ""
|
||||||
|
if idx := strings.Index(contentType, "boundary="); idx != -1 {
|
||||||
|
boundary = contentType[idx+9:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if boundary == "" {
|
||||||
|
return json.Unmarshal(data, v) // Fallback to JSON
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := multipart.NewReader(bytes.NewReader(data), boundary)
|
||||||
|
form, err := reader.ReadForm(32 << 20) // 32 MB max memory
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer form.RemoveAll()
|
||||||
|
formMap := make(map[string]any)
|
||||||
|
for key, vals := range form.Value {
|
||||||
|
if len(vals) == 1 {
|
||||||
|
formMap[key] = vals[0]
|
||||||
|
} else {
|
||||||
|
formMap[key] = vals
|
||||||
|
}
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(formMap)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(jsonData, v)
|
||||||
|
}
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -245,7 +246,8 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||||
//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
//modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
|
||||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
|
if slices.Contains([]string{gin.MIMEPOSTForm, gin.MIMEMultipartPOSTForm}, contentType) {
|
||||||
modelRequest.Model = c.PostForm("model")
|
modelRequest.Model = c.PostForm("model")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,9 +6,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/textproto"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -104,106 +102,107 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
return request, nil
|
return request, nil
|
||||||
case constant.RelayModeImagesEdits:
|
// 根据官方文档,并没有发现豆包生图支持表单请求:https://www.volcengine.com/docs/82379/1824121
|
||||||
|
//case constant.RelayModeImagesEdits:
|
||||||
var requestBody bytes.Buffer
|
//
|
||||||
writer := multipart.NewWriter(&requestBody)
|
// var requestBody bytes.Buffer
|
||||||
|
// writer := multipart.NewWriter(&requestBody)
|
||||||
writer.WriteField("model", request.Model)
|
//
|
||||||
|
// writer.WriteField("model", request.Model)
|
||||||
formData := c.Request.PostForm
|
//
|
||||||
for key, values := range formData {
|
// formData := c.Request.PostForm
|
||||||
if key == "model" {
|
// for key, values := range formData {
|
||||||
continue
|
// if key == "model" {
|
||||||
}
|
// continue
|
||||||
for _, value := range values {
|
// }
|
||||||
writer.WriteField(key, value)
|
// for _, value := range values {
|
||||||
}
|
// writer.WriteField(key, value)
|
||||||
}
|
// }
|
||||||
|
// }
|
||||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
|
//
|
||||||
return nil, errors.New("failed to parse multipart form")
|
// if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
|
||||||
}
|
// return nil, errors.New("failed to parse multipart form")
|
||||||
|
// }
|
||||||
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
|
//
|
||||||
var imageFiles []*multipart.FileHeader
|
// if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
|
||||||
var exists bool
|
// var imageFiles []*multipart.FileHeader
|
||||||
|
// var exists bool
|
||||||
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
|
//
|
||||||
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
|
// if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
|
||||||
foundArrayImages := false
|
// if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||||
for fieldName, files := range c.Request.MultipartForm.File {
|
// foundArrayImages := false
|
||||||
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
// for fieldName, files := range c.Request.MultipartForm.File {
|
||||||
foundArrayImages = true
|
// if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
||||||
for _, file := range files {
|
// foundArrayImages = true
|
||||||
imageFiles = append(imageFiles, file)
|
// for _, file := range files {
|
||||||
}
|
// imageFiles = append(imageFiles, file)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
// }
|
||||||
if !foundArrayImages && (len(imageFiles) == 0) {
|
//
|
||||||
return nil, errors.New("image is required")
|
// if !foundArrayImages && (len(imageFiles) == 0) {
|
||||||
}
|
// return nil, errors.New("image is required")
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
// }
|
||||||
for i, fileHeader := range imageFiles {
|
//
|
||||||
file, err := fileHeader.Open()
|
// for i, fileHeader := range imageFiles {
|
||||||
if err != nil {
|
// file, err := fileHeader.Open()
|
||||||
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
|
// if err != nil {
|
||||||
}
|
// return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
|
||||||
defer file.Close()
|
// }
|
||||||
|
// defer file.Close()
|
||||||
fieldName := "image"
|
//
|
||||||
if len(imageFiles) > 1 {
|
// fieldName := "image"
|
||||||
fieldName = "image[]"
|
// if len(imageFiles) > 1 {
|
||||||
}
|
// fieldName = "image[]"
|
||||||
|
// }
|
||||||
mimeType := detectImageMimeType(fileHeader.Filename)
|
//
|
||||||
|
// mimeType := detectImageMimeType(fileHeader.Filename)
|
||||||
h := make(textproto.MIMEHeader)
|
//
|
||||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
|
// h := make(textproto.MIMEHeader)
|
||||||
h.Set("Content-Type", mimeType)
|
// h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
|
||||||
|
// h.Set("Content-Type", mimeType)
|
||||||
part, err := writer.CreatePart(h)
|
//
|
||||||
if err != nil {
|
// part, err := writer.CreatePart(h)
|
||||||
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
|
// if err != nil {
|
||||||
}
|
// return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
|
||||||
|
// }
|
||||||
if _, err := io.Copy(part, file); err != nil {
|
//
|
||||||
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
|
// if _, err := io.Copy(part, file); err != nil {
|
||||||
}
|
// return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
|
||||||
}
|
// }
|
||||||
|
// }
|
||||||
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
|
//
|
||||||
maskFile, err := maskFiles[0].Open()
|
// if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
|
||||||
if err != nil {
|
// maskFile, err := maskFiles[0].Open()
|
||||||
return nil, errors.New("failed to open mask file")
|
// if err != nil {
|
||||||
}
|
// return nil, errors.New("failed to open mask file")
|
||||||
defer maskFile.Close()
|
// }
|
||||||
|
// defer maskFile.Close()
|
||||||
mimeType := detectImageMimeType(maskFiles[0].Filename)
|
//
|
||||||
|
// mimeType := detectImageMimeType(maskFiles[0].Filename)
|
||||||
h := make(textproto.MIMEHeader)
|
//
|
||||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
|
// h := make(textproto.MIMEHeader)
|
||||||
h.Set("Content-Type", mimeType)
|
// h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
|
||||||
|
// h.Set("Content-Type", mimeType)
|
||||||
maskPart, err := writer.CreatePart(h)
|
//
|
||||||
if err != nil {
|
// maskPart, err := writer.CreatePart(h)
|
||||||
return nil, errors.New("create form file failed for mask")
|
// if err != nil {
|
||||||
}
|
// return nil, errors.New("create form file failed for mask")
|
||||||
|
// }
|
||||||
if _, err := io.Copy(maskPart, maskFile); err != nil {
|
//
|
||||||
return nil, errors.New("copy mask file failed")
|
// if _, err := io.Copy(maskPart, maskFile); err != nil {
|
||||||
}
|
// return nil, errors.New("copy mask file failed")
|
||||||
}
|
// }
|
||||||
} else {
|
// }
|
||||||
return nil, errors.New("no multipart form data found")
|
// } else {
|
||||||
}
|
// return nil, errors.New("no multipart form data found")
|
||||||
|
// }
|
||||||
writer.Close()
|
//
|
||||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
// writer.Close()
|
||||||
return bytes.NewReader(requestBody.Bytes()), nil
|
// c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
// return bytes.NewReader(requestBody.Bytes()), nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return request, nil
|
return request, nil
|
||||||
@ -251,10 +250,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
|
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
|
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
|
||||||
case constant.RelayModeImagesGenerations:
|
//豆包的图生图也走generations接口: https://www.volcengine.com/docs/82379/1824121
|
||||||
|
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||||
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
|
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
|
||||||
case constant.RelayModeImagesEdits:
|
//case constant.RelayModeImagesEdits:
|
||||||
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
|
// return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
||||||
case constant.RelayModeAudioSpeech:
|
case constant.RelayModeAudioSpeech:
|
||||||
@ -278,6 +278,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
}
|
}
|
||||||
req.Set("Content-Type", "application/json")
|
req.Set("Content-Type", "application/json")
|
||||||
return nil
|
return nil
|
||||||
|
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
||||||
|
req.Set("Content-Type", gin.MIMEJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user