162 lines
4.1 KiB
Go
162 lines
4.1 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/types"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestResetStatusCode(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
statusCode int
|
|
statusCodeConfig string
|
|
expectedCode int
|
|
}{
|
|
{
|
|
name: "map string value",
|
|
statusCode: 429,
|
|
statusCodeConfig: `{"429":"503"}`,
|
|
expectedCode: 503,
|
|
},
|
|
{
|
|
name: "map int value",
|
|
statusCode: 429,
|
|
statusCodeConfig: `{"429":503}`,
|
|
expectedCode: 503,
|
|
},
|
|
{
|
|
name: "skip invalid string value",
|
|
statusCode: 429,
|
|
statusCodeConfig: `{"429":"bad-code"}`,
|
|
expectedCode: 429,
|
|
},
|
|
{
|
|
name: "skip status code 200",
|
|
statusCode: 200,
|
|
statusCodeConfig: `{"200":503}`,
|
|
expectedCode: 200,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
newAPIError := &types.NewAPIError{
|
|
StatusCode: tc.statusCode,
|
|
}
|
|
ResetStatusCode(newAPIError, tc.statusCodeConfig)
|
|
require.Equal(t, tc.expectedCode, newAPIError.StatusCode)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRelayErrorHandlerTruncatesInvalidJSONBodyInLog(t *testing.T) {
|
|
withDebugEnabled(t, false)
|
|
|
|
body := strings.Repeat("b", common.LocalLogContentLimit+256)
|
|
var logBuffer bytes.Buffer
|
|
|
|
common.LogWriterMu.Lock()
|
|
oldWriter := gin.DefaultErrorWriter
|
|
gin.DefaultErrorWriter = &logBuffer
|
|
common.LogWriterMu.Unlock()
|
|
t.Cleanup(func() {
|
|
common.LogWriterMu.Lock()
|
|
gin.DefaultErrorWriter = oldWriter
|
|
common.LogWriterMu.Unlock()
|
|
})
|
|
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusInternalServerError,
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
}
|
|
|
|
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
|
|
|
require.NotNil(t, newAPIError)
|
|
require.Equal(t, "bad response status code 500", newAPIError.Error())
|
|
require.Contains(t, logBuffer.String(), "[truncated")
|
|
require.Contains(t, logBuffer.String(), fmt.Sprintf("original_length=%d", len(body)))
|
|
require.NotContains(t, logBuffer.String(), strings.Repeat("b", common.LocalLogContentLimit+1))
|
|
}
|
|
|
|
func TestRelayErrorHandlerKeepsStructuredErrorMessage(t *testing.T) {
|
|
message := strings.Repeat("c", common.LocalLogContentLimit+256)
|
|
body := `{"message":"` + message + `"}`
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusInternalServerError,
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
}
|
|
|
|
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
|
|
|
require.NotNil(t, newAPIError)
|
|
require.Equal(t, message, newAPIError.Error())
|
|
}
|
|
|
|
func TestRelayErrorHandlerKeepsOpenAIErrorMessage(t *testing.T) {
|
|
message := strings.Repeat("d", common.LocalLogContentLimit+256)
|
|
body := `{"error":{"message":"` + message + `","type":"server_error","code":"server_error"}}`
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusInternalServerError,
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
}
|
|
|
|
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
|
|
|
require.NotNil(t, newAPIError)
|
|
require.Equal(t, message, newAPIError.Error())
|
|
}
|
|
|
|
func TestRelayErrorHandlerKeepsInvalidJSONBodyInDebugLog(t *testing.T) {
|
|
withDebugEnabled(t, true)
|
|
|
|
body := strings.Repeat("e", common.LocalLogContentLimit+256)
|
|
var logBuffer bytes.Buffer
|
|
|
|
common.LogWriterMu.Lock()
|
|
oldWriter := gin.DefaultErrorWriter
|
|
gin.DefaultErrorWriter = &logBuffer
|
|
common.LogWriterMu.Unlock()
|
|
t.Cleanup(func() {
|
|
common.LogWriterMu.Lock()
|
|
gin.DefaultErrorWriter = oldWriter
|
|
common.LogWriterMu.Unlock()
|
|
})
|
|
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusInternalServerError,
|
|
Body: io.NopCloser(strings.NewReader(body)),
|
|
}
|
|
|
|
newAPIError := RelayErrorHandler(context.Background(), resp, false)
|
|
|
|
require.NotNil(t, newAPIError)
|
|
require.NotContains(t, logBuffer.String(), "[truncated")
|
|
require.Contains(t, logBuffer.String(), body)
|
|
}
|
|
|
|
func withDebugEnabled(t *testing.T, enabled bool) {
|
|
t.Helper()
|
|
|
|
oldDebug := common.DebugEnabled
|
|
common.DebugEnabled = enabled
|
|
t.Cleanup(func() {
|
|
common.DebugEnabled = oldDebug
|
|
})
|
|
}
|