2024-02-29 01:08:18 +08:00
package openai
2023-07-22 17:48:45 +08:00
import (
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/relay/channel/openrouter"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/relay/helper"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/types"

	"github.com/bytedance/gopkg/util/gopool"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
)
2025-02-23 17:05:57 +08:00
func sendStreamData ( c * gin . Context , info * relaycommon . RelayInfo , data string , forceFormat bool , thinkToContent bool ) error {
2024-12-15 15:52:41 +08:00
if data == "" {
return nil
}
2025-02-23 17:05:57 +08:00
if ! forceFormat && ! thinkToContent {
2025-03-05 19:47:41 +08:00
return helper . StringData ( c , data )
2025-02-23 17:05:57 +08:00
}
var lastStreamResponse dto . ChatCompletionsStreamResponse
2025-06-28 00:02:07 +08:00
if err := common . UnmarshalJsonStr ( data , & lastStreamResponse ) ; err != nil {
2025-02-23 17:05:57 +08:00
return err
}
if ! thinkToContent {
2025-03-05 19:47:41 +08:00
return helper . ObjectData ( c , lastStreamResponse )
}
hasThinkingContent := false
2025-03-06 19:20:29 +08:00
hasContent := false
var thinkingContent strings . Builder
2025-03-05 19:47:41 +08:00
for _ , choice := range lastStreamResponse . Choices {
if len ( choice . Delta . GetReasoningContent ( ) ) > 0 {
hasThinkingContent = true
2025-03-06 19:20:29 +08:00
thinkingContent . WriteString ( choice . Delta . GetReasoningContent ( ) )
}
if len ( choice . Delta . GetContentString ( ) ) > 0 {
hasContent = true
2025-03-05 19:47:41 +08:00
}
2025-02-23 17:05:57 +08:00
}
// Handle think to content conversion
2025-03-05 19:47:41 +08:00
if info . ThinkingContentInfo . IsFirstThinkingContent {
if hasThinkingContent {
response := lastStreamResponse . Copy ( )
for i := range response . Choices {
2025-03-06 19:20:29 +08:00
// send `think` tag with thinking content
response . Choices [ i ] . Delta . SetContentString ( "<think>\n" + thinkingContent . String ( ) )
2025-03-06 19:16:26 +08:00
response . Choices [ i ] . Delta . ReasoningContent = nil
response . Choices [ i ] . Delta . Reasoning = nil
2025-03-05 19:47:41 +08:00
}
info . ThinkingContentInfo . IsFirstThinkingContent = false
2025-03-14 03:13:52 +08:00
info . ThinkingContentInfo . HasSentThinkingContent = true
2025-03-05 19:47:41 +08:00
return helper . ObjectData ( c , response )
2024-12-15 15:52:41 +08:00
}
2025-02-23 17:05:57 +08:00
}
if lastStreamResponse . Choices == nil || len ( lastStreamResponse . Choices ) == 0 {
2025-03-05 19:47:41 +08:00
return helper . ObjectData ( c , lastStreamResponse )
2024-12-15 15:52:41 +08:00
}
2025-02-23 17:05:57 +08:00
// Process each choice
for i , choice := range lastStreamResponse . Choices {
// Handle transition from thinking to content
2025-03-14 03:13:52 +08:00
// only send `</think>` tag when previous thinking content has been sent
if hasContent && ! info . ThinkingContentInfo . SendLastThinkingContent && info . ThinkingContentInfo . HasSentThinkingContent {
2025-02-23 17:05:57 +08:00
response := lastStreamResponse . Copy ( )
for j := range response . Choices {
2025-03-06 19:20:29 +08:00
response . Choices [ j ] . Delta . SetContentString ( "\n</think>\n" )
2025-03-06 19:16:26 +08:00
response . Choices [ j ] . Delta . ReasoningContent = nil
response . Choices [ j ] . Delta . Reasoning = nil
2025-02-23 17:05:57 +08:00
}
2025-03-05 19:47:41 +08:00
info . ThinkingContentInfo . SendLastThinkingContent = true
helper . ObjectData ( c , response )
2025-02-23 17:05:57 +08:00
}
2025-03-14 03:13:52 +08:00
// Convert reasoning content to regular content if any
2025-02-23 17:05:57 +08:00
if len ( choice . Delta . GetReasoningContent ( ) ) > 0 {
lastStreamResponse . Choices [ i ] . Delta . SetContentString ( choice . Delta . GetReasoningContent ( ) )
2025-03-06 19:16:26 +08:00
lastStreamResponse . Choices [ i ] . Delta . ReasoningContent = nil
lastStreamResponse . Choices [ i ] . Delta . Reasoning = nil
2025-03-06 19:20:29 +08:00
} else if ! hasThinkingContent && ! hasContent {
// flush thinking content
lastStreamResponse . Choices [ i ] . Delta . ReasoningContent = nil
lastStreamResponse . Choices [ i ] . Delta . Reasoning = nil
2025-02-23 17:05:57 +08:00
}
}
2025-03-05 19:47:41 +08:00
return helper . ObjectData ( c , lastStreamResponse )
2024-12-15 15:52:41 +08:00
}
2025-07-10 15:02:40 +08:00
func OaiStreamHandler ( c * gin . Context , info * relaycommon . RelayInfo , resp * http . Response ) ( * dto . Usage , * types . NewAPIError ) {
2024-12-15 15:52:41 +08:00
if resp == nil || resp . Body == nil {
2025-08-14 20:05:06 +08:00
logger . LogError ( c , "invalid response or response body" )
2025-07-30 18:39:19 +08:00
return nil , types . NewOpenAIError ( fmt . Errorf ( "invalid response" ) , types . ErrorCodeBadResponse , http . StatusInternalServerError )
2024-12-15 15:52:41 +08:00
}
2025-08-14 20:05:06 +08:00
defer service . CloseResponseBodyGracefully ( resp )
2025-06-27 22:44:20 +08:00
model := info . UpstreamModelName
2024-07-28 01:12:26 +08:00
var responseId string
2024-07-15 18:04:05 +08:00
var createAt int64 = 0
var systemFingerprint string
2025-06-27 22:44:20 +08:00
var containStreamUsage bool
2023-11-28 22:02:09 +08:00
var responseTextBuilder strings . Builder
2025-04-11 23:31:32 +08:00
var toolCount int
2024-07-15 18:04:05 +08:00
var usage = & dto . Usage { }
2024-07-15 19:06:13 +08:00
var streamItems [ ] string // store stream items
2025-07-26 13:31:33 +08:00
var lastStreamData string
2025-11-04 01:43:04 +08:00
var secondLastStreamData string // 存储倒数第二个stream data, 用于音频模型
// 检查是否为音频模型
isAudioModel := strings . Contains ( strings . ToLower ( model ) , "audio" )
2025-03-04 18:42:34 +08:00
2026-03-31 16:50:24 +08:00
helper . StreamScannerHandler ( c , resp , info , func ( data string , sr * helper . StreamResult ) {
2025-03-05 19:47:41 +08:00
if lastStreamData != "" {
2026-03-31 16:50:24 +08:00
if err := HandleStreamFormat ( c , info , lastStreamData , info . ChannelSetting . ForceFormat , info . ChannelSetting . ThinkingToContent ) ; err != nil {
2025-08-14 21:10:04 +08:00
common . SysLog ( "error handling stream format: " + err . Error ( ) )
2026-03-31 16:50:24 +08:00
sr . Error ( err )
2025-03-04 17:35:41 +08:00
}
2025-03-04 17:10:56 +08:00
}
2025-07-23 20:59:56 +08:00
if len ( data ) > 0 {
2025-11-04 01:43:04 +08:00
// 对音频模型, 保存倒数第二个stream data
if isAudioModel && lastStreamData != "" {
secondLastStreamData = lastStreamData
}
2025-07-23 20:59:56 +08:00
lastStreamData = data
streamItems = append ( streamItems , data )
}
2024-07-19 01:07:37 +08:00
} )
2024-07-15 18:04:05 +08:00
2025-11-04 01:43:04 +08:00
// 对音频模型, 从倒数第二个stream data中提取usage信息
if isAudioModel && secondLastStreamData != "" {
var streamResp struct {
Usage * dto . Usage ` json:"usage" `
}
2025-12-13 17:24:23 +08:00
err := common . Unmarshal ( [ ] byte ( secondLastStreamData ) , & streamResp )
2025-11-04 01:43:04 +08:00
if err == nil && streamResp . Usage != nil && service . ValidUsage ( streamResp . Usage ) {
usage = streamResp . Usage
containStreamUsage = true
if common . DebugEnabled {
logger . LogDebug ( c , fmt . Sprintf ( "Audio model usage extracted from second last SSE: PromptTokens=%d, CompletionTokens=%d, TotalTokens=%d, InputTokens=%d, OutputTokens=%d" ,
usage . PromptTokens , usage . CompletionTokens , usage . TotalTokens ,
usage . InputTokens , usage . OutputTokens ) )
}
}
}
2025-06-27 22:44:20 +08:00
// 处理最后的响应
2024-07-19 14:46:25 +08:00
shouldSendLastResp := true
2025-06-27 22:44:20 +08:00
if err := handleLastResponse ( lastStreamData , & responseId , & createAt , & systemFingerprint , & model , & usage ,
& containStreamUsage , info , & shouldSendLastResp ) ; err != nil {
2025-08-14 20:05:06 +08:00
logger . LogError ( c , fmt . Sprintf ( "error handling last response: %s, lastStreamData: [%s]" , err . Error ( ) , lastStreamData ) )
2024-07-19 14:46:25 +08:00
}
2025-04-11 18:28:50 +08:00
2025-08-14 21:10:04 +08:00
if info . RelayFormat == types . RelayFormatOpenAI {
2025-07-23 20:59:56 +08:00
if shouldSendLastResp {
2025-07-26 13:31:33 +08:00
_ = sendStreamData ( c , info , lastStreamData , info . ChannelSetting . ForceFormat , info . ChannelSetting . ThinkingToContent )
2025-07-23 20:59:56 +08:00
}
2024-07-19 14:46:25 +08:00
}
2025-03-13 19:32:08 +08:00
// 处理token计算
if err := processTokens ( info . RelayMode , streamItems , & responseTextBuilder , & toolCount ) ; err != nil {
2025-08-14 20:05:06 +08:00
logger . LogError ( c , "error processing tokens: " + err . Error ( ) )
2024-07-15 18:04:05 +08:00
}
2024-07-19 01:07:37 +08:00
if ! containStreamUsage {
2025-12-02 21:34:39 +08:00
usage = service . ResponseText2Usage ( c , responseTextBuilder . String ( ) , info . UpstreamModelName , info . GetEstimatePromptTokens ( ) )
2024-07-15 18:04:05 +08:00
usage . CompletionTokens += toolCount * 7
}
2025-10-08 16:52:49 +08:00
2025-12-30 17:38:32 +08:00
applyUsagePostProcessing ( info , usage , common . StringToByteSlice ( lastStreamData ) )
2025-10-08 16:52:49 +08:00
2025-07-26 13:31:33 +08:00
HandleFinalResponse ( c , info , lastStreamData , responseId , createAt , model , systemFingerprint , usage , containStreamUsage )
2024-07-15 18:04:05 +08:00
2025-07-10 15:02:40 +08:00
return usage , nil
2023-07-22 17:48:45 +08:00
}
2025-07-10 15:02:40 +08:00
func OpenaiHandler ( c * gin . Context , info * relaycommon . RelayInfo , resp * http . Response ) ( * dto . Usage , * types . NewAPIError ) {
2025-08-14 20:05:06 +08:00
defer service . CloseResponseBodyGracefully ( resp )
2025-06-27 23:35:56 +08:00
2025-03-16 18:34:39 +08:00
var simpleResponse dto . OpenAITextResponse
2023-11-23 02:56:18 +08:00
responseBody , err := io . ReadAll ( resp . Body )
if err != nil {
2025-07-30 18:39:19 +08:00
return nil , types . NewOpenAIError ( err , types . ErrorCodeReadResponseBodyFailed , http . StatusInternalServerError )
2023-11-23 02:56:18 +08:00
}
2025-08-08 13:47:39 +08:00
if common . DebugEnabled {
println ( "upstream response body:" , string ( responseBody ) )
}
2025-09-28 15:23:27 +08:00
// Unmarshal to simpleResponse
if info . ChannelType == constant . ChannelTypeOpenRouter && info . ChannelOtherSettings . IsOpenRouterEnterprise ( ) {
// 尝试解析为 openrouter enterprise
var enterpriseResponse openrouter . OpenRouterEnterpriseResponse
err = common . Unmarshal ( responseBody , & enterpriseResponse )
if err != nil {
return nil , types . NewOpenAIError ( err , types . ErrorCodeBadResponseBody , http . StatusInternalServerError )
}
if enterpriseResponse . Success {
responseBody = enterpriseResponse . Data
} else {
logger . LogError ( c , fmt . Sprintf ( "openrouter enterprise response success=false, data: %s" , enterpriseResponse . Data ) )
return nil , types . NewOpenAIError ( fmt . Errorf ( "openrouter response success=false" ) , types . ErrorCodeBadResponseBody , http . StatusInternalServerError )
}
2023-11-23 02:56:18 +08:00
}
2025-09-28 15:29:01 +08:00
err = common . Unmarshal ( responseBody , & simpleResponse )
if err != nil {
return nil , types . NewOpenAIError ( err , types . ErrorCodeBadResponseBody , http . StatusInternalServerError )
}
2025-07-31 21:16:01 +08:00
if oaiError := simpleResponse . GetOpenAIError ( ) ; oaiError != nil && oaiError . Type != "" {
return nil , types . WithOpenAIError ( * oaiError , resp . StatusCode )
2023-07-22 17:48:45 +08:00
}
2025-06-21 00:54:40 +08:00
2026-01-25 14:52:18 +08:00
for _ , choice := range simpleResponse . Choices {
if choice . FinishReason == constant . FinishReasonContentFilter {
common . SetContextKey ( c , constant . ContextKeyAdminRejectReason , "openai_finish_reason=content_filter" )
break
}
}
2025-05-09 18:57:06 +08:00
forceFormat := false
2025-07-07 14:26:37 +08:00
if info . ChannelSetting . ForceFormat {
forceFormat = true
2025-05-09 18:57:06 +08:00
}
2025-08-16 22:54:00 +08:00
usageModified := false
if simpleResponse . Usage . PromptTokens == 0 {
completionTokens := simpleResponse . Usage . CompletionTokens
if completionTokens == 0 {
for _ , choice := range simpleResponse . Choices {
2026-04-29 02:30:39 +08:00
ctkm := service . CountTextToken ( choice . Message . StringContent ( ) + choice . Message . GetReasoningContent ( ) , info . UpstreamModelName )
2025-08-16 22:54:00 +08:00
completionTokens += ctkm
}
2025-05-09 18:57:06 +08:00
}
simpleResponse . Usage = dto . Usage {
2025-12-02 21:34:39 +08:00
PromptTokens : info . GetEstimatePromptTokens ( ) ,
2025-05-09 18:57:06 +08:00
CompletionTokens : completionTokens ,
2025-12-02 21:34:39 +08:00
TotalTokens : info . GetEstimatePromptTokens ( ) + completionTokens ,
2025-05-09 18:57:06 +08:00
}
2025-08-16 22:54:00 +08:00
usageModified = true
2025-05-09 18:57:06 +08:00
}
2025-03-16 18:34:39 +08:00
2025-10-08 16:52:49 +08:00
applyUsagePostProcessing ( info , & simpleResponse . Usage , responseBody )
2025-03-16 18:34:39 +08:00
switch info . RelayFormat {
2025-08-14 21:10:04 +08:00
case types . RelayFormatOpenAI :
2025-08-22 17:33:20 +08:00
if usageModified {
var bodyMap map [ string ] interface { }
err = common . Unmarshal ( responseBody , & bodyMap )
if err != nil {
return nil , types . NewOpenAIError ( err , types . ErrorCodeBadResponseBody , http . StatusInternalServerError )
}
bodyMap [ "usage" ] = simpleResponse . Usage
responseBody , _ = common . Marshal ( bodyMap )
}
if forceFormat {
2025-07-10 15:02:40 +08:00
responseBody , err = common . Marshal ( simpleResponse )
2025-05-09 18:57:06 +08:00
if err != nil {
2025-07-10 15:02:40 +08:00
return nil , types . NewError ( err , types . ErrorCodeBadResponseBody )
2025-05-09 18:57:06 +08:00
}
} else {
break
}
2025-08-14 21:10:04 +08:00
case types . RelayFormatClaude :
2025-03-16 18:34:39 +08:00
claudeResp := service . ResponseOpenAI2Claude ( & simpleResponse , info )
2025-07-10 15:02:40 +08:00
claudeRespStr , err := common . Marshal ( claudeResp )
2025-03-16 18:34:39 +08:00
if err != nil {
2025-07-10 15:02:40 +08:00
return nil , types . NewError ( err , types . ErrorCodeBadResponseBody )
2025-03-16 18:34:39 +08:00
}
responseBody = claudeRespStr
2025-08-14 21:10:04 +08:00
case types . RelayFormatGemini :
2025-08-01 22:23:35 +08:00
geminiResp := service . ResponseOpenAI2Gemini ( & simpleResponse , info )
geminiRespStr , err := common . Marshal ( geminiResp )
if err != nil {
return nil , types . NewError ( err , types . ErrorCodeBadResponseBody )
}
responseBody = geminiRespStr
2025-03-16 18:34:39 +08:00
}
2025-08-14 20:05:06 +08:00
service . IOCopyBytesGracefully ( c , resp , responseBody )
2025-06-27 22:36:12 +08:00
2025-07-10 15:02:40 +08:00
return & simpleResponse . Usage , nil
2023-07-22 17:48:45 +08:00
}
2024-07-16 22:07:10 +08:00
2025-10-22 00:38:51 +08:00
// streamTTSResponse relays the upstream TTS response body to the client as it
// is read, flushing after every chunk so the client receives data
// incrementally. If the response writer cannot flush, it logs a warning and
// falls back to a single io.Copy. Read errors other than io.EOF and write
// errors are logged and end the relay.
func streamTTSResponse(c *gin.Context, resp *http.Response) {
	c.Writer.WriteHeaderNow()

	flusher, canFlush := c.Writer.(http.Flusher)
	if !canFlush {
		logger.LogWarn(c, "streaming not supported")
		if _, copyErr := io.Copy(c.Writer, resp.Body); copyErr != nil {
			logger.LogWarn(c, copyErr.Error())
		}
		return
	}

	chunk := make([]byte, 4096)
	for {
		n, readErr := resp.Body.Read(chunk)
		// Per io.Reader, bytes returned alongside an error must still be written.
		if n > 0 {
			if _, writeErr := c.Writer.Write(chunk[:n]); writeErr != nil {
				logger.LogError(c, writeErr.Error())
				return
			}
			flusher.Flush()
		}
		if readErr == nil {
			continue
		}
		if readErr != io.EOF {
			logger.LogError(c, readErr.Error())
		}
		return
	}
}
2025-07-10 15:02:40 +08:00
func OpenaiRealtimeHandler ( c * gin . Context , info * relaycommon . RelayInfo ) ( * types . NewAPIError , * dto . RealtimeUsage ) {
2024-12-15 15:52:41 +08:00
if info == nil || info . ClientWs == nil || info . TargetWs == nil {
2025-07-10 15:02:40 +08:00
return types . NewError ( fmt . Errorf ( "invalid websocket connection" ) , types . ErrorCodeBadResponse ) , nil
2024-12-15 15:52:41 +08:00
}
2024-10-04 16:08:18 +08:00
info . IsStream = true
clientConn := info . ClientWs
targetConn := info . TargetWs
clientClosed := make ( chan struct { } )
targetClosed := make ( chan struct { } )
sendChan := make ( chan [ ] byte , 100 )
receiveChan := make ( chan [ ] byte , 100 )
errChan := make ( chan error , 2 )
usage := & dto . RealtimeUsage { }
2024-10-06 14:13:41 +08:00
localUsage := & dto . RealtimeUsage { }
2024-10-07 19:08:20 +08:00
sumUsage := & dto . RealtimeUsage { }
2024-10-04 16:08:18 +08:00
2024-10-07 20:35:33 +08:00
gopool . Go ( func ( ) {
2024-12-15 15:52:41 +08:00
defer func ( ) {
if r := recover ( ) ; r != nil {
errChan <- fmt . Errorf ( "panic in client reader: %v" , r )
}
} ( )
2024-10-04 16:08:18 +08:00
for {
select {
case <- c . Done ( ) :
return
default :
_ , message , err := clientConn . ReadMessage ( )
if err != nil {
if ! websocket . IsCloseError ( err , websocket . CloseNormalClosure , websocket . CloseGoingAway ) {
errChan <- fmt . Errorf ( "error reading from client: %v" , err )
}
close ( clientClosed )
return
}
2024-10-06 14:13:41 +08:00
realtimeEvent := & dto . RealtimeEvent { }
2025-07-10 15:02:40 +08:00
err = common . Unmarshal ( message , realtimeEvent )
2024-10-06 14:13:41 +08:00
if err != nil {
errChan <- fmt . Errorf ( "error unmarshalling message: %v" , err )
return
}
if realtimeEvent . Type == dto . RealtimeEventTypeSessionUpdate {
if realtimeEvent . Session != nil {
if realtimeEvent . Session . Tools != nil {
info . RealtimeTools = realtimeEvent . Session . Tools
}
}
}
textToken , audioToken , err := service . CountTokenRealtime ( info , * realtimeEvent , info . UpstreamModelName )
if err != nil {
errChan <- fmt . Errorf ( "error counting text token: %v" , err )
return
}
2025-08-14 20:05:06 +08:00
logger . LogInfo ( c , fmt . Sprintf ( "type: %s, textToken: %d, audioToken: %d" , realtimeEvent . Type , textToken , audioToken ) )
2024-10-06 14:13:41 +08:00
localUsage . TotalTokens += textToken + audioToken
2024-10-14 15:40:34 +08:00
localUsage . InputTokens += textToken + audioToken
2024-10-06 14:13:41 +08:00
localUsage . InputTokenDetails . TextTokens += textToken
localUsage . InputTokenDetails . AudioTokens += audioToken
2025-03-05 19:47:41 +08:00
err = helper . WssString ( c , targetConn , string ( message ) )
2024-10-04 16:08:18 +08:00
if err != nil {
errChan <- fmt . Errorf ( "error writing to target: %v" , err )
return
}
select {
case sendChan <- message :
default :
}
}
}
2024-10-07 20:35:33 +08:00
} )
2024-10-04 16:08:18 +08:00
2024-10-07 20:35:33 +08:00
gopool . Go ( func ( ) {
2024-12-15 15:52:41 +08:00
defer func ( ) {
if r := recover ( ) ; r != nil {
errChan <- fmt . Errorf ( "panic in target reader: %v" , r )
}
} ( )
2024-10-04 16:08:18 +08:00
for {
select {
case <- c . Done ( ) :
return
default :
_ , message , err := targetConn . ReadMessage ( )
if err != nil {
if ! websocket . IsCloseError ( err , websocket . CloseNormalClosure , websocket . CloseGoingAway ) {
errChan <- fmt . Errorf ( "error reading from target: %v" , err )
}
close ( targetClosed )
return
}
info . SetFirstResponseTime ( )
realtimeEvent := & dto . RealtimeEvent { }
2025-07-10 15:02:40 +08:00
err = common . Unmarshal ( message , realtimeEvent )
2024-10-04 16:08:18 +08:00
if err != nil {
errChan <- fmt . Errorf ( "error unmarshalling message: %v" , err )
return
}
if realtimeEvent . Type == dto . RealtimeEventTypeResponseDone {
realtimeUsage := realtimeEvent . Response . Usage
if realtimeUsage != nil {
usage . TotalTokens += realtimeUsage . TotalTokens
usage . InputTokens += realtimeUsage . InputTokens
usage . OutputTokens += realtimeUsage . OutputTokens
usage . InputTokenDetails . AudioTokens += realtimeUsage . InputTokenDetails . AudioTokens
usage . InputTokenDetails . CachedTokens += realtimeUsage . InputTokenDetails . CachedTokens
usage . InputTokenDetails . TextTokens += realtimeUsage . InputTokenDetails . TextTokens
usage . OutputTokenDetails . AudioTokens += realtimeUsage . OutputTokenDetails . AudioTokens
usage . OutputTokenDetails . TextTokens += realtimeUsage . OutputTokenDetails . TextTokens
2024-10-07 19:08:20 +08:00
err := preConsumeUsage ( c , info , usage , sumUsage )
if err != nil {
errChan <- fmt . Errorf ( "error consume usage: %v" , err )
return
}
2024-10-10 00:15:27 +08:00
// 本次计费完成,清除
2024-10-07 19:08:20 +08:00
usage = & dto . RealtimeUsage { }
2024-10-10 00:15:27 +08:00
localUsage = & dto . RealtimeUsage { }
2024-10-07 17:18:11 +08:00
} else {
textToken , audioToken , err := service . CountTokenRealtime ( info , * realtimeEvent , info . UpstreamModelName )
if err != nil {
errChan <- fmt . Errorf ( "error counting text token: %v" , err )
return
}
2025-08-14 20:05:06 +08:00
logger . LogInfo ( c , fmt . Sprintf ( "type: %s, textToken: %d, audioToken: %d" , realtimeEvent . Type , textToken , audioToken ) )
2024-10-07 17:18:11 +08:00
localUsage . TotalTokens += textToken + audioToken
info . IsFirstRequest = false
localUsage . InputTokens += textToken + audioToken
localUsage . InputTokenDetails . TextTokens += textToken
localUsage . InputTokenDetails . AudioTokens += audioToken
2024-10-07 19:08:20 +08:00
err = preConsumeUsage ( c , info , localUsage , sumUsage )
if err != nil {
errChan <- fmt . Errorf ( "error consume usage: %v" , err )
return
}
2024-10-10 00:15:27 +08:00
// 本次计费完成,清除
2024-10-07 19:08:20 +08:00
localUsage = & dto . RealtimeUsage { }
// print now usage
2024-10-04 16:08:18 +08:00
}
2025-08-14 20:05:06 +08:00
logger . LogInfo ( c , fmt . Sprintf ( "realtime streaming sumUsage: %v" , sumUsage ) )
logger . LogInfo ( c , fmt . Sprintf ( "realtime streaming localUsage: %v" , localUsage ) )
logger . LogInfo ( c , fmt . Sprintf ( "realtime streaming localUsage: %v" , localUsage ) )
2024-10-07 19:08:20 +08:00
2024-10-06 14:13:41 +08:00
} else if realtimeEvent . Type == dto . RealtimeEventTypeSessionUpdated || realtimeEvent . Type == dto . RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent . Session
if realtimeSession != nil {
// update audio format
info . InputAudioFormat = common . GetStringIfEmpty ( realtimeSession . InputAudioFormat , info . InputAudioFormat )
info . OutputAudioFormat = common . GetStringIfEmpty ( realtimeSession . OutputAudioFormat , info . OutputAudioFormat )
}
} else {
textToken , audioToken , err := service . CountTokenRealtime ( info , * realtimeEvent , info . UpstreamModelName )
if err != nil {
errChan <- fmt . Errorf ( "error counting text token: %v" , err )
return
}
2025-08-14 20:05:06 +08:00
logger . LogInfo ( c , fmt . Sprintf ( "type: %s, textToken: %d, audioToken: %d" , realtimeEvent . Type , textToken , audioToken ) )
2024-10-06 14:13:41 +08:00
localUsage . TotalTokens += textToken + audioToken
2024-10-07 17:18:11 +08:00
localUsage . OutputTokens += textToken + audioToken
localUsage . OutputTokenDetails . TextTokens += textToken
localUsage . OutputTokenDetails . AudioTokens += audioToken
2024-10-04 16:08:18 +08:00
}
2025-03-05 19:47:41 +08:00
err = helper . WssString ( c , clientConn , string ( message ) )
2024-10-04 16:08:18 +08:00
if err != nil {
errChan <- fmt . Errorf ( "error writing to client: %v" , err )
return
}
select {
case receiveChan <- message :
default :
}
}
}
2024-10-07 20:35:33 +08:00
} )
2024-10-04 16:08:18 +08:00
select {
case <- clientClosed :
case <- targetClosed :
2024-10-07 19:08:20 +08:00
case err := <- errChan :
2024-10-04 16:08:18 +08:00
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
2025-08-14 20:05:06 +08:00
logger . LogError ( c , "realtime error: " + err . Error ( ) )
2024-10-04 16:08:18 +08:00
case <- c . Done ( ) :
}
2024-10-07 19:08:20 +08:00
if usage . TotalTokens != 0 {
_ = preConsumeUsage ( c , info , usage , sumUsage )
}
if localUsage . TotalTokens != 0 {
_ = preConsumeUsage ( c , info , localUsage , sumUsage )
}
2024-10-06 14:13:41 +08:00
// check usage total tokens, if 0, use local usage
2024-10-07 19:08:20 +08:00
return nil , sumUsage
}
func preConsumeUsage ( ctx * gin . Context , info * relaycommon . RelayInfo , usage * dto . RealtimeUsage , totalUsage * dto . RealtimeUsage ) error {
2024-12-15 15:52:41 +08:00
if usage == nil || totalUsage == nil {
return fmt . Errorf ( "invalid usage pointer" )
}
2024-10-07 19:08:20 +08:00
totalUsage . TotalTokens += usage . TotalTokens
totalUsage . InputTokens += usage . InputTokens
totalUsage . OutputTokens += usage . OutputTokens
totalUsage . InputTokenDetails . CachedTokens += usage . InputTokenDetails . CachedTokens
totalUsage . InputTokenDetails . TextTokens += usage . InputTokenDetails . TextTokens
totalUsage . InputTokenDetails . AudioTokens += usage . InputTokenDetails . AudioTokens
totalUsage . OutputTokenDetails . TextTokens += usage . OutputTokenDetails . TextTokens
totalUsage . OutputTokenDetails . AudioTokens += usage . OutputTokenDetails . AudioTokens
// clear usage
err := service . PreWssConsumeQuota ( ctx , info , usage )
return err
2024-10-04 16:08:18 +08:00
}
2025-04-24 19:25:08 +08:00
2025-07-10 15:02:40 +08:00
func OpenaiHandlerWithUsage ( c * gin . Context , info * relaycommon . RelayInfo , resp * http . Response ) ( * dto . Usage , * types . NewAPIError ) {
2025-08-14 20:05:06 +08:00
defer service . CloseResponseBodyGracefully ( resp )
2025-06-27 23:35:56 +08:00
2025-04-24 19:25:08 +08:00
responseBody , err := io . ReadAll ( resp . Body )
if err != nil {
2025-07-30 18:39:19 +08:00
return nil , types . NewOpenAIError ( err , types . ErrorCodeReadResponseBodyFailed , http . StatusInternalServerError )
2025-04-24 19:25:08 +08:00
}
2025-06-27 22:36:12 +08:00
var usageResp dto . SimpleResponse
2025-07-10 15:02:40 +08:00
err = common . Unmarshal ( responseBody , & usageResp )
2025-04-24 19:25:08 +08:00
if err != nil {
2025-07-30 18:39:19 +08:00
return nil , types . NewOpenAIError ( err , types . ErrorCodeBadResponseBody , http . StatusInternalServerError )
2025-04-24 19:25:08 +08:00
}
2025-06-27 22:36:12 +08:00
// 写入新的 response body
2025-08-14 20:05:06 +08:00
service . IOCopyBytesGracefully ( c , resp , responseBody )
2025-06-27 22:36:12 +08:00
2025-06-27 21:13:21 +08:00
// Once we've written to the client, we should not return errors anymore
// because the upstream has already consumed resources and returned content
// We should still perform billing even if parsing fails
2025-04-24 19:25:08 +08:00
// format
2025-08-07 16:22:40 +08:00
if usageResp . InputTokens > 0 {
usageResp . PromptTokens += usageResp . InputTokens
2025-04-24 19:25:08 +08:00
}
2025-08-07 16:22:40 +08:00
if usageResp . OutputTokens > 0 {
usageResp . CompletionTokens += usageResp . OutputTokens
2025-04-24 19:25:08 +08:00
}
if usageResp . InputTokensDetails != nil {
usageResp . PromptTokensDetails . ImageTokens += usageResp . InputTokensDetails . ImageTokens
usageResp . PromptTokensDetails . TextTokens += usageResp . InputTokensDetails . TextTokens
}
2025-10-08 16:52:49 +08:00
applyUsagePostProcessing ( info , & usageResp . Usage , responseBody )
2025-07-10 15:02:40 +08:00
return & usageResp . Usage , nil
2025-04-24 19:25:08 +08:00
}
2025-10-08 16:52:49 +08:00
func applyUsagePostProcessing ( info * relaycommon . RelayInfo , usage * dto . Usage , responseBody [ ] byte ) {
if info == nil || usage == nil {
return
}
switch info . ChannelType {
case constant . ChannelTypeDeepSeek :
if usage . PromptTokensDetails . CachedTokens == 0 && usage . PromptCacheHitTokens != 0 {
usage . PromptTokensDetails . CachedTokens = usage . PromptCacheHitTokens
}
2025-12-30 17:38:32 +08:00
case constant . ChannelTypeZhipu_v4 :
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
2025-10-08 16:52:49 +08:00
if usage . PromptTokensDetails . CachedTokens == 0 {
if usage . InputTokensDetails != nil && usage . InputTokensDetails . CachedTokens > 0 {
usage . PromptTokensDetails . CachedTokens = usage . InputTokensDetails . CachedTokens
} else if cachedTokens , ok := extractCachedTokensFromBody ( responseBody ) ; ok {
usage . PromptTokensDetails . CachedTokens = cachedTokens
} else if usage . PromptCacheHitTokens > 0 {
usage . PromptTokensDetails . CachedTokens = usage . PromptCacheHitTokens
}
}
2025-12-30 17:38:32 +08:00
case constant . ChannelTypeMoonshot :
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
if usage . PromptTokensDetails . CachedTokens == 0 {
if usage . InputTokensDetails != nil && usage . InputTokensDetails . CachedTokens > 0 {
usage . PromptTokensDetails . CachedTokens = usage . InputTokensDetails . CachedTokens
} else if cachedTokens , ok := extractMoonshotCachedTokensFromBody ( responseBody ) ; ok {
usage . PromptTokensDetails . CachedTokens = cachedTokens
} else if cachedTokens , ok := extractCachedTokensFromBody ( responseBody ) ; ok {
usage . PromptTokensDetails . CachedTokens = cachedTokens
} else if usage . PromptCacheHitTokens > 0 {
usage . PromptTokensDetails . CachedTokens = usage . PromptCacheHitTokens
}
}
2026-03-20 16:10:18 +08:00
case constant . ChannelTypeOpenAI :
if usage . PromptTokensDetails . CachedTokens == 0 {
if cachedTokens , ok := extractLlamaCachedTokensFromBody ( responseBody ) ; ok {
usage . PromptTokensDetails . CachedTokens = cachedTokens
}
}
2025-10-08 16:52:49 +08:00
}
}
func extractCachedTokensFromBody ( body [ ] byte ) ( int , bool ) {
if len ( body ) == 0 {
return 0 , false
}
var payload struct {
Usage struct {
PromptTokensDetails struct {
CachedTokens * int ` json:"cached_tokens" `
} ` json:"prompt_tokens_details" `
CachedTokens * int ` json:"cached_tokens" `
PromptCacheHitTokens * int ` json:"prompt_cache_hit_tokens" `
} ` json:"usage" `
}
2025-12-13 17:24:23 +08:00
if err := common . Unmarshal ( body , & payload ) ; err != nil {
2025-10-08 16:52:49 +08:00
return 0 , false
}
if payload . Usage . PromptTokensDetails . CachedTokens != nil {
return * payload . Usage . PromptTokensDetails . CachedTokens , true
}
if payload . Usage . CachedTokens != nil {
return * payload . Usage . CachedTokens , true
}
if payload . Usage . PromptCacheHitTokens != nil {
return * payload . Usage . PromptCacheHitTokens , true
}
return 0 , false
}
2025-12-30 17:38:32 +08:00
// extractMoonshotCachedTokensFromBody reads cached_tokens from Moonshot's
// non-standard location, choices[].usage.cached_tokens. Moonshot's streaming
// response looks like:
//
//	{"choices":[{"usage":{"cached_tokens":111}}]}
//
// It returns the first positive value across the choices; zero, negative and
// missing values are skipped. It reports false for an empty or undecodable
// body, or when no positive value is found.
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
	if len(body) == 0 {
		return 0, false
	}
	var payload struct {
		Choices []struct {
			Usage struct {
				CachedTokens *int `json:"cached_tokens"`
			} `json:"usage"`
		} `json:"choices"`
	}
	if common.Unmarshal(body, &payload) != nil {
		return 0, false
	}
	for i := range payload.Choices {
		if n := payload.Choices[i].Usage.CachedTokens; n != nil && *n > 0 {
			return *n, true
		}
	}
	return 0, false
}
2026-03-20 16:10:18 +08:00
// extractLlamaCachedTokensFromBody 从llama.cpp的非标准位置提取cache_n
func extractLlamaCachedTokensFromBody ( body [ ] byte ) ( int , bool ) {
if len ( body ) == 0 {
return 0 , false
}
var payload struct {
2026-03-20 16:38:48 +08:00
Timings struct {
2026-03-20 16:48:04 +08:00
CachedTokens * int ` json:"cache_n" `
2026-03-20 16:10:18 +08:00
} ` json:"timings" `
}
if err := common . Unmarshal ( body , & payload ) ; err != nil {
return 0 , false
}
2026-03-20 16:38:48 +08:00
2026-03-20 16:48:04 +08:00
if payload . Timings . CachedTokens == nil {
2026-03-20 16:38:48 +08:00
return 0 , false
}
2026-03-20 16:48:04 +08:00
return * payload . Timings . CachedTokens , true
2026-03-20 16:10:18 +08:00
}