new-api/service/error_test.go

162 lines
4.1 KiB
Go
Raw Normal View History

package service
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"testing"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestResetStatusCode checks that ResetStatusCode rewrites an error's status
// code from a JSON mapping config. The table asserts that string and numeric
// targets are both honoured, that an unparsable target leaves the code
// unchanged, and that a 200 is not remapped.
func TestResetStatusCode(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name             string
		statusCode       int
		statusCodeConfig string
		expectedCode     int
	}{
		{name: "map string value", statusCode: 429, statusCodeConfig: `{"429":"503"}`, expectedCode: 503},
		{name: "map int value", statusCode: 429, statusCodeConfig: `{"429":503}`, expectedCode: 503},
		{name: "skip invalid string value", statusCode: 429, statusCodeConfig: `{"429":"bad-code"}`, expectedCode: 429},
		{name: "skip status code 200", statusCode: 200, statusCodeConfig: `{"200":503}`, expectedCode: 200},
	}

	for i := range cases {
		// Copy the case so each parallel subtest captures its own row,
		// regardless of the Go version's loop-variable semantics.
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			apiErr := &types.NewAPIError{StatusCode: c.statusCode}
			ResetStatusCode(apiErr, c.statusCodeConfig)
			require.Equal(t, c.expectedCode, apiErr.StatusCode)
		})
	}
}
// TestRelayErrorHandlerTruncatesInvalidJSONBodyInLog checks the behaviour when
// debug logging is off and an upstream returns an oversized non-JSON error
// body. The returned error must carry a generic status-code message. The
// logged body must be truncated, carrying a "[truncated" marker and the
// original length.
func TestRelayErrorHandlerTruncatesInvalidJSONBodyInLog(t *testing.T) {
	withDebugEnabled(t, false)

	payload := strings.Repeat("b", common.LocalLogContentLimit+256)

	// Swap gin's error writer for an in-memory buffer while holding
	// common.LogWriterMu, and restore the previous writer the same way on cleanup.
	captured := new(bytes.Buffer)
	common.LogWriterMu.Lock()
	previous := gin.DefaultErrorWriter
	gin.DefaultErrorWriter = captured
	common.LogWriterMu.Unlock()
	t.Cleanup(func() {
		common.LogWriterMu.Lock()
		defer common.LogWriterMu.Unlock()
		gin.DefaultErrorWriter = previous
	})

	apiErr := RelayErrorHandler(context.Background(), &http.Response{
		StatusCode: http.StatusInternalServerError,
		Body:       io.NopCloser(strings.NewReader(payload)),
	}, false)
	require.NotNil(t, apiErr)
	require.Equal(t, "bad response status code 500", apiErr.Error())

	logged := captured.String()
	require.Contains(t, logged, "[truncated")
	require.Contains(t, logged, fmt.Sprintf("original_length=%d", len(payload)))
	// No run of payload bytes longer than the limit may survive in the log.
	require.NotContains(t, logged, strings.Repeat("b", common.LocalLogContentLimit+1))
}
// TestRelayErrorHandlerKeepsStructuredErrorMessage checks that a top-level
// "message" field in a JSON error body becomes the error text in full, even
// when it is longer than common.LocalLogContentLimit.
func TestRelayErrorHandlerKeepsStructuredErrorMessage(t *testing.T) {
	want := strings.Repeat("c", common.LocalLogContentLimit+256)
	payload := fmt.Sprintf(`{"message":"%s"}`, want)

	got := RelayErrorHandler(context.Background(), &http.Response{
		StatusCode: http.StatusInternalServerError,
		Body:       io.NopCloser(strings.NewReader(payload)),
	}, false)

	require.NotNil(t, got)
	require.Equal(t, want, got.Error())
}
// TestRelayErrorHandlerKeepsOpenAIErrorMessage checks that the nested
// error.message of an OpenAI-style error body becomes the error text in full,
// even when it is longer than common.LocalLogContentLimit.
func TestRelayErrorHandlerKeepsOpenAIErrorMessage(t *testing.T) {
	want := strings.Repeat("d", common.LocalLogContentLimit+256)
	payload := fmt.Sprintf(`{"error":{"message":"%s","type":"server_error","code":"server_error"}}`, want)

	got := RelayErrorHandler(context.Background(), &http.Response{
		StatusCode: http.StatusInternalServerError,
		Body:       io.NopCloser(strings.NewReader(payload)),
	}, false)

	require.NotNil(t, got)
	require.Equal(t, want, got.Error())
}
// TestRelayErrorHandlerKeepsInvalidJSONBodyInDebugLog checks that, with debug
// logging on, an oversized non-JSON error body is written to the error log in
// full, with no truncation marker.
func TestRelayErrorHandlerKeepsInvalidJSONBodyInDebugLog(t *testing.T) {
	withDebugEnabled(t, true)

	payload := strings.Repeat("e", common.LocalLogContentLimit+256)

	// Swap gin's error writer for an in-memory buffer while holding
	// common.LogWriterMu, and restore the previous writer the same way on cleanup.
	captured := new(bytes.Buffer)
	common.LogWriterMu.Lock()
	previous := gin.DefaultErrorWriter
	gin.DefaultErrorWriter = captured
	common.LogWriterMu.Unlock()
	t.Cleanup(func() {
		common.LogWriterMu.Lock()
		defer common.LogWriterMu.Unlock()
		gin.DefaultErrorWriter = previous
	})

	apiErr := RelayErrorHandler(context.Background(), &http.Response{
		StatusCode: http.StatusInternalServerError,
		Body:       io.NopCloser(strings.NewReader(payload)),
	}, false)
	require.NotNil(t, apiErr)

	logged := captured.String()
	require.NotContains(t, logged, "[truncated")
	require.Contains(t, logged, payload)
}
// withDebugEnabled sets common.DebugEnabled to enabled for the duration of the
// calling test. It registers a cleanup that restores the value seen on entry.
// Because it mutates package-level state, tests using it should not call
// t.Parallel.
func withDebugEnabled(t *testing.T, enabled bool) {
	t.Helper()
	saved := common.DebugEnabled
	t.Cleanup(func() { common.DebugEnabled = saved })
	common.DebugEnabled = enabled
}