424 lines
13 KiB
Go
Raw Normal View History

2024-08-27 20:19:51 +08:00
package vertex
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
2025-05-02 13:59:46 +08:00
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/claude"
"github.com/QuantumNous/new-api/relay/channel/gemini"
"github.com/QuantumNous/new-api/relay/channel/openai"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/reasoning"
"github.com/QuantumNous/new-api/types"
2025-05-02 13:59:46 +08:00
"github.com/gin-gonic/gin"
"github.com/samber/lo"
2024-08-27 20:19:51 +08:00
)
// Request modes recorded in Adaptor.RequestMode. Adaptor.Init picks one from the
// upstream model name, and the URL builders and response handlers branch on it.
const (
	// RequestModeClaude is used for models whose name starts with "claude".
	RequestModeClaude = 1
	// RequestModeGemini is the fallback mode for every model that is neither
	// Claude nor open source.
	RequestModeGemini = 2
	// RequestModeOpenSource is used for models whose name contains "llama" or "-maas".
	RequestModeOpenSource = 3
)
// claudeModelMap translates a Claude model name into the model identifier used
// in the Vertex AI request URL and stored under the "request_model" context key.
// Dated names map to the "<name>@<date>" form; names that are absent from the
// map are used unchanged by the callers.
var claudeModelMap = map[string]string{
	"claude-3-sonnet-20240229":   "claude-3-sonnet@20240229",
	"claude-3-opus-20240229":     "claude-3-opus@20240229",
	"claude-3-haiku-20240307":    "claude-3-haiku@20240307",
	"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
	"claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
	"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
	"claude-sonnet-4-20250514":   "claude-sonnet-4@20250514",
	"claude-opus-4-20250514":     "claude-opus-4@20250514",
	"claude-opus-4-1-20250805":   "claude-opus-4-1@20250805",
	"claude-sonnet-4-5-20250929": "claude-sonnet-4-5@20250929",
	"claude-haiku-4-5-20251001":  "claude-haiku-4-5@20251001",
	"claude-opus-4-5-20251101":   "claude-opus-4-5@20251101",
	// These names carry no date suffix and map to themselves.
	"claude-opus-4-6": "claude-opus-4-6",
	"claude-opus-4-7": "claude-opus-4-7",
}
// anthropicVersion is the version string passed to copyRequest whenever a
// Claude request is prepared for Vertex AI.
const anthropicVersion = "vertex-2023-10-16"
// Adaptor is the Vertex AI channel adaptor. It converts incoming requests,
// builds the upstream URL and headers, and dispatches responses according to
// the request mode selected in Init.
type Adaptor struct {
// RequestMode is one of the RequestMode* constants; Init sets it from the
// upstream model name.
RequestMode int
// AccountCredentials holds the credentials that getRequestUrl decodes from
// the channel key. getRequestUrl does not populate it for API-key channels.
AccountCredentials Credentials
}
// ConvertGeminiRequest prepares a native Gemini request for Vertex AI.
//
// When the RemoveFunctionResponseIdEnabled Gemini setting is on, every
// functionResponse ID in the request is cleared in place first. The request is
// then converted by the stock Gemini adaptor, whose result is returned as is.
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
	// Vertex AI does not support functionResponse.id; keep it stripped here for consistency.
	if settings := model_setting.GetGeminiSettings(); settings.RemoveFunctionResponseIdEnabled {
		removeFunctionResponseID(request)
	}
	return (&gemini.Adaptor{}).ConvertGeminiRequest(c, info, request)
}
// removeFunctionResponseID clears the ID of every function response found in
// the request's contents, then does the same for each nested request in
// request.Requests. The request is modified in place; a nil request is a no-op.
func removeFunctionResponseID(request *dto.GeminiChatRequest) {
	if request == nil {
		return
	}

	// Ranging over an empty slice does nothing, so no explicit length guards are needed.
	for ci := range request.Contents {
		for pi := range request.Contents[ci].Parts {
			resp := request.Contents[ci].Parts[pi].FunctionResponse
			// Only a non-empty ID is reset; an already-empty one is left untouched.
			if resp != nil && len(resp.ID) > 0 {
				resp.ID = nil
			}
		}
	}

	for ri := range request.Requests {
		removeFunctionResponseID(&request.Requests[ri])
	}
}
2025-03-12 21:31:46 +08:00
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
c.Set("request_model", v)
} else {
c.Set("request_model", request.Model)
}
vertexClaudeReq := copyRequest(request, anthropicVersion)
return vertexClaudeReq, nil
2025-03-12 21:31:46 +08:00
}
2024-08-27 20:19:51 +08:00
// ConvertAudioRequest is not supported by this adaptor yet; it always returns
// a nil reader and a "not implemented" error.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
2025-08-18 21:49:28 +08:00
geminiAdaptor := gemini.Adaptor{}
return geminiAdaptor.ConvertImageRequest(c, info, request)
2024-08-27 20:19:51 +08:00
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude") {
a.RequestMode = RequestModeClaude
} else if strings.Contains(info.UpstreamModelName, "llama") ||
// open source models
strings.Contains(info.UpstreamModelName, "-maas") {
a.RequestMode = RequestModeOpenSource
} else {
a.RequestMode = RequestModeGemini
2024-08-27 20:19:51 +08:00
}
}
// getRequestUrl builds the Vertex AI endpoint URL for modelName and the given
// action suffix (for example "generateContent").
//
// The region comes from GetModelRegion(info.ApiVersion, info.OriginModelName).
// The region "global" uses the host aiplatform.googleapis.com; any other
// region uses "<region>-aiplatform.googleapis.com", except in open-source mode.
//
// API-key channels: the URL targets the google publisher path regardless of
// the request mode and carries info.ApiKey in a "key" query parameter.
//
// All other channels: info.ApiKey is decoded into Credentials with
// common.Unmarshal, the result is stored in a.AccountCredentials, and a
// project-scoped URL is built for the current request mode.
//
// An error is returned when the credentials cannot be decoded or the request
// mode is not one of the RequestMode* constants.
func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
	region := GetModelRegion(info.ApiVersion, info.OriginModelName)
	isGlobal := region == "global"

	if info.ChannelOtherSettings.VertexKeyType == dto.VertexKeyTypeAPIKey {
		// The streaming suffix already carries "?alt=sse", so the key must be
		// appended with "&" instead of starting a new query string.
		separator := "?"
		if strings.HasSuffix(suffix, "?alt=sse") {
			separator = "&"
		}
		if isGlobal {
			return fmt.Sprintf(
				"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
				modelName, suffix, separator, info.ApiKey,
			), nil
		}
		return fmt.Sprintf(
			"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
			region, modelName, suffix, separator, info.ApiKey,
		), nil
	}

	creds := &Credentials{}
	if err := common.Unmarshal([]byte(info.ApiKey), creds); err != nil {
		return "", fmt.Errorf("failed to decode credentials file: %w", err)
	}
	a.AccountCredentials = *creds

	switch a.RequestMode {
	case RequestModeGemini:
		if isGlobal {
			return fmt.Sprintf(
				"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
				creds.ProjectID, modelName, suffix,
			), nil
		}
		return fmt.Sprintf(
			"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
			region, creds.ProjectID, region, modelName, suffix,
		), nil
	case RequestModeClaude:
		if isGlobal {
			return fmt.Sprintf(
				"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
				creds.ProjectID, modelName, suffix,
			), nil
		}
		return fmt.Sprintf(
			"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
			region, creds.ProjectID, region, modelName, suffix,
		), nil
	case RequestModeOpenSource:
		// modelName and suffix are not part of the OpenAI-compatible endpoint path.
		return fmt.Sprintf(
			"https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
			creds.ProjectID, region,
		), nil
	}
	return "", errors.New("unsupported request mode")
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
2024-08-27 20:19:51 +08:00
suffix := ""
if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled &&
!model_setting.ShouldPreserveThinkingSuffix(info.OriginModelName) {
// 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.UpstreamModelName, "-thinking-") {
parts := strings.Split(info.UpstreamModelName, "-thinking-")
info.UpstreamModelName = parts[0]
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
} else if baseModel, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
info.UpstreamModelName = baseModel
}
}
2024-08-27 20:19:51 +08:00
if info.IsStream {
suffix = "streamGenerateContent?alt=sse"
} else {
suffix = "generateContent"
}
2025-07-30 18:39:19 +08:00
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
suffix = "predict"
}
return a.getRequestUrl(info, info.UpstreamModelName, suffix)
2024-08-27 20:19:51 +08:00
} else if a.RequestMode == RequestModeClaude {
if info.IsStream {
suffix = "streamRawPredict?alt=sse"
} else {
suffix = "rawPredict"
}
2025-03-05 16:43:40 +08:00
model := info.UpstreamModelName
2024-08-27 20:19:51 +08:00
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
2025-03-05 16:43:40 +08:00
model = v
2024-08-27 20:19:51 +08:00
}
return a.getRequestUrl(info, model, suffix)
} else if a.RequestMode == RequestModeOpenSource {
return a.getRequestUrl(info, "", "")
2024-08-27 20:19:51 +08:00
}
return "", errors.New("unsupported request mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
2024-08-27 20:19:51 +08:00
channel.SetupApiRequestHeader(info, c, req)
if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
accessToken, err := getAccessToken(a, info)
if err != nil {
return err
}
req.Set("Authorization", "Bearer "+accessToken)
2024-08-27 20:19:51 +08:00
}
if a.AccountCredentials.ProjectID != "" {
2025-09-13 13:14:34 +08:00
req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
2024-08-27 20:19:51 +08:00
}
if strings.Contains(info.UpstreamModelName, "claude") {
claude.CommonClaudeHeadersOperation(c, req, info)
}
2024-08-27 20:19:51 +08:00
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
2024-08-27 20:19:51 +08:00
if request == nil {
return nil, errors.New("request is nil")
}
2025-08-18 21:49:28 +08:00
if a.RequestMode == RequestModeGemini && strings.HasPrefix(info.UpstreamModelName, "imagen") {
prompt := ""
for _, m := range request.Messages {
if m.Role == "user" {
prompt = m.StringContent()
if prompt != "" {
break
}
}
}
if prompt == "" {
if p, ok := request.Prompt.(string); ok {
prompt = p
}
}
if prompt == "" {
return nil, errors.New("prompt is required for image generation")
}
imgReq := dto.ImageRequest{
Model: request.Model,
Prompt: prompt,
N: lo.ToPtr(uint(1)),
2025-08-18 21:49:28 +08:00
Size: "1024x1024",
}
if request.N != nil && *request.N > 0 {
imgReq.N = lo.ToPtr(uint(*request.N))
2025-08-18 21:49:28 +08:00
}
if request.Size != "" {
imgReq.Size = request.Size
}
if len(request.ExtraBody) > 0 {
var extra map[string]any
if err := json.Unmarshal(request.ExtraBody, &extra); err == nil {
if n, ok := extra["n"].(float64); ok && n > 0 {
imgReq.N = lo.ToPtr(uint(n))
2025-08-18 21:49:28 +08:00
}
if size, ok := extra["size"].(string); ok {
imgReq.Size = size
}
// accept aspectRatio in extra body (top-level or under parameters)
if ar, ok := extra["aspectRatio"].(string); ok && ar != "" {
imgReq.Size = ar
}
if params, ok := extra["parameters"].(map[string]any); ok {
if ar, ok := params["aspectRatio"].(string); ok && ar != "" {
imgReq.Size = ar
}
}
}
}
c.Set("request_model", request.Model)
return a.ConvertImageRequest(c, info, imgReq)
}
2024-08-27 20:19:51 +08:00
if a.RequestMode == RequestModeClaude {
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
2024-08-27 20:19:51 +08:00
if err != nil {
return nil, err
}
2025-03-03 20:06:08 +08:00
vertexClaudeReq := copyRequest(claudeReq, anthropicVersion)
c.Set("request_model", claudeReq.Model)
2025-03-05 16:43:40 +08:00
info.UpstreamModelName = claudeReq.Model
2024-08-27 20:19:51 +08:00
return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini {
geminiRequest, err := gemini.CovertOpenAI2Gemini(c, *request, info)
if err != nil {
return nil, err
}
2024-08-27 20:19:51 +08:00
c.Set("request_model", request.Model)
return geminiRequest, nil
} else if a.RequestMode == RequestModeOpenSource {
2024-08-27 20:19:51 +08:00
return request, nil
}
return nil, errors.New("unsupported request mode")
}
// ConvertRerankRequest is not supported by this adaptor.
//
// It returns a "not implemented" error, matching the other unsupported
// converters on Adaptor, instead of a nil payload with a nil error that a
// caller could mistake for a successfully converted request.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
	return nil, errors.New("not implemented")
}
2025-01-23 05:54:39 +08:00
// ConvertEmbeddingRequest is not supported by this adaptor yet; it always
// returns a nil payload and a "not implemented" error.
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
2025-05-02 13:59:46 +08:00
// ConvertOpenAIResponsesRequest is not supported by this adaptor yet; it
// always returns a nil payload and a "not implemented" error.
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
2024-08-27 20:19:51 +08:00
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
claudeAdaptor := claude.Adaptor{}
2024-08-27 20:19:51 +08:00
if info.IsStream {
switch a.RequestMode {
case RequestModeClaude:
return claudeAdaptor.DoResponse(c, resp, info)
2024-08-27 20:19:51 +08:00
case RequestModeGemini:
2025-05-26 15:02:20 +08:00
if info.RelayMode == constant.RelayModeGemini {
return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
2025-05-26 15:02:20 +08:00
} else {
return gemini.GeminiChatStreamHandler(c, info, resp)
2025-05-26 15:02:20 +08:00
}
case RequestModeOpenSource:
return openai.OaiStreamHandler(c, info, resp)
2024-08-27 20:19:51 +08:00
}
} else {
switch a.RequestMode {
case RequestModeClaude:
return claudeAdaptor.DoResponse(c, resp, info)
2024-08-27 20:19:51 +08:00
case RequestModeGemini:
2025-05-26 15:02:20 +08:00
if info.RelayMode == constant.RelayModeGemini {
return gemini.GeminiTextGenerationHandler(c, info, resp)
2025-05-26 15:02:20 +08:00
} else {
2025-07-30 18:39:19 +08:00
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return gemini.GeminiImageHandler(c, info, resp)
}
return gemini.GeminiChatHandler(c, info, resp)
2025-05-26 15:02:20 +08:00
}
case RequestModeOpenSource:
return openai.OpenaiHandler(c, info, resp)
2024-08-27 20:19:51 +08:00
}
}
return
}
func (a *Adaptor) GetModelList() []string {
var modelList []string
for i, s := range ModelList {
modelList = append(modelList, s)
ModelList[i] = s
}
for i, s := range claude.ModelList {
modelList = append(modelList, s)
claude.ModelList[i] = s
}
for i, s := range gemini.ModelList {
modelList = append(modelList, s)
gemini.ModelList[i] = s
}
return modelList
2024-08-27 20:19:51 +08:00
}
// GetChannelName returns the package-level ChannelName.
func (a *Adaptor) GetChannelName() string {
return ChannelName
}