new-api/relay/gemini_handler.go
CaIon fddf54ccc5
perf: reduce heap residency for large base64 relay requests
Three layered optimizations targeting Gemini-style 5MB base64 payloads where
RSS could balloon to tens of GB under concurrent load:

1. Byte-based param override (relay/common/override.go)
   - Switch legacy/operations hot paths from common.Marshal round-trips and
     map[string]any conversions to gjson/sjson on []byte directly.
   - Avoids cloning 5MB strings during each Set/Delete operation.

2. strings.Builder for Gemini response markdown (relay/channel/gemini/relay-gemini.go)
   - Replace string concatenation + strings.Join when assembling
     "![image](data:...;base64,DATA)" content for inline image responses.
   - Pre-allocates capacity from inline_data byte sizes.

3. Outbound BodyStorage + streaming Decoder (this commit's core)
   - New relay/common/outbound_body.go helper wraps marshaled upstream bodies
     in common.BodyStorage, allowing disk-cache mode to offload jsonData to
     a temp file while waiting for upstream TTFB. The original []byte can
     then be GC'd, removing ~5MB/req of heap residency during the longest
     window of a request.
   - All 7 relay handlers (gemini/claude/responses/embedding/image/compatible/
     rerank) plus chat_completions_via_responses adopt the helper with
     defer closer.Close() and explicit jsonData = nil.
   - relay/common/relay_info.go: new UpstreamRequestBodySize so
     relay/channel/api_request.go can populate req.ContentLength (lost when
     body becomes a type-erased io.Reader).
   - common/gin.go UnmarshalBodyReusable: when storage is disk-backed and
     content-type is JSON, decode via DecodeJson(storage) instead of
     storage.Bytes()+Unmarshal, removing one transient 5MB copy per request.
     Memory mode and form/multipart paths are unchanged.
2026-05-22 19:08:38 +08:00

307 lines
9.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package relay
import (
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/relay/channel/gemini"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// isNoThinkingRequest reports whether the request carries a thinking config
// whose thinking budget is explicitly set to zero.
func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
	thinkingCfg := req.GenerationConfig.ThinkingConfig
	if thinkingCfg == nil {
		return false
	}
	budget := thinkingCfg.ThinkingBudget
	// A thinking budget of 0 is treated as a non-thinking request.
	return budget != nil && *budget == 0
}
// trimModelThinking normalizes thinking-related markers in a model name.
//
// A trailing "-nothinking" or "-thinking" is stripped. Otherwise, if the name
// contains "-thinking-", everything from its first occurrence onward is
// replaced by "-thinking". Names with none of these markers are returned
// unchanged.
func trimModelThinking(modelName string) string {
	const (
		noThinkingSuffix = "-nothinking"
		thinkingSuffix   = "-thinking"
		budgetMarker     = "-thinking-"
	)

	switch {
	case strings.HasSuffix(modelName, noThinkingSuffix):
		// Drop the "-nothinking" suffix.
		return modelName[:len(modelName)-len(noThinkingSuffix)]
	case strings.HasSuffix(modelName, thinkingSuffix):
		// Drop the "-thinking" suffix.
		return modelName[:len(modelName)-len(thinkingSuffix)]
	}

	// Collapse "-thinking-<rest>" down to "-thinking".
	if cut := strings.Index(modelName, budgetMarker); cut >= 0 {
		return modelName[:cut] + thinkingSuffix
	}
	return modelName
}
// GeminiHelper relays a native Gemini chat request to the upstream channel.
//
// It expects info.Request to be a *dto.GeminiChatRequest. The handler copies
// that request, applies model mapping, the optional thinking adapter and the
// channel system prompt, then builds the outbound body: either the stored
// inbound body (pass-through mode) or the adaptor-converted request marshaled
// to JSON with param overrides applied. The body is sent through the channel
// adaptor, non-200 upstream responses are turned into an API error, and on
// success the usage returned by the adaptor is passed to
// service.PostTextConsumeQuota.
//
// It returns nil on success, or a *types.NewAPIError describing the failure.
func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
	info.InitChannelMeta(c)

	geminiReq, ok := info.Request.(*dto.GeminiChatRequest)
	if !ok {
		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
	}

	// All modifications below are applied to the value returned by
	// common.DeepCopy rather than to geminiReq itself.
	request, err := common.DeepCopy(geminiReq)
	if err != nil {
		return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
	}

	// Apply model mapping.
	err = helper.ModelMappedHelper(c, info, request)
	if err != nil {
		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
	}

	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
		if isNoThinkingRequest(request) {
			// The request asks for a zero thinking budget; unless the model
			// name already mentions "-nothinking", try the "-nothinking"
			// variant of the model.
			if !strings.Contains(info.OriginModelName, "-nothinking") {
				// Only switch when helper.HasModelBillingConfig reports a
				// billing config for the "-nothinking" name.
				noThinkingModelName := info.OriginModelName + "-nothinking"
				containPrice := helper.HasModelBillingConfig(noThinkingModelName)
				if containPrice {
					info.OriginModelName = noThinkingModelName
					info.UpstreamModelName = noThinkingModelName
				}
			}
		}
		// Only run the thinking adaptor when the client did not supply its
		// own thinking config.
		if request.GenerationConfig.ThinkingConfig == nil {
			gemini.ThinkingAdaptor(request, info)
		}
	}

	adaptor := GetAdaptor(info.ApiType)
	if adaptor == nil {
		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
	}
	adaptor.Init(info)

	// Inject the channel-level system prompt, if one is configured.
	if info.ChannelSetting.SystemPrompt != "" {
		if request.SystemInstructions == nil {
			// No system instruction at all: use the channel prompt.
			request.SystemInstructions = &dto.GeminiChatContent{
				Parts: []dto.GeminiPart{
					{Text: info.ChannelSetting.SystemPrompt},
				},
			}
		} else if len(request.SystemInstructions.Parts) == 0 {
			// System instruction present but without parts: fill it in.
			request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}
		} else if info.ChannelSetting.SystemPromptOverride {
			// The client sent its own parts: prefix the channel prompt to the
			// first non-empty text part, or prepend a new part if every
			// existing part has empty text.
			common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
			merged := false
			for i := range request.SystemInstructions.Parts {
				if request.SystemInstructions.Parts[i].Text == "" {
					continue
				}
				request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text
				merged = true
				break
			}
			if !merged {
				request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...)
			}
		}
	}

	// Clean up an empty system instruction: drop it when no part has text.
	if request.SystemInstructions != nil {
		hasContent := false
		for _, part := range request.SystemInstructions.Parts {
			if part.Text != "" {
				hasContent = true
				break
			}
		}
		if !hasContent {
			request.SystemInstructions = nil
		}
	}

	var requestBody io.Reader
	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
		// Pass-through mode: forward the stored inbound body instead of the
		// converted request.
		storage, err := common.GetBodyStorage(c)
		if err != nil {
			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
		}
		requestBody = common.ReaderOnly(storage)
	} else {
		// Convert the request with the adaptor's ConvertGeminiRequest.
		convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request)
		if err != nil {
			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
		}
		relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)

		jsonData, err := common.Marshal(convertedRequest)
		if err != nil {
			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
		}

		// Apply param override.
		if len(info.ParamOverride) > 0 {
			jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
			if err != nil {
				return newAPIErrorFromParamOverride(err)
			}
		}

		logger.LogDebug(c, "Gemini request body: %s", jsonData)

		body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
		if err != nil {
			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
		}
		// The deferred Close runs when GeminiHelper returns, i.e. after the
		// upstream request and response handling below have finished.
		defer closer.Close()
		// Drop this function's reference to the marshaled bytes; from here on
		// the payload is only reached through body.
		jsonData = nil
		info.UpstreamRequestBodySize = size
		requestBody = body
	}

	resp, err := adaptor.DoRequest(c, info, requestBody)
	if err != nil {
		logger.LogError(c, "Do gemini request failed: "+err.Error())
		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
	}

	statusCodeMappingStr := c.GetString("status_code_mapping")

	var httpResp *http.Response
	if resp != nil {
		httpResp = resp.(*http.Response)
		// Treat the response as a stream if the upstream answered with SSE,
		// even when the request was not flagged as streaming.
		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
		if httpResp.StatusCode != http.StatusOK {
			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
			// Reset the status code according to the configured mapping.
			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
			return newAPIError
		}
	}

	// Pass the already-extracted httpResp instead of asserting resp again: a
	// single-value type assertion on a nil interface panics, which would
	// defeat the nil guard above.
	usage, openaiErr := adaptor.DoResponse(c, httpResp, info)
	if openaiErr != nil {
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
		return openaiErr
	}

	service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
	return nil
}
// GeminiEmbeddingHandler relays a Gemini embedding request to the upstream
// channel.
//
// A request whose URL path ends in "batchEmbedContents" is decoded as a
// *dto.GeminiBatchEmbeddingRequest; any other path is decoded as a
// *dto.GeminiEmbeddingRequest. After model mapping, the request's model name
// is set to "models/" + info.UpstreamModelName, the request is marshaled to
// JSON with param overrides applied, and the result is sent through the
// channel adaptor. Non-200 upstream responses are turned into an API error;
// on success the usage returned by the adaptor is passed to
// service.PostTextConsumeQuota.
//
// It returns nil on success, or a *types.NewAPIError describing the failure.
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
	info.InitChannelMeta(c)

	isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
	info.IsGeminiBatchEmbedding = isBatch

	var req dto.Request
	var err error
	if isBatch {
		batchRequest := &dto.GeminiBatchEmbeddingRequest{}
		err = common.UnmarshalBodyReusable(c, batchRequest)
		if err != nil {
			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
		}
		req = batchRequest
	} else {
		singleRequest := &dto.GeminiEmbeddingRequest{}
		err = common.UnmarshalBodyReusable(c, singleRequest)
		if err != nil {
			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
		}
		req = singleRequest
	}

	// Apply model mapping, then send the mapped upstream model name in the
	// "models/<name>" form.
	err = helper.ModelMappedHelper(c, info, req)
	if err != nil {
		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
	}
	req.SetModelName("models/" + info.UpstreamModelName)

	adaptor := GetAdaptor(info.ApiType)
	if adaptor == nil {
		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
	}
	adaptor.Init(info)

	var requestBody io.Reader
	jsonData, err := common.Marshal(req)
	if err != nil {
		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
	}

	// Apply param override.
	if len(info.ParamOverride) > 0 {
		jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
		if err != nil {
			return newAPIErrorFromParamOverride(err)
		}
	}

	logger.LogDebug(c, "Gemini embedding request body: %s", jsonData)

	body, size, closer, err := relaycommon.NewOutboundJSONBody(jsonData)
	if err != nil {
		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
	}
	// The deferred Close runs when the handler returns, i.e. after the
	// upstream request and response handling below have finished.
	defer closer.Close()
	// Drop this function's reference to the marshaled bytes; from here on the
	// payload is only reached through body.
	jsonData = nil
	info.UpstreamRequestBodySize = size
	requestBody = body

	resp, err := adaptor.DoRequest(c, info, requestBody)
	if err != nil {
		logger.LogError(c, "Do gemini request failed: "+err.Error())
		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
	}

	statusCodeMappingStr := c.GetString("status_code_mapping")

	var httpResp *http.Response
	if resp != nil {
		httpResp = resp.(*http.Response)
		if httpResp.StatusCode != http.StatusOK {
			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
			// Reset the status code according to the configured mapping.
			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
			return newAPIError
		}
	}

	// Pass the already-extracted httpResp instead of asserting resp again: a
	// single-value type assertion on a nil interface panics, which would
	// defeat the nil guard above.
	usage, openaiErr := adaptor.DoResponse(c, httpResp, info)
	if openaiErr != nil {
		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
		return openaiErr
	}

	service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
	return nil
}