diff --git a/common/str.go b/common/str.go index 71391f72..9f3b9d46 100644 --- a/common/str.go +++ b/common/str.go @@ -3,6 +3,7 @@ package common import ( "encoding/base64" "encoding/json" + "fmt" "net/url" "regexp" "strconv" @@ -20,6 +21,16 @@ var ( maskApiKeyPattern = regexp.MustCompile(`(['"]?)api_key:([^\s'"]+)(['"]?)`) ) +const LocalLogContentLimit = 2048 + +// LocalLogPreview limits log-only content unless debug logging is enabled. +func LocalLogPreview(content string) string { + if DebugEnabled || len(content) <= LocalLogContentLimit { + return content + } + return fmt.Sprintf("%s... [truncated, original_length=%d, limit=%d]", content[:LocalLogContentLimit], len(content), LocalLogContentLimit) +} + func GetStringIfEmpty(str string, defaultValue string) string { if str == "" { return defaultValue diff --git a/controller/relay.go b/controller/relay.go index 5e2db44c..1d14dcc6 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -88,7 +88,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { defer func() { if newAPIError != nil { - logger.LogError(c, fmt.Sprintf("relay error: %s", newAPIError.Error())) + logger.LogError(c, fmt.Sprintf("relay error: %s", common.LocalLogPreview(newAPIError.Error()))) newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId)) switch relayFormat { case types.RelayFormatOpenAIRealtime: @@ -354,7 +354,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b } func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { - logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) + logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, common.LocalLogPreview(err.Error()))) // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously if service.ShouldDisableChannel(err) && channelError.AutoBan { diff --git a/model/log.go b/model/log.go index 7edf24b4..c1c01da4 100644 --- a/model/log.go +++ b/model/log.go @@ -17,24 +17,24 @@ import ( ) type Log struct { - Id int `json:"id" gorm:"index:idx_created_at_id,priority:1;index:idx_user_id_id,priority:2"` - UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"` - CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` - Type int `json:"type" gorm:"index:idx_created_at_type"` - Content string `json:"content"` - Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"` - TokenName string `json:"token_name" gorm:"index;default:''"` - ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` - Quota int `json:"quota" gorm:"default:0"` - PromptTokens int `json:"prompt_tokens" gorm:"default:0"` - CompletionTokens int `json:"completion_tokens" gorm:"default:0"` - UseTime int `json:"use_time" gorm:"default:0"` - IsStream bool `json:"is_stream"` - ChannelId int `json:"channel" gorm:"index"` - ChannelName string `json:"channel_name" gorm:"->"` - TokenId int `json:"token_id" gorm:"default:0;index"` - Group string `json:"group" gorm:"index"` - Ip string `json:"ip" gorm:"index;default:''"` + Id int `json:"id" gorm:"index:idx_created_at_id,priority:1;index:idx_user_id_id,priority:2"` + UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"` + CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` + Type int `json:"type" gorm:"index:idx_created_at_type"` + Content string `json:"content"` + Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"` + TokenName string `json:"token_name" gorm:"index;default:''"` + ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` + Quota int `json:"quota" gorm:"default:0"` + PromptTokens int `json:"prompt_tokens" gorm:"default:0"` + CompletionTokens int `json:"completion_tokens" gorm:"default:0"` + UseTime int `json:"use_time" gorm:"default:0"` + IsStream bool `json:"is_stream"` + ChannelId int `json:"channel" gorm:"index"` + ChannelName string `json:"channel_name" gorm:"->"` + TokenId int `json:"token_id" gorm:"default:0;index"` + Group string `json:"group" gorm:"index"` + Ip string `json:"ip" gorm:"index;default:''"` RequestId string `json:"request_id,omitempty" gorm:"type:varchar(64);index:idx_logs_request_id;default:''"` UpstreamRequestId string `json:"upstream_request_id,omitempty" gorm:"type:varchar(128);index:idx_logs_upstream_request_id;default:''"` Other string `json:"other"` @@ -145,7 +145,7 @@ func RecordTopupLog(userId int, content string, callerIp string, paymentMethod s func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int, isStream bool, group string, other map[string]interface{}) { - logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content)) + logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, common.LocalLogPreview(content))) username := c.GetString("username") requestId := c.GetString(common.RequestIdKey) upstreamRequestId := c.GetString(common.UpstreamRequestIdKey) diff --git a/service/channel.go b/service/channel.go index 3fde6e20..856e2cde 100644 --- a/service/channel.go +++ b/service/channel.go @@ -17,7 +17,7 @@ func formatNotifyType(channelId int, status int) string { // disable & notify func DisableChannel(channelError types.ChannelError, reason string) { - common.SysLog(fmt.Sprintf("通道「%s」(#%d)发生错误,准备禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason)) + common.SysLog(fmt.Sprintf("通道「%s」(#%d)发生错误,准备禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, common.LocalLogPreview(reason))) // 检查是否启用自动禁用功能 if !channelError.AutoBan { diff --git a/service/error.go b/service/error.go index a2ff0aad..cf7325b6 100644 --- a/service/error.go +++ b/service/error.go @@ -92,11 +92,13 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai } CloseResponseBodyGracefully(resp) var errResponse dto.GeneralErrorResponse + responseBodyText := string(responseBody) + responseBodyPreview := common.LocalLogPreview(responseBodyText) buildErrWithBody := func(message string) error { if message == "" { - return fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)) + return fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, responseBodyText) } - return fmt.Errorf("bad response status code %d, message: %s, body: %s", resp.StatusCode, message, string(responseBody)) + return fmt.Errorf("bad response status code %d, message: %s, body: %s", resp.StatusCode, message, responseBodyText) } err = common.Unmarshal(responseBody, &errResponse) @@ -104,7 +106,7 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai if showBodyWhenFail { newApiErr.Err = buildErrWithBody("") } else { - logger.LogError(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) + logger.LogError(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, responseBodyPreview)) newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode) } return diff --git a/service/error_test.go b/service/error_test.go index 2303e8f4..9f19bfbb 100644 --- a/service/error_test.go +++ b/service/error_test.go @@ -1,9 +1,17 @@ package service import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "strings" "testing" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/types" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -55,3 +63,99 @@ func TestResetStatusCode(t *testing.T) { }) } } + +func TestRelayErrorHandlerTruncatesInvalidJSONBodyInLog(t *testing.T) { + withDebugEnabled(t, false) + + body := strings.Repeat("b", common.LocalLogContentLimit+256) + var logBuffer bytes.Buffer + + common.LogWriterMu.Lock() + oldWriter := gin.DefaultErrorWriter + gin.DefaultErrorWriter = &logBuffer + common.LogWriterMu.Unlock() + t.Cleanup(func() { + common.LogWriterMu.Lock() + gin.DefaultErrorWriter = oldWriter + common.LogWriterMu.Unlock() + }) + + resp := &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(strings.NewReader(body)), + } + + newAPIError := RelayErrorHandler(context.Background(), resp, false) + + require.NotNil(t, newAPIError) + require.Equal(t, "bad response status code 500", newAPIError.Error()) + require.Contains(t, logBuffer.String(), "[truncated") + require.Contains(t, logBuffer.String(), fmt.Sprintf("original_length=%d", len(body))) + require.NotContains(t, logBuffer.String(), strings.Repeat("b", common.LocalLogContentLimit+1)) +} + +func TestRelayErrorHandlerKeepsStructuredErrorMessage(t *testing.T) { + message := strings.Repeat("c", common.LocalLogContentLimit+256) + body := `{"message":"` + message + `"}` + resp := &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(strings.NewReader(body)), + } + + newAPIError := RelayErrorHandler(context.Background(), resp, false) + + require.NotNil(t, newAPIError) + require.Equal(t, message, newAPIError.Error()) +} + +func TestRelayErrorHandlerKeepsOpenAIErrorMessage(t *testing.T) { + message := strings.Repeat("d", common.LocalLogContentLimit+256) + body := `{"error":{"message":"` + message + `","type":"server_error","code":"server_error"}}` + resp := &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(strings.NewReader(body)), + } + + newAPIError := RelayErrorHandler(context.Background(), resp, false) + + require.NotNil(t, newAPIError) + require.Equal(t, message, newAPIError.Error()) +} + +func TestRelayErrorHandlerKeepsInvalidJSONBodyInDebugLog(t *testing.T) { + withDebugEnabled(t, true) + + body := strings.Repeat("e", common.LocalLogContentLimit+256) + var logBuffer bytes.Buffer + + common.LogWriterMu.Lock() + oldWriter := gin.DefaultErrorWriter + gin.DefaultErrorWriter = &logBuffer + common.LogWriterMu.Unlock() + t.Cleanup(func() { + common.LogWriterMu.Lock() + gin.DefaultErrorWriter = oldWriter + common.LogWriterMu.Unlock() + }) + + resp := &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(strings.NewReader(body)), + } + + newAPIError := RelayErrorHandler(context.Background(), resp, false) + + require.NotNil(t, newAPIError) + require.NotContains(t, logBuffer.String(), "[truncated") + require.Contains(t, logBuffer.String(), body) +} + +func withDebugEnabled(t *testing.T, enabled bool) { + t.Helper() + + oldDebug := common.DebugEnabled + common.DebugEnabled = enabled + t.Cleanup(func() { + common.DebugEnabled = oldDebug + }) +}