perf: optimize request metadata extraction and disabled field filtering (#5009)
* perf: optimize request metadata extraction and disabled field filtering
* perf: optimize stream usage estimation path
This commit is contained in:
parent
006e801652
commit
ae6a03364d
@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
@ -20,6 +21,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type ModelRequest struct {
|
||||
@ -170,6 +172,14 @@ func Distribute() func(c *gin.Context) {
|
||||
// - application/x-www-form-urlencoded
|
||||
// - multipart/form-data
|
||||
func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
|
||||
if strings.HasPrefix(c.Request.Header.Get("Content-Type"), "application/json") {
|
||||
modelRequest, err := getModelFromJSONBody(c)
|
||||
if err != nil {
|
||||
return nil, errors.New(i18n.T(c, i18n.MsgDistributorInvalidRequest, map[string]any{"Error": err.Error()}))
|
||||
}
|
||||
return modelRequest, nil
|
||||
}
|
||||
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
@ -178,6 +188,50 @@ func getModelFromRequest(c *gin.Context) (*ModelRequest, error) {
|
||||
return &modelRequest, nil
|
||||
}
|
||||
|
||||
func getModelFromJSONBody(c *gin.Context) (*ModelRequest, error) {
|
||||
storage, err := common.GetBodyStorage(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody, err := storage.Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !gjson.ValidBytes(requestBody) {
|
||||
return nil, errors.New("invalid JSON request body")
|
||||
}
|
||||
|
||||
values := gjson.GetManyBytes(requestBody, "model", "group")
|
||||
model, err := getJSONStringValue(values[0], "model")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
group, err := getJSONStringValue(values[1], "group")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
|
||||
return nil, seekErr
|
||||
}
|
||||
c.Request.Body = io.NopCloser(storage)
|
||||
|
||||
return &ModelRequest{
|
||||
Model: model,
|
||||
Group: group,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getJSONStringValue(result gjson.Result, field string) (string, error) {
|
||||
if !result.Exists() || result.Type == gjson.Null {
|
||||
return "", nil
|
||||
}
|
||||
if result.Type != gjson.String {
|
||||
return "", fmt.Errorf("field %s must be a string", field)
|
||||
}
|
||||
return result.String(), nil
|
||||
}
|
||||
|
||||
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
var modelRequest ModelRequest
|
||||
shouldSelectChannel := true
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@ -92,78 +91,28 @@ func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, res
|
||||
return nil
|
||||
}
|
||||
|
||||
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
|
||||
func processTokenData(relayMode int, data string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
return ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
return processCompletions(streamResp, streamItems, responseTextBuilder)
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
processCompletionsStreamResponse(streamResponse, responseTextBuilder)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
|
||||
common.SysLog("error processing stream response: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
func processCompletionsStreamResponse(streamResponse dto.CompletionsStreamResponse, responseTextBuilder *strings.Builder) {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > *toolCount {
|
||||
*toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 批量处理所有响应
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
|
||||
|
||||
@ -119,7 +119,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
var responseTextBuilder strings.Builder
|
||||
var toolCount int
|
||||
var usage = &dto.Usage{}
|
||||
var streamItems []string // store stream items
|
||||
var lastStreamData string
|
||||
var secondLastStreamData string // 存储倒数第二个stream data,用于音频模型
|
||||
|
||||
@ -140,7 +139,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
if err := processTokenData(info.RelayMode, data, &responseTextBuilder, &toolCount); err != nil {
|
||||
logger.LogError(c, "error processing stream token data: "+err.Error())
|
||||
sr.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@ -175,11 +177,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
}
|
||||
|
||||
// 处理token计算
|
||||
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
||||
logger.LogError(c, "error processing tokens: "+err.Error())
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
|
||||
@ -2054,6 +2054,17 @@ func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
|
||||
assertJSONEqual(t, `{"cache_control":{"type":"ephemeral"},"store":true}`, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsNoControlledFieldsKeepsBody(t *testing.T) {
|
||||
input := `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`
|
||||
settings := dto.ChannelOtherSettings{}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
require.Equal(t, input, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
|
||||
input := `{
|
||||
"inference_geo":"eu",
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type ThinkingContentInfo struct {
|
||||
@ -785,6 +786,9 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
|
||||
return jsonData, nil
|
||||
}
|
||||
if !hasRemovableDisabledField(jsonData, channelOtherSettings) {
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := common.Unmarshal(jsonData, &data); err != nil {
|
||||
@ -851,6 +855,25 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
||||
return jsonDataAfter, nil
|
||||
}
|
||||
|
||||
func hasRemovableDisabledField(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) bool {
|
||||
values := gjson.GetManyBytes(
|
||||
jsonData,
|
||||
"service_tier",
|
||||
"inference_geo",
|
||||
"speed",
|
||||
"store",
|
||||
"safety_identifier",
|
||||
"stream_options.include_obfuscation",
|
||||
)
|
||||
|
||||
return (!channelOtherSettings.AllowServiceTier && values[0].Exists()) ||
|
||||
(!channelOtherSettings.AllowInferenceGeo && values[1].Exists()) ||
|
||||
(!channelOtherSettings.AllowSpeed && values[2].Exists()) ||
|
||||
(channelOtherSettings.DisableStore && values[3].Exists()) ||
|
||||
(!channelOtherSettings.AllowSafetyIdentifier && values[4].Exists()) ||
|
||||
(!channelOtherSettings.AllowIncludeObfuscation && values[5].Exists())
|
||||
}
|
||||
|
||||
// RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data
|
||||
// Currently supports removing functionResponse.id field which Vertex AI does not support
|
||||
func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user