2023-04-23 18:24:11 +08:00
|
|
|
|
package controller
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
2024-04-04 16:35:44 +08:00
|
|
|
|
"bytes"
|
2025-12-16 17:00:19 +08:00
|
|
|
|
"errors"
|
2023-04-23 18:24:11 +08:00
|
|
|
|
"fmt"
|
2024-04-04 16:35:44 +08:00
|
|
|
|
"io"
|
2023-09-09 03:19:55 +08:00
|
|
|
|
"log"
|
2023-04-23 18:24:11 +08:00
|
|
|
|
"net/http"
|
2024-04-20 17:18:14 +08:00
|
|
|
|
"strings"
|
2025-05-02 13:59:46 +08:00
|
|
|
|
|
2025-10-11 15:30:09 +08:00
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/constant"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/dto"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/logger"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/middleware"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/model"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/relay"
|
|
|
|
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
|
|
|
|
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/relay/helper"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/service"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/setting"
|
2026-01-14 14:34:12 +08:00
|
|
|
|
"github.com/QuantumNous/new-api/setting/operation_setting"
|
2025-10-11 15:30:09 +08:00
|
|
|
|
"github.com/QuantumNous/new-api/types"
|
|
|
|
|
|
|
2025-08-25 18:01:10 +08:00
|
|
|
|
"github.com/bytedance/gopkg/util/gopool"
|
|
|
|
|
|
|
2025-05-02 13:59:46 +08:00
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
|
"github.com/gorilla/websocket"
|
2023-12-27 16:32:54 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
2025-07-10 15:02:40 +08:00
|
|
|
|
var err *types.NewAPIError
|
2025-08-14 20:05:06 +08:00
|
|
|
|
switch info.RelayMode {
|
2025-04-24 19:25:08 +08:00
|
|
|
|
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
2025-08-14 20:05:06 +08:00
|
|
|
|
err = relay.ImageHelper(c, info)
|
2024-02-29 01:08:18 +08:00
|
|
|
|
case relayconstant.RelayModeAudioSpeech:
|
2024-01-09 15:46:45 +08:00
|
|
|
|
fallthrough
|
2024-02-29 01:08:18 +08:00
|
|
|
|
case relayconstant.RelayModeAudioTranslation:
|
2024-01-09 15:46:45 +08:00
|
|
|
|
fallthrough
|
2024-02-29 01:08:18 +08:00
|
|
|
|
case relayconstant.RelayModeAudioTranscription:
|
2025-08-14 20:05:06 +08:00
|
|
|
|
err = relay.AudioHelper(c, info)
|
2024-07-06 17:09:22 +08:00
|
|
|
|
case relayconstant.RelayModeRerank:
|
2025-08-14 20:05:06 +08:00
|
|
|
|
err = relay.RerankHelper(c, info)
|
2025-01-23 05:54:39 +08:00
|
|
|
|
case relayconstant.RelayModeEmbeddings:
|
2025-08-14 20:05:06 +08:00
|
|
|
|
err = relay.EmbeddingHelper(c, info)
|
2026-01-26 20:20:16 +08:00
|
|
|
|
case relayconstant.RelayModeResponses, relayconstant.RelayModeResponsesCompact:
|
2025-08-14 20:05:06 +08:00
|
|
|
|
err = relay.ResponsesHelper(c, info)
|
2023-06-19 10:28:55 +08:00
|
|
|
|
default:
|
2025-08-14 20:05:06 +08:00
|
|
|
|
err = relay.TextHelper(c, info)
|
2023-06-08 14:54:02 +08:00
|
|
|
|
}
|
2025-08-14 20:05:06 +08:00
|
|
|
|
return err
|
|
|
|
|
|
}
|
2025-04-12 00:43:34 +08:00
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
|
|
|
|
|
var err *types.NewAPIError
|
|
|
|
|
|
if strings.Contains(c.Request.URL.Path, "embed") {
|
|
|
|
|
|
err = relay.GeminiEmbeddingHandler(c, info)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
err = relay.GeminiHelper(c, info)
|
2025-04-12 00:43:34 +08:00
|
|
|
|
}
|
2024-04-04 16:35:44 +08:00
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|
|
|
|
|
|
2024-04-04 16:35:44 +08:00
|
|
|
|
requestId := c.GetString(common.RequestIdKey)
|
2025-12-12 22:04:38 +08:00
|
|
|
|
//group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
|
|
|
|
|
//originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
2024-04-04 16:35:44 +08:00
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
var (
|
|
|
|
|
|
newAPIError *types.NewAPIError
|
|
|
|
|
|
ws *websocket.Conn
|
|
|
|
|
|
)
|
2024-08-03 16:55:29 +08:00
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
if relayFormat == types.RelayFormatOpenAIRealtime {
|
|
|
|
|
|
var err error
|
|
|
|
|
|
ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
|
|
|
|
|
|
return
|
2024-08-03 16:55:29 +08:00
|
|
|
|
}
|
2025-08-14 20:05:06 +08:00
|
|
|
|
defer ws.Close()
|
|
|
|
|
|
}
|
2024-08-03 16:55:29 +08:00
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
defer func() {
|
|
|
|
|
|
if newAPIError != nil {
|
2025-10-29 23:33:55 +08:00
|
|
|
|
logger.LogError(c, fmt.Sprintf("relay error: %s", newAPIError.Error()))
|
2025-08-14 20:05:06 +08:00
|
|
|
|
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
|
|
|
|
|
switch relayFormat {
|
|
|
|
|
|
case types.RelayFormatOpenAIRealtime:
|
|
|
|
|
|
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
|
|
|
|
|
case types.RelayFormatClaude:
|
|
|
|
|
|
c.JSON(newAPIError.StatusCode, gin.H{
|
|
|
|
|
|
"type": "error",
|
|
|
|
|
|
"error": newAPIError.ToClaudeError(),
|
|
|
|
|
|
})
|
|
|
|
|
|
default:
|
|
|
|
|
|
c.JSON(newAPIError.StatusCode, gin.H{
|
|
|
|
|
|
"error": newAPIError.ToOpenAIError(),
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
2023-05-18 15:27:15 +08:00
|
|
|
|
}
|
2025-08-14 20:05:06 +08:00
|
|
|
|
}()
|
2024-04-04 16:35:44 +08:00
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
request, err := helper.GetAndValidateRequest(c, relayFormat)
|
|
|
|
|
|
if err != nil {
|
2025-12-16 17:00:19 +08:00
|
|
|
|
// Map "request body too large" to 413 so clients can handle it correctly
|
|
|
|
|
|
if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
|
|
|
|
|
|
newAPIError = types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry())
|
|
|
|
|
|
} else {
|
|
|
|
|
|
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
|
|
|
|
|
}
|
2025-08-14 20:05:06 +08:00
|
|
|
|
return
|
2024-04-04 16:35:44 +08:00
|
|
|
|
}
|
2024-10-04 16:08:18 +08:00
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
|
2024-10-03 20:46:00 +08:00
|
|
|
|
if err != nil {
|
2025-08-14 20:05:06 +08:00
|
|
|
|
newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
|
2024-10-03 20:46:00 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
2024-10-04 16:08:18 +08:00
|
|
|
|
|
2025-12-16 17:00:19 +08:00
|
|
|
|
needSensitiveCheck := setting.ShouldCheckPromptSensitive()
|
|
|
|
|
|
needCountToken := constant.CountToken
|
|
|
|
|
|
// Avoid building huge CombineText (strings.Join) when token counting and sensitive check are both disabled.
|
|
|
|
|
|
var meta *types.TokenCountMeta
|
|
|
|
|
|
if needSensitiveCheck || needCountToken {
|
|
|
|
|
|
meta = request.GetTokenCountMeta()
|
|
|
|
|
|
} else {
|
|
|
|
|
|
meta = fastTokenCountMetaForPricing(request)
|
|
|
|
|
|
}
|
2024-10-03 20:46:00 +08:00
|
|
|
|
|
2025-12-16 17:00:19 +08:00
|
|
|
|
if needSensitiveCheck && meta != nil {
|
2025-08-15 13:20:36 +08:00
|
|
|
|
contains, words := service.CheckSensitiveText(meta.CombineText)
|
|
|
|
|
|
if contains {
|
2025-08-14 20:05:06 +08:00
|
|
|
|
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
|
|
|
|
|
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
|
|
|
|
|
return
|
2024-10-03 20:46:00 +08:00
|
|
|
|
}
|
2025-08-14 20:05:06 +08:00
|
|
|
|
}
|
2024-10-03 20:46:00 +08:00
|
|
|
|
|
2025-12-02 21:34:39 +08:00
|
|
|
|
tokens, err := service.EstimateRequestToken(c, meta, relayInfo)
|
2025-08-14 20:05:06 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
|
|
|
|
|
|
return
|
2024-10-03 20:46:00 +08:00
|
|
|
|
}
|
2025-08-14 20:05:06 +08:00
|
|
|
|
|
2025-12-02 21:34:39 +08:00
|
|
|
|
relayInfo.SetEstimatePromptTokens(tokens)
|
2025-08-16 22:54:00 +08:00
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
|
|
|
|
|
|
return
|
2024-10-03 20:46:00 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-08-15 18:43:08 +08:00
|
|
|
|
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
2025-08-15 18:40:54 +08:00
|
|
|
|
|
2025-10-12 13:31:03 +08:00
|
|
|
|
if priceData.FreeModel {
|
|
|
|
|
|
logger.LogInfo(c, fmt.Sprintf("模型 %s 免费,跳过预扣费", relayInfo.OriginModelName))
|
|
|
|
|
|
} else {
|
|
|
|
|
|
newAPIError = service.PreConsumeQuota(c, priceData.QuotaToPreConsume, relayInfo)
|
|
|
|
|
|
if newAPIError != nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
2024-10-03 20:46:00 +08:00
|
|
|
|
}
|
2024-04-04 16:35:44 +08:00
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
defer func() {
|
2025-08-15 13:20:36 +08:00
|
|
|
|
// Only return quota if downstream failed and quota was actually pre-consumed
|
2026-01-26 20:20:30 +08:00
|
|
|
|
if newAPIError != nil {
|
|
|
|
|
|
newAPIError = service.NormalizeViolationFeeError(newAPIError)
|
|
|
|
|
|
if relayInfo.FinalPreConsumedQuota != 0 {
|
|
|
|
|
|
service.ReturnPreConsumedQuota(c, relayInfo)
|
|
|
|
|
|
}
|
|
|
|
|
|
service.ChargeViolationFeeIfNeeded(c, relayInfo, newAPIError)
|
2025-08-14 20:05:06 +08:00
|
|
|
|
}
|
|
|
|
|
|
}()
|
2025-03-12 21:31:46 +08:00
|
|
|
|
|
2025-12-13 16:43:38 +08:00
|
|
|
|
retryParam := &service.RetryParam{
|
|
|
|
|
|
Ctx: c,
|
|
|
|
|
|
TokenGroup: relayInfo.TokenGroup,
|
|
|
|
|
|
ModelName: relayInfo.OriginModelName,
|
|
|
|
|
|
Retry: common.GetPointer(0),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
|
2025-12-16 18:10:00 +08:00
|
|
|
|
channel, channelErr := getChannel(c, relayInfo, retryParam)
|
|
|
|
|
|
if channelErr != nil {
|
|
|
|
|
|
logger.LogError(c, channelErr.Error())
|
|
|
|
|
|
newAPIError = channelErr
|
2025-03-12 21:31:46 +08:00
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
addUsedChannel(c, channel.Id)
|
2025-12-16 18:10:00 +08:00
|
|
|
|
requestBody, bodyErr := common.GetRequestBody(c)
|
|
|
|
|
|
if bodyErr != nil {
|
|
|
|
|
|
// Ensure consistent 413 for oversized bodies even when error occurs later (e.g., retry path)
|
|
|
|
|
|
if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
|
|
|
|
|
|
newAPIError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry())
|
|
|
|
|
|
} else {
|
|
|
|
|
|
newAPIError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
|
|
|
|
|
}
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
2025-08-14 20:05:06 +08:00
|
|
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
|
|
|
|
|
|
|
|
|
|
switch relayFormat {
|
|
|
|
|
|
case types.RelayFormatOpenAIRealtime:
|
2025-08-14 21:10:04 +08:00
|
|
|
|
newAPIError = relay.WssHelper(c, relayInfo)
|
2025-08-14 20:05:06 +08:00
|
|
|
|
case types.RelayFormatClaude:
|
|
|
|
|
|
newAPIError = relay.ClaudeHelper(c, relayInfo)
|
|
|
|
|
|
case types.RelayFormatGemini:
|
|
|
|
|
|
newAPIError = geminiRelayHandler(c, relayInfo)
|
|
|
|
|
|
default:
|
|
|
|
|
|
newAPIError = relayHandler(c, relayInfo)
|
|
|
|
|
|
}
|
2025-03-12 21:31:46 +08:00
|
|
|
|
|
2025-07-10 15:02:40 +08:00
|
|
|
|
if newAPIError == nil {
|
2025-08-14 20:05:06 +08:00
|
|
|
|
return
|
2025-03-12 21:31:46 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-01-26 20:20:30 +08:00
|
|
|
|
newAPIError = service.NormalizeViolationFeeError(newAPIError)
|
|
|
|
|
|
|
2025-08-16 15:15:19 +08:00
|
|
|
|
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
2025-03-12 21:31:46 +08:00
|
|
|
|
|
2025-12-13 16:43:38 +08:00
|
|
|
|
if !shouldRetry(c, newAPIError, common.RetryTimes-retryParam.GetRetry()) {
|
2025-03-12 21:31:46 +08:00
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2025-08-14 20:05:06 +08:00
|
|
|
|
|
2025-03-12 21:31:46 +08:00
|
|
|
|
useChannel := c.GetStringSlice("use_channel")
|
|
|
|
|
|
if len(useChannel) > 1 {
|
|
|
|
|
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
2025-08-14 20:05:06 +08:00
|
|
|
|
logger.LogInfo(c, retryLogStr)
|
2025-03-12 21:31:46 +08:00
|
|
|
|
}
|
2024-10-04 16:08:18 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-08-14 20:05:06 +08:00
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
|
|
|
|
Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
|
|
|
|
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
|
|
|
|
return true // 允许跨域
|
|
|
|
|
|
},
|
2025-03-12 21:31:46 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2024-08-03 16:55:29 +08:00
|
|
|
|
func addUsedChannel(c *gin.Context, channelId int) {
|
|
|
|
|
|
useChannel := c.GetStringSlice("use_channel")
|
|
|
|
|
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
|
|
|
|
|
c.Set("use_channel", useChannel)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-16 17:00:19 +08:00
|
|
|
|
func fastTokenCountMetaForPricing(request dto.Request) *types.TokenCountMeta {
|
|
|
|
|
|
if request == nil {
|
|
|
|
|
|
return &types.TokenCountMeta{}
|
|
|
|
|
|
}
|
|
|
|
|
|
meta := &types.TokenCountMeta{
|
|
|
|
|
|
TokenType: types.TokenTypeTokenizer,
|
|
|
|
|
|
}
|
|
|
|
|
|
switch r := request.(type) {
|
|
|
|
|
|
case *dto.GeneralOpenAIRequest:
|
|
|
|
|
|
if r.MaxCompletionTokens > r.MaxTokens {
|
|
|
|
|
|
meta.MaxTokens = int(r.MaxCompletionTokens)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
meta.MaxTokens = int(r.MaxTokens)
|
|
|
|
|
|
}
|
|
|
|
|
|
case *dto.OpenAIResponsesRequest:
|
|
|
|
|
|
meta.MaxTokens = int(r.MaxOutputTokens)
|
|
|
|
|
|
case *dto.ClaudeRequest:
|
|
|
|
|
|
meta.MaxTokens = int(r.MaxTokens)
|
|
|
|
|
|
case *dto.ImageRequest:
|
|
|
|
|
|
// Pricing for image requests depends on ImagePriceRatio; safe to compute even when CountToken is disabled.
|
|
|
|
|
|
return r.GetTokenCountMeta()
|
|
|
|
|
|
default:
|
|
|
|
|
|
// Best-effort: leave CombineText empty to avoid large allocations.
|
|
|
|
|
|
}
|
|
|
|
|
|
return meta
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-12-13 16:43:38 +08:00
|
|
|
|
func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service.RetryParam) (*model.Channel, *types.NewAPIError) {
|
|
|
|
|
|
if info.ChannelMeta == nil {
|
2024-08-03 17:32:28 +08:00
|
|
|
|
autoBan := c.GetBool("auto_ban")
|
|
|
|
|
|
autoBanInt := 1
|
|
|
|
|
|
if !autoBan {
|
|
|
|
|
|
autoBanInt = 0
|
|
|
|
|
|
}
|
2024-08-03 16:55:29 +08:00
|
|
|
|
return &model.Channel{
|
2024-08-03 17:32:28 +08:00
|
|
|
|
Id: c.GetInt("channel_id"),
|
|
|
|
|
|
Type: c.GetInt("channel_type"),
|
|
|
|
|
|
Name: c.GetString("channel_name"),
|
|
|
|
|
|
AutoBan: &autoBanInt,
|
2024-08-03 16:55:29 +08:00
|
|
|
|
}, nil
|
|
|
|
|
|
}
|
2025-12-13 16:43:38 +08:00
|
|
|
|
channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(retryParam)
|
2025-12-12 22:04:38 +08:00
|
|
|
|
|
|
|
|
|
|
info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info)
|
|
|
|
|
|
|
2024-08-03 16:55:29 +08:00
|
|
|
|
if err != nil {
|
2025-12-12 22:04:38 +08:00
|
|
|
|
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, info.OriginModelName, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
2025-07-23 16:32:52 +08:00
|
|
|
|
}
|
|
|
|
|
|
if channel == nil {
|
2025-12-12 22:04:38 +08:00
|
|
|
|
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, info.OriginModelName), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
2025-07-10 17:49:53 +08:00
|
|
|
|
}
|
2025-12-12 22:04:38 +08:00
|
|
|
|
|
|
|
|
|
|
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, info.OriginModelName)
|
2025-07-10 17:49:53 +08:00
|
|
|
|
if newAPIError != nil {
|
|
|
|
|
|
return nil, newAPIError
|
2024-08-03 16:55:29 +08:00
|
|
|
|
}
|
|
|
|
|
|
return channel, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-07-10 15:02:40 +08:00
|
|
|
|
func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
|
2024-04-04 16:35:44 +08:00
|
|
|
|
if openaiErr == nil {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
2025-07-10 15:02:40 +08:00
|
|
|
|
if types.IsChannelError(openaiErr) {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
2025-07-30 22:35:31 +08:00
|
|
|
|
if types.IsSkipRetryError(openaiErr) {
|
2024-08-24 17:27:14 +08:00
|
|
|
|
return false
|
|
|
|
|
|
}
|
2024-04-04 16:35:44 +08:00
|
|
|
|
if retryTimes <= 0 {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
if _, ok := c.Get("specific_channel_id"); ok {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
2026-01-14 14:34:12 +08:00
|
|
|
|
code := openaiErr.StatusCode
|
|
|
|
|
|
if code >= 200 && code < 300 {
|
2024-04-23 22:17:36 +08:00
|
|
|
|
return false
|
|
|
|
|
|
}
|
2026-01-14 14:34:12 +08:00
|
|
|
|
if code < 100 || code > 599 {
|
|
|
|
|
|
return true
|
2024-04-04 16:35:44 +08:00
|
|
|
|
}
|
2026-01-14 14:34:12 +08:00
|
|
|
|
return operation_setting.ShouldRetryByStatusCode(code)
|
2024-04-04 16:35:44 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-07-10 17:49:53 +08:00
|
|
|
|
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
2025-10-29 23:33:55 +08:00
|
|
|
|
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
2025-09-10 19:53:32 +08:00
|
|
|
|
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
|
|
|
|
|
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
2025-11-08 20:33:14 +08:00
|
|
|
|
if service.ShouldDisableChannel(channelError.ChannelType, err) && channelError.AutoBan {
|
2025-09-10 19:53:32 +08:00
|
|
|
|
gopool.Go(func() {
|
2026-01-12 18:47:45 +08:00
|
|
|
|
service.DisableChannel(channelError, err.ErrorWithStatusCode())
|
2025-09-10 19:53:32 +08:00
|
|
|
|
})
|
|
|
|
|
|
}
|
2025-08-16 15:15:19 +08:00
|
|
|
|
|
|
|
|
|
|
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
|
|
|
|
|
// 保存错误日志到mysql中
|
|
|
|
|
|
userId := c.GetInt("id")
|
|
|
|
|
|
tokenName := c.GetString("token_name")
|
|
|
|
|
|
modelName := c.GetString("original_model")
|
|
|
|
|
|
tokenId := c.GetInt("token_id")
|
|
|
|
|
|
userGroup := c.GetString("group")
|
|
|
|
|
|
channelId := c.GetInt("channel_id")
|
|
|
|
|
|
other := make(map[string]interface{})
|
2025-10-13 22:44:40 +08:00
|
|
|
|
if c.Request != nil && c.Request.URL != nil {
|
|
|
|
|
|
other["request_path"] = c.Request.URL.Path
|
2025-10-13 22:25:39 +08:00
|
|
|
|
}
|
2025-08-16 15:15:19 +08:00
|
|
|
|
other["error_type"] = err.GetErrorType()
|
|
|
|
|
|
other["error_code"] = err.GetErrorCode()
|
|
|
|
|
|
other["status_code"] = err.StatusCode
|
|
|
|
|
|
other["channel_id"] = channelId
|
|
|
|
|
|
other["channel_name"] = c.GetString("channel_name")
|
|
|
|
|
|
other["channel_type"] = c.GetInt("channel_type")
|
|
|
|
|
|
adminInfo := make(map[string]interface{})
|
|
|
|
|
|
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
|
|
|
|
|
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
|
|
|
|
|
if isMultiKey {
|
|
|
|
|
|
adminInfo["is_multi_key"] = true
|
|
|
|
|
|
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
|
|
|
|
|
}
|
2026-01-26 19:57:41 +08:00
|
|
|
|
service.AppendChannelAffinityAdminInfo(c, adminInfo)
|
2025-08-16 15:15:19 +08:00
|
|
|
|
other["admin_info"] = adminInfo
|
2026-01-12 18:47:45 +08:00
|
|
|
|
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, 0, false, userGroup, other)
|
2023-04-28 17:11:57 +08:00
|
|
|
|
}
|
2025-08-16 15:15:19 +08:00
|
|
|
|
|
2023-04-28 17:11:57 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2023-08-14 22:16:32 +08:00
|
|
|
|
func RelayMidjourney(c *gin.Context) {
|
2025-08-14 21:10:04 +08:00
|
|
|
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
|
|
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{
|
|
|
|
|
|
"description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
|
|
|
|
|
|
"type": "upstream_error",
|
|
|
|
|
|
"code": 4,
|
|
|
|
|
|
})
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
var mjErr *dto.MidjourneyResponse
|
|
|
|
|
|
switch relayInfo.RelayMode {
|
2024-02-29 01:08:18 +08:00
|
|
|
|
case relayconstant.RelayModeMidjourneyNotify:
|
2025-08-14 21:10:04 +08:00
|
|
|
|
mjErr = relay.RelayMidjourneyNotify(c)
|
2024-02-29 01:08:18 +08:00
|
|
|
|
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
2025-08-14 21:10:04 +08:00
|
|
|
|
mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
|
2024-03-14 16:42:37 +08:00
|
|
|
|
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
2025-08-14 21:10:04 +08:00
|
|
|
|
mjErr = relay.RelayMidjourneyTaskImageSeed(c)
|
2024-03-14 18:08:12 +08:00
|
|
|
|
case relayconstant.RelayModeSwapFace:
|
2025-08-14 21:10:04 +08:00
|
|
|
|
mjErr = relay.RelaySwapFace(c, relayInfo)
|
2023-08-14 22:16:32 +08:00
|
|
|
|
default:
|
2025-08-14 21:10:04 +08:00
|
|
|
|
mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
|
2023-08-14 22:16:32 +08:00
|
|
|
|
}
|
|
|
|
|
|
//err = relayMidjourneySubmit(c, relayMode)
|
2025-08-14 21:10:04 +08:00
|
|
|
|
log.Println(mjErr)
|
|
|
|
|
|
if mjErr != nil {
|
2024-03-13 21:19:48 +08:00
|
|
|
|
statusCode := http.StatusBadRequest
|
2025-08-14 21:10:04 +08:00
|
|
|
|
if mjErr.Code == 30 {
|
|
|
|
|
|
mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
2024-03-13 21:19:48 +08:00
|
|
|
|
statusCode = http.StatusTooManyRequests
|
2023-08-14 22:16:32 +08:00
|
|
|
|
}
|
2024-03-13 21:19:48 +08:00
|
|
|
|
c.JSON(statusCode, gin.H{
|
2025-08-14 21:10:04 +08:00
|
|
|
|
"description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result),
|
2024-03-13 21:19:48 +08:00
|
|
|
|
"type": "upstream_error",
|
2025-08-14 21:10:04 +08:00
|
|
|
|
"code": mjErr.Code,
|
2024-03-13 15:37:01 +08:00
|
|
|
|
})
|
2023-08-14 22:16:32 +08:00
|
|
|
|
channelId := c.GetInt("channel_id")
|
2025-08-14 21:10:04 +08:00
|
|
|
|
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)))
|
2023-08-14 22:16:32 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2023-04-28 16:58:55 +08:00
|
|
|
|
func RelayNotImplemented(c *gin.Context) {
|
2025-12-13 16:43:38 +08:00
|
|
|
|
err := types.OpenAIError{
|
2023-05-18 11:11:15 +08:00
|
|
|
|
Message: "API not implemented",
|
2023-12-01 01:29:13 +08:00
|
|
|
|
Type: "new_api_error",
|
2023-05-18 11:11:15 +08:00
|
|
|
|
Param: "",
|
|
|
|
|
|
Code: "api_not_implemented",
|
|
|
|
|
|
}
|
2023-06-23 22:59:44 +08:00
|
|
|
|
c.JSON(http.StatusNotImplemented, gin.H{
|
2023-05-18 11:11:15 +08:00
|
|
|
|
"error": err,
|
2023-04-28 16:58:55 +08:00
|
|
|
|
})
|
|
|
|
|
|
}
|
2023-06-17 09:46:07 +08:00
|
|
|
|
|
|
|
|
|
|
func RelayNotFound(c *gin.Context) {
|
2025-12-13 16:43:38 +08:00
|
|
|
|
err := types.OpenAIError{
|
2023-08-11 19:53:01 +08:00
|
|
|
|
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
|
|
|
|
|
Type: "invalid_request_error",
|
2023-06-17 09:46:07 +08:00
|
|
|
|
Param: "",
|
2023-08-11 19:53:01 +08:00
|
|
|
|
Code: "",
|
2023-06-17 09:46:07 +08:00
|
|
|
|
}
|
2023-06-23 22:59:44 +08:00
|
|
|
|
c.JSON(http.StatusNotFound, gin.H{
|
2023-06-17 09:46:07 +08:00
|
|
|
|
"error": err,
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
2024-06-12 20:37:42 +08:00
|
|
|
|
|
|
|
|
|
|
func RelayTask(c *gin.Context) {
|
|
|
|
|
|
retryTimes := common.RetryTimes
|
|
|
|
|
|
channelId := c.GetInt("channel_id")
|
|
|
|
|
|
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
2025-08-25 18:01:10 +08:00
|
|
|
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
taskErr := taskRelayHandler(c, relayInfo)
|
2024-06-12 20:37:42 +08:00
|
|
|
|
if taskErr == nil {
|
|
|
|
|
|
retryTimes = 0
|
|
|
|
|
|
}
|
2025-12-13 16:43:38 +08:00
|
|
|
|
retryParam := &service.RetryParam{
|
|
|
|
|
|
Ctx: c,
|
|
|
|
|
|
TokenGroup: relayInfo.TokenGroup,
|
|
|
|
|
|
ModelName: relayInfo.OriginModelName,
|
|
|
|
|
|
Retry: common.GetPointer(0),
|
|
|
|
|
|
}
|
|
|
|
|
|
for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() {
|
|
|
|
|
|
channel, newAPIError := getChannel(c, relayInfo, retryParam)
|
2025-07-10 17:49:53 +08:00
|
|
|
|
if newAPIError != nil {
|
2025-08-14 20:05:06 +08:00
|
|
|
|
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
2025-07-10 17:49:53 +08:00
|
|
|
|
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
2024-06-12 20:37:42 +08:00
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
channelId = channel.Id
|
|
|
|
|
|
useChannel := c.GetStringSlice("use_channel")
|
|
|
|
|
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
|
|
|
|
|
c.Set("use_channel", useChannel)
|
2025-12-13 16:43:38 +08:00
|
|
|
|
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
|
2025-07-06 12:37:56 +08:00
|
|
|
|
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
2024-06-12 20:37:42 +08:00
|
|
|
|
|
2025-12-16 18:10:00 +08:00
|
|
|
|
requestBody, err := common.GetRequestBody(c)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
|
|
|
|
|
|
taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest)
|
|
|
|
|
|
}
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
2024-06-12 20:37:42 +08:00
|
|
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
2025-08-25 18:01:10 +08:00
|
|
|
|
taskErr = taskRelayHandler(c, relayInfo)
|
2024-06-12 20:37:42 +08:00
|
|
|
|
}
|
|
|
|
|
|
useChannel := c.GetStringSlice("use_channel")
|
|
|
|
|
|
if len(useChannel) > 1 {
|
|
|
|
|
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
2025-08-14 20:05:06 +08:00
|
|
|
|
logger.LogInfo(c, retryLogStr)
|
2024-06-12 20:37:42 +08:00
|
|
|
|
}
|
|
|
|
|
|
if taskErr != nil {
|
|
|
|
|
|
if taskErr.StatusCode == http.StatusTooManyRequests {
|
|
|
|
|
|
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
|
|
|
|
|
|
}
|
|
|
|
|
|
c.JSON(taskErr.StatusCode, taskErr)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-08-25 18:01:10 +08:00
|
|
|
|
func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
|
2024-06-12 20:37:42 +08:00
|
|
|
|
var err *dto.TaskError
|
2025-08-25 18:01:10 +08:00
|
|
|
|
switch relayInfo.RelayMode {
|
2025-07-22 17:36:38 +08:00
|
|
|
|
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
2025-08-25 18:01:10 +08:00
|
|
|
|
err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
|
2024-06-12 20:37:42 +08:00
|
|
|
|
default:
|
2025-08-25 18:01:10 +08:00
|
|
|
|
err = relay.RelayTaskSubmit(c, relayInfo)
|
2024-06-12 20:37:42 +08:00
|
|
|
|
}
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
|
|
|
|
|
|
if taskErr == nil {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
if retryTimes <= 0 {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
if _, ok := c.Get("specific_channel_id"); ok {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
if taskErr.StatusCode == http.StatusTooManyRequests {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
if taskErr.StatusCode == 307 {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
if taskErr.StatusCode/100 == 5 {
|
|
|
|
|
|
// 超时不重试
|
|
|
|
|
|
if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
if taskErr.StatusCode == http.StatusBadRequest {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
if taskErr.StatusCode == 408 {
|
|
|
|
|
|
// azure处理超时不重试
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
if taskErr.LocalError {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
if taskErr.StatusCode/100 == 2 {
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|