feat: add ali wan video (#2141)
Some checks failed
Publish Docker image (Multi Registries, native amd64+arm64) / Build & push (amd64) [native] (push) Has been cancelled
Publish Docker image (Multi Registries, native amd64+arm64) / Build & push (arm64) [native] (push) Has been cancelled
Publish Docker image (Multi Registries, native amd64+arm64) / Create multi-arch manifests (Docker Hub) (push) Has been cancelled
Build Electron App / build (windows-latest) (push) Has been cancelled
Build Electron App / release (push) Has been cancelled
Release (Linux, macOS, Windows) / Linux Release (push) Has been cancelled
Release (Linux, macOS, Windows) / macOS Release (push) Has been cancelled
Release (Linux, macOS, Windows) / Windows Release (push) Has been cancelled

* feat: add ali wan video

* refactor: use same UnmarshalBodyReusable

* feat: enhance request body metadata

* feat: opt wan convertToOpenAIVideo

* feat: add wan support other param via json metadata

* refactor: remove unused code

* fix ali

---------

Co-authored-by: feitianbubu <feitianbubu@qq.com>
This commit is contained in:
Seefs 2025-10-31 16:51:05 +08:00 committed by GitHub
parent 36b712437d
commit a98e207ef7
9 changed files with 475 additions and 76 deletions

View File

@ -2,7 +2,6 @@ package common
import ( import (
"bytes" "bytes"
"encoding/json"
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
@ -41,11 +40,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
//} //}
contentType := c.Request.Header.Get("Content-Type") contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") { if strings.HasPrefix(contentType, "application/json") {
err = Unmarshal(requestBody, &v) err = Unmarshal(requestBody, v)
} else if strings.Contains(contentType, gin.MIMEPOSTForm) { } else if strings.Contains(contentType, gin.MIMEPOSTForm) {
err = parseFormData(requestBody, &v) err = parseFormData(requestBody, v)
} else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) { } else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
err = parseMultipartFormData(c, requestBody, &v) err = parseMultipartFormData(c, requestBody, v)
} else { } else {
// skip for now // skip for now
// TODO: someday non json request have variant model, we will need to implementation this // TODO: someday non json request have variant model, we will need to implementation this
@ -145,6 +144,20 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
return form, nil return form, nil
} }
func processFormMap(formMap map[string]any, v any) error {
jsonData, err := Marshal(formMap)
if err != nil {
return err
}
err = Unmarshal(jsonData, v)
if err != nil {
return err
}
return nil
}
func parseFormData(data []byte, v any) error { func parseFormData(data []byte, v any) error {
values, err := url.ParseQuery(string(data)) values, err := url.ParseQuery(string(data))
if err != nil { if err != nil {
@ -158,12 +171,8 @@ func parseFormData(data []byte, v any) error {
formMap[key] = vals formMap[key] = vals
} }
} }
jsonData, err := json.Marshal(formMap)
if err != nil {
return err
}
return Unmarshal(jsonData, v) return processFormMap(formMap, v)
} }
func parseMultipartFormData(c *gin.Context, data []byte, v any) error { func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
@ -191,10 +200,6 @@ func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
formMap[key] = vals formMap[key] = vals
} }
} }
jsonData, err := Marshal(formMap)
if err != nil {
return err
}
return Unmarshal(jsonData, v) return processFormMap(formMap, v)
} }

View File

@ -91,7 +91,8 @@ func VideoProxy(c *gin.Context) {
return return
} }
if channel.Type == constant.ChannelTypeGemini { switch channel.Type {
case constant.ChannelTypeGemini:
apiKey := task.PrivateData.Key apiKey := task.PrivateData.Key
if apiKey == "" { if apiKey == "" {
logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID)) logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
@ -116,7 +117,10 @@ func VideoProxy(c *gin.Context) {
return return
} }
req.Header.Set("x-goog-api-key", apiKey) req.Header.Set("x-goog-api-key", apiKey)
} else { case constant.ChannelTypeAli:
// Video URL is directly in task.FailReason
videoURL = task.FailReason
default:
// Default (Sora, etc.): Use original logic // Default (Sora, etc.): Use original logic
videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID) videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
req.Header.Set("Authorization", "Bearer "+channel.Key) req.Header.Set("Authorization", "Bearer "+channel.Key)

View File

@ -27,7 +27,7 @@ type OpenAIVideo struct {
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"` RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"`
Error *OpenAIVideoError `json:"error,omitempty"` Error *OpenAIVideoError `json:"error,omitempty"`
Metadata map[string]any `json:"meta_data,omitempty"` Metadata map[string]any `json:"metadata,omitempty"`
} }
func (m *OpenAIVideo) SetProgressStr(progress string) { func (m *OpenAIVideo) SetProgressStr(progress string) {

View File

@ -73,20 +73,22 @@ func (t *Task) GetData(v any) error {
} }
type Properties struct { type Properties struct {
Input string `json:"input"` Input string `json:"input"`
UpstreamModelName string `json:"upstream_model_name,omitempty"`
OriginModelName string `json:"origin_model_name,omitempty"`
} }
func (m *Properties) Scan(val interface{}) error { func (m *Properties) Scan(val interface{}) error {
bytesValue, _ := val.([]byte) bytesValue, _ := val.([]byte)
if len(bytesValue) == 0 { if len(bytesValue) == 0 {
m.Input = "" *m = Properties{}
return nil return nil
} }
return json.Unmarshal(bytesValue, m) return json.Unmarshal(bytesValue, m)
} }
func (m Properties) Value() (driver.Value, error) { func (m Properties) Value() (driver.Value, error) {
if m.Input == "" { if m == (Properties{}) {
return nil, nil return nil, nil
} }
return json.Marshal(m) return json.Marshal(m)
@ -127,8 +129,16 @@ type SyncTaskQueryParams struct {
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task { func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task {
properties := Properties{} properties := Properties{}
privateData := TaskPrivateData{} privateData := TaskPrivateData{}
if relayInfo != nil && relayInfo.ChannelMeta != nil && relayInfo.ChannelMeta.ChannelType == constant.ChannelTypeGemini { if relayInfo != nil && relayInfo.ChannelMeta != nil {
privateData.Key = relayInfo.ChannelMeta.ApiKey if relayInfo.ChannelMeta.ChannelType == constant.ChannelTypeGemini {
privateData.Key = relayInfo.ChannelMeta.ApiKey
}
if relayInfo.UpstreamModelName != "" {
properties.UpstreamModelName = relayInfo.UpstreamModelName
}
if relayInfo.OriginModelName != "" {
properties.OriginModelName = relayInfo.OriginModelName
}
} }
t := &Task{ t := &Task{

View File

@ -0,0 +1,360 @@
package ali
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
)
// ============================
// Request / Response structures
// ============================
// AliVideoRequest 阿里通义万相视频生成请求
type AliVideoRequest struct {
Model string `json:"model"`
Input AliVideoInput `json:"input"`
Parameters *AliVideoParameters `json:"parameters,omitempty"`
}
// AliVideoInput 视频输入参数
type AliVideoInput struct {
Prompt string `json:"prompt,omitempty"` // 文本提示词
ImgURL string `json:"img_url,omitempty"` // 首帧图像URL或Base64图生视频
FirstFrameURL string `json:"first_frame_url,omitempty"` // 首帧图片URL首尾帧生视频
LastFrameURL string `json:"last_frame_url,omitempty"` // 尾帧图片URL首尾帧生视频
AudioURL string `json:"audio_url,omitempty"` // 音频URLwan2.5支持)
NegativePrompt string `json:"negative_prompt,omitempty"` // 反向提示词
Template string `json:"template,omitempty"` // 视频特效模板
}
// AliVideoParameters 视频参数
type AliVideoParameters struct {
Resolution string `json:"resolution,omitempty"` // 分辨率: 480P/720P/1080P图生视频、首尾帧生视频
Size string `json:"size,omitempty"` // 尺寸: 如 "832*480"(文生视频)
Duration int `json:"duration,omitempty"` // 时长: 3-10秒
PromptExtend bool `json:"prompt_extend,omitempty"` // 是否开启prompt智能改写
Watermark bool `json:"watermark,omitempty"` // 是否添加水印
Audio *bool `json:"audio,omitempty"` // 是否添加音频wan2.5
Seed int `json:"seed,omitempty"` // 随机数种子
}
// AliVideoResponse 阿里通义万相响应
type AliVideoResponse struct {
Output AliVideoOutput `json:"output"`
RequestID string `json:"request_id"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Usage *AliUsage `json:"usage,omitempty"`
}
// AliVideoOutput 输出信息
type AliVideoOutput struct {
TaskID string `json:"task_id"`
TaskStatus string `json:"task_status"`
SubmitTime string `json:"submit_time,omitempty"`
ScheduledTime string `json:"scheduled_time,omitempty"`
EndTime string `json:"end_time,omitempty"`
OrigPrompt string `json:"orig_prompt,omitempty"`
ActualPrompt string `json:"actual_prompt,omitempty"`
VideoURL string `json:"video_url,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// AliUsage 使用统计
type AliUsage struct {
Duration int `json:"duration,omitempty"`
VideoCount int `json:"video_count,omitempty"`
SR int `json:"SR,omitempty"`
}
type AliMetadata struct {
// Input 相关
AudioURL string `json:"audio_url,omitempty"` // 音频URL
ImgURL string `json:"img_url,omitempty"` // 图片URL图生视频
FirstFrameURL string `json:"first_frame_url,omitempty"` // 首帧图片URL首尾帧生视频
LastFrameURL string `json:"last_frame_url,omitempty"` // 尾帧图片URL首尾帧生视频
NegativePrompt string `json:"negative_prompt,omitempty"` // 反向提示词
Template string `json:"template,omitempty"` // 视频特效模板
// Parameters 相关
Resolution *string `json:"resolution,omitempty"` // 分辨率: 480P/720P/1080P
Size *string `json:"size,omitempty"` // 尺寸: 如 "832*480"
Duration *int `json:"duration,omitempty"` // 时长
PromptExtend *bool `json:"prompt_extend,omitempty"` // 是否开启prompt智能改写
Watermark *bool `json:"watermark,omitempty"` // 是否添加水印
Audio *bool `json:"audio,omitempty"` // 是否添加音频
Seed *int `json:"seed,omitempty"` // 随机数种子
}
// ============================
// Adaptor implementation
// ============================
type TaskAdaptor struct {
ChannelType int
apiKey string
baseURL string
}
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.ChannelBaseUrl
a.apiKey = info.ApiKey
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// 阿里通义万相支持 JSON 格式,不使用 multipart
return relaycommon.ValidateMultipartDirect(c, info)
}
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/v1/services/aigc/video-generation/video-synthesis", a.baseURL), nil
}
// BuildRequestHeader sets required headers for Ali API
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
req.Header.Set("Authorization", "Bearer "+a.apiKey)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-DashScope-Async", "enable") // 阿里异步任务必须设置
return nil
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
var taskReq relaycommon.TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
return nil, errors.Wrap(err, "unmarshal_task_request_failed")
}
aliReq := a.convertToAliRequest(taskReq)
bodyBytes, err := common.Marshal(aliReq)
if err != nil {
return nil, errors.Wrap(err, "marshal_ali_request_failed")
}
return bytes.NewReader(bodyBytes), nil
}
func (a *TaskAdaptor) convertToAliRequest(req relaycommon.TaskSubmitReq) *AliVideoRequest {
aliReq := &AliVideoRequest{
Model: req.Model,
Input: AliVideoInput{
Prompt: req.Prompt,
ImgURL: req.InputReference,
},
Parameters: &AliVideoParameters{
PromptExtend: true, // 默认开启智能改写
Watermark: false,
},
}
// 处理分辨率映射
if req.Size != "" {
resolution := strings.ToUpper(req.Size)
// 支持 480p, 720p, 1080p 或 480P, 720P, 1080P
if !strings.HasSuffix(resolution, "P") {
resolution = resolution + "P"
}
aliReq.Parameters.Resolution = resolution
} else {
// 根据模型设置默认分辨率
if strings.HasPrefix(req.Model, "wan2.5") {
aliReq.Parameters.Resolution = "1080P"
} else if strings.HasPrefix(req.Model, "wan2.2-i2v-flash") {
aliReq.Parameters.Resolution = "720P"
} else if strings.HasPrefix(req.Model, "wan2.2-i2v-plus") {
aliReq.Parameters.Resolution = "1080P"
} else {
aliReq.Parameters.Resolution = "720P"
}
}
// 处理时长
if req.Duration > 0 {
aliReq.Parameters.Duration = req.Duration
} else {
aliReq.Parameters.Duration = 5 // 默认5秒
}
// 从 metadata 中提取额外参数
if req.Metadata != nil {
if metadataBytes, err := common.Marshal(req.Metadata); err == nil {
_ = common.Unmarshal(metadataBytes, aliReq)
}
}
return aliReq
}
// DoRequest delegates to common helper
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse handles upstream response
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
_ = resp.Body.Close()
// 解析阿里响应
var aliResp AliVideoResponse
if err := common.Unmarshal(responseBody, &aliResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
// 检查错误
if aliResp.Code != "" {
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s: %s", aliResp.Code, aliResp.Message), "ali_api_error", resp.StatusCode)
return
}
if aliResp.Output.TaskID == "" {
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
return
}
// 转换为 OpenAI 格式响应
openAIResp := dto.NewOpenAIVideo()
openAIResp.ID = aliResp.Output.TaskID
openAIResp.Model = c.GetString("model")
if openAIResp.Model == "" && info != nil {
openAIResp.Model = info.OriginModelName
}
openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus)
openAIResp.CreatedAt = common.GetTimestamp()
// 返回 OpenAI 格式
c.JSON(http.StatusOK, openAIResp)
return aliResp.Output.TaskID, responseBody, nil
}
// FetchTask 查询任务状态
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
uri := fmt.Sprintf("%s/api/v1/tasks/%s", baseUrl, taskID)
req, err := http.NewRequest(http.MethodGet, uri, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+key)
return service.GetHttpClient().Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {
return ModelList
}
func (a *TaskAdaptor) GetChannelName() string {
return ChannelName
}
// ParseTaskResult 解析任务结果
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
var aliResp AliVideoResponse
if err := common.Unmarshal(respBody, &aliResp); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
taskResult := relaycommon.TaskInfo{
Code: 0,
}
// 状态映射
switch aliResp.Output.TaskStatus {
case "PENDING":
taskResult.Status = model.TaskStatusQueued
case "RUNNING":
taskResult.Status = model.TaskStatusInProgress
case "SUCCEEDED":
taskResult.Status = model.TaskStatusSuccess
// 阿里直接返回视频URL不需要额外的代理端点
taskResult.Url = aliResp.Output.VideoURL
case "FAILED", "CANCELED", "UNKNOWN":
taskResult.Status = model.TaskStatusFailure
if aliResp.Message != "" {
taskResult.Reason = aliResp.Message
} else if aliResp.Output.Message != "" {
taskResult.Reason = fmt.Sprintf("task failed, code: %s , message: %s", aliResp.Output.Code, aliResp.Output.Message)
} else {
taskResult.Reason = "task failed"
}
default:
taskResult.Status = model.TaskStatusQueued
}
return &taskResult, nil
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
var aliResp AliVideoResponse
if err := common.Unmarshal(task.Data, &aliResp); err != nil {
return nil, errors.Wrap(err, "unmarshal ali response failed")
}
openAIResp := dto.NewOpenAIVideo()
openAIResp.ID = task.TaskID
openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus)
openAIResp.Model = task.Properties.OriginModelName
openAIResp.SetProgressStr(task.Progress)
openAIResp.CreatedAt = task.CreatedAt
openAIResp.CompletedAt = task.UpdatedAt
// 设置视频URL核心字段
openAIResp.SetMetadata("url", aliResp.Output.VideoURL)
// 错误处理
if aliResp.Code != "" {
openAIResp.Error = &dto.OpenAIVideoError{
Code: aliResp.Code,
Message: aliResp.Message,
}
} else if aliResp.Output.Code != "" {
openAIResp.Error = &dto.OpenAIVideoError{
Code: aliResp.Output.Code,
Message: aliResp.Output.Message,
}
}
return common.Marshal(openAIResp)
}
func convertAliStatus(aliStatus string) string {
switch aliStatus {
case "PENDING":
return dto.VideoStatusQueued
case "RUNNING":
return dto.VideoStatusInProgress
case "SUCCEEDED":
return dto.VideoStatusCompleted
case "FAILED", "CANCELED", "UNKNOWN":
return dto.VideoStatusFailed
default:
return dto.VideoStatusUnknown
}
}

View File

@ -0,0 +1,11 @@
package ali
var ModelList = []string{
"wan2.5-i2v-preview", // 万相2.5 preview有声视频推荐
"wan2.2-i2v-flash", // 万相2.2极速版(无声视频)
"wan2.2-i2v-plus", // 万相2.2专业版(无声视频)
"wanx2.1-i2v-plus", // 万相2.1专业版(无声视频)
"wanx2.1-i2v-turbo", // 万相2.1极速版(无声视频)
}
var ChannelName = "ali"

View File

@ -1,6 +1,7 @@
package common package common
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@ -485,14 +486,16 @@ type TaskRelayInfo struct {
} }
type TaskSubmitReq struct { type TaskSubmitReq struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"` Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"` Image string `json:"image,omitempty"`
Images []string `json:"images,omitempty"` Images []string `json:"images,omitempty"`
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"` Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"` Seconds string `json:"seconds,omitempty"`
InputReference string `json:"input_reference,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
} }
func (t TaskSubmitReq) GetPrompt() string { func (t TaskSubmitReq) GetPrompt() string {
@ -503,6 +506,38 @@ func (t TaskSubmitReq) HasImage() bool {
return len(t.Images) > 0 return len(t.Images) > 0
} }
func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
type Alias TaskSubmitReq
aux := &struct {
Metadata json.RawMessage `json:"metadata,omitempty"`
*Alias
}{
Alias: (*Alias)(t),
}
if err := common.Unmarshal(data, &aux); err != nil {
return err
}
if len(aux.Metadata) > 0 {
var metadataStr string
if err := common.Unmarshal(aux.Metadata, &metadataStr); err == nil && metadataStr != "" {
var metadataObj map[string]interface{}
if err := common.Unmarshal([]byte(metadataStr), &metadataObj); err == nil {
t.Metadata = metadataObj
return nil
}
}
var metadataObj map[string]interface{}
if err := common.Unmarshal(aux.Metadata, &metadataObj); err == nil {
t.Metadata = metadataObj
}
}
return nil
}
type TaskInfo struct { type TaskInfo struct {
Code int `json:"code"` Code int `json:"code"`
TaskID string `json:"task_id"` TaskID string `json:"task_id"`

View File

@ -108,62 +108,33 @@ func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string
} }
func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError { func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
contentType := c.GetHeader("Content-Type")
var prompt string var prompt string
var model string var model string
var seconds int var seconds int
var size string var size string
var hasInputReference bool var hasInputReference bool
if strings.HasPrefix(contentType, "multipart/form-data") { var req TaskSubmitReq
form, err := common.ParseMultipartFormReusable(c) if err := common.UnmarshalBodyReusable(c, &req); err != nil {
if err != nil { return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true) }
}
defer form.RemoveAll()
prompts, ok := form.Value["prompt"] prompt = req.Prompt
if !ok || len(prompts) == 0 { model = req.Model
return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true) seconds, _ = strconv.Atoi(req.Seconds)
} if seconds == 0 {
prompt = prompts[0]
if _, ok := form.Value["model"]; !ok {
return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
}
model = form.Value["model"][0]
if _, ok := form.File["input_reference"]; ok {
hasInputReference = true
}
if ss, ok := form.Value["seconds"]; ok {
sInt := common.String2Int(ss[0])
if sInt > seconds {
seconds = common.String2Int(ss[0])
}
}
if sz, ok := form.Value["size"]; ok {
size = sz[0]
}
} else {
var req TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
}
prompt = req.Prompt
model = req.Model
seconds = req.Duration seconds = req.Duration
}
if req.InputReference != "" {
req.Images = []string{req.InputReference}
}
if strings.TrimSpace(req.Model) == "" { if strings.TrimSpace(req.Model) == "" {
return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true) return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
} }
if req.HasImage() { if req.HasImage() {
hasInputReference = true hasInputReference = true
}
} }
if taskErr := validatePrompt(prompt); taskErr != nil { if taskErr := validatePrompt(prompt); taskErr != nil {

View File

@ -28,6 +28,7 @@ import (
"github.com/QuantumNous/new-api/relay/channel/perplexity" "github.com/QuantumNous/new-api/relay/channel/perplexity"
"github.com/QuantumNous/new-api/relay/channel/siliconflow" "github.com/QuantumNous/new-api/relay/channel/siliconflow"
"github.com/QuantumNous/new-api/relay/channel/submodel" "github.com/QuantumNous/new-api/relay/channel/submodel"
taskali "github.com/QuantumNous/new-api/relay/channel/task/ali"
taskdoubao "github.com/QuantumNous/new-api/relay/channel/task/doubao" taskdoubao "github.com/QuantumNous/new-api/relay/channel/task/doubao"
taskGemini "github.com/QuantumNous/new-api/relay/channel/task/gemini" taskGemini "github.com/QuantumNous/new-api/relay/channel/task/gemini"
taskjimeng "github.com/QuantumNous/new-api/relay/channel/task/jimeng" taskjimeng "github.com/QuantumNous/new-api/relay/channel/task/jimeng"
@ -133,6 +134,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
} }
if channelType, err := strconv.ParseInt(string(platform), 10, 64); err == nil { if channelType, err := strconv.ParseInt(string(platform), 10, 64); err == nil {
switch channelType { switch channelType {
case constant.ChannelTypeAli:
return &taskali.TaskAdaptor{}
case constant.ChannelTypeKling: case constant.ChannelTypeKling:
return &kling.TaskAdaptor{} return &kling.TaskAdaptor{}
case constant.ChannelTypeJimeng: case constant.ChannelTypeJimeng: