From fddf54ccc5cf1c97c7a48c657c3cc204b3d9f68f Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 22 May 2026 19:08:38 +0800 Subject: [PATCH] perf: reduce heap residency for large base64 relay requests Three layered optimizations targeting Gemini-style 5MB base64 payloads where RSS could balloon to tens of GB under concurrent load: 1. Byte-based param override (relay/common/override.go) - Switch legacy/operations hot paths from common.Marshal round-trips and map[string]any conversions to gjson/sjson on []byte directly. - Avoid cloning 5MB strings during each Set/Delete operation. 2. strings.Builder for Gemini response markdown (relay/channel/gemini/relay-gemini.go) - Replace string concatenation + strings.Join when assembling "![image](data:...;base64,DATA)" content for inline image responses. - Pre-allocate capacity from inline_data byte sizes. 3. Outbound BodyStorage + streaming Decoder (this commit's core) - New relay/common/outbound_body.go helper wraps marshaled upstream bodies in common.BodyStorage, allowing disk-cache mode to offload jsonData to a temp file while waiting for upstream TTFB. The original []byte can then be GC'd, removing ~5MB/req of heap residency during the longest window of a request. - All 7 relay handlers (gemini/claude/responses/embedding/image/compatible/rerank) plus chat_completions_via_responses adopt the helper with defer closer.Close() and explicit jsonData = nil. - relay/common/relay_info.go: new UpstreamRequestBodySize field so relay/channel/api_request.go can populate req.ContentLength (lost when body becomes a type-erased io.Reader). - common/gin.go UnmarshalBodyReusable: when storage is disk-backed and content-type is JSON, decode via DecodeJson(storage) instead of storage.Bytes()+Unmarshal, removing one transient 5MB copy per request. Memory mode and form/multipart paths are unchanged. 
--- common/custom-event.go | 4 +- common/gin.go | 20 +- relay/channel/api_request.go | 20 ++ relay/channel/gemini/relay-gemini.go | 105 +++++++-- relay/chat_completions_via_responses.go | 10 +- relay/claude_handler.go | 10 +- relay/common/outbound_body.go | 31 +++ relay/common/override.go | 301 +++++++++++++----------- relay/common/relay_info.go | 7 + relay/compatible_handler.go | 10 +- relay/embedding_handler.go | 10 +- relay/gemini_handler.go | 19 +- relay/image_handler.go | 9 +- relay/rerank_handler.go | 10 +- relay/responses_handler.go | 10 +- 15 files changed, 407 insertions(+), 169 deletions(-) create mode 100644 relay/common/outbound_body.go diff --git a/common/custom-event.go b/common/custom-event.go index 256db546..1bea2fd7 100644 --- a/common/custom-event.go +++ b/common/custom-event.go @@ -37,7 +37,7 @@ func checkWriter(writer io.Writer) stringWriter { // W3C Working Draft 29 October 2009 // http://www.w3.org/TR/2009/WD-eventsource-20091029/ -var contentType = []string{"text/event-stream"} +var writeContentType = []string{"text/event-stream"} var noCache = []string{"no-cache"} var fieldReplacer = strings.NewReplacer( @@ -79,7 +79,7 @@ func (r CustomEvent) WriteContentType(w http.ResponseWriter) { r.Mutex.Lock() defer r.Mutex.Unlock() header := w.Header() - header["Content-Type"] = contentType + header["Content-Type"] = writeContentType if _, exist := header["Cache-Control"]; !exist { header["Cache-Control"] = noCache diff --git a/common/gin.go b/common/gin.go index da7f8be4..315e8661 100644 --- a/common/gin.go +++ b/common/gin.go @@ -110,11 +110,29 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { if err != nil { return err } + contentType := c.Request.Header.Get("Content-Type") + + // disk-backed JSON: stream-decode directly from the file to avoid + // materializing the entire payload back into a transient []byte + // (diskStorage.Bytes() would ReadFull the whole file into the heap). 
+ if storage.IsDisk() && strings.HasPrefix(contentType, "application/json") { + if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil { + return seekErr + } + if err := DecodeJson(storage, v); err != nil { + return err + } + if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil { + return seekErr + } + c.Request.Body = io.NopCloser(storage) + return nil + } + requestBody, err := storage.Bytes() if err != nil { return err } - contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { err = Unmarshal(requestBody, v) } else if strings.Contains(contentType, gin.MIMEPOSTForm) { diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index d5d953b3..f945a838 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -25,6 +25,23 @@ import ( "github.com/gorilla/websocket" ) +// applyUpstreamContentLength populates req.ContentLength when the upstream +// body is wrapped in a BodyStorage (see relay/common/outbound_body.go). +// +// net/http.NewRequest only auto-detects ContentLength for *bytes.Reader, +// *bytes.Buffer and *strings.Reader. When the body is a type-erased io.Reader +// (which is the case for ReaderOnly(BodyStorage)), the Content-Length header +// would otherwise be omitted, forcing chunked transfer encoding and breaking +// some upstreams that require an explicit Content-Length. 
+func applyUpstreamContentLength(req *http.Request, info *common.RelayInfo) { + if info == nil { + return + } + if info.UpstreamRequestBodySize > 0 && req.ContentLength <= 0 { + req.ContentLength = info.UpstreamRequestBodySize + } +} + func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) { if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { // multipart/form-data @@ -297,6 +314,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } + applyUpstreamContentLength(req, info) headers := req.Header err = a.SetupRequestHeader(c, &headers, info) if err != nil { @@ -326,6 +344,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } + applyUpstreamContentLength(req, info) // set form data req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) headers := req.Header @@ -522,6 +541,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, req if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } + applyUpstreamContentLength(req, info) req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(requestBody), nil } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 0824a0e1..53020c3f 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -1079,17 +1079,47 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) FinishReason: constant.FinishReasonStop, } if len(candidate.Content.Parts) > 0 { - var texts []string + // 使用 strings.Builder 直接累积最终 content,避免: + // 1) 每张 inline image 生成一次中间 "![image](...)" 字符串 + // 2) 末尾 strings.Join 再分配一份等大缓冲 + // Gemini 图片返回时 InlineData.Data 可能是数 MB 的 base64, + // 
上述两份临时分配在高并发下会显著放大堆驻留。 + var content strings.Builder + var inlineGrow int + for _, part := range candidate.Content.Parts { + if part.InlineData != nil { + inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32 + } + } + if inlineGrow > 0 { + content.Grow(inlineGrow) + } + appended := 0 + writeSep := func() { + if appended > 0 { + content.WriteByte('\n') + } + appended++ + } var toolCalls []dto.ToolCallResponse for _, part := range candidate.Content.Parts { if part.InlineData != nil { // 媒体内容 if strings.HasPrefix(part.InlineData.MimeType, "image") { - imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")" - texts = append(texts, imgText) + writeSep() + content.WriteString("![image](data:") + content.WriteString(part.InlineData.MimeType) + content.WriteString(";base64,") + content.WriteString(part.InlineData.Data) + content.WriteByte(')') } else { // 其他媒体类型,直接显示链接 - texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data)) + writeSep() + content.WriteString("[media](data:") + content.WriteString(part.InlineData.MimeType) + content.WriteString(";base64,") + content.WriteString(part.InlineData.Data) + content.WriteByte(')') } } else if part.FunctionCall != nil { choice.FinishReason = constant.FinishReasonToolCalls @@ -1100,13 +1130,22 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) choice.Message.ReasoningContent = &part.Text } else { if part.ExecutableCode != nil { - texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```") + writeSep() + content.WriteString("```") + content.WriteString(part.ExecutableCode.Language) + content.WriteByte('\n') + content.WriteString(part.ExecutableCode.Code) + content.WriteString("\n```") } else if part.CodeExecutionResult != nil { - texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```") + writeSep() + 
content.WriteString("```output\n") + content.WriteString(part.CodeExecutionResult.Output) + content.WriteString("\n```") } else { // 过滤掉空行 if part.Text != "\n" { - texts = append(texts, part.Text) + writeSep() + content.WriteString(part.Text) } } } @@ -1115,7 +1154,7 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) choice.Message.SetToolCalls(toolCalls) isToolCall = true } - choice.Message.SetStringContent(strings.Join(texts, "\n")) + choice.Message.SetStringContent(content.String()) } if candidate.FinishReason != nil { @@ -1169,7 +1208,25 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d //Role: "assistant", }, } - var texts []string + // 使用 strings.Builder 直接累积 delta content,避免每张 image / 每个 + // 文本片段都先 `+` 拼出一份临时 string,再 strings.Join 再拷贝一遍。 + var content strings.Builder + var inlineGrow int + for _, part := range candidate.Content.Parts { + if part.InlineData != nil { + inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32 + } + } + if inlineGrow > 0 { + content.Grow(inlineGrow) + } + appended := 0 + writeSep := func() { + if appended > 0 { + content.WriteByte('\n') + } + appended++ + } isTools := false isThought := false if candidate.FinishReason != nil { @@ -1207,8 +1264,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d for _, part := range candidate.Content.Parts { if part.InlineData != nil { if strings.HasPrefix(part.InlineData.MimeType, "image") { - imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")" - texts = append(texts, imgText) + writeSep() + content.WriteString("![image](data:") + content.WriteString(part.InlineData.MimeType) + content.WriteString(";base64,") + content.WriteString(part.InlineData.Data) + content.WriteByte(')') } } else if part.FunctionCall != nil { isTools = true @@ -1219,23 +1280,33 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) 
(*d } else if part.Thought { isThought = true - texts = append(texts, part.Text) + writeSep() + content.WriteString(part.Text) } else { if part.ExecutableCode != nil { - texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n") + writeSep() + content.WriteString("```") + content.WriteString(part.ExecutableCode.Language) + content.WriteByte('\n') + content.WriteString(part.ExecutableCode.Code) + content.WriteString("\n```\n") } else if part.CodeExecutionResult != nil { - texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n") + writeSep() + content.WriteString("```output\n") + content.WriteString(part.CodeExecutionResult.Output) + content.WriteString("\n```\n") } else { if part.Text != "\n" { - texts = append(texts, part.Text) + writeSep() + content.WriteString(part.Text) } } } } if isThought { - choice.Delta.SetReasoningContent(strings.Join(texts, "\n")) + choice.Delta.SetReasoningContent(content.String()) } else { - choice.Delta.SetContentString(strings.Join(texts, "\n")) + choice.Delta.SetContentString(content.String()) } if isTools { choice.FinishReason = &constant.FinishReasonToolCalls diff --git a/relay/chat_completions_via_responses.go b/relay/chat_completions_via_responses.go index 7a2eb9aa..c47da1fa 100644 --- a/relay/chat_completions_via_responses.go +++ b/relay/chat_completions_via_responses.go @@ -1,7 +1,6 @@ package relay import ( - "bytes" "io" "net/http" "strings" @@ -125,7 +124,14 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } - var requestBody io.Reader = bytes.NewBuffer(jsonData) + body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + defer closer.Close() + jsonData = nil + info.UpstreamRequestBodySize = size 
+ var requestBody io.Reader = body var httpResp *http.Response resp, err := adaptor.DoRequest(c, info, requestBody) diff --git a/relay/claude_handler.go b/relay/claude_handler.go index ec028c71..7ec934f9 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -1,7 +1,6 @@ package relay import ( - "bytes" "encoding/json" "fmt" "io" @@ -179,7 +178,14 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ } logger.LogDebug(c, "requestBody: %s", jsonData) - requestBody = bytes.NewBuffer(jsonData) + body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + defer closer.Close() + jsonData = nil + info.UpstreamRequestBodySize = size + requestBody = body } statusCodeMappingStr := c.GetString("status_code_mapping") diff --git a/relay/common/outbound_body.go b/relay/common/outbound_body.go new file mode 100644 index 00000000..94ef8dde --- /dev/null +++ b/relay/common/outbound_body.go @@ -0,0 +1,31 @@ +package common + +import ( + "io" + + "github.com/QuantumNous/new-api/common" +) + +// NewOutboundJSONBody wraps the already-marshaled upstream request body into a +// BodyStorage. When disk cache is enabled and the payload exceeds the configured +// threshold, the data is written to a temp file and the original []byte can be +// GC'd, significantly reducing the heap residency while waiting for the +// upstream provider to respond (the dominant cost for large base64 payloads). +// +// In memory mode the underlying memoryStorage reuses the same backing array, +// so this is equivalent to bytes.NewReader(data) in terms of memory usage. +// +// The caller MUST invoke closer.Close() once the upstream call has finished +// (typically via defer) to release the disk file / memory accounting. 
+// +// The returned reader is wrapped with common.ReaderOnly to prevent the HTTP +// transport from prematurely closing the underlying BodyStorage. The returned +// size is meant to be propagated to http.Request.ContentLength because the +// type-erased io.Reader prevents net/http from auto-detecting it. +func NewOutboundJSONBody(data []byte) (body io.Reader, size int64, closer io.Closer, err error) { + storage, err := common.CreateBodyStorage(data) + if err != nil { + return nil, 0, nil, err + } + return common.ReaderOnly(storage), storage.Size(), storage, nil +} diff --git a/relay/common/override.go b/relay/common/override.go index 5368061d..db59482a 100644 --- a/relay/common/override.go +++ b/relay/common/override.go @@ -153,9 +153,8 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c } } - // 使用新方法 - result, err := applyOperations(string(workingJSON), operations, conditionContext) - return []byte(result), err + // 使用新方法(基于 []byte,避免整包 string 拷贝) + return applyOperations(workingJSON, operations, conditionContext) } // 直接使用旧方法 @@ -510,13 +509,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, return operations, true } -func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) { +func checkConditions(data []byte, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) { if len(conditions) == 0 { return true, nil // 没有条件,直接通过 } results := make([]bool, len(conditions)) for i, condition := range conditions { - result, err := checkSingleCondition(jsonStr, contextJSON, condition) + result, err := checkSingleCondition(data, contextJSON, condition) if err != nil { return false, err } @@ -529,10 +528,10 @@ func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperatio return lo.SomeBy(results, func(item bool) bool { return item }), nil } -func checkSingleCondition(jsonStr, contextJSON string, condition 
ConditionOperation) (bool, error) { +func checkSingleCondition(data []byte, contextJSON string, condition ConditionOperation) (bool, error) { // 处理负数索引 - path := processNegativeIndex(jsonStr, condition.Path) - value := gjson.Get(jsonStr, path) + path := processNegativeIndex(data, condition.Path) + value := gjson.GetBytes(data, path) if !value.Exists() && contextJSON != "" { value = gjson.Get(contextJSON, condition.Path) } @@ -561,7 +560,7 @@ func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperat return result, nil } -func processNegativeIndex(jsonStr string, path string) string { +func processNegativeIndex(data []byte, path string) string { matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1) if len(matches) == 0 { @@ -578,7 +577,7 @@ func processNegativeIndex(jsonStr string, path string) string { arrayPath = arrayPath[:len(arrayPath)-1] } - array := gjson.Get(jsonStr, arrayPath) + array := gjson.GetBytes(data, arrayPath) if array.IsArray() { length := len(array.Array()) actualIndex := length + index @@ -667,36 +666,76 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, } } -// applyOperationsLegacy 原参数覆盖方法 +// applyOperationsLegacy 原参数覆盖方法。 +// +// 旧实现把整个 jsonData unmarshal 成 map[string]interface{} 再 marshal 回来, +// 对包含大 base64 字段(如 Gemini inlineData.data)的请求会放大数倍内存 +// (interface 装箱、map bucket、再次 marshal)。 +// 这里改成在 []byte 上直接调用 sjson.SetBytes,按顶层 key 逐个写入, +// 不再把 payload 解码到 map[string]interface{}。 +// +// 语义保持:每个 paramOverride 顶层 key 视为字面 key(不解析点号路径), +// 与旧的 reqMap[key] = value 一致。包含 `.` `*` `?` `\` 的 key 会被转义, +// 防止被 sjson 当作嵌套路径或通配符。 func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}, auditRecorder *paramOverrideAuditRecorder) ([]byte, error) { - reqMap := make(map[string]interface{}) - err := common.Unmarshal(jsonData, &reqMap) - if err != nil { - return nil, err + if len(paramOverride) == 0 { + return jsonData, nil } + result := jsonData for key, value := 
range paramOverride { - reqMap[key] = value + escaped := escapeSjsonLiteralKey(key) + next, err := sjson.SetBytes(result, escaped, value) + if err != nil { + return nil, err + } + result = next auditRecorder.recordOperation("set", key, "", "", value) } - return common.Marshal(reqMap) + return result, nil } -func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) { +// escapeSjsonLiteralKey 把可能被 sjson 误判为路径或通配符的字符转义, +// 用于把字面 key 安全地传给 sjson.SetBytes / sjson.DeleteBytes。 +func escapeSjsonLiteralKey(key string) string { + if !strings.ContainsAny(key, ".*?\\") { + return key + } + var sb strings.Builder + sb.Grow(len(key) + 4) + for i := 0; i < len(key); i++ { + c := key[i] + switch c { + case '.', '*', '?', '\\': + sb.WriteByte('\\') + } + sb.WriteByte(c) + } + return sb.String() +} + +// applyOperations 在 []byte 上原地应用所有 param override 操作。 +// +// 旧实现走 string-based gjson/sjson,在 ApplyParamOverride 入口会做 +// string(jsonData) 与最终 []byte(result) 各一次整包拷贝,对大 base64 +// payload 来说每次重试都额外多花 2 倍 body 体积的临时内存。 +// 这里改成全程在 []byte 上工作,sjson.SetBytes / gjson.GetBytes 都是 +// 直接读写 []byte,每个操作只会产生一份新 buffer。 +func applyOperations(jsonData []byte, operations []ParamOperation, conditionContext map[string]interface{}) ([]byte, error) { context := ensureContextMap(conditionContext) auditRecorder := getParamOverrideAuditRecorder(context) contextJSON, err := marshalContextJSON(context) if err != nil { - return "", fmt.Errorf("failed to marshal condition context: %v", err) + return nil, fmt.Errorf("failed to marshal condition context: %v", err) } - result := jsonStr + result := jsonData for _, op := range operations { // 检查条件是否满足 ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic) if err != nil { - return "", err + return nil, err } if !ok { continue // 条件不满足,跳过当前操作 @@ -707,7 +746,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte if isPathBasedOperation(op.Mode) { 
opPaths, err = resolveOperationPaths(result, opPath) if err != nil { - return "", err + return nil, err } if len(opPaths) == 0 { continue @@ -725,10 +764,10 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte } case "set": for _, path := range opPaths { - if op.KeepOrigin && gjson.Get(result, path).Exists() { + if op.KeepOrigin && gjson.GetBytes(result, path).Exists() { continue } - result, err = sjson.Set(result, path, op.Value) + result, err = sjson.SetBytes(result, path, op.Value) if err != nil { break } @@ -743,7 +782,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte } case "copy": if op.From == "" || op.To == "" { - return "", fmt.Errorf("copy from/to is required") + return nil, fmt.Errorf("copy from/to is required") } opFrom := processNegativeIndex(result, op.From) opTo := processNegativeIndex(result, op.To) @@ -843,9 +882,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte auditRecorder.recordOperation("return_error", op.Path, "", "", op.Value) returnErr, parseErr := parseParamOverrideReturnError(op.Value) if parseErr != nil { - return "", parseErr + return nil, parseErr } - return "", returnErr + return nil, returnErr case "prune_objects": for _, path := range opPaths { result, err = pruneObjects(result, path, contextJSON, op.Value) @@ -902,7 +941,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte case "pass_headers": headerNames, parseErr := parseHeaderPassThroughNames(op.Value) if parseErr != nil { - return "", parseErr + return nil, parseErr } for _, headerName := range headerNames { if err = copyHeaderInContext(context, headerName, headerName, op.KeepOrigin); err != nil { @@ -924,10 +963,10 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte contextJSON, err = marshalContextJSON(context) } default: - return "", fmt.Errorf("unknown operation: %s", op.Mode) + return nil, 
fmt.Errorf("unknown operation: %s", op.Mode) } if err != nil { - return "", fmt.Errorf("operation %s failed: %w", op.Mode, err) + return nil, fmt.Errorf("operation %s failed: %w", op.Mode, err) } } return result, nil @@ -1361,11 +1400,11 @@ func parseSyncTarget(spec string) (syncTarget, error) { } } -func readSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget) (interface{}, bool, error) { +func readSyncTargetValue(data []byte, context map[string]interface{}, target syncTarget) (interface{}, bool, error) { switch target.kind { case "json": - path := processNegativeIndex(jsonStr, target.key) - value := gjson.Get(jsonStr, path) + path := processNegativeIndex(data, target.key) + value := gjson.GetBytes(data, path) if !value.Exists() || value.Type == gjson.Null { return nil, false, nil } @@ -1384,52 +1423,52 @@ func readSyncTargetValue(jsonStr string, context map[string]interface{}, target } } -func writeSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget, value interface{}) (string, error) { +func writeSyncTargetValue(data []byte, context map[string]interface{}, target syncTarget, value interface{}) ([]byte, error) { switch target.kind { case "json": - path := processNegativeIndex(jsonStr, target.key) - nextJSON, err := sjson.Set(jsonStr, path, value) + path := processNegativeIndex(data, target.key) + nextJSON, err := sjson.SetBytes(data, path, value) if err != nil { - return "", err + return nil, err } return nextJSON, nil case "header": if err := setHeaderOverrideInContext(context, target.key, value, false); err != nil { - return "", err + return nil, err } - return jsonStr, nil + return data, nil default: - return "", fmt.Errorf("unsupported sync_fields target kind: %s", target.kind) + return nil, fmt.Errorf("unsupported sync_fields target kind: %s", target.kind) } } -func syncFieldsBetweenTargets(jsonStr string, context map[string]interface{}, fromSpec string, toSpec string) (string, error) { +func 
syncFieldsBetweenTargets(data []byte, context map[string]interface{}, fromSpec string, toSpec string) ([]byte, error) { fromTarget, err := parseSyncTarget(fromSpec) if err != nil { - return "", err + return nil, err } toTarget, err := parseSyncTarget(toSpec) if err != nil { - return "", err + return nil, err } - fromValue, fromExists, err := readSyncTargetValue(jsonStr, context, fromTarget) + fromValue, fromExists, err := readSyncTargetValue(data, context, fromTarget) if err != nil { - return "", err + return nil, err } - toValue, toExists, err := readSyncTargetValue(jsonStr, context, toTarget) + toValue, toExists, err := readSyncTargetValue(data, context, toTarget) if err != nil { - return "", err + return nil, err } // If one side exists and the other side is missing, sync the missing side. if fromExists && !toExists { - return writeSyncTargetValue(jsonStr, context, toTarget, fromValue) + return writeSyncTargetValue(data, context, toTarget, fromValue) } if toExists && !fromExists { - return writeSyncTargetValue(jsonStr, context, fromTarget, toValue) + return writeSyncTargetValue(data, context, fromTarget, toValue) } - return jsonStr, nil + return data, nil } func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} { @@ -1503,24 +1542,24 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in info.UseRuntimeHeadersOverride = true } -func moveValue(jsonStr, fromPath, toPath string) (string, error) { - sourceValue := gjson.Get(jsonStr, fromPath) +func moveValue(data []byte, fromPath, toPath string) ([]byte, error) { + sourceValue := gjson.GetBytes(data, fromPath) if !sourceValue.Exists() { - return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath) + return data, fmt.Errorf("source path does not exist: %s", fromPath) } - result, err := sjson.Set(jsonStr, toPath, sourceValue.Value()) + result, err := sjson.SetBytes(data, toPath, sourceValue.Value()) if err != nil { - return "", err + return 
nil, err } - return sjson.Delete(result, fromPath) + return sjson.DeleteBytes(result, fromPath) } -func copyValue(jsonStr, fromPath, toPath string) (string, error) { - sourceValue := gjson.Get(jsonStr, fromPath) +func copyValue(data []byte, fromPath, toPath string) ([]byte, error) { + sourceValue := gjson.GetBytes(data, fromPath) if !sourceValue.Exists() { - return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath) + return data, fmt.Errorf("source path does not exist: %s", fromPath) } - return sjson.Set(jsonStr, toPath, sourceValue.Value()) + return sjson.SetBytes(data, toPath, sourceValue.Value()) } func isPathBasedOperation(mode string) bool { @@ -1532,16 +1571,16 @@ func isPathBasedOperation(mode string) bool { } } -func resolveOperationPaths(jsonStr, path string) ([]string, error) { +func resolveOperationPaths(data []byte, path string) ([]string, error) { if !strings.Contains(path, "*") { return []string{path}, nil } - return expandWildcardPaths(jsonStr, path) + return expandWildcardPaths(data, path) } -func expandWildcardPaths(jsonStr, path string) ([]string, error) { +func expandWildcardPaths(data []byte, path string) ([]string, error) { var root interface{} - if err := common.Unmarshal([]byte(jsonStr), &root); err != nil { + if err := common.Unmarshal(data, &root); err != nil { return nil, err } @@ -1602,28 +1641,28 @@ func collectWildcardPaths(node interface{}, segments []string, prefix []string) } } -func deleteValue(jsonStr, path string) (string, error) { +func deleteValue(data []byte, path string) ([]byte, error) { if strings.TrimSpace(path) == "" { - return jsonStr, nil + return data, nil } - return sjson.Delete(jsonStr, path) + return sjson.DeleteBytes(data, path) } -func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) { - current := gjson.Get(jsonStr, path) +func modifyValue(data []byte, path string, value interface{}, keepOrigin, isPrepend bool) ([]byte, error) { + current := 
gjson.GetBytes(data, path) switch { case current.IsArray(): - return modifyArray(jsonStr, path, value, isPrepend) + return modifyArray(data, path, value, isPrepend) case current.Type == gjson.String: - return modifyString(jsonStr, path, value, isPrepend) + return modifyString(data, path, value, isPrepend) case current.Type == gjson.JSON: - return mergeObjects(jsonStr, path, value, keepOrigin) + return mergeObjects(data, path, value, keepOrigin) } - return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) + return data, fmt.Errorf("operation not supported for type: %v", current.Type) } -func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) { - current := gjson.Get(jsonStr, path) +func modifyArray(data []byte, path string, value interface{}, isPrepend bool) ([]byte, error) { + current := gjson.GetBytes(data, path) var newArray []interface{} // 添加新值 addValue := func() { @@ -1647,11 +1686,11 @@ func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (strin addOriginal() addValue() } - return sjson.Set(jsonStr, path, newArray) + return sjson.SetBytes(data, path, newArray) } -func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) { - current := gjson.Get(jsonStr, path) +func modifyString(data []byte, path string, value interface{}, isPrepend bool) ([]byte, error) { + current := gjson.GetBytes(data, path) valueStr := fmt.Sprintf("%v", value) var newStr string if isPrepend { @@ -1659,17 +1698,17 @@ func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (stri } else { newStr = current.String() + valueStr } - return sjson.Set(jsonStr, path, newStr) + return sjson.SetBytes(data, path, newStr) } -func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (string, error) { - current := gjson.Get(jsonStr, path) +func trimStringValue(data []byte, path string, value interface{}, isPrefix bool) ([]byte, error) { + current := 
gjson.GetBytes(data, path) if current.Type != gjson.String { - return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) + return data, fmt.Errorf("operation not supported for type: %v", current.Type) } if value == nil { - return jsonStr, fmt.Errorf("trim value is required") + return data, fmt.Errorf("trim value is required") } valueStr := fmt.Sprintf("%v", value) @@ -1679,69 +1718,69 @@ func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (st } else { newStr = strings.TrimSuffix(current.String(), valueStr) } - return sjson.Set(jsonStr, path, newStr) + return sjson.SetBytes(data, path, newStr) } -func ensureStringAffix(jsonStr, path string, value interface{}, isPrefix bool) (string, error) { - current := gjson.Get(jsonStr, path) +func ensureStringAffix(data []byte, path string, value interface{}, isPrefix bool) ([]byte, error) { + current := gjson.GetBytes(data, path) if current.Type != gjson.String { - return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) + return data, fmt.Errorf("operation not supported for type: %v", current.Type) } if value == nil { - return jsonStr, fmt.Errorf("ensure value is required") + return data, fmt.Errorf("ensure value is required") } valueStr := fmt.Sprintf("%v", value) if valueStr == "" { - return jsonStr, fmt.Errorf("ensure value is required") + return data, fmt.Errorf("ensure value is required") } currentStr := current.String() if isPrefix { if strings.HasPrefix(currentStr, valueStr) { - return jsonStr, nil + return data, nil } - return sjson.Set(jsonStr, path, valueStr+currentStr) + return sjson.SetBytes(data, path, valueStr+currentStr) } if strings.HasSuffix(currentStr, valueStr) { - return jsonStr, nil + return data, nil } - return sjson.Set(jsonStr, path, currentStr+valueStr) + return sjson.SetBytes(data, path, currentStr+valueStr) } -func transformStringValue(jsonStr, path string, transform func(string) string) (string, error) { - current := 
gjson.Get(jsonStr, path) +func transformStringValue(data []byte, path string, transform func(string) string) ([]byte, error) { + current := gjson.GetBytes(data, path) if current.Type != gjson.String { - return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) + return data, fmt.Errorf("operation not supported for type: %v", current.Type) } - return sjson.Set(jsonStr, path, transform(current.String())) + return sjson.SetBytes(data, path, transform(current.String())) } -func replaceStringValue(jsonStr, path, from, to string) (string, error) { - current := gjson.Get(jsonStr, path) +func replaceStringValue(data []byte, path, from, to string) ([]byte, error) { + current := gjson.GetBytes(data, path) if current.Type != gjson.String { - return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) + return data, fmt.Errorf("operation not supported for type: %v", current.Type) } if from == "" { - return jsonStr, fmt.Errorf("replace from is required") + return data, fmt.Errorf("replace from is required") } - return sjson.Set(jsonStr, path, strings.ReplaceAll(current.String(), from, to)) + return sjson.SetBytes(data, path, strings.ReplaceAll(current.String(), from, to)) } -func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string, error) { - current := gjson.Get(jsonStr, path) +func regexReplaceStringValue(data []byte, path, pattern, replacement string) ([]byte, error) { + current := gjson.GetBytes(data, path) if current.Type != gjson.String { - return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) + return data, fmt.Errorf("operation not supported for type: %v", current.Type) } if pattern == "" { - return jsonStr, fmt.Errorf("regex pattern is required") + return data, fmt.Errorf("regex pattern is required") } re, err := regexp.Compile(pattern) if err != nil { - return jsonStr, err + return data, err } - return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), 
replacement)) + return sjson.SetBytes(data, path, re.ReplaceAllString(current.String(), replacement)) } type pruneObjectsOptions struct { @@ -1750,37 +1789,33 @@ type pruneObjectsOptions struct { recursive bool } -func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) { +func pruneObjects(data []byte, path, contextJSON string, value interface{}) ([]byte, error) { options, err := parsePruneObjectsOptions(value) if err != nil { - return "", err + return nil, err } if path == "" { var root interface{} - if err := common.Unmarshal([]byte(jsonStr), &root); err != nil { - return "", err + if err := common.Unmarshal(data, &root); err != nil { + return nil, err } cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true) if err != nil { - return "", err + return nil, err } - cleanedBytes, err := common.Marshal(cleaned) - if err != nil { - return "", err - } - return string(cleanedBytes), nil + return common.Marshal(cleaned) } - target := gjson.Get(jsonStr, path) + target := gjson.GetBytes(data, path) if !target.Exists() { - return jsonStr, nil + return data, nil } var targetNode interface{} if target.Type == gjson.JSON { - if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil { - return "", err + if err := common.UnmarshalJsonStr(target.Raw, &targetNode); err != nil { + return nil, err } } else { targetNode = target.Value() @@ -1788,13 +1823,13 @@ func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true) if err != nil { - return "", err + return nil, err } cleanedBytes, err := common.Marshal(cleaned) if err != nil { - return "", err + return nil, err } - return sjson.SetRaw(jsonStr, path, string(cleanedBytes)) + return sjson.SetRawBytes(data, path, cleanedBytes) } func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) { @@ -1970,16 +2005,16 @@ func shouldPruneObject(node 
map[string]interface{}, options pruneObjectsOptions, if err != nil { return false, err } - return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic) + return checkConditions(nodeBytes, contextJSON, options.conditions, options.logic) } -func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) { - current := gjson.Get(jsonStr, path) +func mergeObjects(data []byte, path string, value interface{}, keepOrigin bool) ([]byte, error) { + current := gjson.GetBytes(data, path) var currentMap, newMap map[string]interface{} - // 解析当前值 - if err := common.Unmarshal([]byte(current.Raw), &currentMap); err != nil { - return "", err + // 解析当前值(current.Raw 是 data 的子串,避免再分配一份) + if err := common.UnmarshalJsonStr(current.Raw, &currentMap); err != nil { + return nil, err } // 解析新值 switch v := value.(type) { @@ -1988,7 +2023,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str default: jsonBytes, _ := common.Marshal(v) if err := common.Unmarshal(jsonBytes, &newMap); err != nil { - return "", err + return nil, err } } // 合并 @@ -2001,7 +2036,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str result[k] = v } } - return sjson.Set(jsonStr, path, result) + return sjson.SetBytes(data, path, result) } // BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。 diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 8a6c471e..2f7afd39 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -154,6 +154,13 @@ type RelayInfo struct { UseRuntimeHeadersOverride bool ParamOverrideAudit []string + // UpstreamRequestBodySize is the byte size of the marshaled upstream request + // body. It is set when the body is wrapped in a BodyStorage (see + // relay/common/outbound_body.go), so that DoApiRequest can populate + // http.Request.ContentLength manually (net/http only auto-detects it for + // *bytes.Reader/Buffer/strings.Reader).
0 means "let net/http decide". + UpstreamRequestBodySize int64 + PriceData types.PriceData // TieredBillingSnapshot is a frozen snapshot of tiered billing rules diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index fdd54f39..a68cfe73 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -1,7 +1,6 @@ package relay import ( - "bytes" "fmt" "io" "net/http" @@ -176,7 +175,14 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types logger.LogDebug(c, "text request body: %s", jsonData) - requestBody = bytes.NewBuffer(jsonData) + body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + defer closer.Close() + jsonData = nil + info.UpstreamRequestBodySize = size + requestBody = body } var httpResp *http.Response diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index e2fda93e..a12ef8d3 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -1,7 +1,6 @@ package relay import ( - "bytes" "fmt" "io" "net/http" @@ -59,7 +58,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * } logger.LogDebug(c, "converted embedding request body: %s", jsonData) - var requestBody io.Reader = bytes.NewBuffer(jsonData) + body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + defer closer.Close() + jsonData = nil + info.UpstreamRequestBodySize = size + var requestBody io.Reader = body statusCodeMappingStr := c.GetString("status_code_mapping") resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index df3bf47c..8f64552a 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -1,7 +1,6 @@ package 
relay import ( - "bytes" "fmt" "io" "net/http" @@ -165,7 +164,14 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ logger.LogDebug(c, "Gemini request body: %s", jsonData) - requestBody = bytes.NewReader(jsonData) + body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + defer closer.Close() + jsonData = nil + info.UpstreamRequestBodySize = size + requestBody = body } resp, err := adaptor.DoRequest(c, info, requestBody) @@ -263,7 +269,14 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI } } logger.LogDebug(c, "Gemini embedding request body: %s", jsonData) - requestBody = bytes.NewReader(jsonData) + body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + defer closer.Close() + jsonData = nil + info.UpstreamRequestBodySize = size + requestBody = body resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { diff --git a/relay/image_handler.go b/relay/image_handler.go index 7b3d961b..2c2990bb 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -77,7 +77,14 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type } logger.LogDebug(c, "image request body: %s", jsonData) - requestBody = bytes.NewBuffer(jsonData) + body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + defer closer.Close() + jsonData = nil + info.UpstreamRequestBodySize = size + requestBody = body } } diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index edc69f68..f1c19e27 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -1,7 +1,6 @@ package relay 
import ( - "bytes" "fmt" "io" "net/http" @@ -69,7 +68,14 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ } logger.LogDebug(c, "Rerank request body: %s", jsonData) - requestBody = bytes.NewBuffer(jsonData) + body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + defer closer.Close() + jsonData = nil + info.UpstreamRequestBodySize = size + requestBody = body } resp, err := adaptor.DoRequest(c, info, requestBody) diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 54ca3cbc..010c38bb 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -1,7 +1,6 @@ package relay import ( - "bytes" "fmt" "io" "net/http" @@ -104,7 +103,14 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * } logger.LogDebug(c, "requestBody: %s", jsonData) - requestBody = bytes.NewBuffer(jsonData) + body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData) + if err != nil { + return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + } + defer closer.Close() + jsonData = nil + info.UpstreamRequestBodySize = size + requestBody = body } var httpResp *http.Response