Merge pull request #2190 from Sh1n3zZ/support-replicate-channel
Some checks failed
Publish Docker image (Multi Registries, native amd64+arm64) / Build & push (amd64) [native] (push) Has been cancelled
Publish Docker image (Multi Registries, native amd64+arm64) / Build & push (arm64) [native] (push) Has been cancelled
Publish Docker image (Multi Registries, native amd64+arm64) / Create multi-arch manifests (Docker Hub) (push) Has been cancelled
Build Electron App / build (windows-latest) (push) Has been cancelled
Build Electron App / release (push) Has been cancelled
Release (Linux, macOS, Windows) / Linux Release (push) Has been cancelled
Release (Linux, macOS, Windows) / macOS Release (push) Has been cancelled
Release (Linux, macOS, Windows) / Windows Release (push) Has been cancelled
Some checks failed
Publish Docker image (Multi Registries, native amd64+arm64) / Build & push (amd64) [native] (push) Has been cancelled
Publish Docker image (Multi Registries, native amd64+arm64) / Build & push (arm64) [native] (push) Has been cancelled
Publish Docker image (Multi Registries, native amd64+arm64) / Create multi-arch manifests (Docker Hub) (push) Has been cancelled
Build Electron App / build (windows-latest) (push) Has been cancelled
Build Electron App / release (push) Has been cancelled
Release (Linux, macOS, Windows) / Linux Release (push) Has been cancelled
Release (Linux, macOS, Windows) / macOS Release (push) Has been cancelled
Release (Linux, macOS, Windows) / Windows Release (push) Has been cancelled
feat: replicate channel flux model
This commit is contained in:
commit
ae8b09d45f
@ -71,6 +71,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
|||||||
apiType = constant.APITypeSubmodel
|
apiType = constant.APITypeSubmodel
|
||||||
case constant.ChannelTypeMiniMax:
|
case constant.ChannelTypeMiniMax:
|
||||||
apiType = constant.APITypeMiniMax
|
apiType = constant.APITypeMiniMax
|
||||||
|
case constant.ChannelTypeReplicate:
|
||||||
|
apiType = constant.APITypeReplicate
|
||||||
}
|
}
|
||||||
if apiType == -1 {
|
if apiType == -1 {
|
||||||
return constant.APITypeOpenAI, false
|
return constant.APITypeOpenAI, false
|
||||||
|
|||||||
@ -34,5 +34,6 @@ const (
|
|||||||
APITypeMoonshot
|
APITypeMoonshot
|
||||||
APITypeSubmodel
|
APITypeSubmodel
|
||||||
APITypeMiniMax
|
APITypeMiniMax
|
||||||
|
APITypeReplicate
|
||||||
APITypeDummy // this one is only for count, do not add any channel after this
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
|
|||||||
@ -53,6 +53,7 @@ const (
|
|||||||
ChannelTypeSubmodel = 53
|
ChannelTypeSubmodel = 53
|
||||||
ChannelTypeDoubaoVideo = 54
|
ChannelTypeDoubaoVideo = 54
|
||||||
ChannelTypeSora = 55
|
ChannelTypeSora = 55
|
||||||
|
ChannelTypeReplicate = 56
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
)
|
)
|
||||||
@ -114,6 +115,7 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://llm.submodel.ai", //53
|
"https://llm.submodel.ai", //53
|
||||||
"https://ark.cn-beijing.volces.com", //54
|
"https://ark.cn-beijing.volces.com", //54
|
||||||
"https://api.openai.com", //55
|
"https://api.openai.com", //55
|
||||||
|
"https://api.replicate.com", //56
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelTypeNames = map[int]string{
|
var ChannelTypeNames = map[int]string{
|
||||||
@ -169,6 +171,7 @@ var ChannelTypeNames = map[int]string{
|
|||||||
ChannelTypeSubmodel: "Submodel",
|
ChannelTypeSubmodel: "Submodel",
|
||||||
ChannelTypeDoubaoVideo: "DoubaoVideo",
|
ChannelTypeDoubaoVideo: "DoubaoVideo",
|
||||||
ChannelTypeSora: "Sora",
|
ChannelTypeSora: "Sora",
|
||||||
|
ChannelTypeReplicate: "Replicate",
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetChannelTypeName(channelType int) string {
|
func GetChannelTypeName(channelType int) string {
|
||||||
|
|||||||
530
relay/channel/replicate/adaptor.go
Normal file
530
relay/channel/replicate/adaptor.go
Normal file
@ -0,0 +1,530 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/constant"
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
"github.com/QuantumNous/new-api/relay/channel"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||||
|
"github.com/QuantumNous/new-api/service"
|
||||||
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Adaptor struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
|
if info == nil {
|
||||||
|
return "", errors.New("replicate adaptor: relay info is nil")
|
||||||
|
}
|
||||||
|
if info.ChannelBaseUrl == "" {
|
||||||
|
info.ChannelBaseUrl = constant.ChannelBaseURLs[constant.ChannelTypeReplicate]
|
||||||
|
}
|
||||||
|
requestPath := info.RequestURLPath
|
||||||
|
if requestPath == "" {
|
||||||
|
return info.ChannelBaseUrl, nil
|
||||||
|
}
|
||||||
|
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestPath, info.ChannelType), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
|
if info == nil {
|
||||||
|
return errors.New("replicate adaptor: relay info is nil")
|
||||||
|
}
|
||||||
|
if info.ApiKey == "" {
|
||||||
|
return errors.New("replicate adaptor: api key is required")
|
||||||
|
}
|
||||||
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
|
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
|
req.Set("Prefer", "wait")
|
||||||
|
if req.Get("Content-Type") == "" {
|
||||||
|
req.Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
|
if req.Get("Accept") == "" {
|
||||||
|
req.Set("Accept", "application/json")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
|
if info == nil {
|
||||||
|
return nil, errors.New("replicate adaptor: relay info is nil")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(request.Prompt) == "" {
|
||||||
|
if v := c.PostForm("prompt"); strings.TrimSpace(v) != "" {
|
||||||
|
request.Prompt = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(request.Prompt) == "" {
|
||||||
|
return nil, errors.New("replicate adaptor: prompt is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
modelName := strings.TrimSpace(info.UpstreamModelName)
|
||||||
|
if modelName == "" {
|
||||||
|
modelName = strings.TrimSpace(request.Model)
|
||||||
|
}
|
||||||
|
if modelName == "" {
|
||||||
|
modelName = ModelFlux11Pro
|
||||||
|
}
|
||||||
|
info.UpstreamModelName = modelName
|
||||||
|
|
||||||
|
info.RequestURLPath = fmt.Sprintf("/v1/models/%s/predictions", modelName)
|
||||||
|
|
||||||
|
inputPayload := make(map[string]any)
|
||||||
|
inputPayload["prompt"] = request.Prompt
|
||||||
|
|
||||||
|
if size := strings.TrimSpace(request.Size); size != "" {
|
||||||
|
if aspect, width, height, ok := mapOpenAISizeToFlux(size); ok {
|
||||||
|
if aspect != "" {
|
||||||
|
if aspect == "custom" {
|
||||||
|
inputPayload["aspect_ratio"] = "custom"
|
||||||
|
if width > 0 {
|
||||||
|
inputPayload["width"] = width
|
||||||
|
}
|
||||||
|
if height > 0 {
|
||||||
|
inputPayload["height"] = height
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
inputPayload["aspect_ratio"] = aspect
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(request.OutputFormat) > 0 {
|
||||||
|
var outputFormat string
|
||||||
|
if err := json.Unmarshal(request.OutputFormat, &outputFormat); err == nil && strings.TrimSpace(outputFormat) != "" {
|
||||||
|
inputPayload["output_format"] = outputFormat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.N > 0 {
|
||||||
|
inputPayload["num_outputs"] = int(request.N)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.EqualFold(request.Quality, "hd") || strings.EqualFold(request.Quality, "high") {
|
||||||
|
inputPayload["prompt_upsampling"] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||||
|
imageURL, err := uploadFileFromForm(c, info, "image", "image[]", "image_prompt")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if imageURL == "" {
|
||||||
|
return nil, errors.New("replicate adaptor: image file is required for edits")
|
||||||
|
}
|
||||||
|
inputPayload["image_prompt"] = imageURL
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(request.ExtraFields) > 0 {
|
||||||
|
var extra map[string]any
|
||||||
|
if err := common.Unmarshal(request.ExtraFields, &extra); err != nil {
|
||||||
|
return nil, fmt.Errorf("replicate adaptor: failed to decode extra_fields: %w", err)
|
||||||
|
}
|
||||||
|
for key, val := range extra {
|
||||||
|
inputPayload[key] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, raw := range request.Extra {
|
||||||
|
if strings.EqualFold(key, "input") {
|
||||||
|
var extraInput map[string]any
|
||||||
|
if err := common.Unmarshal(raw, &extraInput); err != nil {
|
||||||
|
return nil, fmt.Errorf("replicate adaptor: failed to decode extra input: %w", err)
|
||||||
|
}
|
||||||
|
for k, v := range extraInput {
|
||||||
|
inputPayload[k] = v
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if raw == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var val any
|
||||||
|
if err := common.Unmarshal(raw, &val); err != nil {
|
||||||
|
return nil, fmt.Errorf("replicate adaptor: failed to decode extra field %s: %w", key, err)
|
||||||
|
}
|
||||||
|
inputPayload[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]any{
|
||||||
|
"input": inputPayload,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (any, *types.NewAPIError) {
|
||||||
|
if resp == nil {
|
||||||
|
return nil, types.NewError(errors.New("replicate adaptor: empty response"), types.ErrorCodeBadResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
var prediction PredictionResponse
|
||||||
|
if err := common.Unmarshal(responseBody, &prediction); err != nil {
|
||||||
|
return nil, types.NewError(fmt.Errorf("replicate adaptor: failed to decode response: %w", err), types.ErrorCodeBadResponseBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
if prediction.Error != nil {
|
||||||
|
errMsg := prediction.Error.Message
|
||||||
|
if errMsg == "" {
|
||||||
|
errMsg = prediction.Error.Detail
|
||||||
|
}
|
||||||
|
if errMsg == "" {
|
||||||
|
errMsg = prediction.Error.Code
|
||||||
|
}
|
||||||
|
if errMsg == "" {
|
||||||
|
errMsg = "replicate adaptor: prediction error"
|
||||||
|
}
|
||||||
|
return nil, types.NewError(errors.New(errMsg), types.ErrorCodeBadResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
if prediction.Status != "" && !strings.EqualFold(prediction.Status, "succeeded") {
|
||||||
|
return nil, types.NewError(fmt.Errorf("replicate adaptor: prediction status %q", prediction.Status), types.ErrorCodeBadResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
var urls []string
|
||||||
|
|
||||||
|
appendOutput := func(value string) {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
urls = append(urls, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch output := prediction.Output.(type) {
|
||||||
|
case string:
|
||||||
|
appendOutput(output)
|
||||||
|
case []any:
|
||||||
|
for _, item := range output {
|
||||||
|
if str, ok := item.(string); ok {
|
||||||
|
appendOutput(str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case nil:
|
||||||
|
// no output
|
||||||
|
default:
|
||||||
|
if str, ok := output.(fmt.Stringer); ok {
|
||||||
|
appendOutput(str.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(urls) == 0 {
|
||||||
|
return nil, types.NewError(errors.New("replicate adaptor: empty prediction output"), types.ErrorCodeBadResponseBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
var imageReq *dto.ImageRequest
|
||||||
|
if info != nil {
|
||||||
|
if req, ok := info.Request.(*dto.ImageRequest); ok {
|
||||||
|
imageReq = req
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wantsBase64 := imageReq != nil && strings.EqualFold(imageReq.ResponseFormat, "b64_json")
|
||||||
|
|
||||||
|
imageResponse := dto.ImageResponse{
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Data: make([]dto.ImageData, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
if wantsBase64 {
|
||||||
|
converted, convErr := downloadImagesToBase64(urls)
|
||||||
|
if convErr != nil {
|
||||||
|
return nil, types.NewError(convErr, types.ErrorCodeBadResponse)
|
||||||
|
}
|
||||||
|
for _, content := range converted {
|
||||||
|
if content == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
imageResponse.Data = append(imageResponse.Data, dto.ImageData{B64Json: content})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, url := range urls {
|
||||||
|
if url == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
imageResponse.Data = append(imageResponse.Data, dto.ImageData{Url: url})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(imageResponse.Data) == 0 {
|
||||||
|
return nil, types.NewError(errors.New("replicate adaptor: no usable image data"), types.ErrorCodeBadResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
responseBytes, err := common.Marshal(imageResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewError(fmt.Errorf("replicate adaptor: encode response failed: %w", err), types.ErrorCodeBadResponseBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = c.Writer.Write(responseBytes)
|
||||||
|
|
||||||
|
usage := &dto.Usage{}
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetModelList() []string {
|
||||||
|
return ModelList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetChannelName() string {
|
||||||
|
return ChannelName
|
||||||
|
}
|
||||||
|
|
||||||
|
func downloadImagesToBase64(urls []string) ([]string, error) {
|
||||||
|
results := make([]string, 0, len(urls))
|
||||||
|
for _, url := range urls {
|
||||||
|
if strings.TrimSpace(url) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, data, err := service.GetImageFromUrl(url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("replicate adaptor: failed to download image from %s: %w", url, err)
|
||||||
|
}
|
||||||
|
results = append(results, data)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapOpenAISizeToFlux(size string) (aspect string, width int, height int, ok bool) {
|
||||||
|
parts := strings.Split(size, "x")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return "", 0, 0, false
|
||||||
|
}
|
||||||
|
w, err1 := strconv.Atoi(strings.TrimSpace(parts[0]))
|
||||||
|
h, err2 := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||||
|
if err1 != nil || err2 != nil || w <= 0 || h <= 0 {
|
||||||
|
return "", 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case w == h:
|
||||||
|
return "1:1", 0, 0, true
|
||||||
|
case w == 1792 && h == 1024:
|
||||||
|
return "16:9", 0, 0, true
|
||||||
|
case w == 1024 && h == 1792:
|
||||||
|
return "9:16", 0, 0, true
|
||||||
|
case w == 1536 && h == 1024:
|
||||||
|
return "3:2", 0, 0, true
|
||||||
|
case w == 1024 && h == 1536:
|
||||||
|
return "2:3", 0, 0, true
|
||||||
|
}
|
||||||
|
|
||||||
|
rw, rh := reduceRatio(w, h)
|
||||||
|
ratioStr := fmt.Sprintf("%d:%d", rw, rh)
|
||||||
|
switch ratioStr {
|
||||||
|
case "1:1", "16:9", "9:16", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3":
|
||||||
|
return ratioStr, 0, 0, true
|
||||||
|
}
|
||||||
|
|
||||||
|
width = normalizeFluxDimension(w)
|
||||||
|
height = normalizeFluxDimension(h)
|
||||||
|
return "custom", width, height, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func reduceRatio(w, h int) (int, int) {
|
||||||
|
g := gcd(w, h)
|
||||||
|
if g == 0 {
|
||||||
|
return w, h
|
||||||
|
}
|
||||||
|
return w / g, h / g
|
||||||
|
}
|
||||||
|
|
||||||
|
func gcd(a, b int) int {
|
||||||
|
for b != 0 {
|
||||||
|
a, b = b, a%b
|
||||||
|
}
|
||||||
|
if a < 0 {
|
||||||
|
return -a
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeFluxDimension(value int) int {
|
||||||
|
const (
|
||||||
|
minDim = 256
|
||||||
|
maxDim = 1440
|
||||||
|
step = 32
|
||||||
|
)
|
||||||
|
if value < minDim {
|
||||||
|
value = minDim
|
||||||
|
}
|
||||||
|
if value > maxDim {
|
||||||
|
value = maxDim
|
||||||
|
}
|
||||||
|
remainder := value % step
|
||||||
|
if remainder != 0 {
|
||||||
|
if remainder >= step/2 {
|
||||||
|
value += step - remainder
|
||||||
|
} else {
|
||||||
|
value -= remainder
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value < minDim {
|
||||||
|
value = minDim
|
||||||
|
}
|
||||||
|
if value > maxDim {
|
||||||
|
value = maxDim
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func uploadFileFromForm(c *gin.Context, info *relaycommon.RelayInfo, fieldCandidates ...string) (string, error) {
|
||||||
|
if info == nil {
|
||||||
|
return "", errors.New("replicate adaptor: relay info is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
mf := c.Request.MultipartForm
|
||||||
|
if mf == nil {
|
||||||
|
if _, err := c.MultipartForm(); err != nil {
|
||||||
|
return "", fmt.Errorf("replicate adaptor: parse multipart form failed: %w", err)
|
||||||
|
}
|
||||||
|
mf = c.Request.MultipartForm
|
||||||
|
}
|
||||||
|
if mf == nil || len(mf.File) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(fieldCandidates) == 0 {
|
||||||
|
fieldCandidates = []string{"image", "image[]", "image_prompt"}
|
||||||
|
}
|
||||||
|
|
||||||
|
var fileHeader *multipart.FileHeader
|
||||||
|
for _, key := range fieldCandidates {
|
||||||
|
if files := mf.File[key]; len(files) > 0 {
|
||||||
|
fileHeader = files[0]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if fileHeader == nil {
|
||||||
|
for _, files := range mf.File {
|
||||||
|
if len(files) > 0 {
|
||||||
|
fileHeader = files[0]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if fileHeader == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := fileHeader.Open()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("replicate adaptor: failed to open image file: %w", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
|
||||||
|
hdr := make(textproto.MIMEHeader)
|
||||||
|
hdr.Set("Content-Disposition", fmt.Sprintf("form-data; name=\"content\"; filename=\"%s\"", fileHeader.Filename))
|
||||||
|
contentType := fileHeader.Header.Get("Content-Type")
|
||||||
|
if contentType == "" {
|
||||||
|
contentType = "application/octet-stream"
|
||||||
|
}
|
||||||
|
hdr.Set("Content-Type", contentType)
|
||||||
|
|
||||||
|
part, err := writer.CreatePart(hdr)
|
||||||
|
if err != nil {
|
||||||
|
writer.Close()
|
||||||
|
return "", fmt.Errorf("replicate adaptor: create upload form failed: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := io.Copy(part, file); err != nil {
|
||||||
|
writer.Close()
|
||||||
|
return "", fmt.Errorf("replicate adaptor: copy image content failed: %w", err)
|
||||||
|
}
|
||||||
|
formContentType := writer.FormDataContentType()
|
||||||
|
writer.Close()
|
||||||
|
|
||||||
|
baseURL := info.ChannelBaseUrl
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = constant.ChannelBaseURLs[constant.ChannelTypeReplicate]
|
||||||
|
}
|
||||||
|
uploadURL := relaycommon.GetFullRequestURL(baseURL, "/v1/files", info.ChannelType)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, uploadURL, &body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("replicate adaptor: create upload request failed: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", formContentType)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||||
|
|
||||||
|
resp, err := service.GetHttpClient().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("replicate adaptor: upload image failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("replicate adaptor: read upload response failed: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||||
|
return "", fmt.Errorf("replicate adaptor: upload image failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var uploadResp FileUploadResponse
|
||||||
|
if err := common.Unmarshal(respBody, &uploadResp); err != nil {
|
||||||
|
return "", fmt.Errorf("replicate adaptor: decode upload response failed: %w", err)
|
||||||
|
}
|
||||||
|
if uploadResp.Urls.Get == "" {
|
||||||
|
return "", errors.New("replicate adaptor: upload response missing url")
|
||||||
|
}
|
||||||
|
return uploadResp.Urls.Get, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertOpenAIRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
|
return nil, errors.New("replicate adaptor: ConvertOpenAIRequest is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRerankRequest(*gin.Context, int, dto.RerankRequest) (any, error) {
|
||||||
|
return nil, errors.New("replicate adaptor: ConvertRerankRequest is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertEmbeddingRequest(*gin.Context, *relaycommon.RelayInfo, dto.EmbeddingRequest) (any, error) {
|
||||||
|
return nil, errors.New("replicate adaptor: ConvertEmbeddingRequest is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertAudioRequest(*gin.Context, *relaycommon.RelayInfo, dto.AudioRequest) (io.Reader, error) {
|
||||||
|
return nil, errors.New("replicate adaptor: ConvertAudioRequest is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertOpenAIResponsesRequest(*gin.Context, *relaycommon.RelayInfo, dto.OpenAIResponsesRequest) (any, error) {
|
||||||
|
return nil, errors.New("replicate adaptor: ConvertOpenAIResponsesRequest is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||||
|
return nil, errors.New("replicate adaptor: ConvertClaudeRequest is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
|
return nil, errors.New("replicate adaptor: ConvertGeminiRequest is not implemented")
|
||||||
|
}
|
||||||
12
relay/channel/replicate/constants.go
Normal file
12
relay/channel/replicate/constants.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ChannelName identifies the replicate channel.
|
||||||
|
ChannelName = "replicate"
|
||||||
|
// ModelFlux11Pro is the default image generation model supported by this channel.
|
||||||
|
ModelFlux11Pro = "black-forest-labs/flux-1.1-pro"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ModelList = []string{
|
||||||
|
ModelFlux11Pro,
|
||||||
|
}
|
||||||
19
relay/channel/replicate/dto.go
Normal file
19
relay/channel/replicate/dto.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package replicate
|
||||||
|
|
||||||
|
type PredictionResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Output any `json:"output"`
|
||||||
|
Error *PredictionError `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PredictionError struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Detail string `json:"detail"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FileUploadResponse struct {
|
||||||
|
Urls struct {
|
||||||
|
Get string `json:"get"`
|
||||||
|
} `json:"urls"`
|
||||||
|
}
|
||||||
@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/constant"
|
||||||
"github.com/QuantumNous/new-api/dto"
|
"github.com/QuantumNous/new-api/dto"
|
||||||
"github.com/QuantumNous/new-api/logger"
|
"github.com/QuantumNous/new-api/logger"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
@ -92,10 +93,15 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
|||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
if httpResp.StatusCode == http.StatusCreated && info.ApiType == constant.APITypeReplicate {
|
||||||
// reset status code 重置状态码
|
// replicate channel returns 201 Created when using Prefer: wait, treat it as success.
|
||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
httpResp.StatusCode = http.StatusOK
|
||||||
return newAPIError
|
} else {
|
||||||
|
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||||
|
// reset status code 重置状态码
|
||||||
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
|
return newAPIError
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -26,6 +26,7 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||||
"github.com/QuantumNous/new-api/relay/channel/palm"
|
"github.com/QuantumNous/new-api/relay/channel/palm"
|
||||||
"github.com/QuantumNous/new-api/relay/channel/perplexity"
|
"github.com/QuantumNous/new-api/relay/channel/perplexity"
|
||||||
|
"github.com/QuantumNous/new-api/relay/channel/replicate"
|
||||||
"github.com/QuantumNous/new-api/relay/channel/siliconflow"
|
"github.com/QuantumNous/new-api/relay/channel/siliconflow"
|
||||||
"github.com/QuantumNous/new-api/relay/channel/submodel"
|
"github.com/QuantumNous/new-api/relay/channel/submodel"
|
||||||
taskali "github.com/QuantumNous/new-api/relay/channel/task/ali"
|
taskali "github.com/QuantumNous/new-api/relay/channel/task/ali"
|
||||||
@ -113,6 +114,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
|||||||
return &submodel.Adaptor{}
|
return &submodel.Adaptor{}
|
||||||
case constant.APITypeMiniMax:
|
case constant.APITypeMiniMax:
|
||||||
return &minimax.Adaptor{}
|
return &minimax.Adaptor{}
|
||||||
|
case constant.APITypeReplicate:
|
||||||
|
return &replicate.Adaptor{}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -268,31 +268,32 @@ var defaultModelRatio = map[string]float64{
|
|||||||
}
|
}
|
||||||
|
|
||||||
var defaultModelPrice = map[string]float64{
|
var defaultModelPrice = map[string]float64{
|
||||||
"suno_music": 0.1,
|
"suno_music": 0.1,
|
||||||
"suno_lyrics": 0.01,
|
"suno_lyrics": 0.01,
|
||||||
"dall-e-3": 0.04,
|
"dall-e-3": 0.04,
|
||||||
"imagen-3.0-generate-002": 0.03,
|
"imagen-3.0-generate-002": 0.03,
|
||||||
"gpt-4-gizmo-*": 0.1,
|
"black-forest-labs/flux-1.1-pro": 0.04,
|
||||||
"mj_video": 0.8,
|
"gpt-4-gizmo-*": 0.1,
|
||||||
"mj_imagine": 0.1,
|
"mj_video": 0.8,
|
||||||
"mj_edits": 0.1,
|
"mj_imagine": 0.1,
|
||||||
"mj_variation": 0.1,
|
"mj_edits": 0.1,
|
||||||
"mj_reroll": 0.1,
|
"mj_variation": 0.1,
|
||||||
"mj_blend": 0.1,
|
"mj_reroll": 0.1,
|
||||||
"mj_modal": 0.1,
|
"mj_blend": 0.1,
|
||||||
"mj_zoom": 0.1,
|
"mj_modal": 0.1,
|
||||||
"mj_shorten": 0.1,
|
"mj_zoom": 0.1,
|
||||||
"mj_high_variation": 0.1,
|
"mj_shorten": 0.1,
|
||||||
"mj_low_variation": 0.1,
|
"mj_high_variation": 0.1,
|
||||||
"mj_pan": 0.1,
|
"mj_low_variation": 0.1,
|
||||||
"mj_inpaint": 0,
|
"mj_pan": 0.1,
|
||||||
"mj_custom_zoom": 0,
|
"mj_inpaint": 0,
|
||||||
"mj_describe": 0.05,
|
"mj_custom_zoom": 0,
|
||||||
"mj_upscale": 0.05,
|
"mj_describe": 0.05,
|
||||||
"swap_face": 0.05,
|
"mj_upscale": 0.05,
|
||||||
"mj_upload": 0.05,
|
"swap_face": 0.05,
|
||||||
"sora-2": 0.3,
|
"mj_upload": 0.05,
|
||||||
"sora-2-pro": 0.5,
|
"sora-2": 0.3,
|
||||||
|
"sora-2-pro": 0.5,
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultAudioRatio = map[string]float64{
|
var defaultAudioRatio = map[string]float64{
|
||||||
|
|||||||
@ -179,6 +179,11 @@ export const CHANNEL_OPTIONS = [
|
|||||||
color: 'green',
|
color: 'green',
|
||||||
label: 'Sora',
|
label: 'Sora',
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
value: 56,
|
||||||
|
color: 'blue',
|
||||||
|
label: 'Replicate',
|
||||||
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
export const MODEL_TABLE_PAGE_SIZE = 10;
|
export const MODEL_TABLE_PAGE_SIZE = 10;
|
||||||
|
|||||||
@ -55,6 +55,7 @@ import {
|
|||||||
Kling,
|
Kling,
|
||||||
Jimeng,
|
Jimeng,
|
||||||
Perplexity,
|
Perplexity,
|
||||||
|
Replicate,
|
||||||
} from '@lobehub/icons';
|
} from '@lobehub/icons';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@ -342,6 +343,8 @@ export function getChannelIcon(channelType) {
|
|||||||
return <Jimeng.Color size={iconSize} />;
|
return <Jimeng.Color size={iconSize} />;
|
||||||
case 54: // 豆包视频 Doubao Video
|
case 54: // 豆包视频 Doubao Video
|
||||||
return <Doubao.Color size={iconSize} />;
|
return <Doubao.Color size={iconSize} />;
|
||||||
|
case 56: // Replicate
|
||||||
|
return <Replicate size={iconSize} />;
|
||||||
case 8: // 自定义渠道
|
case 8: // 自定义渠道
|
||||||
case 22: // 知识库:FastGPT
|
case 22: // 知识库:FastGPT
|
||||||
return <FastGPT.Color size={iconSize} />;
|
return <FastGPT.Color size={iconSize} />;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user