From ae6a03364d1e73d46d8e8f03a6c1805e558584c6 Mon Sep 17 00:00:00 2001 From: Seefs <40468931+seefs001@users.noreply.github.com> Date: Fri, 22 May 2026 10:32:11 +0800 Subject: [PATCH] perf: optimize request metadata extraction and disabled field filtering (#5009) * perf: optimize request metadata extraction and disabled field filtering * perf: optimize stream usage estimation path --- middleware/distributor.go | 54 +++++++++++++++++++ relay/channel/openai/helper.go | 79 +++++----------------------- relay/channel/openai/relay-openai.go | 11 ++-- relay/common/override_test.go | 11 ++++ relay/common/relay_info.go | 23 ++++++++ 5 files changed, 106 insertions(+), 72 deletions(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index 2263fae3..771719b9 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -3,6 +3,7 @@ package middleware import ( "errors" "fmt" + "io" "net/http" "slices" "strconv" @@ -20,6 +21,7 @@ import ( "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" ) type ModelRequest struct { @@ -170,6 +172,14 @@ func Distribute() func(c *gin.Context) { // - application/x-www-form-urlencoded // - multipart/form-data func getModelFromRequest(c *gin.Context) (*ModelRequest, error) { + if strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") { + modelRequest, err := getModelFromJSONBody(c) + if err != nil { + return nil, errors.New(i18n.T(c, i18n.MsgDistributorInvalidRequest, map[string]any{"Error": err.Error()})) + } + return modelRequest, nil + } + var modelRequest ModelRequest err := common.UnmarshalBodyReusable(c, &modelRequest) if err != nil { @@ -178,6 +188,50 @@ func getModelFromRequest(c *gin.Context) (*ModelRequest, error) { return &modelRequest, nil } +func getModelFromJSONBody(c *gin.Context) (*ModelRequest, error) { + storage, err := common.GetBodyStorage(c) + if err != nil { + return nil, err + } + requestBody, err := storage.Bytes() + if err != nil { + return nil, err + } + if !gjson.ValidBytes(requestBody) { + return nil, errors.New("invalid JSON request body") + } + + values := gjson.GetManyBytes(requestBody, "model", "group") + model, err := getJSONStringValue(values[0], "model") + if err != nil { + return nil, err + } + group, err := getJSONStringValue(values[1], "group") + if err != nil { + return nil, err + } + + if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil { + return nil, seekErr + } + c.Request.Body = io.NopCloser(storage) + + return &ModelRequest{ + Model: model, + Group: group, + }, nil +} + +func getJSONStringValue(result gjson.Result, field string) (string, error) { + if !result.Exists() || result.Type == gjson.Null { + return "", nil + } + if result.Type != gjson.String { + return "", fmt.Errorf("field %s must be a string", field) + } + return result.String(), nil +} + func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { var modelRequest ModelRequest shouldSelectChannel := true diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index 08811a77..1a01d06d 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -1,7 +1,6 @@ package openai import ( - "encoding/json" "strings" "github.com/QuantumNous/new-api/common" @@ -92,78 +91,28 @@ func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, res return nil } -func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error { - streamResp := "[" + strings.Join(streamItems, ",") + "]" - +func processTokenData(relayMode int, data string, responseTextBuilder *strings.Builder, toolCount *int) error { switch relayMode { case relayconstant.RelayModeChatCompletions: - return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount) + var streamResponse dto.ChatCompletionsStreamResponse + if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil { + return err + } + return ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount) case relayconstant.RelayModeCompletions: - return processCompletions(streamResp, streamItems, responseTextBuilder) + var streamResponse dto.CompletionsStreamResponse + if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil { + return err + } + processCompletionsStreamResponse(streamResponse, responseTextBuilder) } return nil } -func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error { - var streamResponses []dto.ChatCompletionsStreamResponse - if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { - // 一次性解析失败,逐个解析 - common.SysLog("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.ChatCompletionsStreamResponse - if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { - return err - } - if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil { - common.SysLog("error processing stream response: " + err.Error()) - } - } - return nil +func processCompletionsStreamResponse(streamResponse dto.CompletionsStreamResponse, responseTextBuilder *strings.Builder) { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Text) } - - // 批量处理所有响应 - for _, streamResponse := range streamResponses { - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Delta.GetContentString()) - responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) - if choice.Delta.ToolCalls != nil { - if len(choice.Delta.ToolCalls) > *toolCount { - *toolCount = len(choice.Delta.ToolCalls) - } - for _, tool := range choice.Delta.ToolCalls { - responseTextBuilder.WriteString(tool.Function.Name) - responseTextBuilder.WriteString(tool.Function.Arguments) - } - } - } - } - return nil -} - -func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error { - var streamResponses []dto.CompletionsStreamResponse - if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { - // 一次性解析失败,逐个解析 - common.SysLog("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.CompletionsStreamResponse - if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { - continue - } - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Text) - } - } - return nil - } - - // 批量处理所有响应 - for _, streamResponse := range streamResponses { - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Text) - } - } - return nil } func handleLastResponse(lastStreamData string, responseId *string, createAt *int64, diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index c21d4399..d6a354f7 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -119,7 +119,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re var responseTextBuilder strings.Builder var toolCount int var usage = &dto.Usage{} - var streamItems []string // store stream items var lastStreamData string var secondLastStreamData string // 存储倒数第二个stream data,用于音频模型 @@ -140,7 +139,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re } lastStreamData = data - streamItems = append(streamItems, data) + if err := processTokenData(info.RelayMode, data, &responseTextBuilder, &toolCount); err != nil { + logger.LogError(c, "error processing stream token data: "+err.Error()) + sr.Error(err) + } } }) @@ -175,11 +177,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re } } - // 处理token计算 - if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil { - logger.LogError(c, "error processing tokens: "+err.Error()) - } - if !containStreamUsage { usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens()) usage.CompletionTokens += toolCount * 7 diff --git a/relay/common/override_test.go b/relay/common/override_test.go index 8c7b7772..79688113 100644 --- a/relay/common/override_test.go +++ b/relay/common/override_test.go @@ -2054,6 +2054,17 @@ func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) { assertJSONEqual(t, `{"cache_control":{"type":"ephemeral"},"store":true}`, string(out)) } +func TestRemoveDisabledFieldsNoControlledFieldsKeepsBody(t *testing.T) { + input := `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}` + settings := dto.ChannelOtherSettings{} + + out, err := RemoveDisabledFields([]byte(input), settings, false) + if err != nil { + t.Fatalf("RemoveDisabledFields returned error: %v", err) + } + require.Equal(t, input, string(out)) +} + func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) { input := `{ "inference_geo":"eu", diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 64d4d4ee..8a6c471e 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -18,6 +18,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" + "github.com/tidwall/gjson" ) type ThinkingContentInfo struct { @@ -785,6 +786,9 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled { return jsonData, nil } + if !hasRemovableDisabledField(jsonData, channelOtherSettings) { + return jsonData, nil + } var data map[string]interface{} if err := common.Unmarshal(jsonData, &data); err != nil { @@ -851,6 +855,25 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther return jsonDataAfter, nil } +func hasRemovableDisabledField(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) bool { + values := gjson.GetManyBytes( + jsonData, + "service_tier", + "inference_geo", + "speed", + "store", + "safety_identifier", + "stream_options.include_obfuscation", + ) + + return (!channelOtherSettings.AllowServiceTier && values[0].Exists()) || + (!channelOtherSettings.AllowInferenceGeo && values[1].Exists()) || + (!channelOtherSettings.AllowSpeed && values[2].Exists()) || + (channelOtherSettings.DisableStore && values[3].Exists()) || + (!channelOtherSettings.AllowSafetyIdentifier && values[4].Exists()) || + (!channelOtherSettings.AllowIncludeObfuscation && values[5].Exists()) +} + // RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data // Currently supports removing functionResponse.id field which Vertex AI does not support func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {