From ba25ba88fe342cf14db61c90acf102b304342587 Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 10 Feb 2026 20:40:33 +0800 Subject: [PATCH 01/10] refactor(task): extract billing and polling logic from controller to service layer Restructure the task relay system for better separation of concerns: - Extract task billing into service/task_billing.go with unified settlement flow - Move task polling loop from controller to service/task_polling.go (supports Suno + video platforms) - Split RelayTask into fetch/submit paths with dedicated retry logic (taskSubmitWithRetry) - Add TaskDto, TaskResponse generics, and FetchReq to dto/task.go - Add taskcommon/helpers.go for shared task adaptor utilities - Remove controller/task_video.go (logic consolidated into service layer) - Update all task adaptors (ali, doubao, gemini, hailuo, jimeng, kling, sora, suno, vertex, vidu) - Simplify frontend task logs to use new TaskDto response format --- controller/relay.go | 122 +++- controller/task.go | 228 +------ controller/task_video.go | 313 ---------- controller/video_proxy.go | 111 +--- controller/video_proxy_gemini.go | 8 +- dto/suno.go | 32 - dto/task.go | 47 ++ main.go | 10 + middleware/auth.go | 18 + model/task.go | 57 +- model/token.go | 6 +- relay/channel/task/ali/adaptor.go | 3 +- relay/channel/task/doubao/adaptor.go | 24 +- relay/channel/task/gemini/adaptor.go | 47 +- relay/channel/task/hailuo/adaptor.go | 15 +- relay/channel/task/jimeng/adaptor.go | 27 +- relay/channel/task/kling/adaptor.go | 43 +- relay/channel/task/sora/adaptor.go | 24 +- relay/channel/task/suno/adaptor.go | 29 +- relay/channel/task/taskcommon/helpers.go | 70 +++ relay/channel/task/vertex/adaptor.go | 50 +- relay/channel/task/vidu/adaptor.go | 45 +- relay/common/relay_info.go | 15 +- relay/helper/price.go | 15 +- relay/relay_task.go | 576 +++++++++--------- router/video-router.go | 8 +- service/billing_session.go | 5 + service/error.go | 13 + service/log_info_generate.go | 2 +- service/task_billing.go | 227 +++++++ service/task_polling.go | 446 ++++++++++++++ types/price_data.go | 9 +- .../table/task-logs/TaskLogsColumnDefs.jsx | 9 +- .../table/task-logs/modals/ContentModal.jsx | 2 - 34 files changed, 1465 insertions(+), 1191 deletions(-) delete mode 100644 controller/task_video.go create mode 100644 relay/channel/task/taskcommon/helpers.go create mode 100644 service/task_billing.go create mode 100644 service/task_polling.go diff --git a/controller/relay.go b/controller/relay.go index 0b30e6e9..132fee9b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -451,17 +451,102 @@ func RelayNotFound(c *gin.Context) { } func RelayTask(c *gin.Context) { - retryTimes := common.RetryTimes channelId := c.GetInt("channel_id") c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) if err != nil { + c.JSON(http.StatusInternalServerError, &dto.TaskError{ + Code: "gen_relay_info_failed", + Message: err.Error(), + StatusCode: http.StatusInternalServerError, + }) return } - taskErr := taskRelayHandler(c, relayInfo) - if taskErr == nil { - retryTimes = 0 + + // Fetch 操作是纯 DB 查询(或 task 自带 channelId 的上游查询),不依赖上下文 channel,无需重试 + // TODO: 在video-route层面优化,避免无谓的 channel 选择和上下文设置,也没必要吧代码放到这里来写这么多屎山 + switch relayInfo.RelayMode { + case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: + if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil { + respondTaskError(c, taskErr) + } + return } + + // ── Submit 路径 ───────────────────────────────────────────────── + + // 1. 解析原始任务(remix / continuation),一次性,可能锁定渠道并禁止重试 + if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil { + respondTaskError(c, taskErr) + return + } + + // 2. defer Refund(全部失败时回滚预扣费) + var result *relay.TaskSubmitResult + var taskErr *dto.TaskError + defer func() { + if taskErr != nil && relayInfo.Billing != nil { + relayInfo.Billing.Refund(c) + } + }() + + // 3. 执行 + 重试(RelayTaskSubmit 内部在首次调用时自动预扣费) + taskErr = taskSubmitWithRetry(c, relayInfo, channelId, common.RetryTimes, func() *dto.TaskError { + var te *dto.TaskError + result, te = relay.RelayTaskSubmit(c, relayInfo) + return te + }) + + // 4. 成功:结算 + 日志 + 插入任务 + if taskErr == nil { + if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil { + common.SysError("settle task billing error: " + settleErr.Error()) + } + service.LogTaskConsumption(c, relayInfo, result.ModelName) + + task := model.InitTask(result.Platform, relayInfo) + task.PrivateData.UpstreamTaskID = result.UpstreamTaskID + task.PrivateData.BillingSource = relayInfo.BillingSource + task.PrivateData.SubscriptionId = relayInfo.SubscriptionId + task.PrivateData.TokenId = relayInfo.TokenId + task.Quota = result.Quota + task.Data = result.TaskData + task.Action = relayInfo.Action + if insertErr := task.Insert(); insertErr != nil { + //taskErr = service.TaskErrorWrapper(insertErr, "insert_task_failed", http.StatusInternalServerError) + common.SysError("insert task error: " + insertErr.Error()) + } + } + + if taskErr != nil { + respondTaskError(c, taskErr) + } +} + +// respondTaskError 统一输出 Task 错误响应(含 429 限流提示改写) +func respondTaskError(c *gin.Context, taskErr *dto.TaskError) { + if taskErr.StatusCode == http.StatusTooManyRequests { + taskErr.Message = "当前分组上游负载已饱和,请稍后再试" + } + c.JSON(taskErr.StatusCode, taskErr) +} + +// taskSubmitWithRetry 执行首次尝试并在失败时切换渠道重试,返回最终的 taskErr。 +// attempt 闭包负责实际的上游请求,不涉及计费。 +func taskSubmitWithRetry(c *gin.Context, relayInfo *relaycommon.RelayInfo, + channelId int, retryTimes int, attempt func() *dto.TaskError) *dto.TaskError { + + taskErr := attempt() + if taskErr == nil { + return nil + } + if !taskErr.LocalError { + processChannelError(c, + *types.NewChannelError(channelId, c.GetInt("channel_type"), c.GetString("channel_name"), common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), + common.GetContextKeyString(c, constant.ContextKeyChannelKey), common.GetContextKeyBool(c, constant.ContextKeyChannelAutoBan)), + types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) + } + retryParam := &service.RetryParam{ Ctx: c, TokenGroup: relayInfo.TokenGroup, @@ -480,7 +565,7 @@ func RelayTask(c *gin.Context) { useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry())) - //middleware.SetupContextForSelectedChannel(c, channel, originalModel) + middleware.SetupContextForSelectedChannel(c, channel, c.GetString("original_model")) bodyStorage, err := common.GetBodyStorage(c) if err != nil { @@ -492,30 +577,21 @@ func RelayTask(c *gin.Context) { break } c.Request.Body = io.NopCloser(bodyStorage) - taskErr = taskRelayHandler(c, relayInfo) + taskErr = attempt() + if taskErr != nil && !taskErr.LocalError { + processChannelError(c, + *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, + common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), + types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) + } } + useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) logger.LogInfo(c, retryLogStr) } - if taskErr != nil { - if taskErr.StatusCode == http.StatusTooManyRequests { - taskErr.Message = "当前分组上游负载已饱和,请稍后再试" - } - c.JSON(taskErr.StatusCode, taskErr) - } -} - -func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError { - var err *dto.TaskError - switch relayInfo.RelayMode { - case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: - err = relay.RelayTaskFetch(c, relayInfo.RelayMode) - default: - err = relay.RelayTaskSubmit(c, relayInfo) - } - return err + return taskErr } func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool { diff --git a/controller/task.go b/controller/task.go index 244f9161..ec713c5d 100644 --- a/controller/task.go +++ b/controller/task.go @@ -1,231 +1,21 @@ package controller import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "sort" "strconv" - "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" - "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay" + "github.com/QuantumNous/new-api/service" "github.com/gin-gonic/gin" - "github.com/samber/lo" ) +// UpdateTaskBulk 薄入口,实际轮询逻辑在 service 层 func UpdateTaskBulk() { - //revocer - //imageModel := "midjourney" - for { - time.Sleep(time.Duration(15) * time.Second) - common.SysLog("任务进度轮询开始") - ctx := context.TODO() - allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit) - platformTask := make(map[constant.TaskPlatform][]*model.Task) - for _, t := range allTasks { - platformTask[t.Platform] = append(platformTask[t.Platform], t) - } - for platform, tasks := range platformTask { - if len(tasks) == 0 { - continue - } - taskChannelM := make(map[int][]string) - taskM := make(map[string]*model.Task) - nullTaskIds := make([]int64, 0) - for _, task := range tasks { - if task.TaskID == "" { - // 统计失败的未完成任务 - nullTaskIds = append(nullTaskIds, task.ID) - continue - } - taskM[task.TaskID] = task - taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID) - } - if len(nullTaskIds) > 0 { - err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{ - "status": "FAILURE", - "progress": "100%", - }) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) - } else { - logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) - } - } - if len(taskChannelM) == 0 { - continue - } - - UpdateTaskByPlatform(platform, taskChannelM, taskM) - } - common.SysLog("任务进度轮询完成") - } -} - -func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) { - switch platform { - case constant.TaskPlatformMidjourney: - //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) - case constant.TaskPlatformSuno: - _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) - default: - if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil { - common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) - } - } -} - -func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { - for channelId, taskIds := range taskChannelM { - err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error())) - } - } - return nil -} - -func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { - logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) - if len(taskIds) == 0 { - return nil - } - channel, err := model.CacheGetChannel(channelId) - if err != nil { - common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) - err = model.TaskBulkUpdate(taskIds, map[string]any{ - "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), - "status": "FAILURE", - "progress": "100%", - }) - if err != nil { - common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) - } - return err - } - adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno) - if adaptor == nil { - return errors.New("adaptor not found") - } - proxy := channel.GetSetting().Proxy - resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{ - "ids": taskIds, - }, proxy) - if err != nil { - common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) - return err - } - if resp.StatusCode != http.StatusOK { - logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) - return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) - } - defer resp.Body.Close() - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err)) - return err - } - var responseItems dto.TaskResponse[[]dto.SunoDataResponse] - err = json.Unmarshal(responseBody, &responseItems) - if err != nil { - logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) - return err - } - if !responseItems.IsSuccess() { - common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody))) - return err - } - - for _, responseItem := range responseItems.Data { - task := taskM[responseItem.TaskID] - if !checkTaskNeedUpdate(task, responseItem) { - continue - } - - task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status) - task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason) - task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime) - task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) - task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) - if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { - logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) - task.Progress = "100%" - //err = model.CacheUpdateUserQuota(task.UserId) ? - if err != nil { - logger.LogError(ctx, "error update user quota cache: "+err.Error()) - } else { - quota := task.Quota - if quota != 0 { - err = model.IncreaseUserQuota(task.UserId, quota, false) - if err != nil { - logger.LogError(ctx, "fail to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } - } - if responseItem.Status == model.TaskStatusSuccess { - task.Progress = "100%" - } - task.Data = responseItem.Data - - err = task.Update() - if err != nil { - common.SysLog("UpdateMidjourneyTask task error: " + err.Error()) - } - } - return nil -} - -func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool { - - if oldTask.SubmitTime != newTask.SubmitTime { - return true - } - if oldTask.StartTime != newTask.StartTime { - return true - } - if oldTask.FinishTime != newTask.FinishTime { - return true - } - if string(oldTask.Status) != newTask.Status { - return true - } - if oldTask.FailReason != newTask.FailReason { - return true - } - if oldTask.FinishTime != newTask.FinishTime { - return true - } - - if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" { - return true - } - - oldData, _ := json.Marshal(oldTask.Data) - newData, _ := json.Marshal(newTask.Data) - - sort.Slice(oldData, func(i, j int) bool { - return oldData[i] < oldData[j] - }) - sort.Slice(newData, func(i, j int) bool { - return newData[i] < newData[j] - }) - - if string(oldData) != string(newData) { - return true - } - return false + service.TaskPollingLoop() } func GetAllTask(c *gin.Context) { @@ -247,7 +37,7 @@ func GetAllTask(c *gin.Context) { items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllTasks(queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(items) + pageInfo.SetItems(tasksToDto(items)) common.ApiSuccess(c, pageInfo) } @@ -271,6 +61,14 @@ func GetUserTask(c *gin.Context) { items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllUserTask(userId, queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(items) + pageInfo.SetItems(tasksToDto(items)) common.ApiSuccess(c, pageInfo) } + +func tasksToDto(tasks []*model.Task) []*dto.TaskDto { + result := make([]*dto.TaskDto, len(tasks)) + for i, task := range tasks { + result[i] = relay.TaskModel2Dto(task) + } + return result +} diff --git a/controller/task_video.go b/controller/task_video.go deleted file mode 100644 index d7c19e62..00000000 --- a/controller/task_video.go +++ /dev/null @@ -1,313 +0,0 @@ -package controller - -import ( - "context" - "encoding/json" - "fmt" - "io" - "time" - - "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/constant" - "github.com/QuantumNous/new-api/dto" - "github.com/QuantumNous/new-api/logger" - "github.com/QuantumNous/new-api/model" - "github.com/QuantumNous/new-api/relay" - "github.com/QuantumNous/new-api/relay/channel" - relaycommon "github.com/QuantumNous/new-api/relay/common" - "github.com/QuantumNous/new-api/setting/ratio_setting" -) - -func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { - for channelId, taskIds := range taskChannelM { - if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { - logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) - } - } - return nil -} - -func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { - logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) - if len(taskIds) == 0 { - return nil - } - cacheGetChannel, err := model.CacheGetChannel(channelId) - if err != nil { - errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{ - "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), - "status": "FAILURE", - "progress": "100%", - }) - if errUpdate != nil { - common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) - } - return fmt.Errorf("CacheGetChannel failed: %w", err) - } - adaptor := relay.GetTaskAdaptor(platform) - if adaptor == nil { - return fmt.Errorf("video adaptor not found") - } - info := &relaycommon.RelayInfo{} - info.ChannelMeta = &relaycommon.ChannelMeta{ - ChannelBaseUrl: cacheGetChannel.GetBaseURL(), - } - info.ApiKey = cacheGetChannel.Key - adaptor.Init(info) - for _, taskId := range taskIds { - if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { - logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) - } - } - return nil -} - -func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error { - baseURL := constant.ChannelBaseURLs[channel.Type] - if channel.GetBaseURL() != "" { - baseURL = channel.GetBaseURL() - } - proxy := channel.GetSetting().Proxy - - task := taskM[taskId] - if task == nil { - logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) - return fmt.Errorf("task %s not found", taskId) - } - key := channel.Key - - privateData := task.PrivateData - if privateData.Key != "" { - key = privateData.Key - } - resp, err := adaptor.FetchTask(baseURL, key, map[string]any{ - "task_id": taskId, - "action": task.Action, - }, proxy) - if err != nil { - return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) - } - //if resp.StatusCode != http.StatusOK { - //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode) - //} - defer resp.Body.Close() - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("readAll failed for task %s: %w", taskId, err) - } - - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody))) - - taskResult := &relaycommon.TaskInfo{} - // try parse as New API response format - var responseItems dto.TaskResponse[model.Task] - if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems)) - t := responseItems.Data - taskResult.TaskID = t.TaskID - taskResult.Status = string(t.Status) - taskResult.Url = t.FailReason - taskResult.Progress = t.Progress - taskResult.Reason = t.FailReason - task.Data = t.Data - } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { - return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) - } else { - task.Data = redactVideoResponseBody(responseBody) - } - - logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult)) - - now := time.Now().Unix() - if taskResult.Status == "" { - //return fmt.Errorf("task %s status is empty", taskId) - taskResult = relaycommon.FailTaskInfo("upstream returned empty status") - } - - // 记录原本的状态,防止重复退款 - shouldRefund := false - quota := task.Quota - preStatus := task.Status - - task.Status = model.TaskStatus(taskResult.Status) - switch taskResult.Status { - case model.TaskStatusSubmitted: - task.Progress = "10%" - case model.TaskStatusQueued: - task.Progress = "20%" - case model.TaskStatusInProgress: - task.Progress = "30%" - if task.StartTime == 0 { - task.StartTime = now - } - case model.TaskStatusSuccess: - task.Progress = "100%" - if task.FinishTime == 0 { - task.FinishTime = now - } - if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") { - task.FailReason = taskResult.Url - } - - // 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费 - if taskResult.TotalTokens > 0 { - // 获取模型名称 - var taskData map[string]interface{} - if err := json.Unmarshal(task.Data, &taskData); err == nil { - if modelName, ok := taskData["model"].(string); ok && modelName != "" { - // 获取模型价格和倍率 - modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) - // 只有配置了倍率(非固定价格)时才按 token 重新计费 - if hasRatioSetting && modelRatio > 0 { - // 获取用户和组的倍率信息 - group := task.Group - if group == "" { - user, err := model.GetUserById(task.UserId, false) - if err == nil { - group = user.Group - } - } - if group != "" { - groupRatio := ratio_setting.GetGroupRatio(group) - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group) - - var finalGroupRatio float64 - if hasUserGroupRatio { - finalGroupRatio = userGroupRatio - } else { - finalGroupRatio = groupRatio - } - - // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio - actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio) - - // 计算差额 - preConsumedQuota := task.Quota - quotaDelta := actualQuota - preConsumedQuota - - if quotaDelta > 0 { - // 需要补扣费 - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)", - task.TaskID, - logger.LogQuota(quotaDelta), - logger.LogQuota(actualQuota), - logger.LogQuota(preConsumedQuota), - taskResult.TotalTokens, - )) - if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil { - logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error())) - } else { - model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) - model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) - task.Quota = actualQuota // 更新任务记录的实际扣费额度 - - // 记录消费日志 - logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s", - modelRatio, finalGroupRatio, taskResult.TotalTokens, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } else if quotaDelta < 0 { - // 需要退还多扣的费用 - refundQuota := -quotaDelta - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)", - task.TaskID, - logger.LogQuota(refundQuota), - logger.LogQuota(actualQuota), - logger.LogQuota(preConsumedQuota), - taskResult.TotalTokens, - )) - if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil { - logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error())) - } else { - task.Quota = actualQuota // 更新任务记录的实际扣费额度 - - // 记录退款日志 - logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s", - modelRatio, finalGroupRatio, taskResult.TotalTokens, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - } else { - // quotaDelta == 0, 预扣费刚好准确 - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", - task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens)) - } - } - } - } - } - } - case model.TaskStatusFailure: - logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) - task.Status = model.TaskStatusFailure - task.Progress = "100%" - if task.FinishTime == 0 { - task.FinishTime = now - } - task.FailReason = taskResult.Reason - logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) - taskResult.Progress = "100%" - if quota != 0 { - if preStatus != model.TaskStatusFailure { - shouldRefund = true - } else { - logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID)) - } - } - default: - return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) - } - if taskResult.Progress != "" { - task.Progress = taskResult.Progress - } - if err := task.Update(); err != nil { - common.SysLog("UpdateVideoTask task error: " + err.Error()) - shouldRefund = false - } - - if shouldRefund { - // 任务失败且之前状态不是失败才退还额度,防止重复退还 - if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { - logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } - - return nil -} - -func redactVideoResponseBody(body []byte) []byte { - var m map[string]any - if err := json.Unmarshal(body, &m); err != nil { - return body - } - resp, _ := m["response"].(map[string]any) - if resp != nil { - delete(resp, "bytesBase64Encoded") - if v, ok := resp["video"].(string); ok { - resp["video"] = truncateBase64(v) - } - if vs, ok := resp["videos"].([]any); ok { - for i := range vs { - if vm, ok := vs[i].(map[string]any); ok { - delete(vm, "bytesBase64Encoded") - } - } - } - } - b, err := json.Marshal(m) - if err != nil { - return body - } - return b -} - -func truncateBase64(s string) string { - const maxKeep = 256 - if len(s) <= maxKeep { - return s - } - return s[:maxKeep] + "..." -} diff --git a/controller/video_proxy.go b/controller/video_proxy.go index f102baae..f1dd2bc9 100644 --- a/controller/video_proxy.go +++ b/controller/video_proxy.go @@ -16,59 +16,44 @@ import ( "github.com/gin-gonic/gin" ) +// videoProxyError returns a standardized OpenAI-style error response. +func videoProxyError(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "message": message, + "type": errType, + }, + }) +} + func VideoProxy(c *gin.Context) { taskID := c.Param("task_id") if taskID == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "task_id is required", - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required") return } task, exists, err := model.GetByOnlyTaskId(taskID) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to query task", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task") return } if !exists || task == nil { - logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err)) - c.JSON(http.StatusNotFound, gin.H{ - "error": gin.H{ - "message": "Task not found", - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found") return } if task.Status != model.TaskStatusSuccess { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status), - "type": "invalid_request_error", - }, - }) + videoProxyError(c, http.StatusBadRequest, "invalid_request_error", + fmt.Sprintf("Task is not completed yet, current status: %s", task.Status)) return } channel, err := model.CacheGetChannel(task.ChannelId) if err != nil { - logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID)) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to retrieve channel information", - "type": "server_error", - }, - }) + logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error())) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information") return } baseURL := channel.GetBaseURL() @@ -81,12 +66,7 @@ func VideoProxy(c *gin.Context) { client, err := service.GetHttpClientWithProxy(proxy) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy client", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client") return } @@ -95,12 +75,7 @@ func VideoProxy(c *gin.Context) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy request", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request") return } @@ -109,68 +84,43 @@ func VideoProxy(c *gin.Context) { apiKey := task.PrivateData.Key if apiKey == "" { logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID)) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "API key not stored for task", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task") return } - videoURL, err = getGeminiVideoURL(channel, task, apiKey) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error())) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": "Failed to resolve Gemini video URL", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL") return } req.Header.Set("x-goog-api-key", apiKey) case constant.ChannelTypeOpenAI, constant.ChannelTypeSora: - videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID) + videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID()) req.Header.Set("Authorization", "Bearer "+channel.Key) default: - // Video URL is directly in task.FailReason - videoURL = task.FailReason + // Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data) + videoURL = task.GetResultURL() } req.URL, err = url.Parse(videoURL) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error())) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": "Failed to create proxy request", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request") return } resp, err := client.Do(req) if err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error())) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": "Failed to fetch video content", - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content") return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL)) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode), - "type": "server_error", - }, - }) + videoProxyError(c, http.StatusBadGateway, "server_error", + fmt.Sprintf("Upstream service returned status %d", resp.StatusCode)) return } @@ -180,10 +130,9 @@ func VideoProxy(c *gin.Context) { } } - c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours + c.Writer.Header().Set("Cache-Control", "public, max-age=86400") c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { + if _, err = io.Copy(c.Writer, resp.Body); err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error())) } } diff --git a/controller/video_proxy_gemini.go b/controller/video_proxy_gemini.go index 053ac651..a63a2a5c 100644 --- a/controller/video_proxy_gemini.go +++ b/controller/video_proxy_gemini.go @@ -1,12 +1,12 @@ package controller import ( - "encoding/json" "fmt" "io" "strconv" "strings" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay" @@ -37,7 +37,7 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string) proxy := channel.GetSetting().Proxy resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{ - "task_id": task.TaskID, + "task_id": task.GetUpstreamTaskID(), "action": task.Action, }, proxy) if err != nil { @@ -71,7 +71,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string { return "" } var payload map[string]any - if err := json.Unmarshal(task.Data, &payload); err != nil { + if err := common.Unmarshal(task.Data, &payload); err != nil { return "" } return extractGeminiVideoURLFromMap(payload) @@ -79,7 +79,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string { func extractGeminiVideoURLFromPayload(body []byte) string { var payload map[string]any - if err := json.Unmarshal(body, &payload); err != nil { + if err := common.Unmarshal(body, &payload); err != nil { return "" } return extractGeminiVideoURLFromMap(payload) diff --git a/dto/suno.go b/dto/suno.go index a6bb3eba..90e11b81 100644 --- a/dto/suno.go +++ b/dto/suno.go @@ -4,10 +4,6 @@ import ( "encoding/json" ) -type TaskData interface { - SunoDataResponse | []SunoDataResponse | string | any -} - type SunoSubmitReq struct { GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"` Prompt string `json:"prompt,omitempty"` @@ -20,10 +16,6 @@ type SunoSubmitReq struct { MakeInstrumental bool `json:"make_instrumental"` } -type FetchReq struct { - IDs []string `json:"ids"` -} - type SunoDataResponse struct { TaskID string `json:"task_id" gorm:"type:varchar(50);index"` Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode @@ -66,30 +58,6 @@ type SunoLyrics struct { Text string `json:"text"` } -const TaskSuccessCode = "success" - -type TaskResponse[T TaskData] struct { - Code string `json:"code"` - Message string `json:"message"` - Data T `json:"data"` -} - -func (t *TaskResponse[T]) IsSuccess() bool { - return t.Code == TaskSuccessCode -} - -type TaskDto struct { - TaskID string `json:"task_id"` // 第三方id,不一定有/ song id\ Task id - Action string `json:"action"` // 任务类型, song, lyrics, description-mode - Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed - FailReason string `json:"fail_reason"` - SubmitTime int64 `json:"submit_time"` - StartTime int64 `json:"start_time"` - FinishTime int64 `json:"finish_time"` - Progress string `json:"progress"` - Data json.RawMessage `json:"data"` -} - type SunoGoAPISubmitReq struct { CustomMode bool `json:"custom_mode"` diff --git a/dto/task.go b/dto/task.go index afc186b4..4a9a8e2e 100644 --- a/dto/task.go +++ b/dto/task.go @@ -1,5 +1,9 @@ package dto +import ( + "encoding/json" +) + type TaskError struct { Code string `json:"code"` Message string `json:"message"` @@ -8,3 +12,46 @@ type TaskError struct { LocalError bool `json:"-"` Error error `json:"-"` } + +type TaskData interface { + SunoDataResponse | []SunoDataResponse | string | any +} + +const TaskSuccessCode = "success" + +type TaskResponse[T TaskData] struct { + Code string `json:"code"` + Message string `json:"message"` + Data T `json:"data"` +} + +func (t *TaskResponse[T]) IsSuccess() bool { + return t.Code == TaskSuccessCode +} + +type TaskDto struct { + ID int64 `json:"id"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + TaskID string `json:"task_id"` + Platform string `json:"platform"` + UserId int `json:"user_id"` + Group string `json:"group"` + ChannelId int `json:"channel_id"` + Quota int `json:"quota"` + Action string `json:"action"` + Status string `json:"status"` + FailReason string `json:"fail_reason"` + ResultURL string `json:"result_url,omitempty"` // 任务结果 URL(视频地址等) + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + Progress string `json:"progress"` + Properties any `json:"properties"` + Username string `json:"username,omitempty"` + Data json.RawMessage `json:"data"` +} + +type FetchReq struct { + IDs []string `json:"ids"` +} diff --git a/main.go b/main.go index 852e1a0a..476a2ed2 100644 --- a/main.go +++ b/main.go @@ -19,6 +19,7 @@ import ( "github.com/QuantumNous/new-api/middleware" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/oauth" + "github.com/QuantumNous/new-api/relay" "github.com/QuantumNous/new-api/router" "github.com/QuantumNous/new-api/service" _ "github.com/QuantumNous/new-api/setting/performance_setting" @@ -111,6 +112,15 @@ func main() { // Subscription quota reset task (daily/weekly/monthly/custom) service.StartSubscriptionQuotaResetTask() + // Wire task polling adaptor factory (breaks service -> relay import cycle) + service.GetTaskAdaptorFunc = func(platform constant.TaskPlatform) service.TaskPollingAdaptor { + a := relay.GetTaskAdaptor(platform) + if a == nil { + return nil + } + return a + } + if common.IsMasterNode && constant.UpdateTask { gopool.Go(func() { controller.UpdateMidjourneyTaskBulk() diff --git a/middleware/auth.go b/middleware/auth.go index cf184351..342e7f49 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -170,6 +170,24 @@ func WssAuth(c *gin.Context) { } +// TokenOrUserAuth allows either session-based user auth or API token auth. +// Used for endpoints that need to be accessible from both the dashboard and API clients. +func TokenOrUserAuth() func(c *gin.Context) { + return func(c *gin.Context) { + // Try session auth first (dashboard users) + session := sessions.Default(c) + if id := session.Get("id"); id != nil { + if status, ok := session.Get("status").(int); ok && status == common.UserStatusEnabled { + c.Set("id", id) + c.Next() + return + } + } + // Fall back to token auth (API clients) + TokenAuth()(c) + } +} + // TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。 // 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。 // 即使令牌已过期、已耗尽或已禁用,也允许访问。 diff --git a/model/task.go b/model/task.go index 82c2e978..38bb4d05 100644 --- a/model/task.go +++ b/model/task.go @@ -5,6 +5,7 @@ import ( "encoding/json" "time" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" commonRelay "github.com/QuantumNous/new-api/relay/common" @@ -64,13 +65,12 @@ type Task struct { } func (t *Task) SetData(data any) { - b, _ := json.Marshal(data) + b, _ := common.Marshal(data) t.Data = json.RawMessage(b) } func (t *Task) GetData(v any) error { - err := json.Unmarshal(t.Data, &v) - return err + return common.Unmarshal(t.Data, &v) } type Properties struct { @@ -85,18 +85,48 @@ func (m *Properties) Scan(val interface{}) error { *m = Properties{} return nil } - return json.Unmarshal(bytesValue, m) + return common.Unmarshal(bytesValue, m) } func (m Properties) Value() (driver.Value, error) { if m == (Properties{}) { return nil, nil } - return json.Marshal(m) + return common.Marshal(m) } type TaskPrivateData struct { - Key string `json:"key,omitempty"` + Key string `json:"key,omitempty"` + UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID + ResultURL string `json:"result_url,omitempty"` // 任务成功后的结果 URL(视频地址等) + // 计费上下文:用于异步退款/差额结算(轮询阶段读取) + BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription" + SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款 + TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款 +} + +// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信) +// 旧数据没有 UpstreamTaskID 时,TaskID 本身就是上游 ID +func (t *Task) GetUpstreamTaskID() string { + if t.PrivateData.UpstreamTaskID != "" { + return t.PrivateData.UpstreamTaskID + } + return t.TaskID +} + +// GetResultURL 获取任务结果 URL(视频地址等) +// 新数据存在 PrivateData.ResultURL 中;旧数据回退到 FailReason(历史兼容) +func (t *Task) GetResultURL() string { + if t.PrivateData.ResultURL != "" { + return t.PrivateData.ResultURL + } + return t.FailReason +} + +// GenerateTaskID 生成对外暴露的 task_xxxx 格式 ID +func GenerateTaskID() string { + key, _ := common.GenerateRandomCharsKey(32) + return "task_" + key } func (p *TaskPrivateData) Scan(val interface{}) error { @@ -104,14 +134,14 @@ func (p *TaskPrivateData) Scan(val interface{}) error { if len(bytesValue) == 0 { return nil } - return json.Unmarshal(bytesValue, p) + return common.Unmarshal(bytesValue, p) } func (p TaskPrivateData) Value() (driver.Value, error) { if (p == TaskPrivateData{}) { return nil, nil } - return json.Marshal(p) + return common.Marshal(p) } // SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 @@ -142,7 +172,16 @@ func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) } } + // 使用预生成的公开 ID(如果有),否则新生成 + taskID := "" + if relayInfo.TaskRelayInfo != nil && relayInfo.TaskRelayInfo.PublicTaskID != "" { + taskID = relayInfo.TaskRelayInfo.PublicTaskID + } else { + taskID = GenerateTaskID() + } + t := &Task{ + TaskID: taskID, UserId: relayInfo.UserId, Group: relayInfo.UsingGroup, SubmitTime: time.Now().Unix(), @@ -438,6 +477,6 @@ func (t *Task) ToOpenAIVideo() *dto.OpenAIVideo { openAIVideo.SetProgressStr(t.Progress) openAIVideo.CreatedAt = t.CreatedAt openAIVideo.CompletedAt = t.UpdatedAt - openAIVideo.SetMetadata("url", t.FailReason) + openAIVideo.SetMetadata("url", t.GetResultURL()) return openAIVideo } diff --git a/model/token.go b/model/token.go index 9e05b63c..773b2d79 100644 --- a/model/token.go +++ b/model/token.go @@ -360,7 +360,7 @@ func DeleteTokenById(id int, userId int) (err error) { return token.Delete() } -func IncreaseTokenQuota(id int, key string, quota int) (err error) { +func IncreaseTokenQuota(tokenId int, key string, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -373,10 +373,10 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) { }) } if common.BatchUpdateEnabled { - addNewRecord(BatchUpdateTypeTokenQuota, id, quota) + addNewRecord(BatchUpdateTypeTokenQuota, tokenId, quota) return nil } - return increaseTokenQuota(id, quota) + return increaseTokenQuota(tokenId, quota) } func increaseTokenQuota(id int, quota int) (err error) { diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go index d55452c0..5d14ff65 100644 --- a/relay/channel/task/ali/adaptor.go +++ b/relay/channel/task/ali/adaptor.go @@ -384,7 +384,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // 转换为 OpenAI 格式响应 openAIResp := dto.NewOpenAIVideo() - openAIResp.ID = aliResp.Output.TaskID + openAIResp.ID = info.PublicTaskID + openAIResp.TaskID = info.PublicTaskID openAIResp.Model = c.GetString("model") if openAIResp.Model == "" && info != nil { openAIResp.Model = info.OriginModelName diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index 6ebecb3c..3da125af 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -2,7 +2,6 @@ package doubao import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -14,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -131,7 +131,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, errors.Wrap(err, "convert request payload failed") } info.UpstreamModelName = body.Model - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -154,7 +154,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // Parse Doubao response var dResp responsePayload - if err := json.Unmarshal(responseBody, &dResp); err != nil { + if err := common.Unmarshal(responseBody, &dResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -165,8 +165,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = dResp.ID - ov.TaskID = dResp.ID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName @@ -234,12 +234,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } @@ -248,7 +243,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := responseTask{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } @@ -286,7 +281,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var dResp responseTask - if err := json.Unmarshal(originTask.Data, &dResp); err != nil { + if err := common.Unmarshal(originTask.Data, &dResp); err != nil { return nil, errors.Wrap(err, "unmarshal doubao task data failed") } @@ -307,6 +302,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index 16c6919b..a863ea85 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -2,8 +2,6 @@ package gemini import ( "bytes" - "encoding/base64" - "encoding/json" "fmt" "io" "net/http" @@ -16,10 +14,10 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/model_setting" - "github.com/QuantumNous/new-api/setting/system_setting" "github.com/gin-gonic/gin" "github.com/pkg/errors" ) @@ -145,16 +143,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &body.Parameters) - if err != nil { + if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -175,16 +168,16 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var s submitResponse - if err := json.Unmarshal(responseBody, &s); err != nil { + if err := common.Unmarshal(responseBody, &s); err != nil { return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) } if strings.TrimSpace(s.Name) == "" { return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) } - taskID = encodeLocalTaskID(s.Name) + taskID = taskcommon.EncodeLocalTaskID(s.Name) ov := dto.NewOpenAIVideo() - ov.ID = taskID - ov.TaskID = taskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -206,7 +199,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy return nil, fmt.Errorf("invalid task_id") } - upstreamName, err := decodeLocalTaskID(taskID) + upstreamName, err := taskcommon.DecodeLocalTaskID(taskID) if err != nil { return nil, fmt.Errorf("decode task_id failed: %w", err) } @@ -232,7 +225,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { var op operationResponse - if err := json.Unmarshal(respBody, &op); err != nil { + if err := common.Unmarshal(respBody, &op); err != nil { return nil, fmt.Errorf("unmarshal operation response failed: %w", err) } @@ -254,9 +247,8 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e ti.Status = model.TaskStatusSuccess ti.Progress = "100%" - taskID := encodeLocalTaskID(op.Name) - ti.TaskID = taskID - ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID) + ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name) + // Url intentionally left empty — the caller constructs the proxy URL using the public task ID // Extract URL from generateVideoResponse if available if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 { @@ -269,7 +261,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { - upstreamName, err := decodeLocalTaskID(task.TaskID) + // Use GetUpstreamTaskID() to get the real upstream operation name for model extraction. + // task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name. + upstreamTaskID := task.GetUpstreamTaskID() + upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID) if err != nil { upstreamName = "" } @@ -297,18 +292,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { // helpers // ============================ -func encodeLocalTaskID(name string) string { - return base64.RawURLEncoding.EncodeToString([]byte(name)) -} - -func decodeLocalTaskID(local string) (string, error) { - b, err := base64.RawURLEncoding.DecodeString(local) - if err != nil { - return "", err - } - return string(b), nil -} - var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`) func extractModelFromOperationName(name string) string { diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go index c77905bf..67a68a10 100644 --- a/relay/channel/task/hailuo/adaptor.go +++ b/relay/channel/task/hailuo/adaptor.go @@ -2,7 +2,6 @@ package hailuo import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -65,7 +64,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, errors.Wrap(err, "convert request payload failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -86,7 +85,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var hResp VideoResponse - if err := json.Unmarshal(responseBody, &hResp); err != nil { + if err := common.Unmarshal(responseBody, &hResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -101,8 +100,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = hResp.TaskID - ov.TaskID = hResp.TaskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName @@ -182,7 +181,7 @@ func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConf func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := QueryTaskResponse{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } @@ -224,7 +223,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var hailuoResp QueryTaskResponse - if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil { + if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil { return nil, errors.Wrap(err, "unmarshal hailuo task data failed") } @@ -271,7 +270,7 @@ func (a *TaskAdaptor) buildVideoURL(_, fileID string) string { } var retrieveResp RetrieveFileResponse - if err := json.Unmarshal(responseBody, &retrieveResp); err != nil { + if err := common.Unmarshal(responseBody, &retrieveResp); err != nil { return "" } diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 1522a967..7f88be24 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -6,7 +6,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/hex" - "encoding/json" "fmt" "io" "net/http" @@ -25,6 +24,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) @@ -168,7 +168,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -191,7 +191,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela // Parse Jimeng response var jResp responsePayload - if err := json.Unmarshal(responseBody, &jResp); err != nil { + if err := common.Unmarshal(responseBody, &jResp); err != nil { taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) return } @@ -202,8 +202,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = jResp.Data.TaskID - ov.TaskID = jResp.Data.TaskID + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -225,7 +225,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774 "task_id": taskID, } - payloadBytes, err := json.Marshal(payload) + payloadBytes, err := common.Marshal(payload) if err != nil { return nil, errors.Wrap(err, "marshal fetch task payload failed") } @@ -398,13 +398,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* r.BinaryDataBase64 = req.Images } } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } @@ -432,7 +426,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { resTask := responseTask{} - if err := json.Unmarshal(respBody, &resTask); err != nil { + if err := common.Unmarshal(respBody, &resTask); err != nil { return nil, errors.Wrap(err, "unmarshal task result failed") } taskResult := relaycommon.TaskInfo{} @@ -458,7 +452,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var jimengResp responseTask - if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil { + if err := common.Unmarshal(originTask.Data, &jimengResp); err != nil { return nil, errors.Wrap(err, "unmarshal jimeng task data failed") } @@ -477,8 +471,7 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } func isNewAPIRelay(apiKey string) bool { diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 5fb85348..4458626b 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -2,7 +2,6 @@ package kling import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -21,6 +20,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) @@ -156,7 +156,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if body.Image == "" && body.ImageTail == "" { c.Set("action", constant.TaskActionTextGenerate) } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -180,7 +180,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } var kResp responsePayload - err = json.Unmarshal(responseBody, &kResp) + err = common.Unmarshal(responseBody, &kResp) if err != nil { taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) return @@ -190,8 +190,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } ov := dto.NewOpenAIVideo() - ov.ID = kResp.Data.TaskId - ov.TaskID = kResp.Data.TaskId + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -251,8 +251,8 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* r := requestPayload{ Prompt: req.Prompt, Image: req.Image, - Mode: defaultString(req.Mode, "std"), - Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), + Mode: taskcommon.DefaultString(req.Mode, "std"), + Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)), AspectRatio: a.getAspectRatio(req.Size), ModelName: req.Model, Model: req.Model, // Keep consistent with model_name, double writing improves compatibility @@ -266,13 +266,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* if r.ModelName == "" { r.ModelName = "kling-v1" } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } return &r, nil @@ -291,20 +285,6 @@ func (a *TaskAdaptor) getAspectRatio(size string) string { } } -func defaultString(s, def string) string { - if strings.TrimSpace(s) == "" { - return def - } - return s -} - -func defaultInt(v int, def int) int { - if v == 0 { - return def - } - return v -} - // ============================ // JWT helpers // ============================ @@ -340,7 +320,7 @@ func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { taskInfo := &relaycommon.TaskInfo{} resPayload := responsePayload{} - err := json.Unmarshal(respBody, &resPayload) + err := common.Unmarshal(respBody, &resPayload) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal response body") } @@ -374,7 +354,7 @@ func isNewAPIRelay(apiKey string) bool { func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var klingResp responsePayload - if err := json.Unmarshal(originTask.Data, &klingResp); err != nil { + if err := common.Unmarshal(originTask.Data, &klingResp); err != nil { return nil, errors.Wrap(err, "unmarshal kling task data failed") } @@ -401,6 +381,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro Code: fmt.Sprintf("%d", klingResp.Code), } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index c149f966..ee69a3e4 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -13,7 +13,6 @@ import ( "github.com/QuantumNous/new-api/relay/channel" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" - "github.com/QuantumNous/new-api/setting/system_setting" "github.com/gin-gonic/gin" "github.com/pkg/errors" @@ -116,7 +115,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, req } // DoResponse handles upstream response, returns taskID etc. -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -131,17 +130,20 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco return } - if dResp.ID == "" { - if dResp.TaskID == "" { - taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) - return - } - dResp.ID = dResp.TaskID - dResp.TaskID = "" + upstreamID := dResp.ID + if upstreamID == "" { + upstreamID = dResp.TaskID + } + if upstreamID == "" { + taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) + return } + // 使用公开 task_xxxx ID 返回给客户端 + dResp.ID = info.PublicTaskID + dResp.TaskID = info.PublicTaskID c.JSON(http.StatusOK, dResp) - return dResp.ID, responseBody, nil + return upstreamID, responseBody, nil } // FetchTask fetch task status @@ -192,7 +194,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e taskResult.Status = model.TaskStatusInProgress case "completed": taskResult.Status = model.TaskStatusSuccess - taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID) + // Url intentionally left empty — the caller constructs the proxy URL using the public task ID case "failed", "cancelled": taskResult.Status = model.TaskStatusFailure if resTask.Error != nil { diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 8ea9a1c7..5dd62a70 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -3,7 +3,6 @@ package suno import ( "bytes" "context" - "encoding/json" "fmt" "io" "net/http" @@ -24,8 +23,12 @@ type TaskAdaptor struct { ChannelType int } +// ParseTaskResult is not used for Suno tasks. +// Suno polling uses a dedicated batch-fetch path (service.UpdateSunoTasks) that +// receives dto.TaskResponse[[]dto.SunoDataResponse] from the upstream /fetch API. +// This differs from the per-task polling used by video adaptors. func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { - return nil, fmt.Errorf("not implement") // todo implement this method if needed + return nil, fmt.Errorf("suno uses batch polling via UpdateSunoTasks, ParseTaskResult is not applicable") } func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { @@ -81,7 +84,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, err } } - data, err := json.Marshal(sunoRequest) + data, err := common.Marshal(sunoRequest) if err != nil { return nil, err } @@ -99,7 +102,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } var sunoResponse dto.TaskResponse[string] - err = json.Unmarshal(responseBody, &sunoResponse) + err = common.Unmarshal(responseBody, &sunoResponse) if err != nil { taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) return @@ -109,17 +112,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela return } - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody)) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - return + // 使用公开 task_xxxx ID 替换上游 ID 返回给客户端 + publicResponse := dto.TaskResponse[string]{ + Code: sunoResponse.Code, + Message: sunoResponse.Message, + Data: info.PublicTaskID, } + c.JSON(http.StatusOK, publicResponse) return sunoResponse.Data, nil, nil } @@ -134,7 +133,7 @@ func (a *TaskAdaptor) GetChannelName() string { func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl) - byteBody, err := json.Marshal(body) + byteBody, err := common.Marshal(body) if err != nil { return nil, err } diff --git a/relay/channel/task/taskcommon/helpers.go b/relay/channel/task/taskcommon/helpers.go new file mode 100644 index 00000000..b1dde998 --- /dev/null +++ b/relay/channel/task/taskcommon/helpers.go @@ -0,0 +1,70 @@ +package taskcommon + +import ( + "encoding/base64" + "fmt" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/setting/system_setting" +) + +// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip. +// This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target). +func UnmarshalMetadata(metadata map[string]any, target any) error { + if metadata == nil { + return nil + } + metaBytes, err := common.Marshal(metadata) + if err != nil { + return fmt.Errorf("marshal metadata failed: %w", err) + } + if err := common.Unmarshal(metaBytes, target); err != nil { + return fmt.Errorf("unmarshal metadata failed: %w", err) + } + return nil +} + +// DefaultString returns val if non-empty, otherwise fallback. +func DefaultString(val, fallback string) string { + if val == "" { + return fallback + } + return val +} + +// DefaultInt returns val if non-zero, otherwise fallback. +func DefaultInt(val, fallback int) int { + if val == 0 { + return fallback + } + return val +} + +// EncodeLocalTaskID encodes an upstream operation name to a URL-safe base64 string. +// Used by Gemini/Vertex to store upstream names as task IDs. +func EncodeLocalTaskID(name string) string { + return base64.RawURLEncoding.EncodeToString([]byte(name)) +} + +// DecodeLocalTaskID decodes a base64-encoded upstream operation name. +func DecodeLocalTaskID(id string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(id) + if err != nil { + return "", err + } + return string(b), nil +} + +// BuildProxyURL constructs the video proxy URL using the public task ID. +// e.g., "https://your-server.com/v1/videos/task_xxxx/content" +func BuildProxyURL(taskID string) string { + return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID) +} + +// Status-to-progress mapping constants for polling updates. +const ( + ProgressSubmitted = "10%" + ProgressQueued = "20%" + ProgressInProgress = "30%" + ProgressComplete = "100%" +) diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index 8ec77266..fb3a313f 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -2,13 +2,12 @@ package vertex import ( "bytes" - "encoding/base64" - "encoding/json" "fmt" "io" "net/http" "regexp" "strings" + "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" @@ -17,6 +16,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -82,7 +82,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { + if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials: %w", err) } modelName := info.OriginModelName @@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info req.Header.Set("Accept", "application/json") adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { + if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil { return fmt.Errorf("failed to decode credentials: %w", err) } @@ -184,7 +184,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn // info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int)) // } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -205,14 +205,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela _ = resp.Body.Close() var s submitResponse - if err := json.Unmarshal(responseBody, &s); err != nil { + if err := common.Unmarshal(responseBody, &s); err != nil { return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) } if strings.TrimSpace(s.Name) == "" { return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) } - localID := encodeLocalTaskID(s.Name) - c.JSON(http.StatusOK, gin.H{"task_id": localID}) + localID := taskcommon.EncodeLocalTaskID(s.Name) + ov := dto.NewOpenAIVideo() + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID + ov.CreatedAt = time.Now().Unix() + ov.Model = info.OriginModelName + c.JSON(http.StatusOK, ov) return localID, responseBody, nil } @@ -225,7 +230,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy if !ok { return nil, fmt.Errorf("invalid task_id") } - upstreamName, err := decodeLocalTaskID(taskID) + upstreamName, err := taskcommon.DecodeLocalTaskID(taskID) if err != nil { return nil, fmt.Errorf("decode task_id failed: %w", err) } @@ -245,12 +250,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName) } payload := map[string]string{"operationName": upstreamName} - data, err := json.Marshal(payload) + data, err := common.Marshal(payload) if err != nil { return nil, err } adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(key), adc); err != nil { + if err := common.Unmarshal([]byte(key), adc); err != nil { return nil, fmt.Errorf("failed to decode credentials: %w", err) } token, err := vertexcore.AcquireAccessToken(*adc, proxy) @@ -274,7 +279,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { var op operationResponse - if err := json.Unmarshal(respBody, &op); err != nil { + if err := common.Unmarshal(respBody, &op); err != nil { return nil, fmt.Errorf("unmarshal operation response failed: %w", err) } ti := &relaycommon.TaskInfo{} @@ -338,7 +343,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { - upstreamName, err := decodeLocalTaskID(task.TaskID) + // Use GetUpstreamTaskID() to get the real upstream operation name for model extraction. + // task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name. + upstreamTaskID := task.GetUpstreamTaskID() + upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID) if err != nil { upstreamName = "" } @@ -353,8 +361,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { v.SetProgressStr(task.Progress) v.CreatedAt = task.CreatedAt v.CompletedAt = task.UpdatedAt - if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 { - v.SetMetadata("url", task.FailReason) + if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 { + v.SetMetadata("url", resultURL) } return common.Marshal(v) @@ -364,18 +372,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { // helpers // ============================ -func encodeLocalTaskID(name string) string { - return base64.RawURLEncoding.EncodeToString([]byte(name)) -} - -func decodeLocalTaskID(local string) (string, error) { - b, err := base64.RawURLEncoding.DecodeString(local) - if err != nil { - return "", err - } - return string(b), nil -} - var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`) func extractRegionFromOperationName(name string) string { diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index 3657161c..1bab12f0 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -2,7 +2,6 @@ package vidu import ( "bytes" - "encoding/json" "fmt" "io" "net/http" @@ -16,6 +15,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -127,7 +127,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } } - data, err := json.Marshal(body) + data, err := common.Marshal(body) if err != nil { return nil, err } @@ -168,7 +168,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } var vResp responsePayload - err = json.Unmarshal(responseBody, &vResp) + err = common.Unmarshal(responseBody, &vResp) if err != nil { taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError) return @@ -180,8 +180,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela } ov := dto.NewOpenAIVideo() - ov.ID = vResp.TaskId - ov.TaskID = vResp.TaskId + ov.ID = info.PublicTaskID + ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) @@ -225,45 +225,25 @@ func (a *TaskAdaptor) GetChannelName() string { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { r := requestPayload{ - Model: defaultString(req.Model, "viduq1"), + Model: taskcommon.DefaultString(req.Model, "viduq1"), Images: req.Images, Prompt: req.Prompt, - Duration: defaultInt(req.Duration, 5), - Resolution: defaultString(req.Size, "1080p"), + Duration: taskcommon.DefaultInt(req.Duration, 5), + Resolution: taskcommon.DefaultString(req.Size, "1080p"), MovementAmplitude: "auto", Bgm: false, } - metadata := req.Metadata - medaBytes, err := json.Marshal(metadata) - if err != nil { - return nil, errors.Wrap(err, "metadata marshal metadata failed") - } - err = json.Unmarshal(medaBytes, &r) - if err != nil { + if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } return &r, nil } -func defaultString(value, defaultValue string) string { - if value == "" { - return defaultValue - } - return value -} - -func defaultInt(value, defaultValue int) int { - if value == 0 { - return defaultValue - } - return value -} - func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { taskInfo := &relaycommon.TaskInfo{} var taskResp taskResultResponse - err := json.Unmarshal(respBody, &taskResp) + err := common.Unmarshal(respBody, &taskResp) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal response body") } @@ -293,7 +273,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { var viduResp taskResultResponse - if err := json.Unmarshal(originTask.Data, &viduResp); err != nil { + if err := common.Unmarshal(originTask.Data, &viduResp); err != nil { return nil, errors.Wrap(err, "unmarshal vidu task data failed") } @@ -315,6 +295,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro } } - jsonData, _ := common.Marshal(openAIVideo) - return jsonData, nil + return common.Marshal(openAIVideo) } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 81b7d21d..b6882681 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -118,8 +118,12 @@ type RelayInfo struct { SendResponseCount int ReceivedResponseCount int FinalPreConsumedQuota int // 最终预消耗的配额 + // ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路, + // 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行, + // 必须在提交前锁定全额。 + ForcePreConsume bool // Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。 - // 免费模型和按次计费(MJ/Task)时为 nil。 + // 免费模型时为 nil。 Billing BillingSettler // BillingSource indicates whether this request is billed from wallet quota or subscription. // "" or "wallet" => wallet; "subscription" => subscription @@ -525,8 +529,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req return nil, errors.New("request is not a OpenAIResponsesCompactionRequest") case types.RelayFormatTask: info = genBaseRelayInfo(c, nil) + info.TaskRelayInfo = &TaskRelayInfo{} case types.RelayFormatMjProxy: info = genBaseRelayInfo(c, nil) + info.TaskRelayInfo = &TaskRelayInfo{} default: err = errors.New("invalid relay format") } @@ -608,6 +614,9 @@ func (info *RelayInfo) HasSendResponse() bool { type TaskRelayInfo struct { Action string OriginTaskID string + // PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID, + // 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。 + PublicTaskID string ConsumeQuota bool } @@ -667,11 +676,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error { func (t *TaskSubmitReq) UnmarshalMetadata(v any) error { metadata := t.Metadata if metadata != nil { - metadataBytes, err := json.Marshal(metadata) + metadataBytes, err := common.Marshal(metadata) if err != nil { return fmt.Errorf("marshal metadata failed: %w", err) } - err = json.Unmarshal(metadataBytes, v) + err = common.Unmarshal(metadataBytes, v) if err != nil { return fmt.Errorf("unmarshal metadata to target failed: %w", err) } diff --git a/relay/helper/price.go b/relay/helper/price.go index c310220f..1cb04166 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -140,7 +140,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens } // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) -func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { +func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PriceData { groupRatioInfo := HandleGroupRatio(c, info) modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) @@ -154,7 +154,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types. } } quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) - priceData := types.PerCallPriceData{ + + // 免费模型检测(与 ModelPriceHelper 对齐) + freeModel := false + if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume { + if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 { + quota = 0 + freeModel = true + } + } + + priceData := types.PriceData{ + FreeModel: freeModel, ModelPrice: modelPrice, Quota: quota, GroupRatioInfo: groupRatioInfo, diff --git a/relay/relay_task.go b/relay/relay_task.go index ebbd1f65..d372ca2e 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "encoding/json" "errors" "fmt" "io" @@ -15,29 +14,33 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" relayconstant "github.com/QuantumNous/new-api/relay/constant" + "github.com/QuantumNous/new-api/relay/helper" "github.com/QuantumNous/new-api/service" - "github.com/QuantumNous/new-api/setting/ratio_setting" - "github.com/gin-gonic/gin" ) -/* -Task 任务通过平台、Action 区分任务 -*/ -func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { - info.InitChannelMeta(c) - // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields - if info.TaskRelayInfo == nil { - info.TaskRelayInfo = &relaycommon.TaskRelayInfo{} - } +type TaskSubmitResult struct { + UpstreamTaskID string + TaskData []byte + Platform constant.TaskPlatform + ModelName string + Quota int + //PerCallPrice types.PriceData +} + +// ResolveOriginTask 处理基于已有任务的提交(remix / continuation): +// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过 +// specific_channel_id 禁止重试),以及提取 OtherRatios(时长、分辨率)。 +// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。 +func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { + // 检测 remix action path := c.Request.URL.Path if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") { info.Action = constant.TaskActionRemix } - - // 提取 remix 任务的 video_id if info.Action == constant.TaskActionRemix { videoID := c.Param("video_id") if strings.TrimSpace(videoID) == "" { @@ -46,241 +49,164 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. info.OriginTaskID = videoID } - platform := constant.TaskPlatform(c.GetString("platform")) + if info.OriginTaskID == "" { + return nil + } - // 获取原始任务信息 - if info.OriginTaskID != "" { - originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) - return - } - if !exist { - taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) - return - } - if info.OriginModelName == "" { - if originTask.Properties.OriginModelName != "" { - info.OriginModelName = originTask.Properties.OriginModelName - } else if originTask.Properties.UpstreamModelName != "" { - info.OriginModelName = originTask.Properties.UpstreamModelName - } else { - var taskData map[string]interface{} - _ = json.Unmarshal(originTask.Data, &taskData) - if m, ok := taskData["model"].(string); ok && m != "" { - info.OriginModelName = m - platform = originTask.Platform - } - } - } - if originTask.ChannelId != info.ChannelId { - channel, err := model.GetChannelById(originTask.ChannelId, true) - if err != nil { - taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) - return - } - if channel.Status != common.ChannelStatusEnabled { - taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) - return - } - key, _, newAPIError := channel.GetNextEnabledKey() - if newAPIError != nil { - taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) - return - } - common.SetContextKey(c, constant.ContextKeyChannelKey, key) - common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type) - common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) - common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId) + // 查找原始任务 + originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) + if err != nil { + return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) + } + if !exist { + return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) + } - info.ChannelBaseUrl = channel.GetBaseURL() - info.ChannelId = originTask.ChannelId - info.ChannelType = channel.Type - info.ApiKey = key - platform = originTask.Platform - } - - // 使用原始任务的参数 - if info.Action == constant.TaskActionRemix { + // 从原始任务推导模型名称 + if info.OriginModelName == "" { + if originTask.Properties.OriginModelName != "" { + info.OriginModelName = originTask.Properties.OriginModelName + } else if originTask.Properties.UpstreamModelName != "" { + info.OriginModelName = originTask.Properties.UpstreamModelName + } else { var taskData map[string]interface{} - _ = json.Unmarshal(originTask.Data, &taskData) - secondsStr, _ := taskData["seconds"].(string) - seconds, _ := strconv.Atoi(secondsStr) - if seconds <= 0 { - seconds = 4 - } - sizeStr, _ := taskData["size"].(string) - if info.PriceData.OtherRatios == nil { - info.PriceData.OtherRatios = map[string]float64{} - } - info.PriceData.OtherRatios["seconds"] = float64(seconds) - info.PriceData.OtherRatios["size"] = 1 - if sizeStr == "1792x1024" || sizeStr == "1024x1792" { - info.PriceData.OtherRatios["size"] = 1.666667 + _ = common.Unmarshal(originTask.Data, &taskData) + if m, ok := taskData["model"].(string); ok && m != "" { + info.OriginModelName = m } } } + + // 锁定到原始任务的渠道(如果与当前选中的不同) + if originTask.ChannelId != info.ChannelId { + ch, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) + } + if ch.Status != common.ChannelStatusEnabled { + return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) + } + key, _, newAPIError := ch.GetNextEnabledKey() + if newAPIError != nil { + return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) + } + common.SetContextKey(c, constant.ContextKeyChannelKey, key) + common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type) + common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL()) + common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId) + + info.ChannelBaseUrl = ch.GetBaseURL() + info.ChannelId = originTask.ChannelId + info.ChannelType = ch.Type + info.ApiKey = key + } + + // 渠道已锁定到原始任务 → 禁止重试切换到其他渠道 + c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId)) + + // 提取 remix 参数(时长、分辨率 → OtherRatios) + if info.Action == constant.TaskActionRemix { + var taskData map[string]interface{} + _ = common.Unmarshal(originTask.Data, &taskData) + secondsStr, _ := taskData["seconds"].(string) + seconds, _ := strconv.Atoi(secondsStr) + if seconds <= 0 { + seconds = 4 + } + sizeStr, _ := taskData["size"].(string) + if info.PriceData.OtherRatios == nil { + info.PriceData.OtherRatios = map[string]float64{} + } + info.PriceData.OtherRatios["seconds"] = float64(seconds) + info.PriceData.OtherRatios["size"] = 1 + if sizeStr == "1792x1024" || sizeStr == "1024x1792" { + info.PriceData.OtherRatios["size"] = 1.666667 + } + } + + return nil +} + +// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次): +// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 → +// 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。 +// 控制器负责 defer Refund 和成功后 Settle。 +func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) { + info.InitChannelMeta(c) + + // 1. 确定 platform → 创建适配器 → 验证请求 + platform := constant.TaskPlatform(c.GetString("platform")) if platform == "" { platform = GetTaskPlatform(c) } - - info.InitChannelMeta(c) adaptor := GetTaskAdaptor(platform) if adaptor == nil { - return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) + return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) } adaptor.Init(info) - // get & validate taskRequest 获取并验证文本请求 - taskErr = adaptor.ValidateRequestAndSetAction(c, info) - if taskErr != nil { - return + if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil { + return nil, taskErr } + // 2. 确定模型名称 modelName := info.OriginModelName if modelName == "" { modelName = service.CoverTaskActionToModelName(platform, info.Action) } - modelPrice, success := ratio_setting.GetModelPrice(modelName, true) - if !success { - defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName] - if !ok { - modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit - } else { - modelPrice = defaultPrice - } + + // 3. 预生成公开 task ID(仅首次) + if info.PublicTaskID == "" { + info.PublicTaskID = model.GenerateTaskID() } - // 处理 auto 分组:从 context 获取实际选中的分组 - // 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中 - if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists { - if groupStr, ok := autoGroup.(string); ok && groupStr != "" { - info.UsingGroup = groupStr - } - } + // 4. 价格计算 + info.OriginModelName = modelName + info.PriceData = helper.ModelPriceHelperPerCall(c, info) - // 预扣 - groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup) - var ratio float64 - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup) - if hasUserGroupRatio { - ratio = modelPrice * userGroupRatio - } else { - ratio = modelPrice * groupRatio - } - // FIXME: 临时修补,支持任务仅按次计费 if !common.StringsContains(constant.TaskPricePatches, modelName) { - if len(info.PriceData.OtherRatios) > 0 { - for _, ra := range info.PriceData.OtherRatios { - if 1.0 != ra { - ratio *= ra - } + for _, ra := range info.PriceData.OtherRatios { + if ra != 1.0 { + info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra) } } } - println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio)) - userQuota, err := model.GetUserQuota(info.UserId, false) - if err != nil { - taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) - return - } - quota := int(ratio * common.QuotaPerUnit) - if userQuota-quota < 0 { - taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden) - return + + // 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过) + if info.Billing == nil && !info.PriceData.FreeModel { + info.ForcePreConsume = true + if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil { + return nil, service.TaskErrorFromAPIError(apiErr) + } } - // build body + // 6. 构建请求体 requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { - taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) - return + return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) } - // do request + + // 7. 发送请求 resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { - taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) - return + return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - // handle response if resp != nil && resp.StatusCode != http.StatusOK { responseBody, _ := io.ReadAll(resp.Body) - taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) - return + return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) } - defer func() { - // release quota - if info.ConsumeQuota && taskErr == nil { - - err := service.PostConsumeQuota(info, quota, 0, true) - if err != nil { - common.SysLog("error consuming token remain quota: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - //gRatio := groupRatio - //if hasUserGroupRatio { - // gRatio = userGroupRatio - //} - logContent := fmt.Sprintf("操作 %s", info.Action) - // FIXME: 临时修补,支持任务仅按次计费 - if common.StringsContains(constant.TaskPricePatches, modelName) { - logContent = fmt.Sprintf("%s,按次计费", logContent) - } else { - if len(info.PriceData.OtherRatios) > 0 { - var contents []string - for key, ra := range info.PriceData.OtherRatios { - if 1.0 != ra { - contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) - } - } - if len(contents) > 0 { - logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) - } - } - } - other := make(map[string]interface{}) - if c != nil && c.Request != nil && c.Request.URL != nil { - other["request_path"] = c.Request.URL.Path - } - other["model_price"] = modelPrice - other["group_ratio"] = groupRatio - if hasUserGroupRatio { - other["user_group_ratio"] = userGroupRatio - } - model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ - ChannelId: info.ChannelId, - ModelName: modelName, - TokenName: tokenName, - Quota: quota, - Content: logContent, - TokenId: info.TokenId, - Group: info.UsingGroup, - Other: other, - }) - model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota) - model.UpdateChannelUsedQuota(info.ChannelId, quota) - } - } - }() - - taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) + // 8. 解析响应 + upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) if taskErr != nil { - return + return nil, taskErr } - info.ConsumeQuota = true - // insert task - task := model.InitTask(platform, info) - task.TaskID = taskID - task.Quota = quota - task.Data = taskData - task.Action = info.Action - err = task.Insert() - if err != nil { - taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) - return - } - return nil + + return &TaskSubmitResult{ + UpstreamTaskID: upstreamTaskID, + TaskData: taskData, + Platform: platform, + ModelName: modelName, + }, nil } var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ @@ -336,7 +262,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta } else { tasks = make([]any, 0) } - respBody, err = json.Marshal(dto.TaskResponse[[]any]{ + respBody, err = common.Marshal(dto.TaskResponse[[]any]{ Code: "success", Data: tasks, }) @@ -357,7 +283,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt return } - respBody, err = json.Marshal(dto.TaskResponse[any]{ + respBody, err = common.Marshal(dto.TaskResponse[any]{ Code: "success", Data: TaskModel2Dto(originTask), }) @@ -381,97 +307,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d return } - func() { - channelModel, err2 := model.GetChannelById(originTask.ChannelId, true) - if err2 != nil { - return - } - if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { - return - } - baseURL := constant.ChannelBaseURLs[channelModel.Type] - if channelModel.GetBaseURL() != "" { - baseURL = channelModel.GetBaseURL() - } - proxy := channelModel.GetSetting().Proxy - adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) - if adaptor == nil { - return - } - resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ - "task_id": originTask.TaskID, - "action": originTask.Action, - }, proxy) - if err2 != nil || resp == nil { - return - } - defer resp.Body.Close() - body, err2 := io.ReadAll(resp.Body) - if err2 != nil { - return - } - ti, err2 := adaptor.ParseTaskResult(body) - if err2 == nil && ti != nil { - if ti.Status != "" { - originTask.Status = model.TaskStatus(ti.Status) - } - if ti.Progress != "" { - originTask.Progress = ti.Progress - } - if ti.Url != "" { - if strings.HasPrefix(ti.Url, "data:") { - } else { - originTask.FailReason = ti.Url - } - } - _ = originTask.Update() - var raw map[string]any - _ = json.Unmarshal(body, &raw) - format := "mp4" - if respObj, ok := raw["response"].(map[string]any); ok { - if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 { - if v0, ok := vids[0].(map[string]any); ok { - if mt, ok := v0["mimeType"].(string); ok && mt != "" { - if strings.Contains(mt, "mp4") { - format = "mp4" - } else { - format = mt - } - } - } - } - } - status := "processing" - switch originTask.Status { - case model.TaskStatusSuccess: - status = "succeeded" - case model.TaskStatusFailure: - status = "failed" - case model.TaskStatusQueued, model.TaskStatusSubmitted: - status = "queued" - } - if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { - out := map[string]any{ - "error": nil, - "format": format, - "metadata": nil, - "status": status, - "task_id": originTask.TaskID, - "url": originTask.FailReason, - } - respBody, _ = json.Marshal(dto.TaskResponse[any]{ - Code: "success", - Data: out, - }) - } - } - }() + isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") - if len(respBody) != 0 { + // Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态 + if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 { + respBody = realtimeResp return } - if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { + // OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo + if isOpenAIVideoAPI { adaptor := GetTaskAdaptor(originTask.Platform) if adaptor == nil { taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest) @@ -486,10 +331,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d respBody = openAIVideoData return } - taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented) + taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented) return } - respBody, err = json.Marshal(dto.TaskResponse[any]{ + + // 通用 TaskDto 格式 + respBody, err = common.Marshal(dto.TaskResponse[any]{ Code: "success", Data: TaskModel2Dto(originTask), }) @@ -499,16 +346,145 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d return } +// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。 +// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。 +// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。 +func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte { + channelModel, err := model.GetChannelById(task.ChannelId, true) + if err != nil { + return nil + } + if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { + return nil + } + + baseURL := constant.ChannelBaseURLs[channelModel.Type] + if channelModel.GetBaseURL() != "" { + baseURL = channelModel.GetBaseURL() + } + proxy := channelModel.GetSetting().Proxy + adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) + if adaptor == nil { + return nil + } + + resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ + "task_id": task.GetUpstreamTaskID(), + "action": task.Action, + }, proxy) + if err != nil || resp == nil { + return nil + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil + } + + ti, err := adaptor.ParseTaskResult(body) + if err != nil || ti == nil { + return nil + } + + // 将上游最新状态更新到 task + if ti.Status != "" { + task.Status = model.TaskStatus(ti.Status) + } + if ti.Progress != "" { + task.Progress = ti.Progress + } + if strings.HasPrefix(ti.Url, "data:") { + // data: URI — kept in Data, not ResultURL + } else if ti.Url != "" { + task.PrivateData.ResultURL = ti.Url + } else if task.Status == model.TaskStatusSuccess { + // No URL from adaptor — construct proxy URL using public task ID + task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) + } + _ = task.Update() + + // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理 + if isOpenAIVideoAPI { + return nil + } + + // 非 OpenAI Video API: 构建自定义格式响应 + format := detectVideoFormat(body) + out := map[string]any{ + "error": nil, + "format": format, + "metadata": nil, + "status": mapTaskStatusToSimple(task.Status), + "task_id": task.TaskID, + "url": task.GetResultURL(), + } + respBody, _ := common.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: out, + }) + return respBody +} + +// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式 +func detectVideoFormat(rawBody []byte) string { + var raw map[string]any + if err := common.Unmarshal(rawBody, &raw); err != nil { + return "mp4" + } + respObj, ok := raw["response"].(map[string]any) + if !ok { + return "mp4" + } + vids, ok := respObj["videos"].([]any) + if !ok || len(vids) == 0 { + return "mp4" + } + v0, ok := vids[0].(map[string]any) + if !ok { + return "mp4" + } + mt, ok := v0["mimeType"].(string) + if !ok || mt == "" || strings.Contains(mt, "mp4") { + return "mp4" + } + return mt +} + +// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串 +func mapTaskStatusToSimple(status model.TaskStatus) string { + switch status { + case model.TaskStatusSuccess: + return "succeeded" + case model.TaskStatusFailure: + return "failed" + case model.TaskStatusQueued, model.TaskStatusSubmitted: + return "queued" + default: + return "processing" + } +} + func TaskModel2Dto(task *model.Task) *dto.TaskDto { return &dto.TaskDto{ + ID: task.ID, + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt, TaskID: task.TaskID, + Platform: string(task.Platform), + UserId: task.UserId, + Group: task.Group, + ChannelId: task.ChannelId, + Quota: task.Quota, Action: task.Action, Status: string(task.Status), FailReason: task.FailReason, + ResultURL: task.GetResultURL(), SubmitTime: task.SubmitTime, StartTime: task.StartTime, FinishTime: task.FinishTime, Progress: task.Progress, + Properties: task.Properties, + Username: task.Username, Data: task.Data, } } diff --git a/router/video-router.go b/router/video-router.go index d5fed1d7..d2bce42b 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -8,10 +8,16 @@ import ( ) func SetVideoRouter(router *gin.Engine) { + // Video proxy: accepts either session auth (dashboard) or token auth (API clients) + videoProxyRouter := router.Group("/v1") + videoProxyRouter.Use(middleware.TokenOrUserAuth()) + { + videoProxyRouter.GET("/videos/:task_id/content", controller.VideoProxy) + } + videoV1Router := router.Group("/v1") videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { - videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy) videoV1Router.POST("/video/generations", controller.RelayTask) videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask) diff --git a/service/billing_session.go b/service/billing_session.go index 1a31316b..f24b68e5 100644 --- a/service/billing_session.go +++ b/service/billing_session.go @@ -193,6 +193,11 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro // shouldTrust 统一信任额度检查,适用于钱包和订阅。 func (s *BillingSession) shouldTrust(c *gin.Context) bool { + // 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路 + if s.relayInfo.ForcePreConsume { + return false + } + trustQuota := common.GetTrustQuota() if trustQuota <= 0 { return false diff --git a/service/error.go b/service/error.go index 7a9d7a81..a2ff0aad 100644 --- a/service/error.go +++ b/service/error.go @@ -206,3 +206,16 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { return taskError } + +// TaskErrorFromAPIError 将 PreConsumeBilling 返回的 NewAPIError 转换为 TaskError。 +func TaskErrorFromAPIError(apiErr *types.NewAPIError) *dto.TaskError { + if apiErr == nil { + return nil + } + return &dto.TaskError{ + Code: string(apiErr.GetErrorCode()), + Message: apiErr.Err.Error(), + StatusCode: apiErr.StatusCode, + Error: apiErr.Err, + } +} diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 771da5b7..1c440911 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -204,7 +204,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, return info } -func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} { +func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PriceData) map[string]interface{} { other := make(map[string]interface{}) other["model_price"] = priceData.ModelPrice other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio diff --git a/service/task_billing.go b/service/task_billing.go new file mode 100644 index 00000000..ec0094bd --- /dev/null +++ b/service/task_billing.go @@ -0,0 +1,227 @@ +package service + +import ( + "context" + "fmt" + "strings" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/setting/ratio_setting" + "github.com/gin-gonic/gin" +) + +// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。 +// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。 +func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("操作 %s", info.Action) + // 支持任务仅按次计费 + if common.StringsContains(constant.TaskPricePatches, modelName) { + logContent = fmt.Sprintf("%s,按次计费", logContent) + } else { + if len(info.PriceData.OtherRatios) > 0 { + var contents []string + for key, ra := range info.PriceData.OtherRatios { + if 1.0 != ra { + contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) + } + } + if len(contents) > 0 { + logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) + } + } + } + other := make(map[string]interface{}) + other["request_path"] = c.Request.URL.Path + other["model_price"] = info.PriceData.ModelPrice + other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio + if info.PriceData.GroupRatioInfo.HasSpecialRatio { + other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio + } + model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ + ChannelId: info.ChannelId, + ModelName: modelName, + TokenName: tokenName, + Quota: info.PriceData.Quota, + Content: logContent, + TokenId: info.TokenId, + Group: info.UsingGroup, + Other: other, + }) + model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota) + model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota) +} + +// --------------------------------------------------------------------------- +// 异步任务计费辅助函数 +// --------------------------------------------------------------------------- + +// resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。 +// 如果令牌已被删除或查询失败,返回空字符串。 +func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string { + token, err := model.GetTokenById(tokenId) + if err != nil { + logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error())) + return "" + } + return token.Key +} + +// taskIsSubscription 判断任务是否通过订阅计费。 +func taskIsSubscription(task *model.Task) bool { + return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0 +} + +// taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。 +func taskAdjustFunding(task *model.Task, delta int) error { + if taskIsSubscription(task) { + return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta)) + } + if delta > 0 { + return model.DecreaseUserQuota(task.UserId, delta) + } + return model.IncreaseUserQuota(task.UserId, -delta, false) +} + +// taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。 +// 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。 +func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) { + if task.PrivateData.TokenId <= 0 || delta == 0 { + return + } + tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID) + if tokenKey == "" { + return + } + var err error + if delta > 0 { + err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta) + } else { + err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta) + } + if err != nil { + logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error())) + } +} + +// RefundTaskQuota 统一的任务失败退款逻辑。 +// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。 +func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { + quota := task.Quota + if quota == 0 { + return + } + + // 1. 退还资金来源(钱包或订阅) + if err := taskAdjustFunding(task, -quota); err != nil { + logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error())) + return + } + + // 2. 退还令牌额度 + taskAdjustTokenQuota(ctx, task, -quota) + + // 3. 记录日志 + logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s,原因:%s", task.TaskID, logger.LogQuota(quota), reason) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) +} + +// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。 +// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度, +// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。 +func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) { + if totalTokens <= 0 { + return + } + + // 获取模型名称 + var taskData map[string]interface{} + if err := common.Unmarshal(task.Data, &taskData); err != nil { + return + } + modelName, ok := taskData["model"].(string) + if !ok || modelName == "" { + return + } + + // 获取模型价格和倍率 + modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) + // 只有配置了倍率(非固定价格)时才按 token 重新计费 + if !hasRatioSetting || modelRatio <= 0 { + return + } + + // 获取用户和组的倍率信息 + group := task.Group + if group == "" { + user, err := model.GetUserById(task.UserId, false) + if err == nil { + group = user.Group + } + } + if group == "" { + return + } + + groupRatio := ratio_setting.GetGroupRatio(group) + userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group) + + var finalGroupRatio float64 + if hasUserGroupRatio { + finalGroupRatio = userGroupRatio + } else { + finalGroupRatio = groupRatio + } + + // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio + actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio) + + // 计算差额(正数=需要补扣,负数=需要退还) + preConsumedQuota := task.Quota + quotaDelta := actualQuota - preConsumedQuota + + if quotaDelta == 0 { + logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", + task.TaskID, logger.LogQuota(actualQuota), totalTokens)) + return + } + + logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,tokens:%d)", + task.TaskID, + logger.LogQuota(quotaDelta), + logger.LogQuota(actualQuota), + logger.LogQuota(preConsumedQuota), + totalTokens, + )) + + // 调整资金来源 + if err := taskAdjustFunding(task, quotaDelta); err != nil { + logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) + return + } + + // 调整令牌额度 + taskAdjustTokenQuota(ctx, task, quotaDelta) + + // 更新统计(仅补扣时更新,退还不影响已用统计) + if quotaDelta > 0 { + model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) + model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) + } + task.Quota = actualQuota + + var action string + if quotaDelta > 0 { + action = "补扣费" + } else { + action = "退还" + } + logContent := fmt.Sprintf("视频任务成功%s,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s", + action, modelRatio, finalGroupRatio, totalTokens, + logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) +} diff --git a/service/task_polling.go b/service/task_polling.go new file mode 100644 index 00000000..847e1659 --- /dev/null +++ b/service/task_polling.go @@ -0,0 +1,446 @@ +package service + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "sort" + "strings" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" + relaycommon "github.com/QuantumNous/new-api/relay/common" + + "github.com/samber/lo" +) + +// TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖 +type TaskPollingAdaptor interface { + Init(info *relaycommon.RelayInfo) + FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error) + ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error) +} + +// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。 +// 打破 service -> relay -> relay/channel -> service 的循环依赖。 +var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor + +// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务 +func TaskPollingLoop() { + for { + time.Sleep(time.Duration(15) * time.Second) + common.SysLog("任务进度轮询开始") + ctx := context.TODO() + allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit) + platformTask := make(map[constant.TaskPlatform][]*model.Task) + for _, t := range allTasks { + platformTask[t.Platform] = append(platformTask[t.Platform], t) + } + for platform, tasks := range platformTask { + if len(tasks) == 0 { + continue + } + taskChannelM := make(map[int][]string) + taskM := make(map[string]*model.Task) + nullTaskIds := make([]int64, 0) + for _, task := range tasks { + upstreamID := task.GetUpstreamTaskID() + if upstreamID == "" { + // 统计失败的未完成任务 + nullTaskIds = append(nullTaskIds, task.ID) + continue + } + taskM[upstreamID] = task + taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID) + } + if len(nullTaskIds) > 0 { + err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{ + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) + } else { + logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) + } + } + if len(taskChannelM) == 0 { + continue + } + + DispatchPlatformUpdate(platform, taskChannelM, taskM) + } + common.SysLog("任务进度轮询完成") + } +} + +// DispatchPlatformUpdate 按平台分发轮询更新 +func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) { + switch platform { + case constant.TaskPlatformMidjourney: + // MJ 轮询由其自身处理,这里预留入口 + case constant.TaskPlatformSuno: + _ = UpdateSunoTasks(context.Background(), taskChannelM, taskM) + default: + if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil { + common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err)) + } + } +} + +// UpdateSunoTasks 按渠道更新所有 Suno 任务 +func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + err := updateSunoTasks(ctx, channelId, taskIds, taskM) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error())) + } + } + return nil +} + +func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + ch, err := model.CacheGetChannel(channelId) + if err != nil { + common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) + // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values) + var failedIDs []int64 + for _, upstreamID := range taskIds { + if t, ok := taskM[upstreamID]; ok { + failedIDs = append(failedIDs, t.ID) + } + } + err = model.TaskBulkUpdateByID(failedIDs, map[string]any{ + "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err)) + } + return err + } + adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno) + if adaptor == nil { + return errors.New("adaptor not found") + } + proxy := ch.GetSetting().Proxy + resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{ + "ids": taskIds, + }, proxy) + if err != nil { + common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) + return err + } + if resp.StatusCode != http.StatusOK { + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + return fmt.Errorf("Get Task status code: %d", resp.StatusCode) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err)) + return err + } + var responseItems dto.TaskResponse[[]dto.SunoDataResponse] + err = common.Unmarshal(responseBody, &responseItems) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + return err + } + if !responseItems.IsSuccess() { + common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody))) + return err + } + + for _, responseItem := range responseItems.Data { + task := taskM[responseItem.TaskID] + if !taskNeedsUpdate(task, responseItem) { + continue + } + + task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status) + task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason) + task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime) + task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) + task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) + if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { + logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) + task.Progress = "100%" + RefundTaskQuota(ctx, task, task.FailReason) + } + if responseItem.Status == model.TaskStatusSuccess { + task.Progress = "100%" + } + task.Data = responseItem.Data + + err = task.Update() + if err != nil { + common.SysLog("UpdateSunoTask task error: " + err.Error()) + } + } + return nil +} + +// taskNeedsUpdate 检查 Suno 任务是否需要更新 +func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool { + if oldTask.SubmitTime != newTask.SubmitTime { + return true + } + if oldTask.StartTime != newTask.StartTime { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + if string(oldTask.Status) != newTask.Status { + return true + } + if oldTask.FailReason != newTask.FailReason { + return true + } + + if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" { + return true + } + + oldData, _ := common.Marshal(oldTask.Data) + newData, _ := common.Marshal(newTask.Data) + + sort.Slice(oldData, func(i, j int) bool { + return oldData[i] < oldData[j] + }) + sort.Slice(newData, func(i, j int) bool { + return newData[i] < newData[j] + }) + + if string(oldData) != string(newData) { + return true + } + return false +} + +// UpdateVideoTasks 按渠道更新所有视频任务 +func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil { + logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + } + } + return nil +} + +func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { + logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + cacheGetChannel, err := model.CacheGetChannel(channelId) + if err != nil { + // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values) + var failedIDs []int64 + for _, upstreamID := range taskIds { + if t, ok := taskM[upstreamID]; ok { + failedIDs = append(failedIDs, t.ID) + } + } + errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{ + "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if errUpdate != nil { + common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + } + return fmt.Errorf("CacheGetChannel failed: %w", err) + } + adaptor := GetTaskAdaptorFunc(platform) + if adaptor == nil { + return fmt.Errorf("video adaptor not found") + } + info := &relaycommon.RelayInfo{} + info.ChannelMeta = &relaycommon.ChannelMeta{ + ChannelBaseUrl: cacheGetChannel.GetBaseURL(), + } + info.ApiKey = cacheGetChannel.Key + adaptor.Init(info) + for _, taskId := range taskIds { + if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { + logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + } + } + return nil +} + +func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error { + baseURL := constant.ChannelBaseURLs[ch.Type] + if ch.GetBaseURL() != "" { + baseURL = ch.GetBaseURL() + } + proxy := ch.GetSetting().Proxy + + task := taskM[taskId] + if task == nil { + logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + return fmt.Errorf("task %s not found", taskId) + } + key := ch.Key + + privateData := task.PrivateData + if privateData.Key != "" { + key = privateData.Key + } + resp, err := adaptor.FetchTask(baseURL, key, map[string]any{ + "task_id": task.GetUpstreamTaskID(), + "action": task.Action, + }, proxy) + if err != nil { + return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("readAll failed for task %s: %w", taskId, err) + } + + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody))) + + taskResult := &relaycommon.TaskInfo{} + // try parse as New API response format + var responseItems dto.TaskResponse[model.Task] + if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems)) + t := responseItems.Data + taskResult.TaskID = t.TaskID + taskResult.Status = string(t.Status) + taskResult.Url = t.GetResultURL() + taskResult.Progress = t.Progress + taskResult.Reason = t.FailReason + task.Data = t.Data + } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { + return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) + } else { + task.Data = redactVideoResponseBody(responseBody) + } + + logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult)) + + now := time.Now().Unix() + if taskResult.Status == "" { + taskResult = relaycommon.FailTaskInfo("upstream returned empty status") + } + + // 记录原本的状态,防止重复退款 + shouldRefund := false + quota := task.Quota + preStatus := task.Status + + task.Status = model.TaskStatus(taskResult.Status) + switch taskResult.Status { + case model.TaskStatusSubmitted: + task.Progress = taskcommon.ProgressSubmitted + case model.TaskStatusQueued: + task.Progress = taskcommon.ProgressQueued + case model.TaskStatusInProgress: + task.Progress = taskcommon.ProgressInProgress + if task.StartTime == 0 { + task.StartTime = now + } + case model.TaskStatusSuccess: + task.Progress = taskcommon.ProgressComplete + if task.FinishTime == 0 { + task.FinishTime = now + } + if strings.HasPrefix(taskResult.Url, "data:") { + // data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL + } else if taskResult.Url != "" { + // Direct upstream URL (e.g. Kling, Ali, Doubao, etc.) + task.PrivateData.ResultURL = taskResult.Url + } else { + // No URL from adaptor — construct proxy URL using public task ID + task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) + } + + // 如果返回了 total_tokens,根据模型倍率重新计费 + if taskResult.TotalTokens > 0 { + RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens) + } + case model.TaskStatusFailure: + logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) + task.Status = model.TaskStatusFailure + task.Progress = taskcommon.ProgressComplete + if task.FinishTime == 0 { + task.FinishTime = now + } + task.FailReason = taskResult.Reason + logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + taskResult.Progress = taskcommon.ProgressComplete + if quota != 0 { + if preStatus != model.TaskStatusFailure { + shouldRefund = true + } else { + logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID)) + } + } + default: + return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) + } + if taskResult.Progress != "" { + task.Progress = taskResult.Progress + } + if err := task.Update(); err != nil { + common.SysLog("UpdateVideoTask task error: " + err.Error()) + shouldRefund = false + } + + if shouldRefund { + RefundTaskQuota(ctx, task, task.FailReason) + } + + return nil +} + +func redactVideoResponseBody(body []byte) []byte { + var m map[string]any + if err := common.Unmarshal(body, &m); err != nil { + return body + } + resp, _ := m["response"].(map[string]any) + if resp != nil { + delete(resp, "bytesBase64Encoded") + if v, ok := resp["video"].(string); ok { + resp["video"] = truncateBase64(v) + } + if vs, ok := resp["videos"].([]any); ok { + for i := range vs { + if vm, ok := vs[i].(map[string]any); ok { + delete(vm, "bytesBase64Encoded") + } + } + } + } + b, err := common.Marshal(m) + if err != nil { + return body + } + return b +} + +func truncateBase64(s string) string { + const maxKeep = 256 + if len(s) <= maxKeep { + return s + } + return s[:maxKeep] + "..." +} diff --git a/types/price_data.go b/types/price_data.go index 3f7121b8..93bc6ae8 100644 --- a/types/price_data.go +++ b/types/price_data.go @@ -22,7 +22,8 @@ type PriceData struct { AudioCompletionRatio float64 OtherRatios map[string]float64 UsePrice bool - QuotaToPreConsume int // 预消耗额度 + Quota int // 按次计费的最终额度(MJ / Task) + QuotaToPreConsume int // 按量计费的预消耗额度 GroupRatioInfo GroupRatioInfo } @@ -36,12 +37,6 @@ func (p *PriceData) AddOtherRatio(key string, ratio float64) { p.OtherRatios[key] = ratio } -type PerCallPriceData struct { - ModelPrice float64 - Quota int - GroupRatioInfo GroupRatioInfo -} - func (p *PriceData) ToSetting() string { return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio) } diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx index c78d5773..4bce4525 100644 --- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx +++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx @@ -396,7 +396,7 @@ export const getTaskLogsColumns = ({ dataIndex: 'fail_reason', fixed: 'right', render: (text, record, index) => { - // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 + // 视频预览:优先使用 result_url,兼容旧数据 fail_reason 中的 URL const isVideoTask = record.action === TASK_ACTION_GENERATE || record.action === TASK_ACTION_TEXT_GENERATE || @@ -404,14 +404,15 @@ export const getTaskLogsColumns = ({ record.action === TASK_ACTION_REFERENCE_GENERATE || record.action === TASK_ACTION_REMIX_GENERATE; const isSuccess = record.status === 'SUCCESS'; - const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); - if (isSuccess && isVideoTask && isUrl) { + const resultUrl = record.result_url; + const hasResultUrl = typeof resultUrl === 'string' && /^https?:\/\//.test(resultUrl); + if (isSuccess && isVideoTask && hasResultUrl) { return ( { e.preventDefault(); - openVideoModal(text); + openVideoModal(resultUrl); }} > {t('点击预览视频')} diff --git a/web/src/components/table/task-logs/modals/ContentModal.jsx b/web/src/components/table/task-logs/modals/ContentModal.jsx index 88df4d8c..3527fd96 100644 --- a/web/src/components/table/task-logs/modals/ContentModal.jsx +++ b/web/src/components/table/task-logs/modals/ContentModal.jsx @@ -144,8 +144,6 @@ const ContentModal = ({ maxHeight: '100%', objectFit: 'contain', }} - autoPlay - crossOrigin='anonymous' onError={handleVideoError} onLoadedData={handleVideoLoaded} onLoadStart={() => setIsLoading(true)} From 8374a830844eaffde78a54b9e977dc888d8f070a Mon Sep 17 00:00:00 2001 From: CaIon Date: Tue, 10 Feb 2026 21:15:09 +0800 Subject: [PATCH 02/10] feat(task): add adaptor billing interface and async settlement framework Add three billing lifecycle methods to the TaskAdaptor interface: - EstimateBilling: compute OtherRatios from user request before pricing - AdjustBillingOnSubmit: adjust ratios from upstream submit response - AdjustBillingOnComplete: determine final quota at task terminal state Introduce BaseBilling as embeddable no-op default for adaptors without custom billing. Move Sora/Ali OtherRatios logic from shared validation into per-adaptor EstimateBilling implementations. Add TaskBillingContext to persist pricing params (model_price, group_ratio, other_ratios) in task private data for async polling settlement. Extract RecalculateTaskQuota as a general-purpose delta settlement function and unify polling billing via settleTaskBillingOnComplete (adaptor-first, then token-based fallback). --- controller/relay.go | 7 ++ logger/logger.go | 3 +- model/task.go | 16 +++- relay/channel/adapter.go | 30 +++++++- relay/channel/task/ali/adaptor.go | 57 +++++++++----- relay/channel/task/doubao/adaptor.go | 1 + relay/channel/task/gemini/adaptor.go | 1 + relay/channel/task/hailuo/adaptor.go | 2 + relay/channel/task/jimeng/adaptor.go | 1 + relay/channel/task/kling/adaptor.go | 1 + relay/channel/task/sora/adaptor.go | 44 ++++++++++- relay/channel/task/suno/adaptor.go | 7 +- relay/channel/task/taskcommon/helpers.go | 25 ++++++ relay/channel/task/vertex/adaptor.go | 41 +++++----- relay/channel/task/vidu/adaptor.go | 1 + relay/common/relay_utils.go | 10 +-- relay/relay_task.go | 64 ++++++++++++++-- service/task_billing.go | 98 +++++++++++++----------- service/task_polling.go | 28 ++++++- 19 files changed, 321 insertions(+), 116 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 132fee9b..3d2f20e8 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -509,6 +509,13 @@ func RelayTask(c *gin.Context) { task.PrivateData.BillingSource = relayInfo.BillingSource task.PrivateData.SubscriptionId = relayInfo.SubscriptionId task.PrivateData.TokenId = relayInfo.TokenId + task.PrivateData.BillingContext = &model.TaskBillingContext{ + ModelPrice: relayInfo.PriceData.ModelPrice, + GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio, + ModelRatio: relayInfo.PriceData.ModelRatio, + OtherRatios: relayInfo.PriceData.OtherRatios, + ModelName: result.ModelName, + } task.Quota = result.Quota task.Data = result.TaskData task.Action = relayInfo.Action diff --git a/logger/logger.go b/logger/logger.go index 61b1d49d..90cf5006 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,7 +2,6 @@ package logger import ( "context" - "encoding/json" "fmt" "io" "log" @@ -151,7 +150,7 @@ func FormatQuota(quota int) string { // LogJson 仅供测试使用 only for test func LogJson(ctx context.Context, msg string, obj any) { - jsonStr, err := json.Marshal(obj) + jsonStr, err := common.Marshal(obj) if err != nil { LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error())) return diff --git a/model/task.go b/model/task.go index 38bb4d05..592643eb 100644 --- a/model/task.go +++ b/model/task.go @@ -100,9 +100,19 @@ type TaskPrivateData struct { UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID ResultURL string `json:"result_url,omitempty"` // 任务成功后的结果 URL(视频地址等) // 计费上下文:用于异步退款/差额结算(轮询阶段读取) - BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription" - SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款 - TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款 + BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription" + SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款 + TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款 + BillingContext *TaskBillingContext `json:"billing_context,omitempty"` // 计费参数快照(用于轮询阶段重新计算) +} + +// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。 +type TaskBillingContext struct { + ModelPrice float64 `json:"model_price,omitempty"` // 模型单价 + GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率 + ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率 + OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等) + ModelName string `json:"model_name,omitempty"` // 模型名称 } // GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信) diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index ff7606e2..d2f7c6bb 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -36,6 +36,32 @@ type TaskAdaptor interface { ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError + // ── Billing ────────────────────────────────────────────────────── + + // EstimateBilling returns OtherRatios for pre-charge based on user request. + // Called after ValidateRequestAndSetAction, before price calculation. + // Adaptors should extract duration, resolution, etc. from the parsed request + // and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}). + // Return nil to use the base model price without extra ratios. + EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 + + // AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream + // submit response. Called after a successful DoResponse. + // If the upstream returned actual parameters that differ from the estimate + // (e.g. actual seconds), return updated ratios so the caller can recalculate + // the quota and settle the delta with the pre-charge. + // Return nil if no adjustment is needed. + AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64 + + // AdjustBillingOnComplete returns the actual quota when a task reaches a + // terminal state (success/failure) during polling. + // Called by the polling loop after ParseTaskResult. + // Return a positive value to trigger delta settlement (supplement / refund). + // Return 0 to keep the pre-charged amount unchanged. + AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int + + // ── Request / Response ─────────────────────────────────────────── + BuildRequestURL(info *relaycommon.RelayInfo) (string, error) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) @@ -46,9 +72,9 @@ type TaskAdaptor interface { GetModelList() []string GetChannelName() string - // FetchTask - FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) + // ── Polling ────────────────────────────────────────────────────── + FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) } diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go index 5d14ff65..f55178b3 100644 --- a/relay/channel/task/ali/adaptor.go +++ b/relay/channel/task/ali/adaptor.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" "github.com/samber/lo" @@ -108,10 +109,10 @@ type AliMetadata struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string - aliReq *AliVideoRequest } func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { @@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { - // 阿里通义万相支持 JSON 格式,不使用 multipart - var taskReq relaycommon.TaskSubmitReq - if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil { - return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest) - } - aliReq, err := a.convertToAliRequest(info, taskReq) - if err != nil { - return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError) - } - a.aliReq = aliReq - logger.LogJson(c, "ali video request body", aliReq) + // ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context return relaycommon.ValidateMultipartDirect(c, info) } @@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { - bodyBytes, err := common.Marshal(a.aliReq) + taskReq, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil, errors.Wrap(err, "get_task_request_failed") + } + + aliReq, err := a.convertToAliRequest(info, taskReq) + if err != nil { + return nil, errors.Wrap(err, "convert_to_ali_request_failed") + } + logger.LogJson(c, "ali video request body", aliReq) + + bodyBytes, err := common.Marshal(aliReq) if err != nil { return nil, errors.Wrap(err, "marshal_ali_request_failed") } - return bytes.NewReader(bodyBytes), nil } @@ -335,19 +336,33 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay return nil, errors.New("can't change model with metadata") } - info.PriceData.OtherRatios = map[string]float64{ + return aliReq, nil +} + +// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。 +// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 { + taskReq, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil + } + + aliReq, err := a.convertToAliRequest(info, taskReq) + if err != nil { + return nil + } + + otherRatios := map[string]float64{ "seconds": float64(aliReq.Parameters.Duration), } - ratios, err := ProcessAliOtherRatios(aliReq) if err != nil { - return nil, err + return otherRatios } - for s, f := range ratios { - info.PriceData.OtherRatios[s] = f + for k, v := range ratios { + otherRatios[k] = v } - - return aliReq, nil + return otherRatios } // DoRequest delegates to common helper diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index 3da125af..eca421bd 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -89,6 +89,7 @@ type responseTask struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index a863ea85..06c00a46 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -85,6 +85,7 @@ type operationResponse struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go index 67a68a10..ab83d659 100644 --- a/relay/channel/task/hailuo/adaptor.go +++ b/relay/channel/task/hailuo/adaptor.go @@ -17,12 +17,14 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) // https://platform.minimaxi.com/docs/api-reference/video-generation-intro type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 7f88be24..b61cca41 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -77,6 +77,7 @@ const ( // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int accessKey string secretKey string diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 4458626b..46e210f1 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -97,6 +97,7 @@ type responsePayload struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index ee69a3e4..8faaf984 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" "github.com/QuantumNous/new-api/common" @@ -11,6 +12,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -56,6 +58,7 @@ type responseTask struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -68,15 +71,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func validateRemixRequest(c *gin.Context) *dto.TaskError { - var req struct { - Prompt string `json:"prompt"` - } + var req relaycommon.TaskSubmitReq if err := common.UnmarshalBodyReusable(c, &req); err != nil { return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) } if strings.TrimSpace(req.Prompt) == "" { return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest) } + // 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致 + c.Set("task_request", req) return nil } @@ -87,6 +90,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom return relaycommon.ValidateMultipartDirect(c, info) } +// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 { + // remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置 + if info.Action == constant.TaskActionRemix { + return nil + } + + req, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil + } + + seconds, _ := strconv.Atoi(req.Seconds) + if seconds == 0 { + seconds = req.Duration + } + if seconds <= 0 { + seconds = 4 + } + + size := req.Size + if size == "" { + size = "720x1280" + } + + ratios := map[string]float64{ + "seconds": float64(seconds), + "size": 1, + } + if size == "1792x1024" || size == "1024x1792" { + ratios["size"] = 1.666667 + } + return ratios +} + func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.Action == constant.TaskActionRemix { return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 5dd62a70..2dbb44f0 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -20,6 +21,7 @@ import ( ) type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int } @@ -79,10 +81,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { sunoRequest, ok := c.Get("task_request") if !ok { - err := common.UnmarshalBodyReusable(c, &sunoRequest) - if err != nil { - return nil, err - } + return nil, fmt.Errorf("task_request not found in context") } data, err := common.Marshal(sunoRequest) if err != nil { diff --git a/relay/channel/task/taskcommon/helpers.go b/relay/channel/task/taskcommon/helpers.go index b1dde998..27d6612d 100644 --- a/relay/channel/task/taskcommon/helpers.go +++ b/relay/channel/task/taskcommon/helpers.go @@ -5,7 +5,10 @@ import ( "fmt" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/setting/system_setting" + "github.com/gin-gonic/gin" ) // UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip. @@ -68,3 +71,25 @@ const ( ProgressInProgress = "30%" ProgressComplete = "100%" ) + +// --------------------------------------------------------------------------- +// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods. +// Adaptors that do not need custom billing can embed this struct directly. +// --------------------------------------------------------------------------- + +type BaseBilling struct{} + +// EstimateBilling returns nil (no extra ratios; use base model price). +func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 { + return nil +} + +// AdjustBillingOnSubmit returns nil (no submit-time adjustment). +func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 { + return nil +} + +// AdjustBillingOnComplete returns 0 (keep pre-charged amount). +func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int { + return 0 +} diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index fb3a313f..4931002d 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -62,6 +62,7 @@ type operationResponse struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info return nil } +// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 { + sampleCount := 1 + v, ok := c.Get("task_request") + if ok { + req := v.(relaycommon.TaskSubmitReq) + if req.Metadata != nil { + if sc, exists := req.Metadata["sampleCount"]; exists { + if i, ok := sc.(int); ok && i > 0 { + sampleCount = i + } + if f, ok := sc.(float64); ok && int(f) > 0 { + sampleCount = int(f) + } + } + } + } + return map[string]float64{ + "sampleCount": float64(sampleCount), + } +} + // BuildRequestBody converts request into Vertex specific format. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, ok := c.Get("task_request") @@ -166,24 +189,6 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, fmt.Errorf("sampleCount must be greater than 0") } - // if req.Duration > 0 { - // body.Parameters["durationSeconds"] = req.Duration - // } else if req.Seconds != "" { - // seconds, err := strconv.Atoi(req.Seconds) - // if err != nil { - // return nil, errors.Wrap(err, "convert seconds to int failed") - // } - // body.Parameters["durationSeconds"] = seconds - // } - - info.PriceData.OtherRatios = map[string]float64{ - "sampleCount": float64(body.Parameters["sampleCount"].(int)), - } - - // if v, ok := body.Parameters["durationSeconds"]; ok { - // info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int)) - // } - data, err := common.Marshal(body) if err != nil { return nil, err diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index 1bab12f0..e689bf88 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -73,6 +73,7 @@ type creation struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int baseURL string } diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index b662f905..3cbb18c2 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError { if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) { return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true) } - info.PriceData.OtherRatios = map[string]float64{ - "seconds": float64(seconds), - "size": 1, - } - if lo.Contains([]string{"1792x1024", "1024x1792"}, size) { - info.PriceData.OtherRatios["size"] = 1.666667 - } + // OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置 } - info.Action = action + storeTaskRequest(c, info, action, req) return nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index d372ca2e..7c6724d8 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -128,8 +128,9 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr } // RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次): -// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 → -// 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。 +// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → +// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→ +// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。 // 控制器负责 defer Refund 和成功后 Settle。 func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) { info.InitChannelMeta(c) @@ -159,10 +160,20 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe info.PublicTaskID = model.GenerateTaskID() } - // 4. 价格计算 + // 4. 价格计算:基础模型价格 info.OriginModelName = modelName info.PriceData = helper.ModelPriceHelperPerCall(c, info) + // 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等) + // 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。 + // ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。 + if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 { + for k, v := range estimatedRatios { + info.PriceData.AddOtherRatio(k, v) + } + } + + // 6. 将 OtherRatios 应用到基础额度 if !common.StringsContains(constant.TaskPricePatches, modelName) { for _, ra := range info.PriceData.OtherRatios { if ra != 1.0 { @@ -171,7 +182,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe } } - // 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过) + // 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过) if info.Billing == nil && !info.PriceData.FreeModel { info.ForcePreConsume = true if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil { @@ -179,13 +190,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe } } - // 6. 构建请求体 + // 8. 构建请求体 requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) } - // 7. 发送请求 + // 9. 发送请求 resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) @@ -195,20 +206,59 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) } - // 8. 解析响应 + // 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置) + otherRatios := info.PriceData.OtherRatios + if otherRatios == nil { + otherRatios = map[string]float64{} + } + ratiosJSON, _ := common.Marshal(otherRatios) + c.Header("X-New-Api-Other-Ratios", string(ratiosJSON)) + + // 11. 解析响应 upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) if taskErr != nil { return nil, taskErr } + // 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios + finalQuota := info.PriceData.Quota + if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 { + // 基于调整后的 ratios 重新计算 quota + finalQuota = recalcQuotaFromRatios(info, adjustedRatios) + info.PriceData.OtherRatios = adjustedRatios + info.PriceData.Quota = finalQuota + } + return &TaskSubmitResult{ UpstreamTaskID: upstreamTaskID, TaskData: taskData, Platform: platform, ModelName: modelName, + Quota: finalQuota, }, nil } +// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。 +// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。 +func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int { + // 从 PriceData 获取不含 OtherRatios 的基础价格 + baseQuota := info.PriceData.Quota + // 先除掉原有的 OtherRatios 恢复基础额度 + for _, ra := range info.PriceData.OtherRatios { + if ra != 1.0 && ra > 0 { + baseQuota = int(float64(baseQuota) / ra) + } + } + // 应用新的 ratios + result := float64(baseQuota) + for _, ra := range ratios { + if ra != 1.0 { + result *= ra + } + } + return int(result) +} + var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, diff --git a/service/task_billing.go b/service/task_billing.go index ec0094bd..fc44c587 100644 --- a/service/task_billing.go +++ b/service/task_billing.go @@ -130,6 +130,58 @@ func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } +// RecalculateTaskQuota 通用的异步差额结算。 +// actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。 +// reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。 +func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) { + if actualQuota <= 0 { + return + } + preConsumedQuota := task.Quota + quotaDelta := actualQuota - preConsumedQuota + + if quotaDelta == 0 { + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)", + task.TaskID, logger.LogQuota(actualQuota), reason)) + return + } + + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)", + task.TaskID, + logger.LogQuota(quotaDelta), + logger.LogQuota(actualQuota), + logger.LogQuota(preConsumedQuota), + reason, + )) + + // 调整资金来源 + if err := taskAdjustFunding(task, quotaDelta); err != nil { + logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) + return + } + + // 调整令牌额度 + taskAdjustTokenQuota(ctx, task, quotaDelta) + + // 更新统计(仅补扣时更新,退还不影响已用统计) + if quotaDelta > 0 { + model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) + model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) + } + task.Quota = actualQuota + + var action string + if quotaDelta > 0 { + action = "补扣费" + } else { + action = "退还" + } + logContent := fmt.Sprintf("异步任务成功%s,预扣费 %s,实际扣费 %s,原因:%s", + action, + logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), reason) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) +} + // RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。 // 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度, // 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。 @@ -180,48 +232,6 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio) - // 计算差额(正数=需要补扣,负数=需要退还) - preConsumedQuota := task.Quota - quotaDelta := actualQuota - preConsumedQuota - - if quotaDelta == 0 { - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", - task.TaskID, logger.LogQuota(actualQuota), totalTokens)) - return - } - - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,tokens:%d)", - task.TaskID, - logger.LogQuota(quotaDelta), - logger.LogQuota(actualQuota), - logger.LogQuota(preConsumedQuota), - totalTokens, - )) - - // 调整资金来源 - if err := taskAdjustFunding(task, quotaDelta); err != nil { - logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) - return - } - - // 调整令牌额度 - taskAdjustTokenQuota(ctx, task, quotaDelta) - - // 更新统计(仅补扣时更新,退还不影响已用统计) - if quotaDelta > 0 { - model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) - model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) - } - task.Quota = actualQuota - - var action string - if quotaDelta > 0 { - action = "补扣费" - } else { - action = "退还" - } - logContent := fmt.Sprintf("视频任务成功%s,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s", - action, modelRatio, finalGroupRatio, totalTokens, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio) + RecalculateTaskQuota(ctx, task, actualQuota, reason) } diff --git a/service/task_polling.go b/service/task_polling.go index 847e1659..efbad8af 100644 --- a/service/task_polling.go +++ b/service/task_polling.go @@ -26,6 +26,9 @@ type TaskPollingAdaptor interface { Init(info *relaycommon.RelayInfo) FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error) ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error) + // AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。 + // 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。 + AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int } // GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。 @@ -372,10 +375,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) } - // 如果返回了 total_tokens,根据模型倍率重新计费 - if taskResult.TotalTokens > 0 { - RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens) - } + // 完成时计费调整:优先由 adaptor 计算,回退到 token 重算 + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) case model.TaskStatusFailure: logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) task.Status = model.TaskStatusFailure @@ -444,3 +445,22 @@ func truncateBase64(s string) string { } return s[:maxKeep] + "..." } + +// settleTaskBillingOnComplete 任务完成时的统一计费调整。 +// 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度 +// +// 2. taskResult.TotalTokens > 0 → 按 token 重算 +// 3. 都不满足 → 保持预扣额度不变 +func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) { + // 1. 优先让 adaptor 决定最终额度 + if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 { + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整") + return + } + // 2. 回退到 token 重算 + if taskResult.TotalTokens > 0 { + RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens) + return + } + // 3. 无调整,保持预扣额度 +} From 64d18a5fdf1ef072070cb7c4bc24623fdaa7c7ff Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 21 Feb 2026 22:48:30 +0800 Subject: [PATCH 03/10] refactor(logs): add refund logging for asynchronous tasks and update translations --- controller/relay.go | 128 +++++++----------- model/log.go | 43 ++++++ service/task_billing.go | 75 ++++++++-- .../table/usage-logs/UsageLogsColumnDefs.jsx | 30 +++- .../table/usage-logs/UsageLogsFilters.jsx | 1 + web/src/hooks/usage-logs/useUsageLogsData.jsx | 24 +++- web/src/i18n/locales/en.json | 5 + web/src/i18n/locales/fr.json | 5 + web/src/i18n/locales/ja.json | 5 + web/src/i18n/locales/ru.json | 5 + web/src/i18n/locales/vi.json | 4 + web/src/i18n/locales/zh-CN.json | 5 + 12 files changed, 229 insertions(+), 101 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 3d2f20e8..e90d6dd0 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -451,8 +451,6 @@ func RelayNotFound(c *gin.Context) { } func RelayTask(c *gin.Context) { - channelId := c.GetInt("channel_id") - c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) if err != nil { c.JSON(http.StatusInternalServerError, &dto.TaskError{ @@ -463,8 +461,7 @@ func RelayTask(c *gin.Context) { return } - // Fetch 操作是纯 DB 查询(或 task 自带 channelId 的上游查询),不依赖上下文 channel,无需重试 - // TODO: 在video-route层面优化,避免无谓的 channel 选择和上下文设置,也没必要吧代码放到这里来写这么多屎山 + // Fetch 路径:纯 DB 查询,不依赖上下文 channel,无需重试 switch relayInfo.RelayMode { case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil { @@ -475,13 +472,11 @@ func RelayTask(c *gin.Context) { // ── Submit 路径 ───────────────────────────────────────────────── - // 1. 解析原始任务(remix / continuation),一次性,可能锁定渠道并禁止重试 if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil { respondTaskError(c, taskErr) return } - // 2. defer Refund(全部失败时回滚预扣费) var result *relay.TaskSubmitResult var taskErr *dto.TaskError defer func() { @@ -490,14 +485,57 @@ func RelayTask(c *gin.Context) { } }() - // 3. 执行 + 重试(RelayTaskSubmit 内部在首次调用时自动预扣费) - taskErr = taskSubmitWithRetry(c, relayInfo, channelId, common.RetryTimes, func() *dto.TaskError { - var te *dto.TaskError - result, te = relay.RelayTaskSubmit(c, relayInfo) - return te - }) + retryParam := &service.RetryParam{ + Ctx: c, + TokenGroup: relayInfo.TokenGroup, + ModelName: relayInfo.OriginModelName, + Retry: common.GetPointer(0), + } - // 4. 成功:结算 + 日志 + 插入任务 + for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { + channel, channelErr := getChannel(c, relayInfo, retryParam) + if channelErr != nil { + logger.LogError(c, channelErr.Error()) + taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError) + break + } + + addUsedChannel(c, channel.Id) + requestBody, bodyErr := common.GetRequestBody(c) + if bodyErr != nil { + if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) { + taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge) + } else { + taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusBadRequest) + } + break + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + + result, taskErr = relay.RelayTaskSubmit(c, relayInfo) + if taskErr == nil { + break + } + + if !taskErr.LocalError { + processChannelError(c, + *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, + common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), + types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) + } + + if !shouldRetryTaskRelay(c, channel.Id, taskErr, common.RetryTimes-retryParam.GetRetry()) { + break + } + } + + useChannel := c.GetStringSlice("use_channel") + if len(useChannel) > 1 { + retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) + logger.LogInfo(c, retryLogStr) + } + + // ── 成功:结算 + 日志 + 插入任务 ── if taskErr == nil { if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil { common.SysError("settle task billing error: " + settleErr.Error()) @@ -520,7 +558,6 @@ func RelayTask(c *gin.Context) { task.Data = result.TaskData task.Action = relayInfo.Action if insertErr := task.Insert(); insertErr != nil { - //taskErr = service.TaskErrorWrapper(insertErr, "insert_task_failed", http.StatusInternalServerError) common.SysError("insert task error: " + insertErr.Error()) } } @@ -538,69 +575,6 @@ func respondTaskError(c *gin.Context, taskErr *dto.TaskError) { c.JSON(taskErr.StatusCode, taskErr) } -// taskSubmitWithRetry 执行首次尝试并在失败时切换渠道重试,返回最终的 taskErr。 -// attempt 闭包负责实际的上游请求,不涉及计费。 -func taskSubmitWithRetry(c *gin.Context, relayInfo *relaycommon.RelayInfo, - channelId int, retryTimes int, attempt func() *dto.TaskError) *dto.TaskError { - - taskErr := attempt() - if taskErr == nil { - return nil - } - if !taskErr.LocalError { - processChannelError(c, - *types.NewChannelError(channelId, c.GetInt("channel_type"), c.GetString("channel_name"), common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey), - common.GetContextKeyString(c, constant.ContextKeyChannelKey), common.GetContextKeyBool(c, constant.ContextKeyChannelAutoBan)), - types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) - } - - retryParam := &service.RetryParam{ - Ctx: c, - TokenGroup: relayInfo.TokenGroup, - ModelName: relayInfo.OriginModelName, - Retry: common.GetPointer(0), - } - for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() { - channel, newAPIError := getChannel(c, relayInfo, retryParam) - if newAPIError != nil { - logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) - taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) - break - } - channelId = channel.Id - useChannel := c.GetStringSlice("use_channel") - useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) - c.Set("use_channel", useChannel) - logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry())) - middleware.SetupContextForSelectedChannel(c, channel, c.GetString("original_model")) - - bodyStorage, err := common.GetBodyStorage(c) - if err != nil { - if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) { - taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge) - } else { - taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest) - } - break - } - c.Request.Body = io.NopCloser(bodyStorage) - taskErr = attempt() - if taskErr != nil && !taskErr.LocalError { - processChannelError(c, - *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, - common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), - types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode)) - } - } - - useChannel := c.GetStringSlice("use_channel") - if len(useChannel) > 1 { - retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - logger.LogInfo(c, retryLogStr) - } - return taskErr -} - func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool { if taskErr == nil { return false diff --git a/model/log.go b/model/log.go index d7cd97a4..1f521b1e 100644 --- a/model/log.go +++ b/model/log.go @@ -199,6 +199,49 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) } } +type RecordTaskBillingLogParams struct { + UserId int + LogType int + Content string + ChannelId int + ModelName string + Quota int + TokenId int + Group string + Other map[string]interface{} +} + +func RecordTaskBillingLog(params RecordTaskBillingLogParams) { + if params.LogType == LogTypeConsume && !common.LogConsumeEnabled { + return + } + username, _ := GetUsernameById(params.UserId, false) + tokenName := "" + if params.TokenId > 0 { + if token, err := GetTokenById(params.TokenId); err == nil { + tokenName = token.Name + } + } + log := &Log{ + UserId: params.UserId, + Username: username, + CreatedAt: common.GetTimestamp(), + Type: params.LogType, + Content: params.Content, + TokenName: tokenName, + ModelName: params.ModelName, + Quota: params.Quota, + ChannelId: params.ChannelId, + TokenId: params.TokenId, + Group: params.Group, + Other: common.MapToJsonStr(params.Other), + } + err := LOG_DB.Create(log).Error + if err != nil { + common.SysLog("failed to record task billing log: " + err.Error()) + } +} + func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string) (logs []*Log, total int64, err error) { var tx *gorm.DB if logType == LogTypeUnknown { diff --git a/service/task_billing.go b/service/task_billing.go index fc44c587..78ad0fc0 100644 --- a/service/task_billing.go +++ b/service/task_billing.go @@ -108,6 +108,29 @@ func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) { } } +// taskBillingOther 从 task 的 BillingContext 构建日志 Other 字段。 +func taskBillingOther(task *model.Task) map[string]interface{} { + other := make(map[string]interface{}) + if bc := task.PrivateData.BillingContext; bc != nil { + other["model_price"] = bc.ModelPrice + other["group_ratio"] = bc.GroupRatio + if len(bc.OtherRatios) > 0 { + for k, v := range bc.OtherRatios { + other[k] = v + } + } + } + return other +} + +// taskModelName 从 BillingContext 或 Properties 中获取模型名称。 +func taskModelName(task *model.Task) string { + if bc := task.PrivateData.BillingContext; bc != nil && bc.ModelName != "" { + return bc.ModelName + } + return task.Properties.OriginModelName +} + // RefundTaskQuota 统一的任务失败退款逻辑。 // 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。 func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { @@ -126,8 +149,20 @@ func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { taskAdjustTokenQuota(ctx, task, -quota) // 3. 记录日志 - logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s,原因:%s", task.TaskID, logger.LogQuota(quota), reason) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + other := taskBillingOther(task) + other["task_id"] = task.TaskID + other["reason"] = reason + model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{ + UserId: task.UserId, + LogType: model.LogTypeRefund, + Content: "", + ChannelId: task.ChannelId, + ModelName: taskModelName(task), + Quota: quota, + TokenId: task.PrivateData.TokenId, + Group: task.Group, + Other: other, + }) } // RecalculateTaskQuota 通用的异步差额结算。 @@ -163,23 +198,35 @@ func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int // 调整令牌额度 taskAdjustTokenQuota(ctx, task, quotaDelta) - // 更新统计(仅补扣时更新,退还不影响已用统计) - if quotaDelta > 0 { - model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) - model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) - } task.Quota = actualQuota - var action string + var logType int + var logQuota int if quotaDelta > 0 { - action = "补扣费" + logType = model.LogTypeConsume + logQuota = quotaDelta + model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) + model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) } else { - action = "退还" + logType = model.LogTypeRefund + logQuota = -quotaDelta } - logContent := fmt.Sprintf("异步任务成功%s,预扣费 %s,实际扣费 %s,原因:%s", - action, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), reason) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + other := taskBillingOther(task) + other["task_id"] = task.TaskID + other["reason"] = reason + other["pre_consumed_quota"] = preConsumedQuota + other["actual_quota"] = actualQuota + model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{ + UserId: task.UserId, + LogType: logType, + Content: "", + ChannelId: task.ChannelId, + ModelName: taskModelName(task), + Quota: logQuota, + TokenId: task.PrivateData.TokenId, + Group: task.Group, + Other: other, + }) } // RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。 diff --git a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx index f0dcd379..b1538877 100644 --- a/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx +++ b/web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx @@ -133,6 +133,12 @@ function renderType(type, t) { {t('错误')} ); + case 6: + return ( + + {t('退款')} + + ); default: return ( @@ -368,7 +374,7 @@ export const getLogsColumns = ({ } return isAdminUser && - (record.type === 0 || record.type === 2 || record.type === 5) ? ( + (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) ? ( @@ -459,7 +465,7 @@ export const getLogsColumns = ({ title: t('令牌'), dataIndex: 'token_name', render: (text, record, index) => { - return record.type === 0 || record.type === 2 || record.type === 5 ? ( + return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? (
{ - if (record.type === 0 || record.type === 2 || record.type === 5) { + if (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) { if (record.group) { return <>{renderGroup(record.group)}; } else { @@ -522,7 +528,7 @@ export const getLogsColumns = ({ title: t('模型'), dataIndex: 'model_name', render: (text, record, index) => { - return record.type === 0 || record.type === 2 || record.type === 5 ? ( + return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? ( <>{renderModelName(record, copyText, t)} ) : ( <> @@ -589,7 +595,7 @@ export const getLogsColumns = ({ cacheText = `${t('缓存写')} ${formatTokenCount(cacheSummary.cacheWriteTokens)}`; } - return record.type === 0 || record.type === 2 || record.type === 5 ? ( + return record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6 ? (
{ return parseInt(text) > 0 && - (record.type === 0 || record.type === 2 || record.type === 5) ? ( + (record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6) ? ( <>{ {text} } ) : ( <> @@ -635,7 +641,7 @@ export const getLogsColumns = ({ title: t('花费'), dataIndex: 'quota', render: (text, record, index) => { - if (!(record.type === 0 || record.type === 2 || record.type === 5)) { + if (!(record.type === 0 || record.type === 2 || record.type === 5 || record.type === 6)) { return <>; } const other = getLogOther(record.other); @@ -722,6 +728,16 @@ export const getLogsColumns = ({ fixed: 'right', render: (text, record, index) => { let other = getLogOther(record.other); + if (record.type === 6) { + return ( + + {t('异步任务退款')} + + ); + } if (other == null || record.type !== 2) { return ( {t('管理')} {t('系统')} {t('错误')} + {t('退款')}
diff --git a/web/src/hooks/usage-logs/useUsageLogsData.jsx b/web/src/hooks/usage-logs/useUsageLogsData.jsx index 14c021e4..b69a7cf1 100644 --- a/web/src/hooks/usage-logs/useUsageLogsData.jsx +++ b/web/src/hooks/usage-logs/useUsageLogsData.jsx @@ -344,7 +344,7 @@ export const useLogsData = () => { let other = getLogOther(logs[i].other); let expandDataLocal = []; - if (isAdminUser && (logs[i].type === 0 || logs[i].type === 2)) { + if (isAdminUser && (logs[i].type === 0 || logs[i].type === 2 || logs[i].type === 6)) { expandDataLocal.push({ key: t('渠道信息'), value: `${logs[i].channel} - ${logs[i].channel_name || '[未知]'}`, @@ -535,6 +535,24 @@ export const useLogsData = () => { }); } } + if (logs[i].type === 6) { + if (other?.task_id) { + expandDataLocal.push({ + key: t('任务ID'), + value: other.task_id, + }); + } + if (other?.reason) { + expandDataLocal.push({ + key: t('失败原因'), + value: ( +
+ {other.reason} +
+ ), + }); + } + } if (other?.request_path) { expandDataLocal.push({ key: t('请求路径'), @@ -590,13 +608,13 @@ export const useLogsData = () => { ), }); } - if (isAdminUser) { + if (isAdminUser && logs[i].type !== 6) { expandDataLocal.push({ key: t('请求转换'), value: requestConversionDisplayValue(other?.request_conversion), }); } - if (isAdminUser) { + if (isAdminUser && logs[i].type !== 6) { let localCountMode = ''; if (other?.admin_info?.local_count_tokens) { localCountMode = t('本地计费'); diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index 8b2b0852..c2546833 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -2545,6 +2545,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "errors", + "退款": "Refund", + "错误详情": "Error Details", + "异步任务退款": "Async Task Refund", + "任务ID": "Task ID", + "失败原因": "Failure Reason", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "The key is the group name, and the value is another JSON object. The key is the group name, and the value is the special group ratio for users in that group. For example: {\"vip\": {\"default\": 0.5, \"test\": 1}} means that users in the vip group have a ratio of 0.5 when using tokens from the default group, and a ratio of 1 when using tokens from the test group", "键为原状态码,值为要复写的状态码,仅影响本地判断": "The key is the original status code, and the value is the status code to override, only affects local judgment", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.", diff --git a/web/src/i18n/locales/fr.json b/web/src/i18n/locales/fr.json index d4c76db6..54fd3617 100644 --- a/web/src/i18n/locales/fr.json +++ b/web/src/i18n/locales/fr.json @@ -2508,6 +2508,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "Erreur", + "退款": "Remboursement", + "错误详情": "Détails de l'erreur", + "异步任务退款": "Remboursement de tâche asynchrone", + "任务ID": "ID de tâche", + "失败原因": "Raison de l'échec", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "La clé est le nom du groupe, la valeur est un autre objet JSON, la clé est le nom du groupe, la valeur est le ratio de groupe spécial des utilisateurs de ce groupe, par exemple : {\"vip\": {\"default\": 0.5, \"test\": 1}}, ce qui signifie que les utilisateurs du groupe vip ont un ratio de 0.5 lors de l'utilisation de jetons du groupe default et un ratio de 1 lors de l'utilisation du groupe test", "键为原状态码,值为要复写的状态码,仅影响本地判断": "La clé est le code d'état d'origine, la valeur est le code d'état à réécrire, n'affecte que le jugement local", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "La clé correspond au nom du groupe d'utilisateurs et la valeur à un objet de mappage des opérations. Les clés internes commençant par \"+:\" ajoutent le groupe indiqué (clé = nom du groupe, valeur = description), celles commençant par \"-:\" retirent le groupe indiqué, et les clés sans préfixe ajoutent directement ce groupe. Exemple : {\"vip\": {\"+:premium\": \"Groupe avancé\", \"special\": \"Groupe spécial\", \"-:default\": \"Groupe par défaut\"}} signifie que les utilisateurs du groupe vip peuvent accéder aux groupes premium et special tout en perdant l'accès au groupe default.", diff --git a/web/src/i18n/locales/ja.json b/web/src/i18n/locales/ja.json index 9ab727ec..d9a49aa5 100644 --- a/web/src/i18n/locales/ja.json +++ b/web/src/i18n/locales/ja.json @@ -2491,6 +2491,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "エラー", + "退款": "返金", + "错误详情": "エラー詳細", + "异步任务退款": "非同期タスク返金", + "任务ID": "タスクID", + "失败原因": "失敗の原因", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "キーはグループ名、値は別のJSONオブジェクトです。このオブジェクトのキーには、利用するトークンが属するグループ名を指定し、値にはそのユーザーグループに適用される特別な倍率を指定します。例:{\"vip\": {\"default\": 0.5, \"test\": 1}} は、vipグループのユーザーがdefaultグループのトークンを利用する際の倍率が0.5、testグループのトークンを利用する際の倍率が1になることを示します", "键为原状态码,值为要复写的状态码,仅影响本地判断": "キーは元のステータスコード、値は上書きするステータスコードで、ローカルでの判断にのみ影響します", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.", diff --git a/web/src/i18n/locales/ru.json b/web/src/i18n/locales/ru.json index 97e243d3..fc117a51 100644 --- a/web/src/i18n/locales/ru.json +++ b/web/src/i18n/locales/ru.json @@ -2521,6 +2521,11 @@ "销毁容器": "Destroy Container", "销毁容器失败": "Failed to destroy container", "错误": "Ошибка", + "退款": "Возврат", + "错误详情": "Детали ошибки", + "异步任务退款": "Возврат асинхронной задачи", + "任务ID": "ID задачи", + "失败原因": "Причина ошибки", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "Ключ - это имя группы, значение - другой JSON объект, ключ - имя группы, значение - специальный групповой коэффициент для пользователей этой группы, например: {\"vip\": {\"default\": 0.5, \"test\": 1}}, означает, что пользователи группы vip при использовании токенов группы default имеют коэффициент 0.5, при использовании группы test - коэффициент 1", "键为原状态码,值为要复写的状态码,仅影响本地判断": "Ключ - исходный код состояния, значение - код состояния для перезаписи, влияет только на локальную проверку", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Ключ — это название группы пользователей, значение — объект сопоставления операций. Внутренние ключи с префиксом \"+:\" добавляют указанные группы (ключ — название группы, значение — описание), с префиксом \"-:\" удаляют указанные группы, без префикса — сразу добавляют эту группу. Пример: {\"vip\": {\"+:premium\": \"Продвинутая группа\", \"special\": \"Особая группа\", \"-:default\": \"Группа по умолчанию\"}} означает, что пользователи группы vip могут использовать группы premium и special, одновременно теряя доступ к группе default.", diff --git a/web/src/i18n/locales/vi.json b/web/src/i18n/locales/vi.json index 8875b1b5..89d8715e 100644 --- a/web/src/i18n/locales/vi.json +++ b/web/src/i18n/locales/vi.json @@ -3060,10 +3060,14 @@ "销毁容器失败": "Failed to destroy container", "锁定": "Khóa", "错误": "Lỗi", + "退款": "Hoàn tiền", "错误信息": "Thông tin lỗi", "错误日志": "Nhật ký lỗi", "错误码": "Mã lỗi", "错误详情": "Chi tiết lỗi", + "异步任务退款": "Hoàn tiền tác vụ bất đồng bộ", + "任务ID": "ID tác vụ", + "失败原因": "Nguyên nhân thất bại", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "Khóa là tên nhóm và giá trị là một đối tượng JSON khác. Khóa là tên nhóm và giá trị là tỷ lệ nhóm đặc biệt cho người dùng trong nhóm đó. Ví dụ: {\"vip\": {\"default\": 0.5, \"test\": 1}} có nghĩa là người dùng trong nhóm vip có tỷ lệ 0.5 khi sử dụng mã thông báo từ nhóm default và tỷ lệ 1 khi sử dụng mã thông báo từ nhóm test.", "键为原状态码,值为要复写的状态码,仅影响本地判断": "Khóa là mã trạng thái gốc và giá trị là mã trạng thái cần ghi đè, chỉ ảnh hưởng đến phán đoán cục bộ", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "Keys are user group names and values are operation mappings. Inner keys prefixed with \"+:\" add the specified group (key is the group name, value is the description); keys prefixed with \"-:\" remove the specified group; keys without a prefix add that group directly. Example: {\"vip\": {\"+:premium\": \"Advanced group\", \"special\": \"Special group\", \"-:default\": \"Default group\"}} means vip users can access the premium and special groups while removing access to the default group.", diff --git a/web/src/i18n/locales/zh-CN.json b/web/src/i18n/locales/zh-CN.json index 43ce65b7..3cfcc032 100644 --- a/web/src/i18n/locales/zh-CN.json +++ b/web/src/i18n/locales/zh-CN.json @@ -2531,6 +2531,11 @@ "销毁容器": "销毁容器", "销毁容器失败": "销毁容器失败", "错误": "错误", + "退款": "退款", + "错误详情": "错误详情", + "异步任务退款": "异步任务退款", + "任务ID": "任务ID", + "失败原因": "失败原因", "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1": "键为分组名称,值为另一个 JSON 对象,键为分组名称,值为该分组的用户的特殊分组倍率,例如:{\"vip\": {\"default\": 0.5, \"test\": 1}},表示 vip 分组的用户在使用default分组的令牌时倍率为0.5,使用test分组时倍率为1", "键为原状态码,值为要复写的状态码,仅影响本地判断": "键为原状态码,值为要复写的状态码,仅影响本地判断", "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限": "键为用户分组名称,值为操作映射对象。内层键以\"+:\"开头表示添加指定分组(键值为分组名称,值为描述),以\"-:\"开头表示移除指定分组(键值为分组名称),不带前缀的键直接添加该分组。例如:{\"vip\": {\"+:premium\": \"高级分组\", \"special\": \"特殊分组\", \"-:default\": \"默认分组\"}},表示 vip 分组的用户可以使用 premium 和 special 分组,同时移除 default 分组的访问权限", From 7d5fc3ff5143bd86c5a1f18b08a02d02e3f12e93 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 21 Feb 2026 23:05:58 +0800 Subject: [PATCH 04/10] refactor(relay): rename RelayTask to RelayTaskFetch and update routing - Renamed RelayTask function to RelayTaskFetch for clarity. - Updated routing in relay-router.go and video-router.go to use RelayTaskFetch for fetch operations. - Enhanced error handling in RelayTaskFetch function. - Adjusted task data conversion in TaskAdaptor to include task ID. --- controller/relay.go | 26 +++++++++++++++----------- relay/channel/task/sora/adaptor.go | 8 +++++++- router/relay-router.go | 4 ++-- router/video-router.go | 8 ++++---- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index e90d6dd0..1477df8f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -450,6 +450,21 @@ func RelayNotFound(c *gin.Context) { }) } +func RelayTaskFetch(c *gin.Context) { + relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) + if err != nil { + c.JSON(http.StatusInternalServerError, &dto.TaskError{ + Code: "gen_relay_info_failed", + Message: err.Error(), + StatusCode: http.StatusInternalServerError, + }) + return + } + if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil { + respondTaskError(c, taskErr) + } +} + func RelayTask(c *gin.Context) { relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) if err != nil { @@ -461,17 +476,6 @@ func RelayTask(c *gin.Context) { return } - // Fetch 路径:纯 DB 查询,不依赖上下文 channel,无需重试 - switch relayInfo.RelayMode { - case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: - if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil { - respondTaskError(c, taskErr) - } - return - } - - // ── Submit 路径 ───────────────────────────────────────────────── - if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil { respondTaskError(c, taskErr) return diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index 8faaf984..bf2f7005 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -18,6 +18,7 @@ import ( "github.com/gin-gonic/gin" "github.com/pkg/errors" + "github.com/tidwall/sjson" ) // ============================ @@ -250,5 +251,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { - return task.Data, nil + data := task.Data + var err error + if data, err = sjson.SetBytes(data, "id", task.TaskID); err != nil { + return nil, errors.Wrap(err, "set id failed") + } + return data, nil } diff --git a/router/relay-router.go b/router/relay-router.go index 04584945..dcec439c 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -174,8 +174,8 @@ func SetRelayRouter(router *gin.Engine) { relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { relaySunoRouter.POST("/submit/:action", controller.RelayTask) - relaySunoRouter.POST("/fetch", controller.RelayTask) - relaySunoRouter.GET("/fetch/:id", controller.RelayTask) + relaySunoRouter.POST("/fetch", controller.RelayTaskFetch) + relaySunoRouter.GET("/fetch/:id", controller.RelayTaskFetch) } relayGeminiRouter := router.Group("/v1beta") diff --git a/router/video-router.go b/router/video-router.go index d2bce42b..875b0af8 100644 --- a/router/video-router.go +++ b/router/video-router.go @@ -19,14 +19,14 @@ func SetVideoRouter(router *gin.Engine) { videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { videoV1Router.POST("/video/generations", controller.RelayTask) - videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) + videoV1Router.GET("/video/generations/:task_id", controller.RelayTaskFetch) videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask) } // openai compatible API video routes // docs: https://platform.openai.com/docs/api-reference/videos/create { videoV1Router.POST("/videos", controller.RelayTask) - videoV1Router.GET("/videos/:task_id", controller.RelayTask) + videoV1Router.GET("/videos/:task_id", controller.RelayTaskFetch) } klingV1Router := router.Group("/kling/v1") @@ -34,8 +34,8 @@ func SetVideoRouter(router *gin.Engine) { { klingV1Router.POST("/videos/text2video", controller.RelayTask) klingV1Router.POST("/videos/image2video", controller.RelayTask) - klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask) - klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask) + klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTaskFetch) + klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTaskFetch) } // Jimeng official API routes - direct mapping to official API format From 143b4535b22d6d6394d4ad3f11a412861be1dbfc Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 21 Feb 2026 23:20:31 +0800 Subject: [PATCH 05/10] refactor(relay): enhance remix logic for billing context extraction - Updated the remix handling in ResolveOriginTask to prioritize extracting OtherRatios from the BillingContext of the original task if available. - Retained the previous logic for extracting seconds and size from task data as a fallback. - Improved clarity and maintainability of the remix logic by separating the new and old approaches. --- relay/relay_task.go | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/relay/relay_task.go b/relay/relay_task.go index 7c6724d8..cc4d0e45 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -106,21 +106,29 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr // 提取 remix 参数(时长、分辨率 → OtherRatios) if info.Action == constant.TaskActionRemix { - var taskData map[string]interface{} - _ = common.Unmarshal(originTask.Data, &taskData) - secondsStr, _ := taskData["seconds"].(string) - seconds, _ := strconv.Atoi(secondsStr) - if seconds <= 0 { - seconds = 4 - } - sizeStr, _ := taskData["size"].(string) - if info.PriceData.OtherRatios == nil { - info.PriceData.OtherRatios = map[string]float64{} - } - info.PriceData.OtherRatios["seconds"] = float64(seconds) - info.PriceData.OtherRatios["size"] = 1 - if sizeStr == "1792x1024" || sizeStr == "1024x1792" { - info.PriceData.OtherRatios["size"] = 1.666667 + if originTask.PrivateData.BillingContext != nil { + // 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios(如果存在) + for s, f := range originTask.PrivateData.BillingContext.OtherRatios { + info.PriceData.AddOtherRatio(s, f) + } + } else { + // 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size(如果存在) + var taskData map[string]interface{} + _ = common.Unmarshal(originTask.Data, &taskData) + secondsStr, _ := taskData["seconds"].(string) + seconds, _ := strconv.Atoi(secondsStr) + if seconds <= 0 { + seconds = 4 + } + sizeStr, _ := taskData["size"].(string) + if info.PriceData.OtherRatios == nil { + info.PriceData.OtherRatios = map[string]float64{} + } + info.PriceData.OtherRatios["seconds"] = float64(seconds) + info.PriceData.OtherRatios["size"] = 1 + if sizeStr == "1792x1024" || sizeStr == "1024x1792" { + info.PriceData.OtherRatios["size"] = 1.666667 + } } } From 6f39c0285706c0eda5836e2916a3c097c060be36 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 21 Feb 2026 23:47:55 +0800 Subject: [PATCH 06/10] refactor(relay): improve channel locking and retry logic in RelayTask - Enhanced the RelayTask function to utilize a locked channel when available, allowing for better reuse during retries. - Updated error handling to ensure proper context setup for the selected channel. - Clarified comments in ResolveOriginTask regarding channel locking and retry behavior. - Introduced a new field in TaskRelayInfo to store the locked channel object, improving type safety and reducing import cycles. --- controller/relay.go | 23 ++++++++++++++++++----- relay/common/relay_info.go | 5 +++++ relay/relay_task.go | 26 +++++++++++++------------- 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 1477df8f..6951974c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -497,11 +497,24 @@ func RelayTask(c *gin.Context) { } for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { - channel, channelErr := getChannel(c, relayInfo, retryParam) - if channelErr != nil { - logger.LogError(c, channelErr.Error()) - taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError) - break + var channel *model.Channel + + if lockedCh, ok := relayInfo.LockedChannel.(*model.Channel); ok && lockedCh != nil { + channel = lockedCh + if retryParam.GetRetry() > 0 { + if setupErr := middleware.SetupContextForSelectedChannel(c, channel, relayInfo.OriginModelName); setupErr != nil { + taskErr = service.TaskErrorWrapperLocal(setupErr.Err, "setup_locked_channel_failed", http.StatusInternalServerError) + break + } + } + } else { + var channelErr *types.NewAPIError + channel, channelErr = getChannel(c, relayInfo, retryParam) + if channelErr != nil { + logger.LogError(c, channelErr.Error()) + taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError) + break + } } addUsedChannel(c, channel.Id) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index b6882681..541f1b9f 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -619,6 +619,11 @@ type TaskRelayInfo struct { PublicTaskID string ConsumeQuota bool + + // LockedChannel holds the full channel object when the request is bound to + // a specific channel (e.g., remix on origin task's channel). Stored as any + // to avoid an import cycle with model; callers type-assert to *model.Channel. + LockedChannel any } type TaskSubmitReq struct { diff --git a/relay/relay_task.go b/relay/relay_task.go index cc4d0e45..8d0e61d7 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -32,8 +32,9 @@ type TaskSubmitResult struct { } // ResolveOriginTask 处理基于已有任务的提交(remix / continuation): -// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过 -// specific_channel_id 禁止重试),以及提取 OtherRatios(时长、分辨率)。 +// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道 +// (通过 info.LockedChannel,重试时复用同一渠道并轮换 key), +// 以及提取 OtherRatios(时长、分辨率)。 // 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。 func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { // 检测 remix action @@ -77,15 +78,17 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr } } - // 锁定到原始任务的渠道(如果与当前选中的不同) + // 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key) + ch, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) + } + if ch.Status != common.ChannelStatusEnabled { + return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) + } + info.LockedChannel = ch + if originTask.ChannelId != info.ChannelId { - ch, err := model.GetChannelById(originTask.ChannelId, true) - if err != nil { - return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) - } - if ch.Status != common.ChannelStatusEnabled { - return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) - } key, _, newAPIError := ch.GetNextEnabledKey() if newAPIError != nil { return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) @@ -101,9 +104,6 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr info.ApiKey = key } - // 渠道已锁定到原始任务 → 禁止重试切换到其他渠道 - c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId)) - // 提取 remix 参数(时长、分辨率 → OtherRatios) if info.Action == constant.TaskActionRemix { if originTask.PrivateData.BillingContext != nil { From b386490d5e532db598f97894c889b6536949a7d2 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sun, 22 Feb 2026 00:52:35 +0800 Subject: [PATCH 07/10] refactor(task): add CAS-guarded updates to prevent concurrent billing conflicts Replace all bare task.Update() (DB.Save) calls with UpdateWithStatus(), which adds a WHERE status = ? guard to prevent concurrent processes from overwriting each other's state transitions. Key changes: model/task.go: - Add taskSnapshot struct with Equal() method for change detection - Add Snapshot() method to capture pre-update state - Add UpdateWithStatus(fromStatus) using DB.Where().Save() for CAS semantics with full-struct save (no explicit field listing needed) model/midjourney.go: - Add UpdateWithStatus(fromStatus string) with same CAS pattern service/task_polling.go (updateVideoSingleTask): - Snapshot before processing upstream response; skip DB write if unchanged - Terminal transitions (SUCCESS/FAILURE) use UpdateWithStatus CAS: billing/refund only executes if this process wins the transition - Non-terminal updates also use UpdateWithStatus to prevent overwriting a concurrent terminal transition back to IN_PROGRESS - Defer settleTaskBillingOnComplete to after CAS check (shouldSettle flag) relay/relay_task.go (tryRealtimeFetch): - Add snapshot + change detection; use UpdateWithStatus for CAS safety controller/midjourney.go (UpdateMidjourneyTaskBulk): - Capture preStatus before mutations; use UpdateWithStatus CAS - Gate refund (IncreaseUserQuota) on CAS success (won && shouldReturnQuota) This prevents the multi-instance race condition where: 1. Instance A reads task (IN_PROGRESS), fetches upstream (still IN_PROGRESS) 2. Instance B reads same task, fetches upstream (now SUCCESS), writes SUCCESS 3. Instance A's bare Save() overwrites SUCCESS back to IN_PROGRESS --- controller/midjourney.go | 17 ++++---- model/midjourney.go | 11 +++++ model/task.go | 91 ++++++++++++++++++---------------------- relay/relay_task.go | 7 +++- service/task_polling.go | 43 ++++++++++++------- 5 files changed, 95 insertions(+), 74 deletions(-) diff --git a/controller/midjourney.go b/controller/midjourney.go index c480c12b..4045a550 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -130,6 +130,7 @@ func UpdateMidjourneyTaskBulk() { if !checkMjTaskNeedUpdate(task, responseItem) { continue } + preStatus := task.Status task.Code = 1 task.Progress = responseItem.Progress task.PromptEn = responseItem.PromptEn @@ -172,18 +173,16 @@ func UpdateMidjourneyTaskBulk() { shouldReturnQuota = true } } - err = task.Update() + won, err := task.UpdateWithStatus(preStatus) if err != nil { logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) - } else { - if shouldReturnQuota { - err = model.IncreaseUserQuota(task.UserId, task.Quota, false) - if err != nil { - logger.LogError(ctx, "fail to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } else if won && shouldReturnQuota { + err = model.IncreaseUserQuota(task.UserId, task.Quota, false) + if err != nil { + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } } } diff --git a/model/midjourney.go b/model/midjourney.go index c6ef5de5..9867e8a9 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -157,6 +157,17 @@ func (midjourney *Midjourney) Update() error { return err } +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Returns (true, nil) if this caller won the update, (false, nil) if +// another process already moved the task out of fromStatus. +func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) { + result := DB.Where("status = ?", fromStatus).Save(midjourney) + if result.Error != nil { + return false, result.Error + } + return result.RowsAffected > 0, nil +} + func MjBulkUpdate(mjIds []string, params map[string]any) error { return DB.Model(&Midjourney{}). Where("mj_id in (?)", mjIds). diff --git a/model/task.go b/model/task.go index 592643eb..4d1482f8 100644 --- a/model/task.go +++ b/model/task.go @@ -1,6 +1,7 @@ package model import ( + "bytes" "database/sql/driver" "encoding/json" "time" @@ -340,38 +341,59 @@ func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) { return task, nil } -func TaskUpdateProgress(id int64, progress string) error { - return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error -} - func (Task *Task) Insert() error { var err error err = DB.Create(Task).Error return err } +type taskSnapshot struct { + Status TaskStatus + Progress string + StartTime int64 + FinishTime int64 + FailReason string + ResultURL string + Data json.RawMessage +} + +func (s taskSnapshot) Equal(other taskSnapshot) bool { + return s.Status == other.Status && + s.Progress == other.Progress && + s.StartTime == other.StartTime && + s.FinishTime == other.FinishTime && + s.FailReason == other.FailReason && + s.ResultURL == other.ResultURL && + bytes.Equal(s.Data, other.Data) +} + +func (t *Task) Snapshot() taskSnapshot { + return taskSnapshot{ + Status: t.Status, + Progress: t.Progress, + StartTime: t.StartTime, + FinishTime: t.FinishTime, + FailReason: t.FailReason, + ResultURL: t.PrivateData.ResultURL, + Data: t.Data, + } +} + func (Task *Task) Update() error { var err error err = DB.Save(Task).Error return err } -func TaskBulkUpdate(TaskIds []string, params map[string]any) error { - if len(TaskIds) == 0 { - return nil +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Returns (true, nil) if this caller won the update, (false, nil) if +// another process already moved the task out of fromStatus. +func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) { + result := DB.Where("status = ?", fromStatus).Save(t) + if result.Error != nil { + return false, result.Error } - return DB.Model(&Task{}). - Where("task_id in (?)", TaskIds). - Updates(params).Error -} - -func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error { - if len(taskIDs) == 0 { - return nil - } - return DB.Model(&Task{}). - Where("id in (?)", taskIDs). - Updates(params).Error + return result.RowsAffected > 0, nil } func TaskBulkUpdateByID(ids []int64, params map[string]any) error { @@ -388,37 +410,6 @@ type TaskQuotaUsage struct { Count float64 `json:"count"` } -func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) { - query := DB.Model(Task{}) - // 添加过滤条件 - if queryParams.ChannelID != "" { - query = query.Where("channel_id = ?", queryParams.ChannelID) - } - if queryParams.UserID != "" { - query = query.Where("user_id = ?", queryParams.UserID) - } - if len(queryParams.UserIDs) != 0 { - query = query.Where("user_id in (?)", queryParams.UserIDs) - } - if queryParams.TaskID != "" { - query = query.Where("task_id = ?", queryParams.TaskID) - } - if queryParams.Action != "" { - query = query.Where("action = ?", queryParams.Action) - } - if queryParams.Status != "" { - query = query.Where("status = ?", queryParams.Status) - } - if queryParams.StartTimestamp != 0 { - query = query.Where("submit_time >= ?", queryParams.StartTimestamp) - } - if queryParams.EndTimestamp != 0 { - query = query.Where("submit_time <= ?", queryParams.EndTimestamp) - } - err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error - return stat, err -} - // TaskCountAllTasks returns total tasks that match the given query params (admin usage) func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 { var total int64 diff --git a/relay/relay_task.go b/relay/relay_task.go index 8d0e61d7..cd43e6eb 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -444,6 +444,8 @@ func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte { return nil } + snap := task.Snapshot() + // 将上游最新状态更新到 task if ti.Status != "" { task.Status = model.TaskStatus(ti.Status) @@ -459,7 +461,10 @@ func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte { // No URL from adaptor — construct proxy URL using public task ID task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) } - _ = task.Update() + + if !snap.Equal(task.Snapshot()) { + _, _ = task.UpdateWithStatus(snap.Status) + } // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理 if isOpenAIVideoAPI { diff --git a/service/task_polling.go b/service/task_polling.go index efbad8af..7e92d14b 100644 --- a/service/task_polling.go +++ b/service/task_polling.go @@ -319,6 +319,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody))) + snap := task.Snapshot() + taskResult := &relaycommon.TaskInfo{} // try parse as New API response format var responseItems dto.TaskResponse[model.Task] @@ -344,10 +346,9 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * taskResult = relaycommon.FailTaskInfo("upstream returned empty status") } - // 记录原本的状态,防止重复退款 shouldRefund := false + shouldSettle := false quota := task.Quota - preStatus := task.Status task.Status = model.TaskStatus(taskResult.Status) switch taskResult.Status { @@ -374,9 +375,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * // No URL from adaptor — construct proxy URL using public task ID task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) } - - // 完成时计费调整:优先由 adaptor 计算,回退到 token 重算 - settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + shouldSettle = true case model.TaskStatusFailure: logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) task.Status = model.TaskStatusFailure @@ -388,23 +387,39 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) taskResult.Progress = taskcommon.ProgressComplete if quota != 0 { - if preStatus != model.TaskStatusFailure { - shouldRefund = true - } else { - logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID)) - } + shouldRefund = true } default: - return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) + return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID) } if taskResult.Progress != "" { task.Progress = taskResult.Progress } - if err := task.Update(); err != nil { - common.SysLog("UpdateVideoTask task error: " + err.Error()) - shouldRefund = false + + isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure + if isDone && snap.Status != task.Status { + won, err := task.UpdateWithStatus(snap.Status) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error())) + shouldRefund = false + shouldSettle = false + } else if !won { + logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID)) + shouldRefund = false + shouldSettle = false + } + } else if !snap.Equal(task.Snapshot()) { + if _, err := task.UpdateWithStatus(snap.Status); err != nil { + logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error())) + } + } else { + // No changes, skip update + logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID)) } + if shouldSettle { + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + } if shouldRefund { RefundTaskQuota(ctx, task, task.FailReason) } From 374aabf3014e8748f1afd388703e64ec9635ae97 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sun, 22 Feb 2026 01:25:04 +0800 Subject: [PATCH 08/10] refactor(task): enhance UpdateWithStatus for CAS updates and add integration tests - Updated UpdateWithStatus method to use Model().Select("*").Updates() for conditional updates, preventing GORM's INSERT fallback. - Introduced comprehensive integration tests for UpdateWithStatus, covering scenarios for winning and losing CAS updates, as well as concurrent updates. - Added task_cas_test.go to validate the new behavior and ensure data integrity during concurrent state transitions. --- model/midjourney.go | 4 +- model/task.go | 6 +- model/task_cas_test.go | 217 +++++++++++++ service/task_billing_test.go | 606 +++++++++++++++++++++++++++++++++++ 4 files changed, 831 insertions(+), 2 deletions(-) create mode 100644 model/task_cas_test.go create mode 100644 service/task_billing_test.go diff --git a/model/midjourney.go b/model/midjourney.go index 9867e8a9..e1a8d772 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -160,8 +160,10 @@ func (midjourney *Midjourney) Update() error { // UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). // Returns (true, nil) if this caller won the update, (false, nil) if // another process already moved the task out of fromStatus. +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Uses Model().Select("*").Updates() to avoid GORM Save()'s INSERT fallback. func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) { - result := DB.Where("status = ?", fromStatus).Save(midjourney) + result := DB.Model(midjourney).Where("status = ?", fromStatus).Select("*").Updates(midjourney) if result.Error != nil { return false, result.Error } diff --git a/model/task.go b/model/task.go index 4d1482f8..0cf6bd47 100644 --- a/model/task.go +++ b/model/task.go @@ -388,8 +388,12 @@ func (Task *Task) Update() error { // UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). // Returns (true, nil) if this caller won the update, (false, nil) if // another process already moved the task out of fromStatus. +// +// Uses Model().Select("*").Updates() instead of Save() because GORM's Save +// falls back to INSERT ON CONFLICT when the WHERE-guarded UPDATE matches +// zero rows, which silently bypasses the CAS guard. func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) { - result := DB.Where("status = ?", fromStatus).Save(t) + result := DB.Model(t).Where("status = ?", fromStatus).Select("*").Updates(t) if result.Error != nil { return false, result.Error } diff --git a/model/task_cas_test.go b/model/task_cas_test.go new file mode 100644 index 00000000..3449c6d2 --- /dev/null +++ b/model/task_cas_test.go @@ -0,0 +1,217 @@ +package model + +import ( + "encoding/json" + "os" + "sync" + "testing" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestMain(m *testing.M) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + panic("failed to open test db: " + err.Error()) + } + DB = db + LOG_DB = db + + common.UsingSQLite = true + common.RedisEnabled = false + common.BatchUpdateEnabled = false + common.LogConsumeEnabled = true + + sqlDB, err := db.DB() + if err != nil { + panic("failed to get sql.DB: " + err.Error()) + } + sqlDB.SetMaxOpenConns(1) + + if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil { + panic("failed to migrate: " + err.Error()) + } + + os.Exit(m.Run()) +} + +func truncateTables(t *testing.T) { + t.Helper() + t.Cleanup(func() { + DB.Exec("DELETE FROM tasks") + DB.Exec("DELETE FROM users") + DB.Exec("DELETE FROM tokens") + DB.Exec("DELETE FROM logs") + DB.Exec("DELETE FROM channels") + }) +} + +func insertTask(t *testing.T, task *Task) { + t.Helper() + task.CreatedAt = time.Now().Unix() + task.UpdatedAt = time.Now().Unix() + require.NoError(t, DB.Create(task).Error) +} + +// --------------------------------------------------------------------------- +// Snapshot / Equal — pure logic tests (no DB) +// --------------------------------------------------------------------------- + +func TestSnapshotEqual_Same(t *testing.T) { + s := taskSnapshot{ + Status: TaskStatusInProgress, + Progress: "50%", + StartTime: 1000, + FinishTime: 0, + FailReason: "", + ResultURL: "", + Data: json.RawMessage(`{"key":"value"}`), + } + assert.True(t, s.Equal(s)) +} + +func TestSnapshotEqual_DifferentStatus(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)} + b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_DifferentProgress(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)} + b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_DifferentData(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)} + b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_NilVsEmpty(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: nil} + b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}} + // bytes.Equal(nil, []byte{}) == true + assert.True(t, a.Equal(b)) +} + +func TestSnapshot_Roundtrip(t *testing.T) { + task := &Task{ + Status: TaskStatusInProgress, + Progress: "42%", + StartTime: 1234, + FinishTime: 5678, + FailReason: "timeout", + PrivateData: TaskPrivateData{ + ResultURL: "https://example.com/result.mp4", + }, + Data: json.RawMessage(`{"model":"test-model"}`), + } + snap := task.Snapshot() + assert.Equal(t, task.Status, snap.Status) + assert.Equal(t, task.Progress, snap.Progress) + assert.Equal(t, task.StartTime, snap.StartTime) + assert.Equal(t, task.FinishTime, snap.FinishTime) + assert.Equal(t, task.FailReason, snap.FailReason) + assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL) + assert.JSONEq(t, string(task.Data), string(snap.Data)) +} + +// --------------------------------------------------------------------------- +// UpdateWithStatus CAS — DB integration tests +// --------------------------------------------------------------------------- + +func TestUpdateWithStatus_Win(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_win", + Status: TaskStatusInProgress, + Progress: "50%", + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + task.Status = TaskStatusSuccess + task.Progress = "100%" + won, err := task.UpdateWithStatus(TaskStatusInProgress) + require.NoError(t, err) + assert.True(t, won) + + var reloaded Task + require.NoError(t, DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, TaskStatusSuccess, reloaded.Status) + assert.Equal(t, "100%", reloaded.Progress) +} + +func TestUpdateWithStatus_Lose(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_lose", + Status: TaskStatusFailure, + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + task.Status = TaskStatusSuccess + won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus + require.NoError(t, err) + assert.False(t, won) + + var reloaded Task + require.NoError(t, DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged +} + +func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_race", + Status: TaskStatusInProgress, + Quota: 1000, + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + const goroutines = 5 + wins := make([]bool, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + t := &Task{} + *t = Task{ + ID: task.ID, + TaskID: task.TaskID, + Status: TaskStatusSuccess, + Progress: "100%", + Quota: task.Quota, + Data: json.RawMessage(`{}`), + } + t.CreatedAt = task.CreatedAt + t.UpdatedAt = time.Now().Unix() + won, err := t.UpdateWithStatus(TaskStatusInProgress) + if err == nil { + wins[idx] = won + } + }(i) + } + wg.Wait() + + winCount := 0 + for _, w := range wins { + if w { + winCount++ + } + } + assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS") +} diff --git a/service/task_billing_test.go b/service/task_billing_test.go new file mode 100644 index 00000000..6c2d231d --- /dev/null +++ b/service/task_billing_test.go @@ -0,0 +1,606 @@ +package service + +import ( + "context" + "encoding/json" + "os" + "testing" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestMain(m *testing.M) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + panic("failed to open test db: " + err.Error()) + } + sqlDB, err := db.DB() + if err != nil { + panic("failed to get sql.DB: " + err.Error()) + } + sqlDB.SetMaxOpenConns(1) + + model.DB = db + model.LOG_DB = db + + common.UsingSQLite = true + common.RedisEnabled = false + common.BatchUpdateEnabled = false + common.LogConsumeEnabled = true + + if err := db.AutoMigrate( + &model.Task{}, + &model.User{}, + &model.Token{}, + &model.Log{}, + &model.Channel{}, + &model.UserSubscription{}, + ); err != nil { + panic("failed to migrate: " + err.Error()) + } + + os.Exit(m.Run()) +} + +// --------------------------------------------------------------------------- +// Seed helpers +// --------------------------------------------------------------------------- + +func truncate(t *testing.T) { + t.Helper() + t.Cleanup(func() { + model.DB.Exec("DELETE FROM tasks") + model.DB.Exec("DELETE FROM users") + model.DB.Exec("DELETE FROM tokens") + model.DB.Exec("DELETE FROM logs") + model.DB.Exec("DELETE FROM channels") + model.DB.Exec("DELETE FROM user_subscriptions") + }) +} + +func seedUser(t *testing.T, id int, quota int) { + t.Helper() + user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled} + require.NoError(t, model.DB.Create(user).Error) +} + +func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) { + t.Helper() + token := &model.Token{ + Id: id, + UserId: userId, + Key: key, + Name: "test_token", + Status: common.TokenStatusEnabled, + RemainQuota: remainQuota, + UsedQuota: 0, + } + require.NoError(t, model.DB.Create(token).Error) +} + +func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) { + t.Helper() + sub := &model.UserSubscription{ + Id: id, + UserId: userId, + AmountTotal: amountTotal, + AmountUsed: amountUsed, + Status: "active", + StartTime: time.Now().Unix(), + EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(), + } + require.NoError(t, model.DB.Create(sub).Error) +} + +func seedChannel(t *testing.T, id int) { + t.Helper() + ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled} + require.NoError(t, model.DB.Create(ch).Error) +} + +func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task { + return &model.Task{ + TaskID: "task_" + time.Now().Format("150405.000"), + UserId: userId, + ChannelId: channelId, + Quota: quota, + Status: model.TaskStatus(model.TaskStatusInProgress), + Group: "default", + Data: json.RawMessage(`{}`), + CreatedAt: time.Now().Unix(), + UpdatedAt: time.Now().Unix(), + Properties: model.Properties{ + OriginModelName: "test-model", + }, + PrivateData: model.TaskPrivateData{ + BillingSource: billingSource, + SubscriptionId: subscriptionId, + TokenId: tokenId, + BillingContext: &model.TaskBillingContext{ + ModelPrice: 0.02, + GroupRatio: 1.0, + ModelName: "test-model", + }, + }, + } +} + +// --------------------------------------------------------------------------- +// Read-back helpers +// --------------------------------------------------------------------------- + +func getUserQuota(t *testing.T, id int) int { + t.Helper() + var user model.User + require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error) + return user.Quota +} + +func getTokenRemainQuota(t *testing.T, id int) int { + t.Helper() + var token model.Token + require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error) + return token.RemainQuota +} + +func getTokenUsedQuota(t *testing.T, id int) int { + t.Helper() + var token model.Token + require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error) + return token.UsedQuota +} + +func getSubscriptionUsed(t *testing.T, id int) int64 { + t.Helper() + var sub model.UserSubscription + require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error) + return sub.AmountUsed +} + +func getLastLog(t *testing.T) *model.Log { + t.Helper() + var log model.Log + err := model.LOG_DB.Order("id desc").First(&log).Error + if err != nil { + return nil + } + return &log +} + +func countLogs(t *testing.T) int64 { + t.Helper() + var count int64 + model.LOG_DB.Model(&model.Log{}).Count(&count) + return count +} + +// =========================================================================== +// RefundTaskQuota tests +// =========================================================================== + +func TestRefundTaskQuota_Wallet(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 1, 1, 1 + const initQuota, preConsumed = 10000, 3000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-test-key", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RefundTaskQuota(ctx, task, "task failed: upstream error") + + // User quota should increase by preConsumed + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + + // Token remain_quota should increase, used_quota should decrease + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID)) + + // A refund log should be created + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) + assert.Equal(t, preConsumed, log.Quota) + assert.Equal(t, "test-model", log.ModelName) +} + +func TestRefundTaskQuota_Subscription(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID, subID = 2, 2, 2, 1 + const preConsumed = 2000 + const subTotal, subUsed int64 = 100000, 50000 + const tokenRemain = 8000 + + seedUser(t, userID, 0) + seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain) + seedChannel(t, channelID) + seedSubscription(t, subID, userID, subTotal, subUsed) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) + + RefundTaskQuota(ctx, task, "subscription task failed") + + // Subscription used should decrease by preConsumed + assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID)) + + // Token should also be refunded + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +func TestRefundTaskQuota_ZeroQuota(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 3 + seedUser(t, userID, 5000) + + task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0) + + RefundTaskQuota(ctx, task, "zero quota task") + + // No change to user quota + assert.Equal(t, 5000, getUserQuota(t, userID)) + + // No log created + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRefundTaskQuota_NoToken(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, channelID = 4, 4 + const initQuota, preConsumed = 10000, 1500 + + seedUser(t, userID, initQuota) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0 + + RefundTaskQuota(ctx, task, "no token task failed") + + // User quota refunded + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + + // Log created + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +// =========================================================================== +// RecalculateTaskQuota tests +// =========================================================================== + +func TestRecalculate_PositiveDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 10, 10, 10 + const initQuota, preConsumed = 10000, 2000 + const actualQuota = 3000 // under-charged by 1000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") + + // User quota should decrease by the delta (1000 additional charge) + assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID)) + + // Token should also be charged the delta + assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID)) + + // task.Quota should be updated to actualQuota + assert.Equal(t, actualQuota, task.Quota) + + // Log type should be Consume (additional charge) + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeConsume, log.Type) + assert.Equal(t, actualQuota-preConsumed, log.Quota) +} + +func TestRecalculate_NegativeDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 11, 11, 11 + const initQuota, preConsumed = 10000, 5000 + const actualQuota = 3000 // over-charged by 2000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") + + // User quota should increase by abs(delta) = 2000 (refund overpayment) + assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) + + // Token should be refunded the difference + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + // task.Quota updated + assert.Equal(t, actualQuota, task.Quota) + + // Log type should be Refund + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) + assert.Equal(t, preConsumed-actualQuota, log.Quota) +} + +func TestRecalculate_ZeroDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 12 + const initQuota, preConsumed = 10000, 3000 + + seedUser(t, userID, initQuota) + + task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, preConsumed, "exact match") + + // No change to user quota + assert.Equal(t, initQuota, getUserQuota(t, userID)) + + // No log created (delta is zero) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRecalculate_ActualQuotaZero(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 13 + const initQuota = 10000 + + seedUser(t, userID, initQuota) + + task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, 0, "zero actual") + + // No change (early return) + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRecalculate_Subscription_NegativeDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID, subID = 14, 14, 14, 2 + const preConsumed = 5000 + const actualQuota = 2000 // over-charged by 3000 + const subTotal, subUsed int64 = 100000, 50000 + const tokenRemain = 8000 + + seedUser(t, userID, 0) + seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain) + seedChannel(t, channelID) + seedSubscription(t, subID, userID, subTotal, subUsed) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) + + RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge") + + // Subscription used should decrease by delta (refund 3000) + assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID)) + + // Token refunded + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + assert.Equal(t, actualQuota, task.Quota) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +// =========================================================================== +// CAS + Billing integration tests +// Simulates the flow in updateVideoSingleTask (service/task_polling.go) +// =========================================================================== + +// simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask. +// It takes a persisted task (already in DB), applies the new status, and performs +// the conditional update + billing exactly as the polling loop does. +func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) { + snap := task.Snapshot() + + shouldRefund := false + shouldSettle := false + quota := task.Quota + + task.Status = newStatus + switch string(newStatus) { + case model.TaskStatusSuccess: + task.Progress = "100%" + task.FinishTime = 9999 + shouldSettle = true + case model.TaskStatusFailure: + task.Progress = "100%" + task.FinishTime = 9999 + task.FailReason = "upstream error" + if quota != 0 { + shouldRefund = true + } + default: + task.Progress = "50%" + } + + isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure) + if isDone && snap.Status != task.Status { + won, err := task.UpdateWithStatus(snap.Status) + if err != nil { + shouldRefund = false + shouldSettle = false + } else if !won { + shouldRefund = false + shouldSettle = false + } + } else if !snap.Equal(task.Snapshot()) { + _, _ = task.UpdateWithStatus(snap.Status) + } + + if shouldSettle && actualQuota > 0 { + RecalculateTaskQuota(ctx, task, actualQuota, "test settle") + } + if shouldRefund { + RefundTaskQuota(ctx, task, task.FailReason) + } +} + +func TestCASGuardedRefund_Win(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 20, 20, 20 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 6000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) + + // CAS wins: task in DB should now be FAILURE + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status) + + // Refund should have happened + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +func TestCASGuardedRefund_Lose(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 21, 21, 21 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 6000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain) + seedChannel(t, channelID) + + // Create task with IN_PROGRESS in DB + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + // Simulate another process already transitioning to FAILURE + model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure) + + // Our process still has the old in-memory state (IN_PROGRESS) and tries to transition + // task.Status is still IN_PROGRESS in the snapshot + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) + + // CAS lost: user quota should NOT change (no double refund) + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + + // No billing log should be created + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestCASGuardedSettle_Win(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 22, 22, 22 + const initQuota, preConsumed = 10000, 5000 + const actualQuota = 3000 // over-charged, should get partial refund + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota) + + // CAS wins: task should be SUCCESS + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status) + + // Settlement should refund the over-charge (5000 - 3000 = 2000 back to user) + assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + // task.Quota should be updated to actualQuota + assert.Equal(t, actualQuota, task.Quota) +} + +func TestNonTerminalUpdate_NoBilling(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, channelID = 23, 23 + const initQuota, preConsumed = 10000, 3000 + + seedUser(t, userID, initQuota) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + task.Progress = "20%" + require.NoError(t, model.DB.Create(task).Error) + + // Simulate a non-terminal poll update (still IN_PROGRESS, progress changed) + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0) + + // User quota should NOT change + assert.Equal(t, initQuota, getUserQuota(t, userID)) + + // No billing log + assert.Equal(t, int64(0), countLogs(t)) + + // Task progress should be updated in DB + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.Equal(t, "50%", reloaded.Progress) +} From 06fe03e34ce1f590be1a32d6b79b7ddccfeb59af Mon Sep 17 00:00:00 2001 From: CaIon Date: Sun, 22 Feb 2026 15:32:33 +0800 Subject: [PATCH 09/10] feat(task): add model redirection, per-call billing, and multipart retry fix for async tasks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Async task model redirection (aligned with sync tasks): - Integrate ModelMappedHelper in RelayTaskSubmit after model name determination, populating OriginModelName / UpstreamModelName on RelayInfo. - All task adaptors now send UpstreamModelName to upstream providers: - Gemini & Vertex: BuildRequestURL uses UpstreamModelName. - Doubao & Ali: BuildRequestBody conditionally overwrites body.Model. - Vidu, Kling, Hailuo, Jimeng: convertToRequestPayload accepts RelayInfo and unconditionally uses info.UpstreamModelName. - Sora: BuildRequestBody parses JSON and multipart bodies to replace the "model" field with UpstreamModelName. - Frontend log visibility: LogTaskConsumption and taskBillingOther now emit is_model_mapped / upstream_model_name in the "other" JSON field. - Billing safety: RecalculateTaskQuotaByTokens reads model name from BillingContext.OriginModelName (via taskModelName) instead of task.Data["model"], preventing billing leaks from upstream model names. 2. Per-call billing (TaskPricePatches lifecycle): - Rename TaskBillingContext.ModelName → OriginModelName; add PerCallBilling bool field, populated from TaskPricePatches at submission time. - settleTaskBillingOnComplete short-circuits when PerCallBilling is true, skipping both adaptor adjustments and token-based recalculation. - Remove ModelName from TaskSubmitResult; use relayInfo.OriginModelName consistently in controller/relay.go for billing context and logging. 3. Multipart retry boundary mismatch fix: - Root cause: after Sora (or OpenAI audio) rebuilds a multipart body with a new boundary and overwrites c.Request.Header["Content-Type"], subsequent calls to ParseMultipartFormReusable on retry would parse the cached original body with the wrong boundary, causing "NextPart: EOF". - Fix: ParseMultipartFormReusable now caches the original Content-Type in gin context key "_original_multipart_ct" on first call and reuses it for all subsequent parses, making multipart parsing retry-safe globally. - Sora adaptor reverted to the standard pattern (direct header set/get), which is now safe thanks to the root fix. 4. Tests: - task_billing_test.go: update makeTask to use OriginModelName; add PerCallBilling settlement tests (skip adaptor adjust, skip token recalc); add non-per-call adaptor adjustment test with refund verification. --- common/gin.go | 10 +- controller/relay.go | 17 +-- controller/task.go | 26 ++++- model/task.go | 11 +- relay/channel/task/ali/adaptor.go | 8 +- relay/channel/task/doubao/adaptor.go | 6 +- relay/channel/task/gemini/adaptor.go | 2 +- relay/channel/task/hailuo/adaptor.go | 8 +- relay/channel/task/jimeng/adaptor.go | 6 +- relay/channel/task/kling/adaptor.go | 9 +- relay/channel/task/sora/adaptor.go | 55 +++++++++ relay/channel/task/vertex/adaptor.go | 2 +- relay/channel/task/vidu/adaptor.go | 6 +- relay/relay_task.go | 9 +- service/task_billing.go | 29 ++--- service/task_billing_test.go | 108 +++++++++++++++++- service/task_polling.go | 5 + .../table/task-logs/TaskLogsColumnDefs.jsx | 36 +++--- web/src/components/table/task-logs/index.jsx | 2 - 19 files changed, 277 insertions(+), 78 deletions(-) diff --git a/common/gin.go b/common/gin.go index 48971c13..009e3908 100644 --- a/common/gin.go +++ b/common/gin.go @@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) { return nil, err } - contentType := c.Request.Header.Get("Content-Type") + // Use the original Content-Type saved on first call to avoid boundary + // mismatch when callers overwrite c.Request.Header after multipart rebuild. + var contentType string + if saved, ok := c.Get("_original_multipart_ct"); ok { + contentType = saved.(string) + } else { + contentType = c.Request.Header.Get("Content-Type") + c.Set("_original_multipart_ct", contentType) + } boundary, err := parseBoundary(contentType) if err != nil { return nil, err diff --git a/controller/relay.go b/controller/relay.go index 6951974c..7e7922e7 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -518,7 +518,7 @@ func RelayTask(c *gin.Context) { } addUsedChannel(c, channel.Id) - requestBody, bodyErr := common.GetRequestBody(c) + bodyStorage, bodyErr := common.GetBodyStorage(c) if bodyErr != nil { if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) { taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge) @@ -527,7 +527,7 @@ func RelayTask(c *gin.Context) { } break } - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + c.Request.Body = io.NopCloser(bodyStorage) result, taskErr = relay.RelayTaskSubmit(c, relayInfo) if taskErr == nil { @@ -557,7 +557,7 @@ func RelayTask(c *gin.Context) { if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil { common.SysError("settle task billing error: " + settleErr.Error()) } - service.LogTaskConsumption(c, relayInfo, result.ModelName) + service.LogTaskConsumption(c, relayInfo) task := model.InitTask(result.Platform, relayInfo) task.PrivateData.UpstreamTaskID = result.UpstreamTaskID @@ -565,11 +565,12 @@ func RelayTask(c *gin.Context) { task.PrivateData.SubscriptionId = relayInfo.SubscriptionId task.PrivateData.TokenId = relayInfo.TokenId task.PrivateData.BillingContext = &model.TaskBillingContext{ - ModelPrice: relayInfo.PriceData.ModelPrice, - GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio, - ModelRatio: relayInfo.PriceData.ModelRatio, - OtherRatios: relayInfo.PriceData.OtherRatios, - ModelName: result.ModelName, + ModelPrice: relayInfo.PriceData.ModelPrice, + GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio, + ModelRatio: relayInfo.PriceData.ModelRatio, + OtherRatios: relayInfo.PriceData.OtherRatios, + OriginModelName: relayInfo.OriginModelName, + PerCallBilling: common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName), } task.Quota = result.Quota task.Data = result.TaskData diff --git a/controller/task.go b/controller/task.go index ec713c5d..eac7db15 100644 --- a/controller/task.go +++ b/controller/task.go @@ -9,6 +9,7 @@ import ( "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay" "github.com/QuantumNous/new-api/service" + "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" ) @@ -37,7 +38,7 @@ func GetAllTask(c *gin.Context) { items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllTasks(queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(tasksToDto(items)) + pageInfo.SetItems(tasksToDto(items, true)) common.ApiSuccess(c, pageInfo) } @@ -61,13 +62,32 @@ func GetUserTask(c *gin.Context) { items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams) total := model.TaskCountAllUserTask(userId, queryParams) pageInfo.SetTotal(int(total)) - pageInfo.SetItems(tasksToDto(items)) + pageInfo.SetItems(tasksToDto(items, false)) common.ApiSuccess(c, pageInfo) } -func tasksToDto(tasks []*model.Task) []*dto.TaskDto { +func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto { + var userIdMap map[int]*model.UserBase + if fillUser { + userIdMap = make(map[int]*model.UserBase) + userIds := types.NewSet[int]() + for _, task := range tasks { + userIds.Add(task.UserId) + } + for _, userId := range userIds.Items() { + cacheUser, err := model.GetUserCache(userId) + if err == nil { + userIdMap[userId] = cacheUser + } + } + } result := make([]*dto.TaskDto, len(tasks)) for i, task := range tasks { + if fillUser { + if user, ok := userIdMap[task.UserId]; ok { + task.Username = user.Username + } + } result[i] = relay.TaskModel2Dto(task) } return result diff --git a/model/task.go b/model/task.go index 0cf6bd47..da3be34e 100644 --- a/model/task.go +++ b/model/task.go @@ -109,11 +109,12 @@ type TaskPrivateData struct { // TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。 type TaskBillingContext struct { - ModelPrice float64 `json:"model_price,omitempty"` // 模型单价 - GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率 - ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率 - OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等) - ModelName string `json:"model_name,omitempty"` // 模型名称 + ModelPrice float64 `json:"model_price,omitempty"` // 模型单价 + GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率 + ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率 + OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等) + OriginModelName string `json:"origin_model_name,omitempty"` // 模型名称,必须为OriginModelName + PerCallBilling bool `json:"per_call_billing,omitempty"` // 按次计费:跳过轮询阶段的差额结算 } // GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信) diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go index f55178b3..f698fc9f 100644 --- a/relay/channel/task/ali/adaptor.go +++ b/relay/channel/task/ali/adaptor.go @@ -253,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error) } func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) { + upstreamModel := req.Model + if info.IsModelMapped { + upstreamModel = info.UpstreamModelName + } aliReq := &AliVideoRequest{ - Model: req.Model, + Model: upstreamModel, Input: AliVideoInput{ Prompt: req.Prompt, ImgURL: req.InputReference, @@ -332,7 +336,7 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay } } - if aliReq.Model != req.Model { + if aliReq.Model != upstreamModel { return nil, errors.New("can't change model with metadata") } diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index eca421bd..8f1d748c 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -131,7 +131,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } - info.UpstreamModelName = body.Model + if info.IsModelMapped { + body.Model = info.UpstreamModelName + } else { + info.UpstreamModelName = body.Model + } data, err := common.Marshal(body) if err != nil { return nil, err diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index 06c00a46..5644cd5d 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -105,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { - modelName := info.OriginModelName + modelName := info.UpstreamModelName version := model_setting.GetGeminiVersionSetting(modelName) return fmt.Sprintf( diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go index ab83d659..28b3a97f 100644 --- a/relay/channel/task/hailuo/adaptor.go +++ b/relay/channel/task/hailuo/adaptor.go @@ -61,7 +61,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, fmt.Errorf("invalid request type in context") } - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } @@ -142,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string { return ChannelName } -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) { - modelConfig := GetModelConfig(req.Model) +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) { + modelConfig := GetModelConfig(info.UpstreamModelName) duration := DefaultDuration if req.Duration > 0 { duration = req.Duration @@ -154,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } videoRequest := &VideoRequest{ - Model: req.Model, + Model: info.UpstreamModelName, Prompt: req.Prompt, Duration: &duration, Resolution: resolution, diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index b61cca41..e6211b1e 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -165,7 +165,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } } - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, errors.Wrap(err, "convert request payload failed") } @@ -378,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte { return h.Sum(nil) } -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ - ReqKey: req.Model, + ReqKey: info.UpstreamModelName, Prompt: req.Prompt, } diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 46e210f1..cdbb5687 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -150,7 +150,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } req := v.(relaycommon.TaskSubmitReq) - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, err } @@ -248,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ Prompt: req.Prompt, Image: req.Image, Mode: taskcommon.DefaultString(req.Mode, "std"), Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)), AspectRatio: a.getAspectRatio(req.Size), - ModelName: req.Model, - Model: req.Model, // Keep consistent with model_name, double writing improves compatibility + ModelName: info.UpstreamModelName, + Model: info.UpstreamModelName, CfgScale: 0.5, StaticMask: "", DynamicMasks: []DynamicMask{}, @@ -266,6 +266,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (* } if r.ModelName == "" { r.ModelName = "kling-v1" + r.Model = "kling-v1" } if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index bf2f7005..33db8fe5 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -1,8 +1,10 @@ package sora import ( + "bytes" "fmt" "io" + "mime/multipart" "net/http" "strconv" "strings" @@ -145,6 +147,59 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, errors.Wrap(err, "get_request_body_failed") } + cachedBody, err := storage.Bytes() + if err != nil { + return nil, errors.Wrap(err, "read_body_bytes_failed") + } + contentType := c.GetHeader("Content-Type") + + if strings.HasPrefix(contentType, "application/json") { + var bodyMap map[string]interface{} + if err := common.Unmarshal(cachedBody, &bodyMap); err == nil { + bodyMap["model"] = info.UpstreamModelName + if newBody, err := common.Marshal(bodyMap); err == nil { + return bytes.NewReader(newBody), nil + } + } + return bytes.NewReader(cachedBody), nil + } + + if strings.Contains(contentType, "multipart/form-data") { + formData, err := common.ParseMultipartFormReusable(c) + if err != nil { + return bytes.NewReader(cachedBody), nil + } + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + writer.WriteField("model", info.UpstreamModelName) + for key, values := range formData.Value { + if key == "model" { + continue + } + for _, v := range values { + writer.WriteField(key, v) + } + } + for fieldName, fileHeaders := range formData.File { + for _, fh := range fileHeaders { + f, err := fh.Open() + if err != nil { + continue + } + part, err := writer.CreateFormFile(fieldName, fh.Filename) + if err != nil { + f.Close() + continue + } + io.Copy(part, f) + f.Close() + } + } + writer.Close() + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + return &buf, nil + } + return common.ReaderOnly(storage), nil } diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index 4931002d..700e6097 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -86,7 +86,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials: %w", err) } - modelName := info.OriginModelName + modelName := info.UpstreamModelName if modelName == "" { modelName = "veo-3.0-generate-001" } diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index e689bf88..6ae1c181 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn } req := v.(relaycommon.TaskSubmitReq) - body, err := a.convertToRequestPayload(&req) + body, err := a.convertToRequestPayload(&req, info) if err != nil { return nil, err } @@ -224,9 +224,9 @@ func (a *TaskAdaptor) GetChannelName() string { // helpers // ============================ -func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) { r := requestPayload{ - Model: taskcommon.DefaultString(req.Model, "viduq1"), + Model: taskcommon.DefaultString(info.UpstreamModelName, "viduq1"), Images: req.Images, Prompt: req.Prompt, Duration: taskcommon.DefaultInt(req.Duration, 5), diff --git a/relay/relay_task.go b/relay/relay_task.go index cd43e6eb..c740facd 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -26,7 +26,6 @@ type TaskSubmitResult struct { UpstreamTaskID string TaskData []byte Platform constant.TaskPlatform - ModelName string Quota int //PerCallPrice types.PriceData } @@ -163,6 +162,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe modelName = service.CoverTaskActionToModelName(platform, info.Action) } + // 2.5 应用渠道的模型映射(与同步任务对齐) + info.OriginModelName = modelName + info.UpstreamModelName = modelName + if err := helper.ModelMappedHelper(c, info, nil); err != nil { + return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest) + } + // 3. 预生成公开 task ID(仅首次) if info.PublicTaskID == "" { info.PublicTaskID = model.GenerateTaskID() @@ -241,7 +247,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe UpstreamTaskID: upstreamTaskID, TaskData: taskData, Platform: platform, - ModelName: modelName, Quota: finalQuota, }, nil } diff --git a/service/task_billing.go b/service/task_billing.go index 78ad0fc0..0da4cf43 100644 --- a/service/task_billing.go +++ b/service/task_billing.go @@ -16,11 +16,11 @@ import ( // LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。 // 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。 -func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) { +func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("操作 %s", info.Action) // 支持任务仅按次计费 - if common.StringsContains(constant.TaskPricePatches, modelName) { + if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) { logContent = fmt.Sprintf("%s,按次计费", logContent) } else { if len(info.PriceData.OtherRatios) > 0 { @@ -42,9 +42,13 @@ func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName s if info.PriceData.GroupRatioInfo.HasSpecialRatio { other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio } + if info.IsModelMapped { + other["is_model_mapped"] = true + other["upstream_model_name"] = info.UpstreamModelName + } model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ ChannelId: info.ChannelId, - ModelName: modelName, + ModelName: info.OriginModelName, TokenName: tokenName, Quota: info.PriceData.Quota, Content: logContent, @@ -120,13 +124,18 @@ func taskBillingOther(task *model.Task) map[string]interface{} { } } } + props := task.Properties + if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName { + other["is_model_mapped"] = true + other["upstream_model_name"] = props.UpstreamModelName + } return other } // taskModelName 从 BillingContext 或 Properties 中获取模型名称。 func taskModelName(task *model.Task) string { - if bc := task.PrivateData.BillingContext; bc != nil && bc.ModelName != "" { - return bc.ModelName + if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" { + return bc.OriginModelName } return task.Properties.OriginModelName } @@ -237,15 +246,7 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo return } - // 获取模型名称 - var taskData map[string]interface{} - if err := common.Unmarshal(task.Data, &taskData); err != nil { - return - } - modelName, ok := taskData["model"].(string) - if !ok || modelName == "" { - return - } + modelName := taskModelName(task) // 获取模型价格和倍率 modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) diff --git a/service/task_billing_test.go b/service/task_billing_test.go index 6c2d231d..1145bba5 100644 --- a/service/task_billing_test.go +++ b/service/task_billing_test.go @@ -3,12 +3,14 @@ package service import ( "context" "encoding/json" + "net/http" "os" "testing" "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/glebarez/sqlite" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -125,7 +127,7 @@ func makeTask(userId, channelId, quota, tokenId int, billingSource string, subsc BillingContext: &model.TaskBillingContext{ ModelPrice: 0.02, GroupRatio: 1.0, - ModelName: "test-model", + OriginModelName: "test-model", }, }, } @@ -604,3 +606,107 @@ func TestNonTerminalUpdate_NoBilling(t *testing.T) { require.NoError(t, model.DB.First(&reloaded, task.ID).Error) assert.Equal(t, "50%", reloaded.Progress) } + +// =========================================================================== +// Mock adaptor for settleTaskBillingOnComplete tests +// =========================================================================== + +type mockAdaptor struct { + adjustReturn int +} + +func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {} +func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil } +func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil } +func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int { + return m.adjustReturn +} + +// =========================================================================== +// PerCallBilling tests — settleTaskBillingOnComplete +// =========================================================================== + +func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 30, 30, 30 + const initQuota, preConsumed = 10000, 5000 + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.PrivateData.BillingContext.PerCallBilling = true + + adaptor := &mockAdaptor{adjustReturn: 2000} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Per-call: no adjustment despite adaptor returning 2000 + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, preConsumed, task.Quota) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 31, 31, 31 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 7000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.PrivateData.BillingContext.PerCallBilling = true + + adaptor := &mockAdaptor{adjustReturn: 0} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Per-call: no recalculation by tokens + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, preConsumed, task.Quota) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 32, 32, 32 + const initQuota, preConsumed = 10000, 5000 + const adaptorQuota = 3000 + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + // PerCallBilling defaults to false + + adaptor := &mockAdaptor{adjustReturn: adaptorQuota} + taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} + + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) + + // Non-per-call: adaptor adjustment applies (refund 2000) + assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID)) + assert.Equal(t, adaptorQuota, task.Quota) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} diff --git a/service/task_polling.go b/service/task_polling.go index 7e92d14b..a03fc9b8 100644 --- a/service/task_polling.go +++ b/service/task_polling.go @@ -467,6 +467,11 @@ func truncateBase64(s string) string { // 2. taskResult.TotalTokens > 0 → 按 token 重算 // 3. 都不满足 → 保持预扣额度不变 func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) { + // 0. 按次计费的任务不做差额结算 + if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling { + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID)) + return + } // 1. 优先让 adaptor 决定最终额度 if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 { RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整") diff --git a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx index 4bce4525..7fddb0a5 100644 --- a/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx +++ b/web/src/components/table/task-logs/TaskLogsColumnDefs.jsx @@ -84,8 +84,8 @@ function renderDuration(submit_time, finishTime) { // 返回带有样式的颜色标签 return ( - }> - {durationSec} 秒 + + {durationSec} s ); } @@ -149,7 +149,7 @@ const renderPlatform = (platform, t) => { ); if (option) { return ( - }> + {option.label} ); @@ -157,13 +157,13 @@ const renderPlatform = (platform, t) => { switch (platform) { case 'suno': return ( - }> + Suno ); default: return ( - }> + {t('未知')} ); @@ -240,7 +240,6 @@ export const getTaskLogsColumns = ({ openContentModal, isAdminUser, openVideoModal, - showUserInfoFunc, }) => { return [ { @@ -278,7 +277,6 @@ export const getTaskLogsColumns = ({ color={colors[parseInt(text) % colors.length]} size='large' shape='circle' - prefixIcon={} onClick={() => { copyText(text); }} @@ -294,7 +292,7 @@ export const getTaskLogsColumns = ({ { key: COLUMN_KEYS.USERNAME, title: t('用户'), - dataIndex: 'user_id', + dataIndex: 'username', render: (userId, record, index) => { if (!isAdminUser) { return <>; @@ -302,22 +300,14 @@ export const getTaskLogsColumns = ({ const displayText = String(record.username || userId || '?'); return ( - - showUserInfoFunc && showUserInfoFunc(userId)} - > - {displayText.slice(0, 1)} - - - showUserInfoFunc && showUserInfoFunc(userId)} + - {userId} + {displayText.slice(0, 1)} + + + {displayText} ); diff --git a/web/src/components/table/task-logs/index.jsx b/web/src/components/table/task-logs/index.jsx index 140725a8..bc5b9178 100644 --- a/web/src/components/table/task-logs/index.jsx +++ b/web/src/components/table/task-logs/index.jsx @@ -25,7 +25,6 @@ import TaskLogsActions from './TaskLogsActions'; import TaskLogsFilters from './TaskLogsFilters'; import ColumnSelectorModal from './modals/ColumnSelectorModal'; import ContentModal from './modals/ContentModal'; -import UserInfoModal from '../usage-logs/modals/UserInfoModal'; import { useTaskLogsData } from '../../../hooks/task-logs/useTaskLogsData'; import { useIsMobile } from '../../../hooks/common/useIsMobile'; import { createCardProPagination } from '../../../helpers/utils'; @@ -46,7 +45,6 @@ const TaskLogsPage = () => { modalContent={taskLogsData.videoUrl} isVideo={true} /> - Date: Sun, 22 Feb 2026 16:45:35 +0800 Subject: [PATCH 10/10] fix(i18n): remove duplicate task ID translations and clean up unused keys across multiple languages --- web/src/i18n/locales/en.json | 43 --------------------------------- web/src/i18n/locales/fr.json | 2 -- web/src/i18n/locales/ja.json | 2 -- web/src/i18n/locales/ru.json | 2 -- web/src/i18n/locales/vi.json | 2 -- web/src/i18n/locales/zh-CN.json | 6 ----- 6 files changed, 57 deletions(-) diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index c2546833..93b5f18c 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -302,7 +302,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "Task ID", - "任务ID": "Task ID", "任务日志": "Task Logs", "任务状态": "Status", "任务记录": "Task Records", @@ -544,7 +543,6 @@ "创建": "Create", "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)": "Create token with auto group by default, initial token will also be set to auto (otherwise leave blank for user default group)", "创建失败": "Creation failed", - "创建成功": "Creation successful", "创建或选择密钥时,将 Project 设置为 io.cloud": "When creating or selecting a key, set Project to io.cloud", "创建新用户账户": "Create new user account", "创建新的令牌": "Create New Token", @@ -787,7 +785,6 @@ "天": "day", "天前": "days ago", "失败": "Failed", - "失败原因": "Failure reason", "失败时自动禁用通道": "Automatically disable channel on failure", "失败重试次数": "Failed retry times", "奖励说明": "Reward description", @@ -1336,7 +1333,6 @@ "更新失败,请检查输入信息": "Update failed, please check the input information", "更新容器配置": "Update Container Configuration", "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。": "Updating container configuration may cause the container to restart, please ensure you perform this operation at an appropriate time.", - "更新成功": "Update successful", "更新所有已启用通道余额": "Update balance for all enabled channels", "更新支付设置": "Update payment settings", "更新时间": "Update time", @@ -1767,7 +1763,6 @@ "确认清理不活跃的磁盘缓存?": "Confirm cleanup of inactive disk cache?", "确认禁用": "Confirm disable", "确认补单": "Confirm Order Completion", - "确认解绑": "Confirm Unbind", "确认解绑 Passkey": "Confirm Unbind Passkey", "确认设置并完成初始化": "Confirm settings and complete initialization", "确认重置 Passkey": "Confirm Passkey Reset", @@ -1945,7 +1940,6 @@ "自动分组auto,从第一个开始选择": "Auto grouping auto, select from the first one", "自动刷新": "Auto Refresh", "自动刷新中": "Auto refreshing", - "自动检测": "Auto Detect", "自动模式": "Auto Mode", "自动测试所有通道间隔时间": "Auto test interval for all channels", "自动禁用": "Auto disabled", @@ -2343,46 +2337,9 @@ "输入验证码完成设置": "Enter verification code to complete setup", "输出": "Output", "输出 {{completion}} tokens / 1M tokens * {{symbol}}{{compPrice}}) * {{ratioType}} {{ratio}}": "Output {{completion}} tokens / 1M tokens * {{symbol}}{{compPrice}} * {{ratioType}} {{ratio}}", - "磁盘缓存设置(磁盘换内存)": "Disk Cache Settings (Disk Swap Memory)", - "启用磁盘缓存后,大请求体将临时存储到磁盘而非内存,可显著降低内存占用,适用于处理包含大量图片/文件的请求。建议在 SSD 环境下使用。": "When enabled, large request bodies are temporarily stored on disk instead of memory, significantly reducing memory usage. Suitable for requests with large images/files. SSD recommended.", - "启用磁盘缓存": "Enable Disk Cache", - "将大请求体临时存储到磁盘": "Store large request bodies temporarily on disk", - "磁盘缓存阈值 (MB)": "Disk Cache Threshold (MB)", - "请求体超过此大小时使用磁盘缓存": "Use disk cache when request body exceeds this size", - "磁盘缓存最大总量 (MB)": "Max Disk Cache Size (MB)", - "可用空间: {{free}} / 总空间: {{total}}": "Free: {{free}} / Total: {{total}}", - "磁盘缓存占用的最大空间": "Maximum space occupied by disk cache", - "留空使用系统临时目录": "Leave empty to use system temp directory", - "例如 /var/cache/new-api": "e.g. /var/cache/new-api", - "性能监控": "Performance Monitor", - "刷新统计": "Refresh Stats", - "重置统计": "Reset Stats", - "执行 GC": "Run GC", - "请求体磁盘缓存": "Request Body Disk Cache", - "活跃文件": "Active Files", - "磁盘命中": "Disk Hits", - "请求体内存缓存": "Request Body Memory Cache", - "当前缓存大小": "Current Cache Size", - "活跃缓存数": "Active Cache Count", - "内存命中": "Memory Hits", - "缓存目录磁盘空间": "Cache Directory Disk Space", - "磁盘可用空间小于缓存最大总量设置": "Disk free space is less than max cache size setting", - "已分配内存": "Allocated Memory", - "总分配内存": "Total Allocated Memory", - "系统内存": "System Memory", - "GC 次数": "GC Count", - "Goroutine 数": "Goroutine Count", - "目录文件数": "Directory File Count", - "目录总大小": "Directory Total Size", - "磁盘缓存已清理": "Disk cache cleared", - "清理失败": "Cleanup failed", - "统计已重置": "Statistics reset", - "重置失败": "Reset failed", - "GC 已执行": "GC executed", "GC 执行失败": "GC execution failed", "缓存目录": "Cache Directory", "可用": "Available", - "输出价格": "Output Price", "输出价格:{{symbol}}{{price}} * {{completionRatio}} = {{symbol}}{{total}} / 1M tokens (补全倍率: {{completionRatio}})": "Output price: {{symbol}}{{price}} * {{completionRatio}} = {{symbol}}{{total}} / 1M tokens (Completion ratio: {{completionRatio}})", "输出倍率 {{completionRatio}}": "Output ratio {{completionRatio}}", "边栏设置": "Sidebar Settings", diff --git a/web/src/i18n/locales/fr.json b/web/src/i18n/locales/fr.json index 54fd3617..702a61de 100644 --- a/web/src/i18n/locales/fr.json +++ b/web/src/i18n/locales/fr.json @@ -304,7 +304,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "ID de la tâche", - "任务ID": "ID de la tâche", "任务日志": "Tâches", "任务状态": "Statut de la tâche", "任务记录": "Tâches", @@ -792,7 +791,6 @@ "天": "Jour", "天前": "il y a des jours", "失败": "Échec", - "失败原因": "Raison de l'échec", "失败时自动禁用通道": "Désactiver automatiquement le canal en cas d'échec", "失败重试次数": "Nombre de tentatives en cas d'échec", "奖励说明": "Description de la récompense", diff --git a/web/src/i18n/locales/ja.json b/web/src/i18n/locales/ja.json index d9a49aa5..d1e770e9 100644 --- a/web/src/i18n/locales/ja.json +++ b/web/src/i18n/locales/ja.json @@ -300,7 +300,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "タスクID", - "任务ID": "タスクID", "任务日志": "タスク履歴", "任务状态": "タスクステータス", "任务记录": "タスク履歴", @@ -783,7 +782,6 @@ "天": "日", "天前": "日前", "失败": "失敗", - "失败原因": "失敗理由", "失败时自动禁用通道": "失敗時にチャネルを自動的に無効にする", "失败重试次数": "再試行回数", "奖励说明": "特典説明", diff --git a/web/src/i18n/locales/ru.json b/web/src/i18n/locales/ru.json index fc117a51..e2a52904 100644 --- a/web/src/i18n/locales/ru.json +++ b/web/src/i18n/locales/ru.json @@ -307,7 +307,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "ID задачи", - "任务ID": "ID задачи", "任务日志": "Журнал задач", "任务状态": "Статус задачи", "任务记录": "Записи задач", @@ -798,7 +797,6 @@ "天": "день", "天前": "дней назад", "失败": "Неудача", - "失败原因": "Причина неудачи", "失败时自动禁用通道": "Автоматически отключать канал при неудаче", "失败重试次数": "Количество повторных попыток при неудаче", "奖励说明": "Описание награды", diff --git a/web/src/i18n/locales/vi.json b/web/src/i18n/locales/vi.json index 89d8715e..a311ca9e 100644 --- a/web/src/i18n/locales/vi.json +++ b/web/src/i18n/locales/vi.json @@ -301,7 +301,6 @@ "价格重新计算中...": "Recalculating price...", "价格预估": "Price Estimate", "任务 ID": "ID tác vụ", - "任务ID": "ID tác vụ", "任务日志": "Nhật ký tác vụ", "任务状态": "Trạng thái", "任务记录": "Hồ sơ tác vụ", @@ -784,7 +783,6 @@ "天": "ngày", "天前": "ngày trước", "失败": "Thất bại", - "失败原因": "Lý do thất bại", "失败时自动禁用通道": "Tự động vô hiệu hóa kênh khi thất bại", "失败重试次数": "Số lần thử lại thất bại", "奖励说明": "Mô tả phần thưởng", diff --git a/web/src/i18n/locales/zh-CN.json b/web/src/i18n/locales/zh-CN.json index 3cfcc032..a5bace57 100644 --- a/web/src/i18n/locales/zh-CN.json +++ b/web/src/i18n/locales/zh-CN.json @@ -298,7 +298,6 @@ "价格重新计算中...": "价格重新计算中...", "价格预估": "价格预估", "任务 ID": "任务 ID", - "任务ID": "任务ID", "任务日志": "任务日志", "任务状态": "任务状态", "任务记录": "任务记录", @@ -539,7 +538,6 @@ "创建": "创建", "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)": "创建令牌默认选择auto分组,初始令牌也将设为auto(否则留空,为用户默认分组)", "创建失败": "创建失败", - "创建成功": "创建成功", "创建或选择密钥时,将 Project 设置为 io.cloud": "创建或选择密钥时,将 Project 设置为 io.cloud", "创建新用户账户": "创建新用户账户", "创建新的令牌": "创建新的令牌", @@ -782,7 +780,6 @@ "天": "天", "天前": "天前", "失败": "失败", - "失败原因": "失败原因", "失败时自动禁用通道": "失败时自动禁用通道", "失败重试次数": "失败重试次数", "奖励说明": "奖励说明", @@ -1326,7 +1323,6 @@ "更新失败,请检查输入信息": "更新失败,请检查输入信息", "更新容器配置": "更新容器配置", "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。": "更新容器配置可能会导致容器重启,请确保在合适的时间进行此操作。", - "更新成功": "更新成功", "更新所有已启用通道余额": "更新所有已启用通道余额", "更新支付设置": "更新支付设置", "更新时间": "更新时间", @@ -1754,7 +1750,6 @@ "确认清除历史日志": "确认清除历史日志", "确认禁用": "确认禁用", "确认补单": "确认补单", - "确认解绑": "确认解绑", "确认解绑 Passkey": "确认解绑 Passkey", "确认设置并完成初始化": "确认设置并完成初始化", "确认重置 Passkey": "确认重置 Passkey", @@ -1932,7 +1927,6 @@ "自动分组auto,从第一个开始选择": "自动分组auto,从第一个开始选择", "自动刷新": "自动刷新", "自动刷新中": "自动刷新中", - "自动检测": "自动检测", "自动模式": "自动模式", "自动测试所有通道间隔时间": "自动测试所有通道间隔时间", "自动禁用": "自动禁用",