new-api/relay/channel/aws/adaptor.go

185 lines
5.3 KiB
Go
Raw Normal View History

2024-04-23 11:44:40 +08:00
package aws
import (
2025-10-15 17:29:10 +08:00
"fmt"
2024-04-23 11:44:40 +08:00
"io"
"net/http"
2025-10-15 17:29:10 +08:00
"strings"
"github.com/QuantumNous/new-api/dto"
2025-10-15 17:29:10 +08:00
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/claude"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
2025-10-15 16:44:33 +08:00
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/pkg/errors"
2025-05-02 13:59:46 +08:00
"github.com/gin-gonic/gin"
2024-04-23 11:44:40 +08:00
)
2025-10-15 17:29:10 +08:00
// ClientMode selects how the adaptor talks to Bedrock for the current request.
// It is chosen in GetRequestURL from the channel's AWS key type.
type ClientMode int

// Supported client modes. Numbering starts at 1 so that the zero ClientMode
// never matches a configured mode.
const (
	// ClientModeApiKey sends requests with a Bedrock API key passed as an
	// "Authorization: Bearer" header.
	ClientModeApiKey ClientMode = 1
	// ClientModeAKSK sends requests through the AWS client path
	// (doAwsClientRequest), presumably with an access key / secret key pair.
	ClientModeAKSK ClientMode = 2
)
type Adaptor struct {
2025-10-15 17:29:10 +08:00
ClientMode ClientMode
2025-10-15 16:44:33 +08:00
AwsClient *bedrockruntime.Client
AwsModelId string
AwsReq any
IsNova bool
2024-04-23 11:44:40 +08:00
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
2025-03-12 21:31:46 +08:00
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
for i, message := range request.Messages {
updated := false
if !message.IsStringContent() {
content, err := message.ParseContent()
if err != nil {
return nil, errors.Wrap(err, "failed to parse message content")
}
for i2, mediaMessage := range content {
if mediaMessage.Source != nil {
if mediaMessage.Source.Type == "url" {
// 使用统一的文件服务获取图片数据
source := types.NewURLFileSource(mediaMessage.Source.Url)
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude")
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
}
mediaMessage.Source.MediaType = mimeType
mediaMessage.Source.Data = base64Data
mediaMessage.Source.Url = ""
mediaMessage.Source.Type = "base64"
content[i2] = mediaMessage
updated = true
}
}
}
if updated {
message.SetContent(content)
}
}
if updated {
request.Messages[i] = message
}
}
2025-03-12 21:31:46 +08:00
return request, nil
}
2024-07-16 22:07:10 +08:00
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
2024-07-06 17:09:22 +08:00
//TODO implement me
2024-07-16 22:07:10 +08:00
return nil, errors.New("not implemented")
}
2024-07-06 17:09:22 +08:00
2024-07-16 22:07:10 +08:00
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
2024-07-06 17:09:22 +08:00
}
2024-07-16 22:07:10 +08:00
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
2024-04-23 11:44:40 +08:00
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
2025-10-15 17:29:10 +08:00
if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
awsModelId := getAwsModelID(info.UpstreamModelName)
2025-10-15 17:29:10 +08:00
a.ClientMode = ClientModeApiKey
awsSecret := strings.Split(info.ApiKey, "|")
if len(awsSecret) != 2 {
return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>")
}
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", awsModelId, awsSecret[1]), nil
} else {
a.ClientMode = ClientModeAKSK
return "", nil
}
2024-04-23 11:44:40 +08:00
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
claude.CommonClaudeHeadersOperation(c, req, info)
2025-10-15 17:29:10 +08:00
if a.ClientMode == ClientModeApiKey {
req.Set("Authorization", "Bearer "+info.ApiKey)
}
2024-04-23 11:44:40 +08:00
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
2024-04-23 11:44:40 +08:00
if request == nil {
return nil, errors.New("request is nil")
}
2025-09-10 20:30:00 +08:00
// 检查是否为Nova模型
if isNovaModel(request.Model) {
novaReq := convertToNovaRequest(request)
2025-10-15 16:44:33 +08:00
a.IsNova = true
2025-09-10 20:30:00 +08:00
return novaReq, nil
}
2024-04-23 11:44:40 +08:00
2025-09-10 20:30:00 +08:00
// 原有的Claude模型处理逻辑
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
if err != nil {
return nil, errors.Wrap(err, "failed to convert openai request to claude request")
}
info.UpstreamModelName = claudeReq.Model
2024-04-23 11:44:40 +08:00
return claudeReq, err
}
2024-07-06 17:09:22 +08:00
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
2025-01-23 05:54:39 +08:00
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
2025-05-02 13:59:46 +08:00
// ConvertOpenAIResponsesRequest is not supported by the AWS adaptor; it always
// fails with a "not implemented" error.
func (a *Adaptor) ConvertOpenAIResponsesRequest(_ *gin.Context, _ *relaycommon.RelayInfo, _ dto.OpenAIResponsesRequest) (any, error) {
	return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
2025-10-15 17:29:10 +08:00
if a.ClientMode == ClientModeApiKey {
return channel.DoApiRequest(a, c, info, requestBody)
2025-10-15 16:44:33 +08:00
} else {
2025-10-15 17:29:10 +08:00
return doAwsClientRequest(c, info, a, requestBody)
2025-10-15 16:44:33 +08:00
}
2024-04-23 11:44:40 +08:00
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
2025-10-15 17:29:10 +08:00
if a.ClientMode == ClientModeApiKey {
claudeAdaptor := claude.Adaptor{}
usage, err = claudeAdaptor.DoResponse(c, resp, info)
2024-04-23 11:44:40 +08:00
} else {
2025-10-15 17:29:10 +08:00
if a.IsNova {
err, usage = handleNovaRequest(c, info, a)
2025-10-15 16:44:33 +08:00
} else {
2025-10-15 17:29:10 +08:00
if info.IsStream {
err, usage = awsStreamHandler(c, info, a)
} else {
err, usage = awsHandler(c, info, a)
}
2025-10-15 16:44:33 +08:00
}
2024-04-23 11:44:40 +08:00
}
return
}
// GetModelList returns the keys of awsModelIDMap, the model names this
// channel can map to Bedrock. Order follows map iteration and is not stable
// between calls. The result is nil when the map is empty.
func (a *Adaptor) GetModelList() []string {
	var names []string
	for name := range awsModelIDMap {
		names = append(names, name)
	}
	return names
}
// GetChannelName reports this adaptor's channel name (the package-level
// ChannelName constant).
func (a *Adaptor) GetChannelName() string {
	return ChannelName
}