perf: reduce heap residency for large base64 relay requests

Three layered optimizations targeting Gemini-style 5MB base64 payloads where
RSS could balloon to tens of GB under concurrent load:

1. Byte-based param override (relay/common/override.go)
   - Switch legacy/operations hot paths from common.Marshal round-trips and
     map[string]any conversions to gjson/sjson on []byte directly.
   - Avoids cloning 5MB strings during each Set/Delete operation.

2. strings.Builder for Gemini response markdown (relay/channel/gemini/relay-gemini.go)
   - Replace string concatenation + strings.Join when assembling
     "![image](data:...;base64,DATA)" content for inline image responses.
   - Pre-allocates capacity from inline_data byte sizes.

3. Outbound BodyStorage + streaming Decoder (this commit's core)
   - New relay/common/outbound_body.go helper wraps marshaled upstream bodies
     in common.BodyStorage, allowing disk-cache mode to offload jsonData to
     a temp file while waiting for upstream TTFB. The original []byte can
     then be GC'd, removing ~5MB/req of heap residency during the longest
     window of a request.
   - All 7 relay handlers (gemini/claude/responses/embedding/image/compatible/
     rerank) plus chat_completions_via_responses adopt the helper with
     defer closer.Close() and explicit jsonData = nil.
   - relay/common/relay_info.go: new UpstreamRequestBodySize so
     relay/channel/api_request.go can populate req.ContentLength (lost when
     body becomes a type-erased io.Reader).
   - common/gin.go UnmarshalBodyReusable: when storage is disk-backed and
     content-type is JSON, decode via DecodeJson(storage) instead of
     storage.Bytes()+Unmarshal, removing one transient 5MB copy per request.
     memory mode and form/multipart paths unchanged.
This commit is contained in:
CaIon 2026-05-22 19:08:38 +08:00
parent b9bc6f0e21
commit fddf54ccc5
No known key found for this signature in database
GPG Key ID: 0CFA613529A9921D
15 changed files with 407 additions and 169 deletions

View File

@ -37,7 +37,7 @@ func checkWriter(writer io.Writer) stringWriter {
// W3C Working Draft 29 October 2009 // W3C Working Draft 29 October 2009
// http://www.w3.org/TR/2009/WD-eventsource-20091029/ // http://www.w3.org/TR/2009/WD-eventsource-20091029/
var contentType = []string{"text/event-stream"} var writeContentType = []string{"text/event-stream"}
var noCache = []string{"no-cache"} var noCache = []string{"no-cache"}
var fieldReplacer = strings.NewReplacer( var fieldReplacer = strings.NewReplacer(
@ -79,7 +79,7 @@ func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
r.Mutex.Lock() r.Mutex.Lock()
defer r.Mutex.Unlock() defer r.Mutex.Unlock()
header := w.Header() header := w.Header()
header["Content-Type"] = contentType header["Content-Type"] = writeContentType
if _, exist := header["Cache-Control"]; !exist { if _, exist := header["Cache-Control"]; !exist {
header["Cache-Control"] = noCache header["Cache-Control"] = noCache

View File

@ -110,11 +110,29 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil { if err != nil {
return err return err
} }
contentType := c.Request.Header.Get("Content-Type")
// disk-backed JSON: stream-decode directly from the file to avoid
// materializing the entire payload back into a transient []byte
// (diskStorage.Bytes() would ReadFull the whole file into the heap).
if storage.IsDisk() && strings.HasPrefix(contentType, "application/json") {
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
return seekErr
}
if err := DecodeJson(storage, v); err != nil {
return err
}
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
return seekErr
}
c.Request.Body = io.NopCloser(storage)
return nil
}
requestBody, err := storage.Bytes() requestBody, err := storage.Bytes()
if err != nil { if err != nil {
return err return err
} }
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") { if strings.HasPrefix(contentType, "application/json") {
err = Unmarshal(requestBody, v) err = Unmarshal(requestBody, v)
} else if strings.Contains(contentType, gin.MIMEPOSTForm) { } else if strings.Contains(contentType, gin.MIMEPOSTForm) {

View File

@ -25,6 +25,23 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
// applyUpstreamContentLength populates req.ContentLength when the upstream
// body is wrapped in a BodyStorage (see relay/common/outbound_body.go).
//
// net/http.NewRequest only auto-detects ContentLength for *bytes.Reader,
// *bytes.Buffer and *strings.Reader. When the body is a type-erased io.Reader
// (which is the case for ReaderOnly(BodyStorage)), the Content-Length header
// would otherwise be omitted, forcing chunked transfer encoding and breaking
// some upstreams that require an explicit Content-Length.
func applyUpstreamContentLength(req *http.Request, info *common.RelayInfo) {
if info == nil {
return
}
if info.UpstreamRequestBodySize > 0 && req.ContentLength <= 0 {
req.ContentLength = info.UpstreamRequestBodySize
}
}
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) { func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
// multipart/form-data // multipart/form-data
@ -297,6 +314,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil { if err != nil {
return nil, fmt.Errorf("new request failed: %w", err) return nil, fmt.Errorf("new request failed: %w", err)
} }
applyUpstreamContentLength(req, info)
headers := req.Header headers := req.Header
err = a.SetupRequestHeader(c, &headers, info) err = a.SetupRequestHeader(c, &headers, info)
if err != nil { if err != nil {
@ -326,6 +344,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
if err != nil { if err != nil {
return nil, fmt.Errorf("new request failed: %w", err) return nil, fmt.Errorf("new request failed: %w", err)
} }
applyUpstreamContentLength(req, info)
// set form data // set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
headers := req.Header headers := req.Header
@ -522,6 +541,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, req
if err != nil { if err != nil {
return nil, fmt.Errorf("new request failed: %w", err) return nil, fmt.Errorf("new request failed: %w", err)
} }
applyUpstreamContentLength(req, info)
req.GetBody = func() (io.ReadCloser, error) { req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(requestBody), nil return io.NopCloser(requestBody), nil
} }

View File

@ -1079,17 +1079,47 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
FinishReason: constant.FinishReasonStop, FinishReason: constant.FinishReasonStop,
} }
if len(candidate.Content.Parts) > 0 { if len(candidate.Content.Parts) > 0 {
var texts []string // 使用 strings.Builder 直接累积最终 content避免:
// 1) 每张 inline image 生成一次中间 "![image](...)" 字符串
// 2) 末尾 strings.Join 再分配一份等大缓冲
// Gemini 图片返回时 InlineData.Data 可能是数 MB 的 base64
// 上述两份临时分配在高并发下会显著放大堆驻留。
var content strings.Builder
var inlineGrow int
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
}
}
if inlineGrow > 0 {
content.Grow(inlineGrow)
}
appended := 0
writeSep := func() {
if appended > 0 {
content.WriteByte('\n')
}
appended++
}
var toolCalls []dto.ToolCallResponse var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.InlineData != nil { if part.InlineData != nil {
// 媒体内容 // 媒体内容
if strings.HasPrefix(part.InlineData.MimeType, "image") { if strings.HasPrefix(part.InlineData.MimeType, "image") {
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")" writeSep()
texts = append(texts, imgText) content.WriteString("![image](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
} else { } else {
// 其他媒体类型,直接显示链接 // 其他媒体类型,直接显示链接
texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data)) writeSep()
content.WriteString("[media](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
} }
} else if part.FunctionCall != nil { } else if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls choice.FinishReason = constant.FinishReasonToolCalls
@ -1100,13 +1130,22 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
choice.Message.ReasoningContent = &part.Text choice.Message.ReasoningContent = &part.Text
} else { } else {
if part.ExecutableCode != nil { if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```") writeSep()
content.WriteString("```")
content.WriteString(part.ExecutableCode.Language)
content.WriteByte('\n')
content.WriteString(part.ExecutableCode.Code)
content.WriteString("\n```")
} else if part.CodeExecutionResult != nil { } else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```") writeSep()
content.WriteString("```output\n")
content.WriteString(part.CodeExecutionResult.Output)
content.WriteString("\n```")
} else { } else {
// 过滤掉空行 // 过滤掉空行
if part.Text != "\n" { if part.Text != "\n" {
texts = append(texts, part.Text) writeSep()
content.WriteString(part.Text)
} }
} }
} }
@ -1115,7 +1154,7 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
choice.Message.SetToolCalls(toolCalls) choice.Message.SetToolCalls(toolCalls)
isToolCall = true isToolCall = true
} }
choice.Message.SetStringContent(strings.Join(texts, "\n")) choice.Message.SetStringContent(content.String())
} }
if candidate.FinishReason != nil { if candidate.FinishReason != nil {
@ -1169,7 +1208,25 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
//Role: "assistant", //Role: "assistant",
}, },
} }
var texts []string // 使用 strings.Builder 直接累积 delta content避免每张 image / 每个
// 文本片段都先 `+` 拼出一份临时 string再 strings.Join 再拷贝一遍。
var content strings.Builder
var inlineGrow int
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
inlineGrow += len(part.InlineData.MimeType) + len(part.InlineData.Data) + 32
}
}
if inlineGrow > 0 {
content.Grow(inlineGrow)
}
appended := 0
writeSep := func() {
if appended > 0 {
content.WriteByte('\n')
}
appended++
}
isTools := false isTools := false
isThought := false isThought := false
if candidate.FinishReason != nil { if candidate.FinishReason != nil {
@ -1207,8 +1264,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.InlineData != nil { if part.InlineData != nil {
if strings.HasPrefix(part.InlineData.MimeType, "image") { if strings.HasPrefix(part.InlineData.MimeType, "image") {
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")" writeSep()
texts = append(texts, imgText) content.WriteString("![image](data:")
content.WriteString(part.InlineData.MimeType)
content.WriteString(";base64,")
content.WriteString(part.InlineData.Data)
content.WriteByte(')')
} }
} else if part.FunctionCall != nil { } else if part.FunctionCall != nil {
isTools = true isTools = true
@ -1219,23 +1280,33 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
} else if part.Thought { } else if part.Thought {
isThought = true isThought = true
texts = append(texts, part.Text) writeSep()
content.WriteString(part.Text)
} else { } else {
if part.ExecutableCode != nil { if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n") writeSep()
content.WriteString("```")
content.WriteString(part.ExecutableCode.Language)
content.WriteByte('\n')
content.WriteString(part.ExecutableCode.Code)
content.WriteString("\n```\n")
} else if part.CodeExecutionResult != nil { } else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n") writeSep()
content.WriteString("```output\n")
content.WriteString(part.CodeExecutionResult.Output)
content.WriteString("\n```\n")
} else { } else {
if part.Text != "\n" { if part.Text != "\n" {
texts = append(texts, part.Text) writeSep()
content.WriteString(part.Text)
} }
} }
} }
} }
if isThought { if isThought {
choice.Delta.SetReasoningContent(strings.Join(texts, "\n")) choice.Delta.SetReasoningContent(content.String())
} else { } else {
choice.Delta.SetContentString(strings.Join(texts, "\n")) choice.Delta.SetContentString(content.String())
} }
if isTools { if isTools {
choice.FinishReason = &constant.FinishReasonToolCalls choice.FinishReason = &constant.FinishReasonToolCalls

View File

@ -1,7 +1,6 @@
package relay package relay
import ( import (
"bytes"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -125,7 +124,14 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
} }
var requestBody io.Reader = bytes.NewBuffer(jsonData) body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
var requestBody io.Reader = body
var httpResp *http.Response var httpResp *http.Response
resp, err := adaptor.DoRequest(c, info, requestBody) resp, err := adaptor.DoRequest(c, info, requestBody)

View File

@ -1,7 +1,6 @@
package relay package relay
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -179,7 +178,14 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
} }
logger.LogDebug(c, "requestBody: %s", jsonData) logger.LogDebug(c, "requestBody: %s", jsonData)
requestBody = bytes.NewBuffer(jsonData) body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
} }
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")

View File

@ -0,0 +1,31 @@
package common
import (
"io"
"github.com/QuantumNous/new-api/common"
)
// NewOutboundJSONBody wraps the already-marshaled upstream request body into a
// BodyStorage. When disk cache is enabled and the payload exceeds the configured
// threshold, the data is written to a temp file and the original []byte can be
// GC'd, significantly reducing the heap residency while waiting for the
// upstream provider to respond (the dominant cost for large base64 payloads).
//
// In memory mode the underlying memoryStorage reuses the same backing array,
// so this is equivalent to bytes.NewReader(data) in terms of memory usage.
//
// The caller MUST invoke closer.Close() once the upstream call has finished
// (typically via defer) to release the disk file / memory accounting.
//
// The returned reader is wrapped with common.ReaderOnly to prevent the HTTP
// transport from prematurely closing the underlying BodyStorage. The returned
// size is meant to be propagated to http.Request.ContentLength because the
// type-erased io.Reader prevents net/http from auto-detecting it.
func NewOutboundJSONBody(data []byte) (body io.Reader, size int64, closer io.Closer, err error) {
storage, err := common.CreateBodyStorage(data)
if err != nil {
return nil, 0, nil, err
}
return common.ReaderOnly(storage), storage.Size(), storage, nil
}

View File

@ -153,9 +153,8 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
} }
} }
// 使用新方法 // 使用新方法(基于 []byte避免整包 string 拷贝)
result, err := applyOperations(string(workingJSON), operations, conditionContext) return applyOperations(workingJSON, operations, conditionContext)
return []byte(result), err
} }
// 直接使用旧方法 // 直接使用旧方法
@ -510,13 +509,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
return operations, true return operations, true
} }
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) { func checkConditions(data []byte, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
if len(conditions) == 0 { if len(conditions) == 0 {
return true, nil // 没有条件,直接通过 return true, nil // 没有条件,直接通过
} }
results := make([]bool, len(conditions)) results := make([]bool, len(conditions))
for i, condition := range conditions { for i, condition := range conditions {
result, err := checkSingleCondition(jsonStr, contextJSON, condition) result, err := checkSingleCondition(data, contextJSON, condition)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -529,10 +528,10 @@ func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperatio
return lo.SomeBy(results, func(item bool) bool { return item }), nil return lo.SomeBy(results, func(item bool) bool { return item }), nil
} }
func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) { func checkSingleCondition(data []byte, contextJSON string, condition ConditionOperation) (bool, error) {
// 处理负数索引 // 处理负数索引
path := processNegativeIndex(jsonStr, condition.Path) path := processNegativeIndex(data, condition.Path)
value := gjson.Get(jsonStr, path) value := gjson.GetBytes(data, path)
if !value.Exists() && contextJSON != "" { if !value.Exists() && contextJSON != "" {
value = gjson.Get(contextJSON, condition.Path) value = gjson.Get(contextJSON, condition.Path)
} }
@ -561,7 +560,7 @@ func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperat
return result, nil return result, nil
} }
func processNegativeIndex(jsonStr string, path string) string { func processNegativeIndex(data []byte, path string) string {
matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1) matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1)
if len(matches) == 0 { if len(matches) == 0 {
@ -578,7 +577,7 @@ func processNegativeIndex(jsonStr string, path string) string {
arrayPath = arrayPath[:len(arrayPath)-1] arrayPath = arrayPath[:len(arrayPath)-1]
} }
array := gjson.Get(jsonStr, arrayPath) array := gjson.GetBytes(data, arrayPath)
if array.IsArray() { if array.IsArray() {
length := len(array.Array()) length := len(array.Array())
actualIndex := length + index actualIndex := length + index
@ -667,36 +666,76 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
} }
} }
// applyOperationsLegacy 原参数覆盖方法 // applyOperationsLegacy 原参数覆盖方法。
//
// 旧实现把整个 jsonData unmarshal 成 map[string]interface{} 再 marshal 回来,
// 对包含大 base64 字段(如 Gemini inlineData.data的请求会放大数倍内存
// interface 装箱、map bucket、再次 marshal
// 这里改成在 []byte 上直接调用 sjson.SetBytes按顶层 key 逐个写入,
// 不再把 payload 解码到 map[string]interface{}。
//
// 语义保持:每个 paramOverride 顶层 key 视为字面 key不解析点号路径
// 与旧的 reqMap[key] = value 一致。包含 `.` `*` `?` `\` 的 key 会被转义,
// 防止被 sjson 当作嵌套路径或通配符。
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}, auditRecorder *paramOverrideAuditRecorder) ([]byte, error) { func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}, auditRecorder *paramOverrideAuditRecorder) ([]byte, error) {
reqMap := make(map[string]interface{}) if len(paramOverride) == 0 {
err := common.Unmarshal(jsonData, &reqMap) return jsonData, nil
if err != nil {
return nil, err
} }
result := jsonData
for key, value := range paramOverride { for key, value := range paramOverride {
reqMap[key] = value escaped := escapeSjsonLiteralKey(key)
next, err := sjson.SetBytes(result, escaped, value)
if err != nil {
return nil, err
}
result = next
auditRecorder.recordOperation("set", key, "", "", value) auditRecorder.recordOperation("set", key, "", "", value)
} }
return common.Marshal(reqMap) return result, nil
} }
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) { // escapeSjsonLiteralKey 把可能被 sjson 误判为路径或通配符的字符转义,
// 用于把字面 key 安全地传给 sjson.SetBytes / sjson.DeleteBytes。
func escapeSjsonLiteralKey(key string) string {
if !strings.ContainsAny(key, ".*?\\") {
return key
}
var sb strings.Builder
sb.Grow(len(key) + 4)
for i := 0; i < len(key); i++ {
c := key[i]
switch c {
case '.', '*', '?', '\\':
sb.WriteByte('\\')
}
sb.WriteByte(c)
}
return sb.String()
}
// applyOperations 在 []byte 上原地应用所有 param override 操作。
//
// 旧实现走 string-based gjson/sjson在 ApplyParamOverride 入口会做
// string(jsonData) 与最终 []byte(result) 各一次整包拷贝,对大 base64
// payload 来说每次重试都额外多花 2 倍 body 体积的临时内存。
// 这里改成全程在 []byte 上工作sjson.SetBytes / gjson.GetBytes 都是
// 直接读写 []byte每个操作只会产生一份新 buffer。
func applyOperations(jsonData []byte, operations []ParamOperation, conditionContext map[string]interface{}) ([]byte, error) {
context := ensureContextMap(conditionContext) context := ensureContextMap(conditionContext)
auditRecorder := getParamOverrideAuditRecorder(context) auditRecorder := getParamOverrideAuditRecorder(context)
contextJSON, err := marshalContextJSON(context) contextJSON, err := marshalContextJSON(context)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to marshal condition context: %v", err) return nil, fmt.Errorf("failed to marshal condition context: %v", err)
} }
result := jsonStr result := jsonData
for _, op := range operations { for _, op := range operations {
// 检查条件是否满足 // 检查条件是否满足
ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic) ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
if err != nil { if err != nil {
return "", err return nil, err
} }
if !ok { if !ok {
continue // 条件不满足,跳过当前操作 continue // 条件不满足,跳过当前操作
@ -707,7 +746,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
if isPathBasedOperation(op.Mode) { if isPathBasedOperation(op.Mode) {
opPaths, err = resolveOperationPaths(result, opPath) opPaths, err = resolveOperationPaths(result, opPath)
if err != nil { if err != nil {
return "", err return nil, err
} }
if len(opPaths) == 0 { if len(opPaths) == 0 {
continue continue
@ -725,10 +764,10 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
} }
case "set": case "set":
for _, path := range opPaths { for _, path := range opPaths {
if op.KeepOrigin && gjson.Get(result, path).Exists() { if op.KeepOrigin && gjson.GetBytes(result, path).Exists() {
continue continue
} }
result, err = sjson.Set(result, path, op.Value) result, err = sjson.SetBytes(result, path, op.Value)
if err != nil { if err != nil {
break break
} }
@ -743,7 +782,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
} }
case "copy": case "copy":
if op.From == "" || op.To == "" { if op.From == "" || op.To == "" {
return "", fmt.Errorf("copy from/to is required") return nil, fmt.Errorf("copy from/to is required")
} }
opFrom := processNegativeIndex(result, op.From) opFrom := processNegativeIndex(result, op.From)
opTo := processNegativeIndex(result, op.To) opTo := processNegativeIndex(result, op.To)
@ -843,9 +882,9 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
auditRecorder.recordOperation("return_error", op.Path, "", "", op.Value) auditRecorder.recordOperation("return_error", op.Path, "", "", op.Value)
returnErr, parseErr := parseParamOverrideReturnError(op.Value) returnErr, parseErr := parseParamOverrideReturnError(op.Value)
if parseErr != nil { if parseErr != nil {
return "", parseErr return nil, parseErr
} }
return "", returnErr return nil, returnErr
case "prune_objects": case "prune_objects":
for _, path := range opPaths { for _, path := range opPaths {
result, err = pruneObjects(result, path, contextJSON, op.Value) result, err = pruneObjects(result, path, contextJSON, op.Value)
@ -902,7 +941,7 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
case "pass_headers": case "pass_headers":
headerNames, parseErr := parseHeaderPassThroughNames(op.Value) headerNames, parseErr := parseHeaderPassThroughNames(op.Value)
if parseErr != nil { if parseErr != nil {
return "", parseErr return nil, parseErr
} }
for _, headerName := range headerNames { for _, headerName := range headerNames {
if err = copyHeaderInContext(context, headerName, headerName, op.KeepOrigin); err != nil { if err = copyHeaderInContext(context, headerName, headerName, op.KeepOrigin); err != nil {
@ -924,10 +963,10 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
contextJSON, err = marshalContextJSON(context) contextJSON, err = marshalContextJSON(context)
} }
default: default:
return "", fmt.Errorf("unknown operation: %s", op.Mode) return nil, fmt.Errorf("unknown operation: %s", op.Mode)
} }
if err != nil { if err != nil {
return "", fmt.Errorf("operation %s failed: %w", op.Mode, err) return nil, fmt.Errorf("operation %s failed: %w", op.Mode, err)
} }
} }
return result, nil return result, nil
@ -1361,11 +1400,11 @@ func parseSyncTarget(spec string) (syncTarget, error) {
} }
} }
func readSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget) (interface{}, bool, error) { func readSyncTargetValue(data []byte, context map[string]interface{}, target syncTarget) (interface{}, bool, error) {
switch target.kind { switch target.kind {
case "json": case "json":
path := processNegativeIndex(jsonStr, target.key) path := processNegativeIndex(data, target.key)
value := gjson.Get(jsonStr, path) value := gjson.GetBytes(data, path)
if !value.Exists() || value.Type == gjson.Null { if !value.Exists() || value.Type == gjson.Null {
return nil, false, nil return nil, false, nil
} }
@ -1384,52 +1423,52 @@ func readSyncTargetValue(jsonStr string, context map[string]interface{}, target
} }
} }
func writeSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget, value interface{}) (string, error) { func writeSyncTargetValue(data []byte, context map[string]interface{}, target syncTarget, value interface{}) ([]byte, error) {
switch target.kind { switch target.kind {
case "json": case "json":
path := processNegativeIndex(jsonStr, target.key) path := processNegativeIndex(data, target.key)
nextJSON, err := sjson.Set(jsonStr, path, value) nextJSON, err := sjson.SetBytes(data, path, value)
if err != nil { if err != nil {
return "", err return nil, err
} }
return nextJSON, nil return nextJSON, nil
case "header": case "header":
if err := setHeaderOverrideInContext(context, target.key, value, false); err != nil { if err := setHeaderOverrideInContext(context, target.key, value, false); err != nil {
return "", err return nil, err
} }
return jsonStr, nil return data, nil
default: default:
return "", fmt.Errorf("unsupported sync_fields target kind: %s", target.kind) return nil, fmt.Errorf("unsupported sync_fields target kind: %s", target.kind)
} }
} }
func syncFieldsBetweenTargets(jsonStr string, context map[string]interface{}, fromSpec string, toSpec string) (string, error) { func syncFieldsBetweenTargets(data []byte, context map[string]interface{}, fromSpec string, toSpec string) ([]byte, error) {
fromTarget, err := parseSyncTarget(fromSpec) fromTarget, err := parseSyncTarget(fromSpec)
if err != nil { if err != nil {
return "", err return nil, err
} }
toTarget, err := parseSyncTarget(toSpec) toTarget, err := parseSyncTarget(toSpec)
if err != nil { if err != nil {
return "", err return nil, err
} }
fromValue, fromExists, err := readSyncTargetValue(jsonStr, context, fromTarget) fromValue, fromExists, err := readSyncTargetValue(data, context, fromTarget)
if err != nil { if err != nil {
return "", err return nil, err
} }
toValue, toExists, err := readSyncTargetValue(jsonStr, context, toTarget) toValue, toExists, err := readSyncTargetValue(data, context, toTarget)
if err != nil { if err != nil {
return "", err return nil, err
} }
// If one side exists and the other side is missing, sync the missing side. // If one side exists and the other side is missing, sync the missing side.
if fromExists && !toExists { if fromExists && !toExists {
return writeSyncTargetValue(jsonStr, context, toTarget, fromValue) return writeSyncTargetValue(data, context, toTarget, fromValue)
} }
if toExists && !fromExists { if toExists && !fromExists {
return writeSyncTargetValue(jsonStr, context, fromTarget, toValue) return writeSyncTargetValue(data, context, fromTarget, toValue)
} }
return jsonStr, nil return data, nil
} }
func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} { func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} {
@ -1503,24 +1542,24 @@ func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]in
info.UseRuntimeHeadersOverride = true info.UseRuntimeHeadersOverride = true
} }
func moveValue(jsonStr, fromPath, toPath string) (string, error) { func moveValue(data []byte, fromPath, toPath string) ([]byte, error) {
sourceValue := gjson.Get(jsonStr, fromPath) sourceValue := gjson.GetBytes(data, fromPath)
if !sourceValue.Exists() { if !sourceValue.Exists() {
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath) return data, fmt.Errorf("source path does not exist: %s", fromPath)
} }
result, err := sjson.Set(jsonStr, toPath, sourceValue.Value()) result, err := sjson.SetBytes(data, toPath, sourceValue.Value())
if err != nil { if err != nil {
return "", err return nil, err
} }
return sjson.Delete(result, fromPath) return sjson.DeleteBytes(result, fromPath)
} }
func copyValue(jsonStr, fromPath, toPath string) (string, error) { func copyValue(data []byte, fromPath, toPath string) ([]byte, error) {
sourceValue := gjson.Get(jsonStr, fromPath) sourceValue := gjson.GetBytes(data, fromPath)
if !sourceValue.Exists() { if !sourceValue.Exists() {
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath) return data, fmt.Errorf("source path does not exist: %s", fromPath)
} }
return sjson.Set(jsonStr, toPath, sourceValue.Value()) return sjson.SetBytes(data, toPath, sourceValue.Value())
} }
func isPathBasedOperation(mode string) bool { func isPathBasedOperation(mode string) bool {
@ -1532,16 +1571,16 @@ func isPathBasedOperation(mode string) bool {
} }
} }
func resolveOperationPaths(jsonStr, path string) ([]string, error) { func resolveOperationPaths(data []byte, path string) ([]string, error) {
if !strings.Contains(path, "*") { if !strings.Contains(path, "*") {
return []string{path}, nil return []string{path}, nil
} }
return expandWildcardPaths(jsonStr, path) return expandWildcardPaths(data, path)
} }
func expandWildcardPaths(jsonStr, path string) ([]string, error) { func expandWildcardPaths(data []byte, path string) ([]string, error) {
var root interface{} var root interface{}
if err := common.Unmarshal([]byte(jsonStr), &root); err != nil { if err := common.Unmarshal(data, &root); err != nil {
return nil, err return nil, err
} }
@ -1602,28 +1641,28 @@ func collectWildcardPaths(node interface{}, segments []string, prefix []string)
} }
} }
func deleteValue(jsonStr, path string) (string, error) { func deleteValue(data []byte, path string) ([]byte, error) {
if strings.TrimSpace(path) == "" { if strings.TrimSpace(path) == "" {
return jsonStr, nil return data, nil
} }
return sjson.Delete(jsonStr, path) return sjson.DeleteBytes(data, path)
} }
func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) { func modifyValue(data []byte, path string, value interface{}, keepOrigin, isPrepend bool) ([]byte, error) {
current := gjson.Get(jsonStr, path) current := gjson.GetBytes(data, path)
switch { switch {
case current.IsArray(): case current.IsArray():
return modifyArray(jsonStr, path, value, isPrepend) return modifyArray(data, path, value, isPrepend)
case current.Type == gjson.String: case current.Type == gjson.String:
return modifyString(jsonStr, path, value, isPrepend) return modifyString(data, path, value, isPrepend)
case current.Type == gjson.JSON: case current.Type == gjson.JSON:
return mergeObjects(jsonStr, path, value, keepOrigin) return mergeObjects(data, path, value, keepOrigin)
} }
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) return data, fmt.Errorf("operation not supported for type: %v", current.Type)
} }
func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) { func modifyArray(data []byte, path string, value interface{}, isPrepend bool) ([]byte, error) {
current := gjson.Get(jsonStr, path) current := gjson.GetBytes(data, path)
var newArray []interface{} var newArray []interface{}
// 添加新值 // 添加新值
addValue := func() { addValue := func() {
@ -1647,11 +1686,11 @@ func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (strin
addOriginal() addOriginal()
addValue() addValue()
} }
return sjson.Set(jsonStr, path, newArray) return sjson.SetBytes(data, path, newArray)
} }
func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) { func modifyString(data []byte, path string, value interface{}, isPrepend bool) ([]byte, error) {
current := gjson.Get(jsonStr, path) current := gjson.GetBytes(data, path)
valueStr := fmt.Sprintf("%v", value) valueStr := fmt.Sprintf("%v", value)
var newStr string var newStr string
if isPrepend { if isPrepend {
@ -1659,17 +1698,17 @@ func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (stri
} else { } else {
newStr = current.String() + valueStr newStr = current.String() + valueStr
} }
return sjson.Set(jsonStr, path, newStr) return sjson.SetBytes(data, path, newStr)
} }
func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (string, error) { func trimStringValue(data []byte, path string, value interface{}, isPrefix bool) ([]byte, error) {
current := gjson.Get(jsonStr, path) current := gjson.GetBytes(data, path)
if current.Type != gjson.String { if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) return data, fmt.Errorf("operation not supported for type: %v", current.Type)
} }
if value == nil { if value == nil {
return jsonStr, fmt.Errorf("trim value is required") return data, fmt.Errorf("trim value is required")
} }
valueStr := fmt.Sprintf("%v", value) valueStr := fmt.Sprintf("%v", value)
@ -1679,69 +1718,69 @@ func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (st
} else { } else {
newStr = strings.TrimSuffix(current.String(), valueStr) newStr = strings.TrimSuffix(current.String(), valueStr)
} }
return sjson.Set(jsonStr, path, newStr) return sjson.SetBytes(data, path, newStr)
} }
func ensureStringAffix(jsonStr, path string, value interface{}, isPrefix bool) (string, error) { func ensureStringAffix(data []byte, path string, value interface{}, isPrefix bool) ([]byte, error) {
current := gjson.Get(jsonStr, path) current := gjson.GetBytes(data, path)
if current.Type != gjson.String { if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) return data, fmt.Errorf("operation not supported for type: %v", current.Type)
} }
if value == nil { if value == nil {
return jsonStr, fmt.Errorf("ensure value is required") return data, fmt.Errorf("ensure value is required")
} }
valueStr := fmt.Sprintf("%v", value) valueStr := fmt.Sprintf("%v", value)
if valueStr == "" { if valueStr == "" {
return jsonStr, fmt.Errorf("ensure value is required") return data, fmt.Errorf("ensure value is required")
} }
currentStr := current.String() currentStr := current.String()
if isPrefix { if isPrefix {
if strings.HasPrefix(currentStr, valueStr) { if strings.HasPrefix(currentStr, valueStr) {
return jsonStr, nil return data, nil
} }
return sjson.Set(jsonStr, path, valueStr+currentStr) return sjson.SetBytes(data, path, valueStr+currentStr)
} }
if strings.HasSuffix(currentStr, valueStr) { if strings.HasSuffix(currentStr, valueStr) {
return jsonStr, nil return data, nil
} }
return sjson.Set(jsonStr, path, currentStr+valueStr) return sjson.SetBytes(data, path, currentStr+valueStr)
} }
func transformStringValue(jsonStr, path string, transform func(string) string) (string, error) { func transformStringValue(data []byte, path string, transform func(string) string) ([]byte, error) {
current := gjson.Get(jsonStr, path) current := gjson.GetBytes(data, path)
if current.Type != gjson.String { if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) return data, fmt.Errorf("operation not supported for type: %v", current.Type)
} }
return sjson.Set(jsonStr, path, transform(current.String())) return sjson.SetBytes(data, path, transform(current.String()))
} }
func replaceStringValue(jsonStr, path, from, to string) (string, error) { func replaceStringValue(data []byte, path, from, to string) ([]byte, error) {
current := gjson.Get(jsonStr, path) current := gjson.GetBytes(data, path)
if current.Type != gjson.String { if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) return data, fmt.Errorf("operation not supported for type: %v", current.Type)
} }
if from == "" { if from == "" {
return jsonStr, fmt.Errorf("replace from is required") return data, fmt.Errorf("replace from is required")
} }
return sjson.Set(jsonStr, path, strings.ReplaceAll(current.String(), from, to)) return sjson.SetBytes(data, path, strings.ReplaceAll(current.String(), from, to))
} }
func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string, error) { func regexReplaceStringValue(data []byte, path, pattern, replacement string) ([]byte, error) {
current := gjson.Get(jsonStr, path) current := gjson.GetBytes(data, path)
if current.Type != gjson.String { if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type) return data, fmt.Errorf("operation not supported for type: %v", current.Type)
} }
if pattern == "" { if pattern == "" {
return jsonStr, fmt.Errorf("regex pattern is required") return data, fmt.Errorf("regex pattern is required")
} }
re, err := regexp.Compile(pattern) re, err := regexp.Compile(pattern)
if err != nil { if err != nil {
return jsonStr, err return data, err
} }
return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement)) return sjson.SetBytes(data, path, re.ReplaceAllString(current.String(), replacement))
} }
type pruneObjectsOptions struct { type pruneObjectsOptions struct {
@ -1750,37 +1789,33 @@ type pruneObjectsOptions struct {
recursive bool recursive bool
} }
func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) { func pruneObjects(data []byte, path, contextJSON string, value interface{}) ([]byte, error) {
options, err := parsePruneObjectsOptions(value) options, err := parsePruneObjectsOptions(value)
if err != nil { if err != nil {
return "", err return nil, err
} }
if path == "" { if path == "" {
var root interface{} var root interface{}
if err := common.Unmarshal([]byte(jsonStr), &root); err != nil { if err := common.Unmarshal(data, &root); err != nil {
return "", err return nil, err
} }
cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true) cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true)
if err != nil { if err != nil {
return "", err return nil, err
} }
cleanedBytes, err := common.Marshal(cleaned) return common.Marshal(cleaned)
if err != nil {
return "", err
}
return string(cleanedBytes), nil
} }
target := gjson.Get(jsonStr, path) target := gjson.GetBytes(data, path)
if !target.Exists() { if !target.Exists() {
return jsonStr, nil return data, nil
} }
var targetNode interface{} var targetNode interface{}
if target.Type == gjson.JSON { if target.Type == gjson.JSON {
if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil { if err := common.UnmarshalJsonStr(target.Raw, &targetNode); err != nil {
return "", err return nil, err
} }
} else { } else {
targetNode = target.Value() targetNode = target.Value()
@ -1788,13 +1823,13 @@ func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string,
cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true) cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true)
if err != nil { if err != nil {
return "", err return nil, err
} }
cleanedBytes, err := common.Marshal(cleaned) cleanedBytes, err := common.Marshal(cleaned)
if err != nil { if err != nil {
return "", err return nil, err
} }
return sjson.SetRaw(jsonStr, path, string(cleanedBytes)) return sjson.SetRawBytes(data, path, cleanedBytes)
} }
func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) { func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
@ -1970,16 +2005,16 @@ func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions,
if err != nil { if err != nil {
return false, err return false, err
} }
return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic) return checkConditions(nodeBytes, contextJSON, options.conditions, options.logic)
} }
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) { func mergeObjects(data []byte, path string, value interface{}, keepOrigin bool) ([]byte, error) {
current := gjson.Get(jsonStr, path) current := gjson.GetBytes(data, path)
var currentMap, newMap map[string]interface{} var currentMap, newMap map[string]interface{}
// 解析当前值 // 解析当前值current.Raw 是 data 的子串,避免再分配一份)
if err := common.Unmarshal([]byte(current.Raw), &currentMap); err != nil { if err := common.UnmarshalJsonStr(current.Raw, &currentMap); err != nil {
return "", err return nil, err
} }
// 解析新值 // 解析新值
switch v := value.(type) { switch v := value.(type) {
@ -1988,7 +2023,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
default: default:
jsonBytes, _ := common.Marshal(v) jsonBytes, _ := common.Marshal(v)
if err := common.Unmarshal(jsonBytes, &newMap); err != nil { if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
return "", err return nil, err
} }
} }
// 合并 // 合并
@ -2001,7 +2036,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
result[k] = v result[k] = v
} }
} }
return sjson.Set(jsonStr, path, result) return sjson.SetBytes(data, path, result)
} }
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。 // BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。

View File

@ -154,6 +154,13 @@ type RelayInfo struct {
UseRuntimeHeadersOverride bool UseRuntimeHeadersOverride bool
ParamOverrideAudit []string ParamOverrideAudit []string
// UpstreamRequestBodySize is the byte size of the marshaled upstream request
// body. It is set when the body is wrapped in a BodyStorage (see
// relay/common/outbound_body.go), so that DoApiRequest can populate
// http.Request.ContentLength manually (net/http only auto-detects it for
// *bytes.Reader/Buffer/strings.Reader). 0 means "let net/http decide".
UpstreamRequestBodySize int64
PriceData types.PriceData PriceData types.PriceData
// TieredBillingSnapshot is a frozen snapshot of tiered billing rules // TieredBillingSnapshot is a frozen snapshot of tiered billing rules

View File

@ -1,7 +1,6 @@
package relay package relay
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -176,7 +175,14 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
logger.LogDebug(c, "text request body: %s", jsonData) logger.LogDebug(c, "text request body: %s", jsonData)
requestBody = bytes.NewBuffer(jsonData) body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
} }
var httpResp *http.Response var httpResp *http.Response

View File

@ -1,7 +1,6 @@
package relay package relay
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -59,7 +58,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
} }
logger.LogDebug(c, "converted embedding request body: %s", jsonData) logger.LogDebug(c, "converted embedding request body: %s", jsonData)
var requestBody io.Reader = bytes.NewBuffer(jsonData) body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
var requestBody io.Reader = body
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, info, requestBody) resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil { if err != nil {

View File

@ -1,7 +1,6 @@
package relay package relay
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -165,7 +164,14 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
logger.LogDebug(c, "Gemini request body: %s", jsonData) logger.LogDebug(c, "Gemini request body: %s", jsonData)
requestBody = bytes.NewReader(jsonData) body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
} }
resp, err := adaptor.DoRequest(c, info, requestBody) resp, err := adaptor.DoRequest(c, info, requestBody)
@ -263,7 +269,14 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
} }
} }
logger.LogDebug(c, "Gemini embedding request body: %s", jsonData) logger.LogDebug(c, "Gemini embedding request body: %s", jsonData)
requestBody = bytes.NewReader(jsonData) body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
resp, err := adaptor.DoRequest(c, info, requestBody) resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil { if err != nil {

View File

@ -77,7 +77,14 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
} }
logger.LogDebug(c, "image request body: %s", jsonData) logger.LogDebug(c, "image request body: %s", jsonData)
requestBody = bytes.NewBuffer(jsonData) body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
} }
} }

View File

@ -1,7 +1,6 @@
package relay package relay
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -69,7 +68,14 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
} }
logger.LogDebug(c, "Rerank request body: %s", jsonData) logger.LogDebug(c, "Rerank request body: %s", jsonData)
requestBody = bytes.NewBuffer(jsonData) body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
} }
resp, err := adaptor.DoRequest(c, info, requestBody) resp, err := adaptor.DoRequest(c, info, requestBody)

View File

@ -1,7 +1,6 @@
package relay package relay
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -104,7 +103,14 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
} }
logger.LogDebug(c, "requestBody: %s", jsonData) logger.LogDebug(c, "requestBody: %s", jsonData)
requestBody = bytes.NewBuffer(jsonData) body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
defer closer.Close()
jsonData = nil
info.UpstreamRequestBodySize = size
requestBody = body
} }
var httpResp *http.Response var httpResp *http.Response