new-api/relay/common/relay_utils.go
2026-04-03 15:28:08 +08:00

225 lines
5.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package common
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
)
// HasPrompt is satisfied by any value that exposes its prompt text through
// a GetPrompt method.
type HasPrompt interface {
	// GetPrompt returns the prompt as a string.
	GetPrompt() string
}
// HasImage is satisfied by any value that can report, through a HasImage
// method, whether it carries an image.
type HasImage interface {
	// HasImage returns the value's image-presence flag.
	HasImage() bool
}
// GetFullRequestURL joins baseURL and requestURL into the URL to call.
//
// When baseURL starts with the Cloudflare AI Gateway host, a channel-specific
// prefix is first stripped from requestURL: "/v1" for OpenAI channels and
// "/openai/deployments" for Azure channels. Any other channel type, and any
// other base URL, uses requestURL unchanged.
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
	const cloudflareGateway = "https://gateway.ai.cloudflare.com"

	path := requestURL
	if strings.HasPrefix(baseURL, cloudflareGateway) {
		if channelType == constant.ChannelTypeOpenAI {
			path = strings.TrimPrefix(requestURL, "/v1")
		} else if channelType == constant.ChannelTypeAzure {
			path = strings.TrimPrefix(requestURL, "/openai/deployments")
		}
	}
	return fmt.Sprintf("%s%s", baseURL, path)
}
// GetAPIVersion returns the API version for the current request.
//
// The "api-version" query parameter wins when it is non-empty; otherwise the
// string stored in the gin context under "api_version" is returned.
func GetAPIVersion(c *gin.Context) string {
	if fromQuery := c.Request.URL.Query().Get("api-version"); fromQuery != "" {
		return fromQuery
	}
	return c.GetString("api_version")
}
// createTaskError builds a *dto.TaskError from err and the supplied code,
// HTTP status and local-error flag. The error's text becomes Message and the
// error value itself is kept in the Error field.
//
// err must be non-nil: its Error method is called unconditionally.
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
	taskErr := new(dto.TaskError)
	taskErr.Code = code
	taskErr.Message = err.Error()
	taskErr.StatusCode = statusCode
	taskErr.LocalError = localError
	taskErr.Error = err
	return taskErr
}
// storeTaskRequest records action on info and saves requestObj in the gin
// context under the "task_request" key.
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
	const contextKey = "task_request"

	info.Action = action
	c.Set(contextKey, requestObj)
}
// GetTaskRequest reads the value stored in the gin context under
// "task_request" and returns it as a TaskSubmitReq.
//
// It returns an error when the key is absent or when the stored value is not
// a TaskSubmitReq; in both cases the returned request is the zero value.
func GetTaskRequest(c *gin.Context) (TaskSubmitReq, error) {
	stored, found := c.Get("task_request")
	if !found {
		return TaskSubmitReq{}, fmt.Errorf("request not found in context")
	}
	if taskReq, isTaskReq := stored.(TaskSubmitReq); isTaskReq {
		return taskReq, nil
	}
	return TaskSubmitReq{}, fmt.Errorf("invalid task request type")
}
// validatePrompt returns nil when prompt contains at least one
// non-whitespace character. Otherwise it returns a local 400
// "invalid_request" task error stating that the prompt is required.
func validatePrompt(prompt string) *dto.TaskError {
	if trimmed := strings.TrimSpace(prompt); trimmed != "" {
		return nil
	}
	missing := fmt.Errorf("prompt is required")
	return createTaskError(missing, "invalid_request", http.StatusBadRequest, true)
}
// validateMultipartTaskRequest parses the request's multipart form and maps
// its fields onto a TaskSubmitReq.
//
// prompt, model, mode, image and size are copied verbatim. "seconds" is
// parsed as an integer into Duration and silently ignored when it is not a
// valid integer. All "images" values are kept as a slice. Every other field
// that isKnownTaskField does not recognise is stored in Metadata using its
// first value, converted to int if possible, else float64, else left as a
// string.
//
// The info and action parameters are not used by this function.
//
// NOTE(review): "seconds" is not listed in isKnownTaskField, so it is also
// copied into Metadata, while "duration" is treated as known but never read
// here — confirm this is intended.
func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
	if _, err := c.MultipartForm(); err != nil {
		return TaskSubmitReq{}, err
	}
	form := c.Request.PostForm

	// Collect the unrecognised fields first; each key is handled
	// independently, so map iteration order does not matter.
	extras := make(map[string]interface{})
	for name, vals := range form {
		if len(vals) == 0 || isKnownTaskField(name) {
			continue
		}
		raw := vals[0]
		if asInt, err := strconv.Atoi(raw); err == nil {
			extras[name] = asInt
			continue
		}
		if asFloat, err := strconv.ParseFloat(raw, 64); err == nil {
			extras[name] = asFloat
			continue
		}
		extras[name] = raw
	}

	req := TaskSubmitReq{
		Prompt:   form.Get("prompt"),
		Model:    form.Get("model"),
		Mode:     form.Get("mode"),
		Image:    form.Get("image"),
		Size:     form.Get("size"),
		Metadata: extras,
	}

	if rawSeconds := form.Get("seconds"); rawSeconds != "" {
		// Best effort: an unparsable value leaves Duration at zero.
		if parsed, err := strconv.Atoi(rawSeconds); err == nil {
			req.Duration = parsed
		}
	}

	if uploaded := form["images"]; len(uploaded) > 0 {
		req.Images = uploaded
	}

	return req, nil
}
// ValidateMultipartDirect decodes a task submission from the request body,
// validates it and stores it for later retrieval via GetTaskRequest.
//
// Checks, in order:
//   - the body must decode into a TaskSubmitReq ("invalid_json");
//   - Model must be non-blank ("missing_model");
//   - Prompt must be non-blank ("invalid_request");
//   - for models whose name starts with "sora-2", Size (defaulting to
//     "720x1280" when empty) must be in the model's allow-list
//     ("invalid_size").
//
// A non-empty InputReference replaces Images with that single entry. The
// stored action is constant.TaskActionGenerate when the request reports an
// image and constant.TaskActionTextGenerate otherwise.
//
// It returns nil on success, or a local 400 task error describing the first
// failed check. The size default is used for validation only and is not
// written back to the stored request.
func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
	var req TaskSubmitReq
	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
		return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
	}

	// A single input reference takes the place of the image list.
	if req.InputReference != "" {
		req.Images = []string{req.InputReference}
	}

	if strings.TrimSpace(req.Model) == "" {
		return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
	}

	action := constant.TaskActionTextGenerate
	if req.HasImage() {
		action = constant.TaskActionGenerate
	}

	if taskErr := validatePrompt(req.Prompt); taskErr != nil {
		return taskErr
	}

	if strings.HasPrefix(req.Model, "sora-2") {
		size := req.Size
		if size == "" {
			size = "720x1280"
		}

		// Only the two exact model names have a size allow-list; other
		// "sora-2*" variants are accepted without a size check.
		var allowedSizes []string
		switch req.Model {
		case "sora-2":
			allowedSizes = []string{"720x1280", "1280x720"}
		case "sora-2-pro":
			allowedSizes = []string{"720x1280", "1280x720", "1792x1024", "1024x1792"}
		}
		if allowedSizes != nil && !lo.Contains(allowedSizes, size) {
			return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
		}
		// OtherRatios is now set in the Sora adaptor's EstimateBilling.
	}

	storeTaskRequest(c, info, action, req)
	return nil
}
// isKnownTaskField reports whether field is one of the form field names that
// are handled explicitly rather than being treated as free-form metadata.
// The comparison is exact and case-sensitive.
//
// A switch is used instead of a map literal so that no lookup table has to
// be rebuilt on every call; callers invoke this once per form key.
func isKnownTaskField(field string) bool {
	switch field {
	case "prompt",
		"model",
		"mode",
		"image",
		"images",
		"size",
		"duration",
		"input_reference": // Sora-specific field
		return true
	default:
		return false
	}
}
// ValidateBasicTaskRequest builds a TaskSubmitReq from the incoming request,
// checks that it has a prompt and stores it, together with action, for later
// retrieval via GetTaskRequest.
//
// When the Content-Type starts with "multipart/form-data" the form is parsed
// first; the body is then always decoded on top of the same request value.
// If Images is empty and Image is non-blank, Image becomes the sole entry of
// Images.
//
// It returns nil on success, or a local 400 task error with code
// "invalid_multipart_form" (form parsing failed) or "invalid_request"
// (decoding failed or the prompt is blank).
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
	var req TaskSubmitReq

	if strings.HasPrefix(c.GetHeader("Content-Type"), "multipart/form-data") {
		parsed, formErr := validateMultipartTaskRequest(c, info, action)
		if formErr != nil {
			return createTaskError(formErr, "invalid_multipart_form", http.StatusBadRequest, true)
		}
		req = parsed
	}

	// Always run UnmarshalBodyReusable as well, to keep the metadata field
	// compatible across content types.
	decodeErr := common.UnmarshalBodyReusable(c, &req)
	if decodeErr != nil {
		return createTaskError(decodeErr, "invalid_request", http.StatusBadRequest, true)
	}

	if promptErr := validatePrompt(req.Prompt); promptErr != nil {
		return promptErr
	}

	// Backward compatibility with single-image uploads.
	if strings.TrimSpace(req.Image) != "" && len(req.Images) == 0 {
		req.Images = []string{req.Image}
	}

	storeTaskRequest(c, info, action, req)
	return nil
}