Merge branch 'main' of github.com:danding5/new-api

# Conflicts:
#	relay/relay_adaptor.go
This commit is contained in:
DD 2025-09-15 14:31:06 +08:00
commit 71d41d6eaa
75 changed files with 1721 additions and 579 deletions

View File

@ -10,7 +10,7 @@ import (
"one-api/constant" "one-api/constant"
"one-api/model" "one-api/model"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting/operation_setting"
"one-api/types" "one-api/types"
"strconv" "strconv"
"time" "time"
@ -342,7 +342,7 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode) return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
} }
availableBalanceCny := response.Data.AvailableBalance availableBalanceCny := response.Data.AvailableBalance
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64() availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(operation_setting.Price)).InexactFloat64()
channel.UpdateBalance(availableBalanceUsd) channel.UpdateBalance(availableBalanceUsd)
return availableBalanceUsd, nil return availableBalanceUsd, nil
} }

View File

@ -235,7 +235,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
if resp != nil { if resp != nil {
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(httpResp, true) err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
return testResult{ return testResult{
context: c, context: c,
localErr: err, localErr: err,

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings" "strings"
@ -560,7 +561,7 @@ func AddChannel(c *gin.Context) {
case "multi_to_single": case "multi_to_single":
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key) array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -585,7 +586,7 @@ func AddChannel(c *gin.Context) {
} }
keys = []string{addChannelRequest.Channel.Key} keys = []string{addChannelRequest.Channel.Key}
case "batch": case "batch":
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
// multi json // multi json
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key) keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
if err != nil { if err != nil {
@ -840,7 +841,7 @@ func UpdateChannel(c *gin.Context) {
} }
// 处理 Vertex AI 的特殊情况 // 处理 Vertex AI 的特殊情况
if channel.Type == constant.ChannelTypeVertexAi { if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
// 尝试解析新密钥为JSON数组 // 尝试解析新密钥为JSON数组
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") { if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
array, err := getVertexArrayKeys(channel.Key) array, err := getVertexArrayKeys(channel.Key)

View File

@ -13,6 +13,7 @@ import (
"one-api/model" "one-api/model"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
"one-api/setting/system_setting"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -259,7 +260,7 @@ func GetAllMidjourney(c *gin.Context) {
if setting.MjForwardUrlEnabled { if setting.MjForwardUrlEnabled {
for i, midjourney := range items { for i, midjourney := range items {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
items[i] = midjourney items[i] = midjourney
} }
} }
@ -284,7 +285,7 @@ func GetUserMidjourney(c *gin.Context) {
if setting.MjForwardUrlEnabled { if setting.MjForwardUrlEnabled {
for i, midjourney := range items { for i, midjourney := range items {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
items[i] = midjourney items[i] = midjourney
} }
} }

View File

@ -58,11 +58,7 @@ func GetStatus(c *gin.Context) {
"footer_html": common.Footer, "footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled, "wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress, "server_address": system_setting.ServerAddress,
"price": setting.Price,
"stripe_unit_price": setting.StripeUnitPrice,
"min_topup": setting.MinTopUp,
"stripe_min_topup": setting.StripeMinTopUp,
"turnstile_check": common.TurnstileCheckEnabled, "turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey, "turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink, "top_up_link": common.TopUpLink,
@ -75,15 +71,15 @@ func GetStatus(c *gin.Context) {
"enable_data_export": common.DataExportEnabled, "enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime, "data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar, "default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
"mj_notify_enabled": setting.MjNotifyEnabled, "mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats, "chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled, "demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled, "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"default_use_auto_group": setting.DefaultUseAutoGroup, "default_use_auto_group": setting.DefaultUseAutoGroup,
"pay_methods": setting.PayMethods,
"usd_exchange_rate": setting.USDExchangeRate, "usd_exchange_rate": operation_setting.USDExchangeRate,
"price": operation_setting.Price,
"stripe_unit_price": setting.StripeUnitPrice,
// 面板启用开关 // 面板启用开关
"api_info_enabled": cs.ApiInfoEnabled, "api_info_enabled": cs.ApiInfoEnabled,
@ -253,7 +249,7 @@ func SendPasswordResetEmail(c *gin.Context) {
} }
code := common.GenerateVerificationCode(0) code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName) subject := fmt.Sprintf("%s密码重置", common.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ "<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+

View File

@ -8,7 +8,6 @@ import (
"net/url" "net/url"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"one-api/setting"
"one-api/setting/system_setting" "one-api/setting/system_setting"
"strconv" "strconv"
"strings" "strings"
@ -45,7 +44,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret) values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
values.Set("code", code) values.Set("code", code)
values.Set("grant_type", "authorization_code") values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress)) values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
formData := values.Encode() formData := values.Encode()
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData)) req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
if err != nil { if err != nil {

View File

@ -139,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta) // common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil { if newAPIError != nil {
return return
} }
defer func() { defer func() {
// Only return quota if downstream failed and quota was actually pre-consumed // Only return quota if downstream failed and quota was actually pre-consumed
if newAPIError != nil && preConsumedQuota != 0 { if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 {
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) service.ReturnPreConsumedQuota(c, relayInfo)
} }
}() }()
@ -277,14 +277,13 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
gopool.Go(func() { // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况 if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously gopool.Go(func() {
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
service.DisableChannel(channelError, err.Error()) service.DisableChannel(channelError, err.Error())
} })
}) }
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) { if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
// 保存错误日志到mysql中 // 保存错误日志到mysql中

View File

@ -94,7 +94,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
} else { } else {
task.Data = responseBody task.Data = redactVideoResponseBody(responseBody)
} }
now := time.Now().Unix() now := time.Now().Unix()
@ -117,7 +117,9 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
if task.FinishTime == 0 { if task.FinishTime == 0 {
task.FinishTime = now task.FinishTime = now
} }
task.FailReason = taskResult.Url if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
task.FailReason = taskResult.Url
}
case model.TaskStatusFailure: case model.TaskStatusFailure:
task.Status = model.TaskStatusFailure task.Status = model.TaskStatusFailure
task.Progress = "100%" task.Progress = "100%"
@ -146,3 +148,37 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
return nil return nil
} }
func redactVideoResponseBody(body []byte) []byte {
var m map[string]any
if err := json.Unmarshal(body, &m); err != nil {
return body
}
resp, _ := m["response"].(map[string]any)
if resp != nil {
delete(resp, "bytesBase64Encoded")
if v, ok := resp["video"].(string); ok {
resp["video"] = truncateBase64(v)
}
if vs, ok := resp["videos"].([]any); ok {
for i := range vs {
if vm, ok := vs[i].(map[string]any); ok {
delete(vm, "bytesBase64Encoded")
}
}
}
}
b, err := json.Marshal(m)
if err != nil {
return body
}
return b
}
func truncateBase64(s string) string {
const maxKeep = 256
if len(s) <= maxKeep {
return s
}
return s[:maxKeep] + "..."
}

View File

@ -9,6 +9,8 @@ import (
"one-api/model" "one-api/model"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -19,6 +21,44 @@ import (
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
) )
func GetTopUpInfo(c *gin.Context) {
// 获取支付方式
payMethods := operation_setting.PayMethods
// 如果启用了 Stripe 支付,添加到支付方法列表
if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" {
// 检查是否已经包含 Stripe
hasStripe := false
for _, method := range payMethods {
if method["type"] == "stripe" {
hasStripe = true
break
}
}
if !hasStripe {
stripeMethod := map[string]string{
"name": "Stripe",
"type": "stripe",
"color": "rgba(var(--semi-purple-5), 1)",
"min_topup": strconv.Itoa(setting.StripeMinTopUp),
}
payMethods = append(payMethods, stripeMethod)
}
}
data := gin.H{
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
"pay_methods": payMethods,
"min_topup": operation_setting.MinTopUp,
"stripe_min_topup": setting.StripeMinTopUp,
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
}
common.ApiSuccess(c, data)
}
type EpayRequest struct { type EpayRequest struct {
Amount int64 `json:"amount"` Amount int64 `json:"amount"`
PaymentMethod string `json:"payment_method"` PaymentMethod string `json:"payment_method"`
@ -31,13 +71,13 @@ type AmountRequest struct {
} }
func GetEpayClient() *epay.Client { func GetEpayClient() *epay.Client {
if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" { if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
return nil return nil
} }
withUrl, err := epay.NewClient(&epay.Config{ withUrl, err := epay.NewClient(&epay.Config{
PartnerID: setting.EpayId, PartnerID: operation_setting.EpayId,
Key: setting.EpayKey, Key: operation_setting.EpayKey,
}, setting.PayAddress) }, operation_setting.PayAddress)
if err != nil { if err != nil {
return nil return nil
} }
@ -58,15 +98,23 @@ func getPayMoney(amount int64, group string) float64 {
} }
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio) dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
dPrice := decimal.NewFromFloat(setting.Price) dPrice := decimal.NewFromFloat(operation_setting.Price)
// apply optional preset discount by the original request amount (if configured), default 1.0
discount := 1.0
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok {
if ds > 0 {
discount = ds
}
}
dDiscount := decimal.NewFromFloat(discount)
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio) payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio).Mul(dDiscount)
return payMoney.InexactFloat64() return payMoney.InexactFloat64()
} }
func getMinTopup() int64 { func getMinTopup() int64 {
minTopup := setting.MinTopUp minTopup := operation_setting.MinTopUp
if !common.DisplayInCurrencyEnabled { if !common.DisplayInCurrencyEnabled {
dMinTopup := decimal.NewFromInt(int64(minTopup)) dMinTopup := decimal.NewFromInt(int64(minTopup))
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
@ -99,13 +147,13 @@ func RequestEpay(c *gin.Context) {
return return
} }
if !setting.ContainsPayMethod(req.PaymentMethod) { if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"}) c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
return return
} }
callBackAddress := service.GetCallbackAddress() callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log") returnUrl, _ := url.Parse(system_setting.ServerAddress + "/console/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix()) tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo) tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)

View File

@ -8,6 +8,8 @@ import (
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"one-api/setting" "one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -215,8 +217,8 @@ func genStripeLink(referenceId string, customerId string, email string, amount i
params := &stripe.CheckoutSessionParams{ params := &stripe.CheckoutSessionParams{
ClientReferenceID: stripe.String(referenceId), ClientReferenceID: stripe.String(referenceId),
SuccessURL: stripe.String(setting.ServerAddress + "/log"), SuccessURL: stripe.String(system_setting.ServerAddress + "/log"),
CancelURL: stripe.String(setting.ServerAddress + "/topup"), CancelURL: stripe.String(system_setting.ServerAddress + "/topup"),
LineItems: []*stripe.CheckoutSessionLineItemParams{ LineItems: []*stripe.CheckoutSessionLineItemParams{
{ {
Price: stripe.String(setting.StripePriceId), Price: stripe.String(setting.StripePriceId),
@ -254,6 +256,7 @@ func GetChargedAmount(count float64, user model.User) float64 {
} }
func getStripePayMoney(amount float64, group string) float64 { func getStripePayMoney(amount float64, group string) float64 {
originalAmount := amount
if !common.DisplayInCurrencyEnabled { if !common.DisplayInCurrencyEnabled {
amount = amount / common.QuotaPerUnit amount = amount / common.QuotaPerUnit
} }
@ -262,7 +265,14 @@ func getStripePayMoney(amount float64, group string) float64 {
if topupGroupRatio == 0 { if topupGroupRatio == 0 {
topupGroupRatio = 1 topupGroupRatio = 1
} }
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio // apply optional preset discount by the original request amount (if configured), default 1.0
discount := 1.0
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok {
if ds > 0 {
discount = ds
}
}
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio * discount
return payMoney return payMoney
} }

View File

@ -9,6 +9,14 @@ type ChannelSettings struct {
SystemPromptOverride bool `json:"system_prompt_override,omitempty"` SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
} }
type VertexKeyType string
const (
VertexKeyTypeJSON VertexKeyType = "json"
VertexKeyTypeAPIKey VertexKeyType = "api_key"
)
type ChannelOtherSettings struct { type ChannelOtherSettings struct {
AzureResponsesVersion string `json:"azure_responses_version,omitempty"` AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
} }

View File

@ -2,12 +2,11 @@ package dto
import ( import (
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/logger" "one-api/logger"
"one-api/types" "one-api/types"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
type GeminiChatRequest struct { type GeminiChatRequest struct {
@ -269,15 +268,14 @@ type GeminiChatResponse struct {
} }
type GeminiUsageMetadata struct { type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"` PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"` CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"` TotalTokenCount int `json:"totalTokenCount"`
ThoughtsTokenCount int `json:"thoughtsTokenCount"` ThoughtsTokenCount int `json:"thoughtsTokenCount"`
PromptTokensDetails []GeminiModalityTokenCount `json:"promptTokensDetails"` PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
CandidatesTokensDetails []GeminiModalityTokenCount `json:"candidatesTokensDetails"`
} }
type GeminiModalityTokenCount struct { type GeminiPromptTokensDetails struct {
Modality string `json:"modality"` Modality string `json:"modality"`
TokenCount int `json:"tokenCount"` TokenCount int `json:"tokenCount"`
} }

View File

@ -59,6 +59,31 @@ func (i *ImageRequest) UnmarshalJSON(data []byte) error {
return nil return nil
} }
// 序列化时需要重新把字段平铺
func (r ImageRequest) MarshalJSON() ([]byte, error) {
// 将已定义字段转为 map
type Alias ImageRequest
alias := Alias(r)
base, err := common.Marshal(alias)
if err != nil {
return nil, err
}
var baseMap map[string]json.RawMessage
if err := common.Unmarshal(base, &baseMap); err != nil {
return nil, err
}
// 合并 ExtraFields
for k, v := range r.Extra {
if _, exists := baseMap[k]; !exists {
baseMap[k] = v
}
}
return json.Marshal(baseMap)
}
func GetJSONFieldNames(t reflect.Type) map[string]struct{} { func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
fields := make(map[string]struct{}) fields := make(map[string]struct{})
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {

View File

@ -166,9 +166,9 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode) c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
relayMode := relayconstant.RelayModeUnknown relayMode := relayconstant.RelayModeUnknown
if c.Request.Method == http.MethodPost { if c.Request.Method == http.MethodPost {
err = common.UnmarshalBodyReusable(c, &modelRequest)
relayMode = relayconstant.RelayModeVideoSubmit relayMode = relayconstant.RelayModeVideoSubmit
} else if c.Request.Method == http.MethodGet { } else if c.Request.Method == http.MethodGet {
relayMode = relayconstant.RelayModeVideoFetchByID relayMode = relayconstant.RelayModeVideoFetchByID

View File

@ -42,7 +42,6 @@ type Channel struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"` AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"` OtherInfo string `json:"other_info"`
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置
Tag *string `json:"tag" gorm:"index"` Tag *string `json:"tag" gorm:"index"`
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
ParamOverride *string `json:"param_override" gorm:"type:text"` ParamOverride *string `json:"param_override" gorm:"type:text"`
@ -51,6 +50,8 @@ type Channel struct {
// add after v0.8.5 // add after v0.8.5
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置存储azure版本等不需要检索的信息详见dto.ChannelOtherSettings
// cache info // cache info
Keys []string `json:"-" gorm:"-"` Keys []string `json:"-" gorm:"-"`
} }

View File

@ -6,6 +6,7 @@ import (
"one-api/setting/config" "one-api/setting/config"
"one-api/setting/operation_setting" "one-api/setting/operation_setting"
"one-api/setting/ratio_setting" "one-api/setting/ratio_setting"
"one-api/setting/system_setting"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -66,16 +67,16 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = "" common.OptionMap["ServerAddress"] = ""
common.OptionMap["WorkerUrl"] = setting.WorkerUrl common.OptionMap["WorkerUrl"] = system_setting.WorkerUrl
common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey common.OptionMap["WorkerValidKey"] = system_setting.WorkerValidKey
common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled) common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(system_setting.WorkerAllowHttpImageRequestEnabled)
common.OptionMap["PayAddress"] = "" common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = "" common.OptionMap["EpayId"] = ""
common.OptionMap["EpayKey"] = "" common.OptionMap["EpayKey"] = ""
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64) common.OptionMap["Price"] = strconv.FormatFloat(operation_setting.Price, 'f', -1, 64)
common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(setting.USDExchangeRate, 'f', -1, 64) common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(operation_setting.USDExchangeRate, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp) common.OptionMap["MinTopUp"] = strconv.Itoa(operation_setting.MinTopUp)
common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp) common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
@ -85,7 +86,7 @@ func InitOptionMap() {
common.OptionMap["Chats"] = setting.Chats2JsonString() common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString() common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup) common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString() common.OptionMap["PayMethods"] = operation_setting.PayMethods2JsonString()
common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = "" common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = "" common.OptionMap["TelegramBotToken"] = ""
@ -271,7 +272,7 @@ func updateOptionMap(key string, value string) (err error) {
case "SMTPSSLEnabled": case "SMTPSSLEnabled":
common.SMTPSSLEnabled = boolValue common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled": case "WorkerAllowHttpImageRequestEnabled":
setting.WorkerAllowHttpImageRequestEnabled = boolValue system_setting.WorkerAllowHttpImageRequestEnabled = boolValue
case "DefaultUseAutoGroup": case "DefaultUseAutoGroup":
setting.DefaultUseAutoGroup = boolValue setting.DefaultUseAutoGroup = boolValue
case "ExposeRatioEnabled": case "ExposeRatioEnabled":
@ -293,29 +294,29 @@ func updateOptionMap(key string, value string) (err error) {
case "SMTPToken": case "SMTPToken":
common.SMTPToken = value common.SMTPToken = value
case "ServerAddress": case "ServerAddress":
setting.ServerAddress = value system_setting.ServerAddress = value
case "WorkerUrl": case "WorkerUrl":
setting.WorkerUrl = value system_setting.WorkerUrl = value
case "WorkerValidKey": case "WorkerValidKey":
setting.WorkerValidKey = value system_setting.WorkerValidKey = value
case "PayAddress": case "PayAddress":
setting.PayAddress = value operation_setting.PayAddress = value
case "Chats": case "Chats":
err = setting.UpdateChatsByJsonString(value) err = setting.UpdateChatsByJsonString(value)
case "AutoGroups": case "AutoGroups":
err = setting.UpdateAutoGroupsByJsonString(value) err = setting.UpdateAutoGroupsByJsonString(value)
case "CustomCallbackAddress": case "CustomCallbackAddress":
setting.CustomCallbackAddress = value operation_setting.CustomCallbackAddress = value
case "EpayId": case "EpayId":
setting.EpayId = value operation_setting.EpayId = value
case "EpayKey": case "EpayKey":
setting.EpayKey = value operation_setting.EpayKey = value
case "Price": case "Price":
setting.Price, _ = strconv.ParseFloat(value, 64) operation_setting.Price, _ = strconv.ParseFloat(value, 64)
case "USDExchangeRate": case "USDExchangeRate":
setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64) operation_setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
case "MinTopUp": case "MinTopUp":
setting.MinTopUp, _ = strconv.Atoi(value) operation_setting.MinTopUp, _ = strconv.Atoi(value)
case "StripeApiSecret": case "StripeApiSecret":
setting.StripeApiSecret = value setting.StripeApiSecret = value
case "StripeWebhookSecret": case "StripeWebhookSecret":
@ -413,7 +414,7 @@ func updateOptionMap(key string, value string) (err error) {
case "StreamCacheQueueLength": case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value) setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
case "PayMethods": case "PayMethods":
err = setting.UpdatePayMethodsByJsonString(value) err = operation_setting.UpdatePayMethodsByJsonString(value)
} }
return err return err
} }

View File

@ -53,7 +53,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
if resp != nil { if resp != nil {
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false) newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr) service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError return newAPIError

View File

@ -264,9 +264,8 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
} }
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
} }
if resp == nil { if resp == nil {
return nil, errors.New("resp is nil") return nil, errors.New("resp is nil")

View File

@ -60,7 +60,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
// 检查是否为Nova模型
if isNovaModel(request.Model) {
novaReq := convertToNovaRequest(request)
c.Set("request_model", request.Model)
c.Set("converted_request", novaReq)
c.Set("is_nova_model", true)
return novaReq, nil
}
// 原有的Claude模型处理逻辑
var claudeReq *dto.ClaudeRequest var claudeReq *dto.ClaudeRequest
var err error var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request) claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
@ -69,6 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
} }
c.Set("request_model", claudeReq.Model) c.Set("request_model", claudeReq.Model)
c.Set("converted_request", claudeReq) c.Set("converted_request", claudeReq)
c.Set("is_nova_model", false)
return claudeReq, err return claudeReq, err
} }

View File

@ -1,5 +1,7 @@
package aws package aws
import "strings"
var awsModelIDMap = map[string]string{ var awsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1", "claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2", "claude-2.0": "anthropic.claude-v2",
@ -14,6 +16,11 @@ var awsModelIDMap = map[string]string{
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0", "claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0", "claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0", "claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
// Nova models
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
"nova-pro-v1:0": "amazon.nova-pro-v1:0",
"nova-premier-v1:0": "amazon.nova-premier-v1:0",
} }
var awsModelCanCrossRegionMap = map[string]map[string]bool{ var awsModelCanCrossRegionMap = map[string]map[string]bool{
@ -58,7 +65,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
"anthropic.claude-opus-4-1-20250805-v1:0": { "anthropic.claude-opus-4-1-20250805-v1:0": {
"us": true, "us": true,
}, },
} // Nova models - all support three major regions
"amazon.nova-micro-v1:0": {
"us": true,
"eu": true,
"apac": true,
},
"amazon.nova-lite-v1:0": {
"us": true,
"eu": true,
"apac": true,
},
"amazon.nova-pro-v1:0": {
"us": true,
"eu": true,
"apac": true,
},
"amazon.nova-premier-v1:0": {
"us": true,
"eu": true,
"apac": true,
}}
var awsRegionCrossModelPrefixMap = map[string]string{ var awsRegionCrossModelPrefixMap = map[string]string{
"us": "us", "us": "us",
@ -67,3 +94,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{
} }
var ChannelName = "aws" var ChannelName = "aws"
// 判断是否为Nova模型
func isNovaModel(modelId string) bool {
return strings.HasPrefix(modelId, "nova-")
}

View File

@ -34,3 +34,92 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
Thinking: req.Thinking, Thinking: req.Thinking,
} }
} }
// NovaMessage Nova模型使用messages-v1格式
type NovaMessage struct {
Role string `json:"role"`
Content []NovaContent `json:"content"`
}
type NovaContent struct {
Text string `json:"text"`
}
type NovaRequest struct {
SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0"
Messages []NovaMessage `json:"messages"` // 对话消息列表
InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
}
type NovaInferenceConfig struct {
MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数
Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1)
TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1)
TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128)
StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
}
// 转换OpenAI请求为Nova格式
func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
novaMessages := make([]NovaMessage, len(req.Messages))
for i, msg := range req.Messages {
novaMessages[i] = NovaMessage{
Role: msg.Role,
Content: []NovaContent{{Text: msg.StringContent()}},
}
}
novaReq := &NovaRequest{
SchemaVersion: "messages-v1",
Messages: novaMessages,
}
// 设置推理配置
if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
novaReq.InferenceConfig = &NovaInferenceConfig{}
if req.MaxTokens != 0 {
novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
}
if req.Temperature != nil && *req.Temperature != 0 {
novaReq.InferenceConfig.Temperature = *req.Temperature
}
if req.TopP != 0 {
novaReq.InferenceConfig.TopP = req.TopP
}
if req.TopK != 0 {
novaReq.InferenceConfig.TopK = req.TopK
}
if req.Stop != nil {
if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
novaReq.InferenceConfig.StopSequences = stopSequences
}
}
}
return novaReq
}
// parseStopSequences 解析停止序列,支持字符串或字符串数组
func parseStopSequences(stop any) []string {
if stop == nil {
return nil
}
switch v := stop.(type) {
case string:
if v != "" {
return []string{v}
}
case []string:
return v
case []interface{}:
var sequences []string
for _, item := range v {
if str, ok := item.(string); ok && str != "" {
sequences = append(sequences, str)
}
}
return sequences
}
return nil
}

View File

@ -1,6 +1,7 @@
package aws package aws
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
@ -93,7 +94,19 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
} }
awsModelId := awsModelID(c.GetString("request_model")) awsModelId := awsModelID(c.GetString("request_model"))
// 检查是否为Nova模型
isNova, _ := c.Get("is_nova_model")
if isNova == true {
// Nova模型也支持跨区域
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
return handleNovaRequest(c, awsCli, info, awsModelId)
}
// 原有的Claude处理逻辑
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion { if canCrossRegion {
@ -209,3 +222,74 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage) claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
return nil, claudeInfo.Usage return nil, claudeInfo.Usage
} }
// Nova模型处理函数
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
novaReq_, ok := c.Get("converted_request")
if !ok {
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
}
novaReq := novaReq_.(*NovaRequest)
// 使用InvokeModel API但使用Nova格式的请求体
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
reqBody, err := json.Marshal(novaReq)
if err != nil {
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
}
awsReq.Body = reqBody
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
}
// 解析Nova响应
var novaResp struct {
Output struct {
Message struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
} `json:"message"`
} `json:"output"`
Usage struct {
InputTokens int `json:"inputTokens"`
OutputTokens int `json:"outputTokens"`
TotalTokens int `json:"totalTokens"`
} `json:"usage"`
}
if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
}
// 构造OpenAI格式响应
response := dto.OpenAITextResponse{
Id: helper.GetResponseID(c),
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
Choices: []dto.OpenAITextResponseChoice{{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: novaResp.Output.Message.Content[0].Text,
},
FinishReason: "stop",
}},
Usage: dto.Usage{
PromptTokens: novaResp.Usage.InputTokens,
CompletionTokens: novaResp.Usage.OutputTokens,
TotalTokens: novaResp.Usage.TotalTokens,
},
}
c.JSON(http.StatusOK, response)
return nil, &response.Usage
}

View File

@ -46,32 +46,6 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
imageOutputCounts := 0
for _, candidate := range geminiResponse.Candidates {
for _, part := range candidate.Content.Parts {
if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image/") {
imageOutputCounts++
}
}
}
if imageOutputCounts != 0 {
usage.CompletionTokens = usage.CompletionTokens - imageOutputCounts*1290
usage.TotalTokens = usage.TotalTokens - imageOutputCounts*1290
c.Set("gemini_image_tokens", imageOutputCounts*1290)
}
}
// if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
// for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
// if detail.Modality == "IMAGE" {
// usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
// usage.TotalTokens = usage.TotalTokens - detail.TokenCount
// c.Set("gemini_image_tokens", detail.TokenCount)
// }
// }
// }
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" { if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount usage.PromptTokensDetails.AudioTokens = detail.TokenCount
@ -162,16 +136,6 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
usage.PromptTokensDetails.TextTokens = detail.TokenCount usage.PromptTokensDetails.TextTokens = detail.TokenCount
} }
} }
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
if detail.Modality == "IMAGE" {
usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
usage.TotalTokens = usage.TotalTokens - detail.TokenCount
c.Set("gemini_image_tokens", detail.TokenCount)
}
}
}
} }
// 直接发送 GeminiChatResponse 响应 // 直接发送 GeminiChatResponse 响应

View File

@ -17,15 +17,15 @@ type Adaptor struct {
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
return nil, errors.New("not implemented") return nil, errors.New("submodel channel: endpoint not supported")
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented") return nil, errors.New("submodel channel: endpoint not supported")
} }
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return nil, errors.New("not implemented") return nil, errors.New("submodel channel: endpoint not supported")
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@ -49,15 +49,15 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented") return nil, errors.New("submodel channel: endpoint not supported")
} }
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented") return nil, errors.New("submodel channel: endpoint not supported")
} }
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented") return nil, errors.New("submodel channel: endpoint not supported")
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {

View File

@ -18,7 +18,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors" "github.com/pkg/errors"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
@ -89,22 +88,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action. // ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action. // Accept only POST /v1/video/generations as "generate" action.
action := constant.TaskActionGenerate return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
info.Action = action
req := relaycommon.TaskSubmitReq{}
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("task_request", req)
return nil
} }
// BuildRequestURL constructs the upstream URL. // BuildRequestURL constructs the upstream URL.
@ -334,11 +318,11 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
} }
// Handle one-of image_urls or binary_data_base64 // Handle one-of image_urls or binary_data_base64
if req.Image != "" { if req.HasImage() {
if strings.HasPrefix(req.Image, "http") { if strings.HasPrefix(req.Images[0], "http") {
r.ImageUrls = []string{req.Image} r.ImageUrls = req.Images
} else { } else {
r.BinaryDataBase64 = []string{req.Image} r.BinaryDataBase64 = req.Images
} }
} }
metadata := req.Metadata metadata := req.Metadata

View File

@ -16,7 +16,6 @@ import (
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
@ -28,16 +27,6 @@ import (
// Request / Response structures // Request / Response structures
// ============================ // ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type TrajectoryPoint struct { type TrajectoryPoint struct {
X int `json:"x"` X int `json:"x"`
Y int `json:"y"` Y int `json:"y"`
@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action. // ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action. // Use the standard validation method for TaskSubmitReq
action := constant.TaskActionGenerate return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
info.Action = action
var req SubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("task_request", req)
return nil
} }
// BuildRequestURL constructs the upstream URL. // BuildRequestURL constructs the upstream URL.
@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
if !exists { if !exists {
return nil, fmt.Errorf("request not found in context") return nil, fmt.Errorf("request not found in context")
} }
req := v.(SubmitReq) req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req) body, err := a.convertToRequestPayload(&req)
if err != nil { if err != nil {
@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers // helpers
// ============================ // ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
r := requestPayload{ r := requestPayload{
Prompt: req.Prompt, Prompt: req.Prompt,
Image: req.Image, Image: req.Image,

View File

@ -0,0 +1,355 @@
package vertex
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/model"
"regexp"
"strings"
"github.com/gin-gonic/gin"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
vertexcore "one-api/relay/channel/vertex"
relaycommon "one-api/relay/common"
"one-api/service"
)
// ============================
// Request / Response structures
// ============================
type requestPayload struct {
Instances []map[string]any `json:"instances"`
Parameters map[string]any `json:"parameters,omitempty"`
}
type submitResponse struct {
Name string `json:"name"`
}
type operationVideo struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
Encoding string `json:"encoding"`
}
type operationResponse struct {
Name string `json:"name"`
Done bool `json:"done"`
Response struct {
Type string `json:"@type"`
RaiMediaFilteredCount int `json:"raiMediaFilteredCount"`
Videos []operationVideo `json:"videos"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
Encoding string `json:"encoding"`
Video string `json:"video"`
} `json:"response"`
Error struct {
Message string `json:"message"`
} `json:"error"`
}
// ============================
// Adaptor implementation
// ============================
type TaskAdaptor struct {
ChannelType int
apiKey string
baseURL string
}
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.ChannelBaseUrl
a.apiKey = info.ApiKey
}
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
// Use the standard validation method for TaskSubmitReq
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate)
}
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
adc := &vertexcore.Credentials{}
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials: %w", err)
}
modelName := info.OriginModelName
if modelName == "" {
modelName = "veo-3.0-generate-001"
}
region := vertexcore.GetModelRegion(info.ApiVersion, modelName)
if strings.TrimSpace(region) == "" {
region = "global"
}
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
adc.ProjectID,
modelName,
), nil
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
region,
adc.ProjectID,
region,
modelName,
), nil
}
// BuildRequestHeader sets required headers.
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
adc := &vertexcore.Credentials{}
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
return fmt.Errorf("failed to decode credentials: %w", err)
}
token, err := vertexcore.AcquireAccessToken(*adc, "")
if err != nil {
return fmt.Errorf("failed to acquire access token: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("x-goog-user-project", adc.ProjectID)
return nil
}
// BuildRequestBody converts request into Vertex specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
v, ok := c.Get("task_request")
if !ok {
return nil, fmt.Errorf("request not found in context")
}
req := v.(relaycommon.TaskSubmitReq)
body := requestPayload{
Instances: []map[string]any{{"prompt": req.Prompt}},
Parameters: map[string]any{},
}
if req.Metadata != nil {
if v, ok := req.Metadata["storageUri"]; ok {
body.Parameters["storageUri"] = v
}
if v, ok := req.Metadata["sampleCount"]; ok {
body.Parameters["sampleCount"] = v
}
}
if _, ok := body.Parameters["sampleCount"]; !ok {
body.Parameters["sampleCount"] = 1
}
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
// DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse handles upstream response, returns taskID etc.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
_ = resp.Body.Close()
var s submitResponse
if err := json.Unmarshal(responseBody, &s); err != nil {
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
}
if strings.TrimSpace(s.Name) == "" {
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
}
localID := encodeLocalTaskID(s.Name)
c.JSON(http.StatusOK, gin.H{"task_id": localID})
return localID, responseBody, nil
}
func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} }
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
// FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
upstreamName, err := decodeLocalTaskID(taskID)
if err != nil {
return nil, fmt.Errorf("decode task_id failed: %w", err)
}
region := extractRegionFromOperationName(upstreamName)
if region == "" {
region = "us-central1"
}
project := extractProjectFromOperationName(upstreamName)
modelName := extractModelFromOperationName(upstreamName)
if project == "" || modelName == "" {
return nil, fmt.Errorf("cannot extract project/model from operation name")
}
var url string
if region == "global" {
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
} else {
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
}
payload := map[string]string{"operationName": upstreamName}
data, err := json.Marshal(payload)
if err != nil {
return nil, err
}
adc := &vertexcore.Credentials{}
if err := json.Unmarshal([]byte(key), adc); err != nil {
return nil, fmt.Errorf("failed to decode credentials: %w", err)
}
token, err := vertexcore.AcquireAccessToken(*adc, "")
if err != nil {
return nil, fmt.Errorf("failed to acquire access token: %w", err)
}
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("x-goog-user-project", adc.ProjectID)
return service.GetHttpClient().Do(req)
}
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
var op operationResponse
if err := json.Unmarshal(respBody, &op); err != nil {
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
}
ti := &relaycommon.TaskInfo{}
if op.Error.Message != "" {
ti.Status = model.TaskStatusFailure
ti.Reason = op.Error.Message
ti.Progress = "100%"
return ti, nil
}
if !op.Done {
ti.Status = model.TaskStatusInProgress
ti.Progress = "50%"
return ti, nil
}
ti.Status = model.TaskStatusSuccess
ti.Progress = "100%"
if len(op.Response.Videos) > 0 {
v0 := op.Response.Videos[0]
if v0.BytesBase64Encoded != "" {
mime := strings.TrimSpace(v0.MimeType)
if mime == "" {
enc := strings.TrimSpace(v0.Encoding)
if enc == "" {
enc = "mp4"
}
if strings.Contains(enc, "/") {
mime = enc
} else {
mime = "video/" + enc
}
}
ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded
return ti, nil
}
}
if op.Response.BytesBase64Encoded != "" {
enc := strings.TrimSpace(op.Response.Encoding)
if enc == "" {
enc = "mp4"
}
mime := enc
if !strings.Contains(enc, "/") {
mime = "video/" + enc
}
ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded
return ti, nil
}
if op.Response.Video != "" { // some variants use `video` as base64
enc := strings.TrimSpace(op.Response.Encoding)
if enc == "" {
enc = "mp4"
}
mime := enc
if !strings.Contains(enc, "/") {
mime = "video/" + enc
}
ti.Url = "data:" + mime + ";base64," + op.Response.Video
return ti, nil
}
return ti, nil
}
// ============================
// helpers
// ============================
func encodeLocalTaskID(name string) string {
return base64.RawURLEncoding.EncodeToString([]byte(name))
}
func decodeLocalTaskID(local string) (string, error) {
b, err := base64.RawURLEncoding.DecodeString(local)
if err != nil {
return "", err
}
return string(b), nil
}
var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
func extractRegionFromOperationName(name string) string {
m := regionRe.FindStringSubmatch(name)
if len(m) == 2 {
return m[1]
}
return ""
}
var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
func extractModelFromOperationName(name string) string {
m := modelRe.FindStringSubmatch(name)
if len(m) == 2 {
return m[1]
}
idx := strings.Index(name, "models/")
if idx >= 0 {
s := name[idx+len("models/"):]
if p := strings.Index(s, "/operations/"); p > 0 {
return s[:p]
}
}
return ""
}
var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`)
func extractProjectFromOperationName(name string) string {
m := projectRe.FindStringSubmatch(name)
if len(m) == 2 {
return m[1]
}
return ""
}

View File

@ -23,16 +23,6 @@ import (
// Request / Response structures // Request / Response structures
// ============================ // ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type requestPayload struct { type requestPayload struct {
Model string `json:"model"` Model string `json:"model"`
Images []string `json:"images"` Images []string `json:"images"`
@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
var req SubmitReq // Use the unified validation method for TaskSubmitReq with image-based action determination
if err := c.ShouldBindJSON(&req); err != nil { return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
}
if req.Prompt == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
}
if req.Image != "" {
info.Action = constant.TaskActionGenerate
} else {
info.Action = constant.TaskActionTextGenerate
}
c.Set("task_request", req)
return nil
} }
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
if !exists { if !exists {
return nil, fmt.Errorf("request not found in context") return nil, fmt.Errorf("request not found in context")
} }
req := v.(SubmitReq) req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req) body, err := a.convertToRequestPayload(&req)
if err != nil { if err != nil {
@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers // helpers
// ============================ // ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
var images []string var images []string
if req.Image != "" { if req.Image != "" {
images = []string{req.Image} images = []string{req.Image}

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"one-api/common"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
@ -80,16 +81,64 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
adc := &Credentials{}
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials file: %w", err)
}
region := GetModelRegion(info.ApiVersion, info.OriginModelName) region := GetModelRegion(info.ApiVersion, info.OriginModelName)
a.AccountCredentials = *adc if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
adc := &Credentials{}
if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials file: %w", err)
}
a.AccountCredentials = *adc
if a.RequestMode == RequestModeLlama {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
region,
adc.ProjectID,
region,
), nil
}
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
adc.ProjectID,
modelName,
suffix,
), nil
} else {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
region,
adc.ProjectID,
region,
modelName,
suffix,
), nil
}
} else {
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
modelName,
suffix,
info.ApiKey,
), nil
} else {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
region,
modelName,
suffix,
info.ApiKey,
), nil
}
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
suffix := "" suffix := ""
if a.RequestMode == RequestModeGemini { if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-<budget> 格式 // 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.UpstreamModelName, "-thinking-") { if strings.Contains(info.UpstreamModelName, "-thinking-") {
@ -111,24 +160,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if strings.HasPrefix(info.UpstreamModelName, "imagen") { if strings.HasPrefix(info.UpstreamModelName, "imagen") {
suffix = "predict" suffix = "predict"
} }
return a.getRequestUrl(info, info.UpstreamModelName, suffix)
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
adc.ProjectID,
info.UpstreamModelName,
suffix,
), nil
} else {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
region,
adc.ProjectID,
region,
info.UpstreamModelName,
suffix,
), nil
}
} else if a.RequestMode == RequestModeClaude { } else if a.RequestMode == RequestModeClaude {
if info.IsStream { if info.IsStream {
suffix = "streamRawPredict?alt=sse" suffix = "streamRawPredict?alt=sse"
@ -139,41 +171,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok { if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
model = v model = v
} }
if region == "global" { return a.getRequestUrl(info, model, suffix)
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
adc.ProjectID,
model,
suffix,
), nil
} else {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
region,
adc.ProjectID,
region,
model,
suffix,
), nil
}
} else if a.RequestMode == RequestModeLlama { } else if a.RequestMode == RequestModeLlama {
return fmt.Sprintf( return a.getRequestUrl(info, "", "")
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
region,
adc.ProjectID,
region,
), nil
} }
return "", errors.New("unsupported request mode") return "", errors.New("unsupported request mode")
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
accessToken, err := getAccessToken(a, info) if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
if err != nil { accessToken, err := getAccessToken(a, info)
return err if err != nil {
return err
}
req.Set("Authorization", "Bearer "+accessToken)
}
if a.AccountCredentials.ProjectID != "" {
req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
} }
req.Set("Authorization", "Bearer "+accessToken)
return nil return nil
} }

View File

@ -12,7 +12,10 @@ func GetModelRegion(other string, localModelName string) string {
if m[localModelName] != nil { if m[localModelName] != nil {
return m[localModelName].(string) return m[localModelName].(string)
} else { } else {
return m["default"].(string) if v, ok := m["default"]; ok {
return v.(string)
}
return "global"
} }
} }
return other return other

View File

@ -6,14 +6,15 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors" "errors"
"github.com/bytedance/gopkg/cache/asynccache"
"github.com/golang-jwt/jwt"
"net/http" "net/http"
"net/url" "net/url"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
"github.com/bytedance/gopkg/cache/asynccache"
"github.com/golang-jwt/jwt"
"fmt" "fmt"
"time" "time"
) )
@ -137,3 +138,45 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s
return "", fmt.Errorf("failed to get access token: %v", result) return "", fmt.Errorf("failed to get access token: %v", result)
} }
func AcquireAccessToken(creds Credentials, proxy string) (string, error) {
signedJWT, err := createSignedJWT(creds.ClientEmail, creds.PrivateKey)
if err != nil {
return "", fmt.Errorf("failed to create signed JWT: %w", err)
}
return exchangeJwtForAccessTokenWithProxy(signedJWT, proxy)
}
func exchangeJwtForAccessTokenWithProxy(signedJWT string, proxy string) (string, error) {
authURL := "https://www.googleapis.com/oauth2/v4/token"
data := url.Values{}
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
data.Set("assertion", signedJWT)
var client *http.Client
var err error
if proxy != "" {
client, err = service.NewProxyHttpClient(proxy)
if err != nil {
return "", fmt.Errorf("new proxy http client failed: %w", err)
}
} else {
client = service.GetHttpClient()
}
resp, err := client.PostForm(authURL, data)
if err != nil {
return "", err
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}
if accessToken, ok := result["access_token"].(string); ok {
return accessToken, nil
}
return "", fmt.Errorf("failed to get access token: %v", result)
}

View File

@ -111,7 +111,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false) newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr) service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError return newAPIError

View File

@ -481,11 +481,20 @@ type TaskSubmitReq struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"` Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"` Image string `json:"image,omitempty"`
Images []string `json:"images,omitempty"`
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"` Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"`
} }
func (t TaskSubmitReq) GetPrompt() string {
return t.Prompt
}
func (t TaskSubmitReq) HasImage() bool {
return len(t.Images) > 0
}
type TaskInfo struct { type TaskInfo struct {
Code int `json:"code"` Code int `json:"code"`
TaskID string `json:"task_id"` TaskID string `json:"task_id"`

View File

@ -2,12 +2,23 @@ package common
import ( import (
"fmt" "fmt"
"net/http"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type HasPrompt interface {
GetPrompt() string
}
type HasImage interface {
HasImage() bool
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
@ -30,3 +41,72 @@ func GetAPIVersion(c *gin.Context) string {
} }
return apiVersion return apiVersion
} }
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
return &dto.TaskError{
Code: code,
Message: err.Error(),
StatusCode: statusCode,
LocalError: localError,
Error: err,
}
}
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
info.Action = action
c.Set("task_request", requestObj)
}
func validatePrompt(prompt string) *dto.TaskError {
if strings.TrimSpace(prompt) == "" {
return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
}
return nil
}
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
var req TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
}
if taskErr := validatePrompt(req.Prompt); taskErr != nil {
return taskErr
}
if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
// 兼容单图上传
req.Images = []string{req.Image}
}
storeTaskRequest(c, info, action, req)
return nil
}
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
hasPrompt, ok := requestObj.(HasPrompt)
if !ok {
return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
}
if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
return taskErr
}
action := constant.TaskActionTextGenerate
if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
action = constant.TaskActionGenerate
}
storeTaskRequest(c, info, action, requestObj)
return nil
}
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
var req TaskSubmitReq
if err := c.ShouldBindJSON(&req); err != nil {
return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
}
return ValidateTaskRequestWithImage(c, info, req)
}

View File

@ -158,7 +158,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
newApiErr := service.RelayErrorHandler(httpResp, false) newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(newApiErr, statusCodeMappingStr) service.ResetStatusCode(newApiErr, statusCodeMappingStr)
return newApiErr return newApiErr
@ -195,6 +195,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
imageTokens := usage.PromptTokensDetails.ImageTokens imageTokens := usage.PromptTokensDetails.ImageTokens
audioTokens := usage.PromptTokensDetails.AudioTokens audioTokens := usage.PromptTokensDetails.AudioTokens
completionTokens := usage.CompletionTokens completionTokens := usage.CompletionTokens
cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
modelName := relayInfo.OriginModelName modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name") tokenName := ctx.GetString("token_name")
@ -204,6 +206,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
modelRatio := relayInfo.PriceData.ModelRatio modelRatio := relayInfo.PriceData.ModelRatio
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
modelPrice := relayInfo.PriceData.ModelPrice modelPrice := relayInfo.PriceData.ModelPrice
cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
// Convert values to decimal for precise calculation // Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens)) dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@ -211,12 +214,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
dImageTokens := decimal.NewFromInt(int64(imageTokens)) dImageTokens := decimal.NewFromInt(int64(imageTokens))
dAudioTokens := decimal.NewFromInt(int64(audioTokens)) dAudioTokens := decimal.NewFromInt(int64(audioTokens))
dCompletionTokens := decimal.NewFromInt(int64(completionTokens)) dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
dCompletionRatio := decimal.NewFromFloat(completionRatio) dCompletionRatio := decimal.NewFromFloat(completionRatio)
dCacheRatio := decimal.NewFromFloat(cacheRatio) dCacheRatio := decimal.NewFromFloat(cacheRatio)
dImageRatio := decimal.NewFromFloat(imageRatio) dImageRatio := decimal.NewFromFloat(imageRatio)
dModelRatio := decimal.NewFromFloat(modelRatio) dModelRatio := decimal.NewFromFloat(modelRatio)
dGroupRatio := decimal.NewFromFloat(groupRatio) dGroupRatio := decimal.NewFromFloat(groupRatio)
dModelPrice := decimal.NewFromFloat(modelPrice) dModelPrice := decimal.NewFromFloat(modelPrice)
dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
ratio := dModelRatio.Mul(dGroupRatio) ratio := dModelRatio.Mul(dGroupRatio)
@ -284,6 +289,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
baseTokens = baseTokens.Sub(dCacheTokens) baseTokens = baseTokens.Sub(dCacheTokens)
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio) cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
} }
var dCachedCreationTokensWithRatio decimal.Decimal
if !dCachedCreationTokens.IsZero() {
baseTokens = baseTokens.Sub(dCachedCreationTokens)
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
}
// 减去 image tokens // 减去 image tokens
var imageTokensWithRatio decimal.Decimal var imageTokensWithRatio decimal.Decimal
@ -302,7 +312,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()) extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
} }
} }
promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio) promptQuota := baseTokens.Add(cachedTokensWithRatio).
Add(imageTokensWithRatio).
Add(dCachedCreationTokensWithRatio)
completionQuota := dCompletionTokens.Mul(dCompletionRatio) completionQuota := dCompletionTokens.Mul(dCompletionRatio)
@ -314,22 +326,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
} else { } else {
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio) quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
} }
var dGeminiImageOutputQuota decimal.Decimal
var imageOutputPrice float64
if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") {
imageOutputPrice = operation_setting.GetGeminiImageOutputPricePerMillionTokens(modelName)
if imageOutputPrice > 0 {
dImageOutputTokens := decimal.NewFromInt(int64(ctx.GetInt("gemini_image_tokens")))
dGeminiImageOutputQuota = decimal.NewFromFloat(imageOutputPrice).Div(decimal.NewFromInt(1000000)).Mul(dImageOutputTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
}
}
// 添加 responses tools call 调用的配额 // 添加 responses tools call 调用的配额
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota) quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
// 添加 audio input 独立计费 // 添加 audio input 独立计费
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota) quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
// 添加 Gemini image output 计费
quotaCalculateDecimal = quotaCalculateDecimal.Add(dGeminiImageOutputQuota)
quota := int(quotaCalculateDecimal.Round(0).IntPart()) quota := int(quotaCalculateDecimal.Round(0).IntPart())
totalTokens := promptTokens + completionTokens totalTokens := promptTokens + completionTokens
@ -395,6 +396,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
other["image_ratio"] = imageRatio other["image_ratio"] = imageRatio
other["image_output"] = imageTokens other["image_output"] = imageTokens
} }
if cachedCreationTokens != 0 {
other["cache_creation_tokens"] = cachedCreationTokens
other["cache_creation_ratio"] = cachedCreationRatio
}
if !dWebSearchQuota.IsZero() { if !dWebSearchQuota.IsZero() {
if relayInfo.ResponsesUsageInfo != nil { if relayInfo.ResponsesUsageInfo != nil {
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists { if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
@ -424,10 +429,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
other["audio_input_token_count"] = audioTokens other["audio_input_token_count"] = audioTokens
other["audio_input_price"] = audioInputPrice other["audio_input_price"] = audioInputPrice
} }
if !dGeminiImageOutputQuota.IsZero() {
other["image_output_token_count"] = ctx.GetInt("gemini_image_tokens")
other["image_output_price"] = imageOutputPrice
}
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId, ChannelId: relayInfo.ChannelId,
PromptTokens: promptTokens, PromptTokens: promptTokens,

View File

@ -58,7 +58,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if resp != nil { if resp != nil {
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false) newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr) service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError return newAPIError

View File

@ -152,7 +152,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false) newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr) service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError return newAPIError
@ -249,7 +249,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
if resp != nil { if resp != nil {
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false) newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
service.ResetStatusCode(newAPIError, statusCodeMappingStr) service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError return newAPIError
} }

View File

@ -91,7 +91,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false) newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr) service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError return newAPIError
@ -120,7 +120,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
var logContent string var logContent string
if len(request.Size) > 0 { if len(request.Size) > 0 {
logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality) logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
} }
postConsumeQuota(c, info, usage.(*dto.Usage), logContent) postConsumeQuota(c, info, usage.(*dto.Usage), logContent)

View File

@ -16,6 +16,7 @@ import (
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
"one-api/setting/system_setting"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -131,7 +132,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.FinishTime = originTask.FinishTime midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = "" midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled { if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId midjourneyTask.ImageUrl = system_setting.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" { if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
} }

View File

@ -1,7 +1,6 @@
package relay package relay
import ( import (
"github.com/gin-gonic/gin"
"one-api/constant" "one-api/constant"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/ali" "one-api/relay/channel/ali"
@ -28,6 +27,7 @@ import (
taskjimeng "one-api/relay/channel/task/jimeng" taskjimeng "one-api/relay/channel/task/jimeng"
"one-api/relay/channel/task/kling" "one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno" "one-api/relay/channel/task/suno"
taskvertex "one-api/relay/channel/task/vertex"
taskVidu "one-api/relay/channel/task/vidu" taskVidu "one-api/relay/channel/task/vidu"
"one-api/relay/channel/tencent" "one-api/relay/channel/tencent"
"one-api/relay/channel/vertex" "one-api/relay/channel/vertex"
@ -37,7 +37,12 @@ import (
"one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_4v" "one-api/relay/channel/zhipu_4v"
"strconv" "strconv"
<<<<<<< HEAD
"one-api/relay/channel/submodel" "one-api/relay/channel/submodel"
=======
"github.com/gin-gonic/gin"
>>>>>>> 4f760a8d407d321bf7f011331ecffb2744b555fd
) )
func GetAdaptor(apiType int) channel.Adaptor { func GetAdaptor(apiType int) channel.Adaptor {
@ -129,6 +134,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
return &kling.TaskAdaptor{} return &kling.TaskAdaptor{}
case constant.ChannelTypeJimeng: case constant.ChannelTypeJimeng:
return &taskjimeng.TaskAdaptor{} return &taskjimeng.TaskAdaptor{}
case constant.ChannelTypeVertexAi:
return &taskvertex.TaskAdaptor{}
case constant.ChannelTypeVidu: case constant.ChannelTypeVidu:
return &taskVidu.TaskAdaptor{} return &taskVidu.TaskAdaptor{}
} }

View File

@ -15,6 +15,8 @@ import (
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"one-api/setting/ratio_setting" "one-api/setting/ratio_setting"
"strconv"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -33,6 +35,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
platform = GetTaskPlatform(c) platform = GetTaskPlatform(c)
} }
info.InitChannelMeta(c)
adaptor := GetTaskAdaptor(platform) adaptor := GetTaskAdaptor(platform)
if adaptor == nil { if adaptor == nil {
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
@ -197,6 +200,9 @@ func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
if taskErr != nil { if taskErr != nil {
return taskErr return taskErr
} }
if len(respBody) == 0 {
respBody = []byte("{\"code\":\"success\",\"data\":null}")
}
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody)) _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
@ -276,10 +282,92 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
return return
} }
respBody, err = json.Marshal(dto.TaskResponse[any]{ func() {
Code: "success", channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
Data: TaskModel2Dto(originTask), if err2 != nil {
}) return
}
if channelModel.Type != constant.ChannelTypeVertexAi {
return
}
baseURL := constant.ChannelBaseURLs[channelModel.Type]
if channelModel.GetBaseURL() != "" {
baseURL = channelModel.GetBaseURL()
}
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
if adaptor == nil {
return
}
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
"task_id": originTask.TaskID,
"action": originTask.Action,
})
if err2 != nil || resp == nil {
return
}
defer resp.Body.Close()
body, err2 := io.ReadAll(resp.Body)
if err2 != nil {
return
}
ti, err2 := adaptor.ParseTaskResult(body)
if err2 == nil && ti != nil {
if ti.Status != "" {
originTask.Status = model.TaskStatus(ti.Status)
}
if ti.Progress != "" {
originTask.Progress = ti.Progress
}
if ti.Url != "" {
originTask.FailReason = ti.Url
}
_ = originTask.Update()
var raw map[string]any
_ = json.Unmarshal(body, &raw)
format := "mp4"
if respObj, ok := raw["response"].(map[string]any); ok {
if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
if v0, ok := vids[0].(map[string]any); ok {
if mt, ok := v0["mimeType"].(string); ok && mt != "" {
if strings.Contains(mt, "mp4") {
format = "mp4"
} else {
format = mt
}
}
}
}
}
status := "processing"
switch originTask.Status {
case model.TaskStatusSuccess:
status = "succeeded"
case model.TaskStatusFailure:
status = "failed"
case model.TaskStatusQueued, model.TaskStatusSubmitted:
status = "queued"
}
out := map[string]any{
"error": nil,
"format": format,
"metadata": nil,
"status": status,
"task_id": originTask.TaskID,
"url": originTask.FailReason,
}
respBody, _ = json.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: out,
})
}
}()
if len(respBody) == 0 {
respBody, err = json.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: TaskModel2Dto(originTask),
})
}
return return
} }

View File

@ -81,7 +81,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
if resp != nil { if resp != nil {
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false) newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr) service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError return newAPIError

View File

@ -41,7 +41,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
} }
adaptor.Init(info) adaptor.Init(info)
var requestBody io.Reader var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled { if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c) body, err := common.GetRequestBody(c)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
@ -82,7 +82,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false) newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr) service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError return newAPIError

View File

@ -60,6 +60,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.DELETE("/self", controller.DeleteSelf) selfRoute.DELETE("/self", controller.DeleteSelf)
selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode) selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.GET("/topup/info", controller.GetTopUpInfo)
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp) selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay) selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay)
selfRoute.POST("/amount", controller.RequestAmount) selfRoute.POST("/amount", controller.RequestAmount)

View File

@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/setting" "one-api/setting/system_setting"
"strings" "strings"
) )
@ -21,14 +21,14 @@ type WorkerRequest struct {
// DoWorkerRequest 通过Worker发送请求 // DoWorkerRequest 通过Worker发送请求
func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
if !setting.EnableWorker() { if !system_setting.EnableWorker() {
return nil, fmt.Errorf("worker not enabled") return nil, fmt.Errorf("worker not enabled")
} }
if !setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") { if !system_setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
return nil, fmt.Errorf("only support https url") return nil, fmt.Errorf("only support https url")
} }
workerUrl := setting.WorkerUrl workerUrl := system_setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") { if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/" workerUrl += "/"
} }
@ -43,11 +43,11 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
} }
func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) { func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
if setting.EnableWorker() { if system_setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", "))) common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
req := &WorkerRequest{ req := &WorkerRequest{
URL: originUrl, URL: originUrl,
Key: setting.WorkerValidKey, Key: system_setting.WorkerValidKey,
} }
return DoWorkerRequest(req) return DoWorkerRequest(req)
} else { } else {

View File

@ -1,12 +1,13 @@
package service package service
import ( import (
"one-api/setting" "one-api/setting/operation_setting"
"one-api/setting/system_setting"
) )
func GetCallbackAddress() string { func GetCallbackAddress() string {
if setting.CustomCallbackAddress == "" { if operation_setting.CustomCallbackAddress == "" {
return setting.ServerAddress return system_setting.ServerAddress
} }
return setting.CustomCallbackAddress return operation_setting.CustomCallbackAddress
} }

View File

@ -1,12 +1,14 @@
package service package service
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/logger"
"one-api/types" "one-api/types"
"strconv" "strconv"
"strings" "strings"
@ -78,7 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
return claudeErr return claudeErr
} }
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) { func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode) newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
@ -94,7 +96,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)) newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
} else { } else {
if common.DebugEnabled { if common.DebugEnabled {
println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
} }
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode) newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
} }

View File

@ -13,13 +13,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) { func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
if preConsumedQuota != 0 { if relayInfo.FinalPreConsumedQuota != 0 {
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota))) logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota)))
gopool.Go(func() { gopool.Go(func() {
relayInfoCopy := *relayInfo relayInfoCopy := *relayInfo
err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false)
if err != nil { if err != nil {
common.SysLog("error return pre-consumed quota: " + err.Error()) common.SysLog("error return pre-consumed quota: " + err.Error())
} }
@ -29,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr
// PreConsumeQuota checks if the user has enough quota to pre-consume. // PreConsumeQuota checks if the user has enough quota to pre-consume.
// It returns the pre-consumed quota if successful, or an error if not. // It returns the pre-consumed quota if successful, or an error if not.
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) { func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
userQuota, err := model.GetUserQuota(relayInfo.UserId, false) userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil { if err != nil {
return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
} }
if userQuota <= 0 { if userQuota <= 0 {
return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
} }
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
} }
trustQuota := common.GetTrustQuota() trustQuota := common.GetTrustQuota()
@ -65,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
err := PreConsumeTokenQuota(relayInfo, preConsumedQuota) err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil { if err != nil {
return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
} }
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil { if err != nil {
return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
} }
logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota))) logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
} }
relayInfo.FinalPreConsumedQuota = preConsumedQuota relayInfo.FinalPreConsumedQuota = preConsumedQuota
return preConsumedQuota, nil return nil
} }

View File

@ -11,8 +11,8 @@ import (
"one-api/logger" "one-api/logger"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/setting"
"one-api/setting/ratio_setting" "one-api/setting/ratio_setting"
"one-api/setting/system_setting"
"one-api/types" "one-api/types"
"strings" "strings"
"time" "time"
@ -534,7 +534,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
} }
if quotaTooLow { if quotaTooLow {
prompt := "您的额度即将用尽" prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", system_setting.ServerAddress)
// 根据通知方式生成不同的内容格式 // 根据通知方式生成不同的内容格式
var content string var content string

View File

@ -336,7 +336,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
for i, file := range meta.Files { for i, file := range meta.Files {
switch file.FileType { switch file.FileType {
case types.FileTypeImage: case types.FileTypeImage:
if info.RelayFormat == types.RelayFormatGemini && !strings.HasPrefix(model, "gemini-2.5-flash-image-preview") { if info.RelayFormat == types.RelayFormatGemini {
tkm += 256 tkm += 256
} else { } else {
token, err := getImageToken(file, model, info.IsStream) token, err := getImageToken(file, model, info.IsStream)

View File

@ -7,7 +7,7 @@ import (
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
"one-api/setting" "one-api/setting/system_setting"
"strings" "strings"
) )
@ -91,11 +91,11 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
var resp *http.Response var resp *http.Response
var err error var err error
if setting.EnableWorker() { if system_setting.EnableWorker() {
// 使用worker发送请求 // 使用worker发送请求
workerReq := &WorkerRequest{ workerReq := &WorkerRequest{
URL: finalURL, URL: finalURL,
Key: setting.WorkerValidKey, Key: system_setting.WorkerValidKey,
Method: http.MethodGet, Method: http.MethodGet,
Headers: map[string]string{ Headers: map[string]string{
"User-Agent": "OneAPI-Bark-Notify/1.0", "User-Agent": "OneAPI-Bark-Notify/1.0",

View File

@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/setting" "one-api/setting/system_setting"
"time" "time"
) )
@ -56,11 +56,11 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
var req *http.Request var req *http.Request
var resp *http.Response var resp *http.Response
if setting.EnableWorker() { if system_setting.EnableWorker() {
// 构建worker请求数据 // 构建worker请求数据
workerReq := &WorkerRequest{ workerReq := &WorkerRequest{
URL: webhookURL, URL: webhookURL,
Key: setting.WorkerValidKey, Key: system_setting.WorkerValidKey,
Method: http.MethodPost, Method: http.MethodPost,
Headers: map[string]string{ Headers: map[string]string{
"Content-Type": "application/json", "Content-Type": "application/json",

View File

@ -26,7 +26,6 @@ var defaultGeminiSettings = GeminiSettings{
SupportedImagineModels: []string{ SupportedImagineModels: []string{
"gemini-2.0-flash-exp-image-generation", "gemini-2.0-flash-exp-image-generation",
"gemini-2.0-flash-exp", "gemini-2.0-flash-exp",
"gemini-2.5-flash-image-preview",
}, },
ThinkingAdapterEnabled: false, ThinkingAdapterEnabled: false,
ThinkingAdapterBudgetTokensPercentage: 0.6, ThinkingAdapterBudgetTokensPercentage: 0.6,

View File

@ -0,0 +1,23 @@
package operation_setting
import "one-api/setting/config"
type PaymentSetting struct {
AmountOptions []int `json:"amount_options"`
AmountDiscount map[int]float64 `json:"amount_discount"` // 充值金额对应的折扣,例如 100 元 0.9 表示 100 元充值享受 9 折优惠
}
// 默认配置
var paymentSetting = PaymentSetting{
AmountOptions: []int{10, 20, 50, 100, 200, 500},
AmountDiscount: map[int]float64{},
}
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("payment_setting", &paymentSetting)
}
func GetPaymentSetting() *PaymentSetting {
return &paymentSetting
}

View File

@ -1,6 +1,13 @@
package setting /**
此文件为旧版支付设置文件如需增加新的参数变量等请在 payment_setting.go 中添加
This file is the old version of the payment settings file. If you need to add new parameters, variables, etc., please add them in payment_setting.go
*/
import "encoding/json" package operation_setting
import (
"one-api/common"
)
var PayAddress = "" var PayAddress = ""
var CustomCallbackAddress = "" var CustomCallbackAddress = ""
@ -21,15 +28,21 @@ var PayMethods = []map[string]string{
"color": "rgba(var(--semi-green-5), 1)", "color": "rgba(var(--semi-green-5), 1)",
"type": "wxpay", "type": "wxpay",
}, },
{
"name": "自定义1",
"color": "black",
"type": "custom1",
"min_topup": "50",
},
} }
func UpdatePayMethodsByJsonString(jsonString string) error { func UpdatePayMethodsByJsonString(jsonString string) error {
PayMethods = make([]map[string]string, 0) PayMethods = make([]map[string]string, 0)
return json.Unmarshal([]byte(jsonString), &PayMethods) return common.Unmarshal([]byte(jsonString), &PayMethods)
} }
func PayMethods2JsonString() string { func PayMethods2JsonString() string {
jsonBytes, err := json.Marshal(PayMethods) jsonBytes, err := common.Marshal(PayMethods)
if err != nil { if err != nil {
return "[]" return "[]"
} }

View File

@ -24,10 +24,6 @@ const (
ClaudeWebSearchPrice = 10.00 ClaudeWebSearchPrice = 10.00
) )
const (
Gemini25FlashImagePreviewImageOutputPrice = 30.00
)
func GetClaudeWebSearchPricePerThousand() float64 { func GetClaudeWebSearchPricePerThousand() float64 {
return ClaudeWebSearchPrice return ClaudeWebSearchPrice
} }
@ -69,10 +65,3 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 {
} }
return 0 return 0
} }
func GetGeminiImageOutputPricePerMillionTokens(modelName string) float64 {
if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") {
return Gemini25FlashImagePreviewImageOutputPrice
}
return 0
}

View File

@ -178,7 +178,6 @@ var defaultModelRatio = map[string]float64{
"gemini-2.5-flash-lite-preview-thinking-*": 0.05, "gemini-2.5-flash-lite-preview-thinking-*": 0.05,
"gemini-2.5-flash-lite-preview-06-17": 0.05, "gemini-2.5-flash-lite-preview-06-17": 0.05,
"gemini-2.5-flash": 0.15, "gemini-2.5-flash": 0.15,
"gemini-2.5-flash-image-preview": 0.15, // $0.30text/image) / 1M tokens
"text-embedding-004": 0.001, "text-embedding-004": 0.001,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
@ -305,11 +304,10 @@ var (
) )
var defaultCompletionRatio = map[string]float64{ var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2, "gpt-4-gizmo-*": 2,
"gpt-4o-gizmo-*": 3, "gpt-4o-gizmo-*": 3,
"gpt-4-all": 2, "gpt-4-all": 2,
"gpt-image-1": 8, "gpt-image-1": 8,
"gemini-2.5-flash-image-preview": 8.3333333333,
} }
// InitRatioSettings initializes all model related settings maps // InitRatioSettings initializes all model related settings maps

View File

@ -1,4 +1,4 @@
package setting package system_setting
var ServerAddress = "http://localhost:3000" var ServerAddress = "http://localhost:3000"
var WorkerUrl = "" var WorkerUrl = ""

View File

@ -185,6 +185,14 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
type NewAPIErrorOptions func(*NewAPIError) type NewAPIErrorOptions func(*NewAPIError)
func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError { func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
var newErr *NewAPIError
// 保留深层传递的 new err
if errors.As(err, &newErr) {
for _, op := range ops {
op(newErr)
}
return newErr
}
e := &NewAPIError{ e := &NewAPIError{
Err: err, Err: err,
RelayError: nil, RelayError: nil,
@ -199,8 +207,21 @@ func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPI
} }
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
if errorCode == ErrorCodeDoRequestFailed { var newErr *NewAPIError
err = errors.New("upstream error: do request failed") // 保留深层传递的 new err
if errors.As(err, &newErr) {
if newErr.RelayError == nil {
openaiError := OpenAIError{
Message: newErr.Error(),
Type: string(errorCode),
Code: errorCode,
}
newErr.RelayError = openaiError
}
for _, op := range ops {
op(newErr)
}
return newErr
} }
openaiError := OpenAIError{ openaiError := OpenAIError{
Message: err.Error(), Message: err.Error(),
@ -305,6 +326,15 @@ func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
} }
} }
func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions {
return func(e *NewAPIError) {
if common.DebugEnabled {
fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err)
}
e.Err = errors.New(replaceStr)
}
}
func IsRecordErrorLog(e *NewAPIError) bool { func IsRecordErrorLog(e *NewAPIError) bool {
if e == nil { if e == nil {
return false return false

View File

@ -135,7 +135,7 @@ const TwoFactorAuthModal = ({
autoFocus autoFocus
/> />
<Typography.Text type='tertiary' size='small' className='mt-2 block'> <Typography.Text type='tertiary' size='small' className='mt-2 block'>
{t('支持6位TOTP验证码或8位备用码')} {t('支持6位TOTP验证码或8位备用码,可到`个人设置-安全设置-两步验证设置`配置或查看。')}
</Typography.Text> </Typography.Text>
</div> </div>
</div> </div>

View File

@ -37,6 +37,8 @@ const PaymentSetting = () => {
TopupGroupRatio: '', TopupGroupRatio: '',
CustomCallbackAddress: '', CustomCallbackAddress: '',
PayMethods: '', PayMethods: '',
AmountOptions: '',
AmountDiscount: '',
StripeApiSecret: '', StripeApiSecret: '',
StripeWebhookSecret: '', StripeWebhookSecret: '',
@ -66,6 +68,30 @@ const PaymentSetting = () => {
newInputs[item.key] = item.value; newInputs[item.key] = item.value;
} }
break; break;
case 'payment_setting.amount_options':
try {
newInputs['AmountOptions'] = JSON.stringify(
JSON.parse(item.value),
null,
2,
);
} catch (error) {
console.error('解析AmountOptions出错:', error);
newInputs['AmountOptions'] = item.value;
}
break;
case 'payment_setting.amount_discount':
try {
newInputs['AmountDiscount'] = JSON.stringify(
JSON.parse(item.value),
null,
2,
);
} catch (error) {
console.error('解析AmountDiscount出错:', error);
newInputs['AmountDiscount'] = item.value;
}
break;
case 'Price': case 'Price':
case 'MinTopUp': case 'MinTopUp':
case 'StripeUnitPrice': case 'StripeUnitPrice':

View File

@ -142,6 +142,8 @@ const EditChannelModal = (props) => {
system_prompt: '', system_prompt: '',
system_prompt_override: false, system_prompt_override: false,
settings: '', settings: '',
// Vertex: settings.vertex_key_type
vertex_key_type: 'json',
}; };
const [batch, setBatch] = useState(false); const [batch, setBatch] = useState(false);
const [multiToSingle, setMultiToSingle] = useState(false); const [multiToSingle, setMultiToSingle] = useState(false);
@ -409,11 +411,17 @@ const EditChannelModal = (props) => {
const parsedSettings = JSON.parse(data.settings); const parsedSettings = JSON.parse(data.settings);
data.azure_responses_version = data.azure_responses_version =
parsedSettings.azure_responses_version || ''; parsedSettings.azure_responses_version || '';
// Vertex
data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
} catch (error) { } catch (error) {
console.error('解析其他设置失败:', error); console.error('解析其他设置失败:', error);
data.azure_responses_version = ''; data.azure_responses_version = '';
data.region = ''; data.region = '';
data.vertex_key_type = 'json';
} }
} else {
// settings json
data.vertex_key_type = 'json';
} }
setInputs(data); setInputs(data);
@ -745,59 +753,56 @@ const EditChannelModal = (props) => {
let localInputs = { ...formValues }; let localInputs = { ...formValues };
if (localInputs.type === 41) { if (localInputs.type === 41) {
if (useManualInput) { const keyType = localInputs.vertex_key_type || 'json';
// if (keyType === 'api_key') {
if (localInputs.key && localInputs.key.trim() !== '') { //
try { if (!isEdit && (!localInputs.key || localInputs.key.trim() === '')) {
// JSON
const parsedKey = JSON.parse(localInputs.key);
//
localInputs.key = JSON.stringify(parsedKey);
} catch (err) {
showError(t('密钥格式无效,请输入有效的 JSON 格式密钥'));
return;
}
} else if (!isEdit) {
showInfo(t('请输入密钥!')); showInfo(t('请输入密钥!'));
return; return;
} }
} else { } else {
// // JSON
let keys = vertexKeys; if (useManualInput) {
if (localInputs.key && localInputs.key.trim() !== '') {
// try {
if (keys.length === 0 && vertexFileList.length > 0) { const parsedKey = JSON.parse(localInputs.key);
try { localInputs.key = JSON.stringify(parsedKey);
const parsed = await Promise.all( } catch (err) {
vertexFileList.map(async (item) => { showError(t('密钥格式无效,请输入有效的 JSON 格式密钥'));
const fileObj = item.fileInstance; return;
if (!fileObj) return null; }
const txt = await fileObj.text(); } else if (!isEdit) {
return JSON.parse(txt); showInfo(t('请输入密钥!'));
}),
);
keys = parsed.filter(Boolean);
} catch (err) {
showError(t('解析密钥文件失败: {{msg}}', { msg: err.message }));
return; return;
} }
}
//
if (keys.length === 0) {
if (!isEdit) {
showInfo(t('请上传密钥文件!'));
return;
} else {
// key
delete localInputs.key;
}
} else { } else {
// //
if (batch) { let keys = vertexKeys;
localInputs.key = JSON.stringify(keys); if (keys.length === 0 && vertexFileList.length > 0) {
try {
const parsed = await Promise.all(
vertexFileList.map(async (item) => {
const fileObj = item.fileInstance;
if (!fileObj) return null;
const txt = await fileObj.text();
return JSON.parse(txt);
}),
);
keys = parsed.filter(Boolean);
} catch (err) {
showError(t('解析密钥文件失败: {{msg}}', { msg: err.message }));
return;
}
}
if (keys.length === 0) {
if (!isEdit) {
showInfo(t('请上传密钥文件!'));
return;
} else {
delete localInputs.key;
}
} else { } else {
localInputs.key = JSON.stringify(keys[0]); localInputs.key = batch ? JSON.stringify(keys) : JSON.stringify(keys[0]);
} }
} }
} }
@ -853,6 +858,8 @@ const EditChannelModal = (props) => {
delete localInputs.pass_through_body_enabled; delete localInputs.pass_through_body_enabled;
delete localInputs.system_prompt; delete localInputs.system_prompt;
delete localInputs.system_prompt_override; delete localInputs.system_prompt_override;
// vertex_key_type
delete localInputs.vertex_key_type;
let res; let res;
localInputs.auto_ban = localInputs.auto_ban ? 1 : 0; localInputs.auto_ban = localInputs.auto_ban ? 1 : 0;
@ -1178,8 +1185,40 @@ const EditChannelModal = (props) => {
autoComplete='new-password' autoComplete='new-password'
/> />
{inputs.type === 41 && (
<Form.Select
field='vertex_key_type'
label={t('密钥格式')}
placeholder={t('请选择密钥格式')}
optionList={[
{ label: 'JSON', value: 'json' },
{ label: 'API Key', value: 'api_key' },
]}
style={{ width: '100%' }}
value={inputs.vertex_key_type || 'json'}
onChange={(value) => {
// vertex_key_type
handleChannelOtherSettingsChange('vertex_key_type', value);
// api_key /
if (value === 'api_key') {
setBatch(false);
setUseManualInput(false);
setVertexKeys([]);
setVertexFileList([]);
if (formApiRef.current) {
formApiRef.current.setValue('vertex_files', []);
}
}
}}
extraText={
inputs.vertex_key_type === 'api_key'
? t('API Key 模式下不支持批量创建')
: t('JSON 模式支持手动输入或上传服务账号 JSON')
}
/>
)}
{batch ? ( {batch ? (
inputs.type === 41 ? ( inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? (
<Form.Upload <Form.Upload
field='vertex_files' field='vertex_files'
label={t('密钥文件 (.json)')} label={t('密钥文件 (.json)')}
@ -1243,7 +1282,7 @@ const EditChannelModal = (props) => {
) )
) : ( ) : (
<> <>
{inputs.type === 41 ? ( {inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? (
<> <>
{!batch && ( {!batch && (
<div className='flex items-center justify-between mb-3'> <div className='flex items-center justify-between mb-3'>

View File

@ -21,6 +21,7 @@ import React, { useRef } from 'react';
import { import {
Avatar, Avatar,
Typography, Typography,
Tag,
Card, Card,
Button, Button,
Banner, Banner,
@ -29,7 +30,7 @@ import {
Space, Space,
Row, Row,
Col, Col,
Spin, Spin, Tooltip
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { SiAlipay, SiWechat, SiStripe } from 'react-icons/si'; import { SiAlipay, SiWechat, SiStripe } from 'react-icons/si';
import { CreditCard, Coins, Wallet, BarChart2, TrendingUp } from 'lucide-react'; import { CreditCard, Coins, Wallet, BarChart2, TrendingUp } from 'lucide-react';
@ -68,6 +69,7 @@ const RechargeCard = ({
userState, userState,
renderQuota, renderQuota,
statusLoading, statusLoading,
topupInfo,
}) => { }) => {
const onlineFormApiRef = useRef(null); const onlineFormApiRef = useRef(null);
const redeemFormApiRef = useRef(null); const redeemFormApiRef = useRef(null);
@ -261,44 +263,58 @@ const RechargeCard = ({
</Col> </Col>
<Col xs={24} sm={24} md={24} lg={14} xl={14}> <Col xs={24} sm={24} md={24} lg={14} xl={14}>
<Form.Slot label={t('选择支付方式')}> <Form.Slot label={t('选择支付方式')}>
<Space wrap> {payMethods && payMethods.length > 0 ? (
{payMethods.map((payMethod) => ( <Space wrap>
<Button {payMethods.map((payMethod) => {
key={payMethod.type} const minTopupVal = Number(payMethod.min_topup) || 0;
theme='outline' const isStripe = payMethod.type === 'stripe';
type='tertiary' const disabled =
onClick={() => preTopUp(payMethod.type)} (!enableOnlineTopUp && !isStripe) ||
disabled={ (!enableStripeTopUp && isStripe) ||
(!enableOnlineTopUp && minTopupVal > Number(topUpCount || 0);
payMethod.type !== 'stripe') ||
(!enableStripeTopUp && const buttonEl = (
payMethod.type === 'stripe') <Button
} key={payMethod.type}
loading={ theme='outline'
paymentLoading && payWay === payMethod.type type='tertiary'
} onClick={() => preTopUp(payMethod.type)}
icon={ disabled={disabled}
payMethod.type === 'alipay' ? ( loading={paymentLoading && payWay === payMethod.type}
<SiAlipay size={18} color='#1677FF' /> icon={
) : payMethod.type === 'wxpay' ? ( payMethod.type === 'alipay' ? (
<SiWechat size={18} color='#07C160' /> <SiAlipay size={18} color='#1677FF' />
) : payMethod.type === 'stripe' ? ( ) : payMethod.type === 'wxpay' ? (
<SiStripe size={18} color='#635BFF' /> <SiWechat size={18} color='#07C160' />
) : ( ) : payMethod.type === 'stripe' ? (
<CreditCard <SiStripe size={18} color='#635BFF' />
size={18} ) : (
color={ <CreditCard
payMethod.color || size={18}
'var(--semi-color-text-2)' color={payMethod.color || 'var(--semi-color-text-2)'}
} />
/> )
) }
} className='!rounded-lg !px-4 !py-2'
> >
{payMethod.name} {payMethod.name}
</Button> </Button>
))} );
</Space>
return disabled && minTopupVal > Number(topUpCount || 0) ? (
<Tooltip content={t('此支付方式最低充值金额为') + ' ' + minTopupVal} key={payMethod.type}>
{buttonEl}
</Tooltip>
) : (
<React.Fragment key={payMethod.type}>{buttonEl}</React.Fragment>
);
})}
</Space>
) : (
<div className='text-gray-500 text-sm p-3 bg-gray-50 rounded-lg border border-dashed border-gray-300'>
{t('暂无可用的支付方式,请联系管理员配置')}
</div>
)}
</Form.Slot> </Form.Slot>
</Col> </Col>
</Row> </Row>
@ -306,41 +322,60 @@ const RechargeCard = ({
{(enableOnlineTopUp || enableStripeTopUp) && ( {(enableOnlineTopUp || enableStripeTopUp) && (
<Form.Slot label={t('选择充值额度')}> <Form.Slot label={t('选择充值额度')}>
<Space wrap> <div className='grid grid-cols-2 sm:grid-cols-3 md:grid-cols-4 gap-2'>
{presetAmounts.map((preset, index) => ( {presetAmounts.map((preset, index) => {
<Button const discount = preset.discount || topupInfo?.discount?.[preset.value] || 1.0;
key={index} const originalPrice = preset.value * priceRatio;
theme={ const discountedPrice = originalPrice * discount;
selectedPreset === preset.value const hasDiscount = discount < 1.0;
? 'solid' const actualPay = discountedPrice;
: 'outline' const save = originalPrice - discountedPrice;
}
type={ return (
selectedPreset === preset.value <Card
? 'primary' key={index}
: 'tertiary' style={{
} cursor: 'pointer',
onClick={() => { border: selectedPreset === preset.value
selectPresetAmount(preset); ? '2px solid var(--semi-color-primary)'
onlineFormApiRef.current?.setValue( : '1px solid var(--semi-color-border)',
'topUpCount', height: '100%',
preset.value, width: '100%'
); }}
}} bodyStyle={{ padding: '12px' }}
className='!rounded-lg !py-2 !px-3' onClick={() => {
> selectPresetAmount(preset);
<div className='flex items-center gap-2'> onlineFormApiRef.current?.setValue(
<Coins size={14} className='opacity-80' /> 'topUpCount',
<span className='font-medium'> preset.value,
{formatLargeNumber(preset.value)} );
</span> }}
<span className='text-xs text-gray-500'> >
{(preset.value * priceRatio).toFixed(2)} <div style={{ textAlign: 'center' }}>
</span> <Typography.Title heading={6} style={{ margin: '0 0 8px 0' }}>
</div> <Coins size={18} />
</Button> {formatLargeNumber(preset.value)}
))} {hasDiscount && (
</Space> <Tag style={{ marginLeft: 4 }} color="green">
{t('折').includes('off') ?
((1 - parseFloat(discount)) * 100).toFixed(1) :
(discount * 10).toFixed(1)}{t('折')}
</Tag>
)}
</Typography.Title>
<div style={{
color: 'var(--semi-color-text-2)',
fontSize: '12px',
margin: '4px 0'
}}>
{t('实付')} {actualPay.toFixed(2)}
{hasDiscount ? `${t('节省')} ${save.toFixed(2)}` : `${t('节省')} 0.00`}
</div>
</div>
</Card>
);
})}
</div>
</Form.Slot> </Form.Slot>
)} )}
</div> </div>

View File

@ -81,6 +81,12 @@ const TopUp = () => {
const [presetAmounts, setPresetAmounts] = useState([]); const [presetAmounts, setPresetAmounts] = useState([]);
const [selectedPreset, setSelectedPreset] = useState(null); const [selectedPreset, setSelectedPreset] = useState(null);
//
const [topupInfo, setTopupInfo] = useState({
amount_options: [],
discount: {}
});
const topUp = async () => { const topUp = async () => {
if (redemptionCode === '') { if (redemptionCode === '') {
showInfo(t('请输入兑换码!')); showInfo(t('请输入兑换码!'));
@ -248,6 +254,99 @@ const TopUp = () => {
} }
}; };
//
const getTopupInfo = async () => {
try {
const res = await API.get('/api/user/topup/info');
const { message, data, success } = res.data;
if (success) {
setTopupInfo({
amount_options: data.amount_options || [],
discount: data.discount || {}
});
//
let payMethods = data.pay_methods || [];
try {
if (typeof payMethods === 'string') {
payMethods = JSON.parse(payMethods);
}
if (payMethods && payMethods.length > 0) {
// nametype
payMethods = payMethods.filter((method) => {
return method.name && method.type;
});
// color
payMethods = payMethods.map((method) => {
//
const normalizedMinTopup = Number(method.min_topup);
method.min_topup = Number.isFinite(normalizedMinTopup) ? normalizedMinTopup : 0;
// Stripe
if (method.type === 'stripe' && (!method.min_topup || method.min_topup <= 0)) {
const stripeMin = Number(data.stripe_min_topup);
if (Number.isFinite(stripeMin)) {
method.min_topup = stripeMin;
}
}
if (!method.color) {
if (method.type === 'alipay') {
method.color = 'rgba(var(--semi-blue-5), 1)';
} else if (method.type === 'wxpay') {
method.color = 'rgba(var(--semi-green-5), 1)';
} else if (method.type === 'stripe') {
method.color = 'rgba(var(--semi-purple-5), 1)';
} else {
method.color = 'rgba(var(--semi-primary-5), 1)';
}
}
return method;
});
} else {
payMethods = [];
}
// Stripe
// Stripe pay_methods
setPayMethods(payMethods);
const enableStripeTopUp = data.enable_stripe_topup || false;
const enableOnlineTopUp = data.enable_online_topup || false;
const minTopUpValue = enableOnlineTopUp? data.min_topup : enableStripeTopUp? data.stripe_min_topup : 1;
setEnableOnlineTopUp(enableOnlineTopUp);
setEnableStripeTopUp(enableStripeTopUp);
setMinTopUp(minTopUpValue);
setTopUpCount(minTopUpValue);
//
if (topupInfo.amount_options.length === 0) {
setPresetAmounts(generatePresetAmounts(minTopUpValue));
}
//
getAmount(minTopUpValue);
} catch (e) {
console.log('解析支付方式失败:', e);
setPayMethods([]);
}
// 使
if (data.amount_options && data.amount_options.length > 0) {
const customPresets = data.amount_options.map(amount => ({
value: amount,
discount: data.discount[amount] || 1.0
}));
setPresetAmounts(customPresets);
}
} else {
console.error('获取充值配置失败:', data);
}
} catch (error) {
console.error('获取充值配置异常:', error);
}
};
// //
const getAffLink = async () => { const getAffLink = async () => {
const res = await API.get('/api/user/aff'); const res = await API.get('/api/user/aff');
@ -290,52 +389,7 @@ const TopUp = () => {
getUserQuota().then(); getUserQuota().then();
} }
setTransferAmount(getQuotaPerUnit()); setTransferAmount(getQuotaPerUnit());
}, []);
let payMethods = localStorage.getItem('pay_methods');
try {
payMethods = JSON.parse(payMethods);
if (payMethods && payMethods.length > 0) {
// nametype
payMethods = payMethods.filter((method) => {
return method.name && method.type;
});
// color
payMethods = payMethods.map((method) => {
if (!method.color) {
if (method.type === 'alipay') {
method.color = 'rgba(var(--semi-blue-5), 1)';
} else if (method.type === 'wxpay') {
method.color = 'rgba(var(--semi-green-5), 1)';
} else if (method.type === 'stripe') {
method.color = 'rgba(var(--semi-purple-5), 1)';
} else {
method.color = 'rgba(var(--semi-primary-5), 1)';
}
}
return method;
});
} else {
payMethods = [];
}
// Stripe
if (statusState?.status?.enable_stripe_topup) {
const hasStripe = payMethods.some((method) => method.type === 'stripe');
if (!hasStripe) {
payMethods.push({
name: 'Stripe',
type: 'stripe',
color: 'rgba(var(--semi-purple-5), 1)',
});
}
}
setPayMethods(payMethods);
} catch (e) {
console.log(e);
showError(t('支付方式配置错误, 请联系管理员'));
}
}, [statusState?.status?.enable_stripe_topup]);
useEffect(() => { useEffect(() => {
if (affFetchedRef.current) return; if (affFetchedRef.current) return;
@ -343,20 +397,18 @@ const TopUp = () => {
getAffLink().then(); getAffLink().then();
}, []); }, []);
// statusState
useEffect(() => {
getTopupInfo().then();
}, []);
useEffect(() => { useEffect(() => {
if (statusState?.status) { if (statusState?.status) {
const minTopUpValue = statusState.status.min_topup || 1; // const minTopUpValue = statusState.status.min_topup || 1;
setMinTopUp(minTopUpValue); // setMinTopUp(minTopUpValue);
setTopUpCount(minTopUpValue); // setTopUpCount(minTopUpValue);
setTopUpLink(statusState.status.top_up_link || ''); setTopUpLink(statusState.status.top_up_link || '');
setEnableOnlineTopUp(statusState.status.enable_online_topup || false);
setPriceRatio(statusState.status.price || 1); setPriceRatio(statusState.status.price || 1);
setEnableStripeTopUp(statusState.status.enable_stripe_topup || false);
//
setPresetAmounts(generatePresetAmounts(minTopUpValue));
//
getAmount(minTopUpValue);
setStatusLoading(false); setStatusLoading(false);
} }
@ -431,7 +483,11 @@ const TopUp = () => {
const selectPresetAmount = (preset) => { const selectPresetAmount = (preset) => {
setTopUpCount(preset.value); setTopUpCount(preset.value);
setSelectedPreset(preset.value); setSelectedPreset(preset.value);
setAmount(preset.value * priceRatio);
//
const discount = preset.discount || topupInfo.discount[preset.value] || 1.0;
const discountedAmount = preset.value * priceRatio * discount;
setAmount(discountedAmount);
}; };
// //
@ -475,6 +531,8 @@ const TopUp = () => {
renderAmount={renderAmount} renderAmount={renderAmount}
payWay={payWay} payWay={payWay}
payMethods={payMethods} payMethods={payMethods}
amountNumber={amount}
discountRate={topupInfo?.discount?.[topUpCount] || 1.0}
/> />
{/* 用户信息头部 */} {/* 用户信息头部 */}
@ -512,6 +570,7 @@ const TopUp = () => {
userState={userState} userState={userState}
renderQuota={renderQuota} renderQuota={renderQuota}
statusLoading={statusLoading} statusLoading={statusLoading}
topupInfo={topupInfo}
/> />
</div> </div>

View File

@ -36,7 +36,13 @@ const PaymentConfirmModal = ({
renderAmount, renderAmount,
payWay, payWay,
payMethods, payMethods,
//
amountNumber,
discountRate,
}) => { }) => {
const hasDiscount = discountRate && discountRate > 0 && discountRate < 1 && amountNumber > 0;
const originalAmount = hasDiscount ? (amountNumber / discountRate) : 0;
const discountAmount = hasDiscount ? (originalAmount - amountNumber) : 0;
return ( return (
<Modal <Modal
title={ title={
@ -71,11 +77,38 @@ const PaymentConfirmModal = ({
{amountLoading ? ( {amountLoading ? (
<Skeleton.Title style={{ width: '60px', height: '16px' }} /> <Skeleton.Title style={{ width: '60px', height: '16px' }} />
) : ( ) : (
<Text strong className='font-bold' style={{ color: 'red' }}> <div className='flex items-baseline space-x-2'>
{renderAmount()} <Text strong className='font-bold' style={{ color: 'red' }}>
</Text> {renderAmount()}
</Text>
{hasDiscount && (
<Text size='small' className='text-rose-500'>
{Math.round(discountRate * 100)}%
</Text>
)}
</div>
)} )}
</div> </div>
{hasDiscount && !amountLoading && (
<>
<div className='flex justify-between items-center'>
<Text className='text-slate-500 dark:text-slate-400'>
{t('原价')}
</Text>
<Text delete className='text-slate-500 dark:text-slate-400'>
{`${originalAmount.toFixed(2)} ${t('元')}`}
</Text>
</div>
<div className='flex justify-between items-center'>
<Text className='text-slate-500 dark:text-slate-400'>
{t('优惠')}
</Text>
<Text className='text-emerald-600 dark:text-emerald-400'>
{`- ${discountAmount.toFixed(2)} ${t('元')}`}
</Text>
</div>
</>
)}
<div className='flex justify-between items-center'> <div className='flex justify-between items-center'>
<Text strong className='text-slate-700 dark:text-slate-200'> <Text strong className='text-slate-700 dark:text-slate-200'>
{t('支付方式')} {t('支付方式')}

View File

@ -28,7 +28,6 @@ export function setStatusData(data) {
localStorage.setItem('enable_task', data.enable_task); localStorage.setItem('enable_task', data.enable_task);
localStorage.setItem('enable_data_export', data.enable_data_export); localStorage.setItem('enable_data_export', data.enable_data_export);
localStorage.setItem('chats', JSON.stringify(data.chats)); localStorage.setItem('chats', JSON.stringify(data.chats));
localStorage.setItem('pay_methods', JSON.stringify(data.pay_methods));
localStorage.setItem( localStorage.setItem(
'data_export_default_time', 'data_export_default_time',
data.data_export_default_time, data.data_export_default_time,

View File

@ -1017,7 +1017,7 @@ export function renderModelPrice(
cacheRatio = 1.0, cacheRatio = 1.0,
image = false, image = false,
imageRatio = 1.0, imageRatio = 1.0,
imageInputTokens = 0, imageOutputTokens = 0,
webSearch = false, webSearch = false,
webSearchCallCount = 0, webSearchCallCount = 0,
webSearchPrice = 0, webSearchPrice = 0,
@ -1027,8 +1027,6 @@ export function renderModelPrice(
audioInputSeperatePrice = false, audioInputSeperatePrice = false,
audioInputTokens = 0, audioInputTokens = 0,
audioInputPrice = 0, audioInputPrice = 0,
imageOutputTokens = 0,
imageOutputPrice = 0,
) { ) {
const { ratio: effectiveGroupRatio, label: ratioLabel } = getEffectiveRatio( const { ratio: effectiveGroupRatio, label: ratioLabel } = getEffectiveRatio(
groupRatio, groupRatio,
@ -1059,9 +1057,9 @@ export function renderModelPrice(
let effectiveInputTokens = let effectiveInputTokens =
inputTokens - cacheTokens + cacheTokens * cacheRatio; inputTokens - cacheTokens + cacheTokens * cacheRatio;
// Handle image tokens if present // Handle image tokens if present
if (image && imageInputTokens > 0) { if (image && imageOutputTokens > 0) {
effectiveInputTokens = effectiveInputTokens =
inputTokens - imageInputTokens + imageInputTokens * imageRatio; inputTokens - imageOutputTokens + imageOutputTokens * imageRatio;
} }
if (audioInputTokens > 0) { if (audioInputTokens > 0) {
effectiveInputTokens -= audioInputTokens; effectiveInputTokens -= audioInputTokens;
@ -1071,8 +1069,7 @@ export function renderModelPrice(
(audioInputTokens / 1000000) * audioInputPrice * groupRatio + (audioInputTokens / 1000000) * audioInputPrice * groupRatio +
(completionTokens / 1000000) * completionRatioPrice * groupRatio + (completionTokens / 1000000) * completionRatioPrice * groupRatio +
(webSearchCallCount / 1000) * webSearchPrice * groupRatio + (webSearchCallCount / 1000) * webSearchPrice * groupRatio +
(fileSearchCallCount / 1000) * fileSearchPrice * groupRatio + (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio;
(imageOutputTokens / 1000000) * imageOutputPrice * groupRatio;
return ( return (
<> <>
@ -1107,7 +1104,7 @@ export function renderModelPrice(
)} )}
</p> </p>
)} )}
{image && imageInputTokens > 0 && ( {image && imageOutputTokens > 0 && (
<p> <p>
{i18next.t( {i18next.t(
'图片输入价格:${{price}} * {{ratio}} = ${{total}} / 1M tokens (图片倍率: {{imageRatio}})', '图片输入价格:${{price}} * {{ratio}} = ${{total}} / 1M tokens (图片倍率: {{imageRatio}})',
@ -1134,26 +1131,17 @@ export function renderModelPrice(
})} })}
</p> </p>
)} )}
{imageOutputPrice > 0 && imageOutputTokens > 0 && (
<p>
{i18next.t('图片输出价格:${{price}} * 分组倍率{{ratio}} = ${{total}} / 1M tokens', {
price: imageOutputPrice,
ratio: groupRatio,
total: imageOutputPrice * groupRatio,
})}
</p>
)}
<p></p> <p></p>
<p> <p>
{(() => { {(() => {
// //
let inputDesc = ''; let inputDesc = '';
if (image && imageInputTokens > 0) { if (image && imageOutputTokens > 0) {
inputDesc = i18next.t( inputDesc = i18next.t(
'(输入 {{nonImageInput}} tokens + 图片输入 {{imageInput}} tokens * {{imageRatio}} / 1M tokens * ${{price}}', '(输入 {{nonImageInput}} tokens + 图片输入 {{imageInput}} tokens * {{imageRatio}} / 1M tokens * ${{price}}',
{ {
nonImageInput: inputTokens - imageInputTokens, nonImageInput: inputTokens - imageOutputTokens,
imageInput: imageInputTokens, imageInput: imageOutputTokens,
imageRatio: imageRatio, imageRatio: imageRatio,
price: inputRatioPrice, price: inputRatioPrice,
}, },
@ -1223,16 +1211,6 @@ export function renderModelPrice(
}, },
) )
: '', : '',
imageOutputPrice > 0 && imageOutputTokens > 0
? i18next.t(
' + 图片输出 {{tokenCounts}} tokens * ${{price}} / 1M tokens * 分组倍率{{ratio}}',
{
tokenCounts: imageOutputTokens,
price: imageOutputPrice,
ratio: groupRatio,
},
)
: '',
].join(''); ].join('');
return i18next.t( return i18next.t(

View File

@ -447,8 +447,6 @@ export const useLogsData = () => {
other?.audio_input_seperate_price || false, other?.audio_input_seperate_price || false,
other?.audio_input_token_count || 0, other?.audio_input_token_count || 0,
other?.audio_input_price || 0, other?.audio_input_price || 0,
other?.image_output_token_count || 0,
other?.image_output_price || 0,
); );
} }
expandDataLocal.push({ expandDataLocal.push({

View File

@ -1993,7 +1993,7 @@
"安全验证": "Security verification", "安全验证": "Security verification",
"验证": "Verify", "验证": "Verify",
"为了保护账户安全,请验证您的两步验证码。": "To protect account security, please verify your two-factor authentication code.", "为了保护账户安全,请验证您的两步验证码。": "To protect account security, please verify your two-factor authentication code.",
"支持6位TOTP验证码或8位备用码": "Supports 6-digit TOTP verification code or 8-digit backup code", "支持6位TOTP验证码或8位备用码,可到`个人设置-安全设置-两步验证设置`配置或查看。": "Supports 6-digit TOTP verification code or 8-digit backup code, can be configured or viewed in `Personal Settings - Security Settings - Two-Factor Authentication Settings`.",
"获取密钥失败": "Failed to get key", "获取密钥失败": "Failed to get key",
"查看密钥": "View key", "查看密钥": "View key",
"查看渠道密钥": "View channel key", "查看渠道密钥": "View channel key",
@ -2080,5 +2080,9 @@
"官方": "Official", "官方": "Official",
"模型社区需要大家的共同维护,如发现数据有误或想贡献新的模型数据,请访问:": "The model community needs everyone's contribution. If you find incorrect data or want to contribute new models, please visit:", "模型社区需要大家的共同维护,如发现数据有误或想贡献新的模型数据,请访问:": "The model community needs everyone's contribution. If you find incorrect data or want to contribute new models, please visit:",
"是": "Yes", "是": "Yes",
"否": "No" "否": "No",
"原价": "Original price",
"优惠": "Discount",
"折": "% off",
"节省": "Save"
} }

View File

@ -130,17 +130,19 @@ export default function GeneralSettings(props) {
showClear showClear
/> />
</Col> </Col>
<Col xs={24} sm={12} md={8} lg={8} xl={8}> {inputs.QuotaPerUnit !== '500000' && inputs.QuotaPerUnit !== 500000 && (
<Form.Input <Col xs={24} sm={12} md={8} lg={8} xl={8}>
field={'QuotaPerUnit'} <Form.Input
label={t('单位美元额度')} field={'QuotaPerUnit'}
initValue={''} label={t('单位美元额度')}
placeholder={t('一单位货币能兑换的额度')} initValue={''}
onChange={handleFieldChange('QuotaPerUnit')} placeholder={t('一单位货币能兑换的额度')}
showClear onChange={handleFieldChange('QuotaPerUnit')}
onClick={() => setShowQuotaWarning(true)} showClear
/> onClick={() => setShowQuotaWarning(true)}
</Col> />
</Col>
)}
<Col xs={24} sm={12} md={8} lg={8} xl={8}> <Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.Input <Form.Input
field={'USDExchangeRate'} field={'USDExchangeRate'}
@ -194,7 +196,7 @@ export default function GeneralSettings(props) {
/> />
</Col> </Col>
</Row> </Row>
<Row> <Row gutter={16}>
<Col xs={24} sm={12} md={8} lg={8} xl={8}> <Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.Switch <Form.Switch
field={'DemoSiteEnabled'} field={'DemoSiteEnabled'}

View File

@ -304,7 +304,7 @@ export default function SettingsHeaderNavModules(props) {
headerNavModules.pricing?.requireAuth || false headerNavModules.pricing?.requireAuth || false
} }
onChange={handlePricingAuthChange} onChange={handlePricingAuthChange}
size='small' size='default'
/> />
</div> </div>
</div> </div>

View File

@ -41,6 +41,8 @@ export default function SettingsPaymentGateway(props) {
TopupGroupRatio: '', TopupGroupRatio: '',
CustomCallbackAddress: '', CustomCallbackAddress: '',
PayMethods: '', PayMethods: '',
AmountOptions: '',
AmountDiscount: '',
}); });
const [originInputs, setOriginInputs] = useState({}); const [originInputs, setOriginInputs] = useState({});
const formApiRef = useRef(null); const formApiRef = useRef(null);
@ -62,7 +64,30 @@ export default function SettingsPaymentGateway(props) {
TopupGroupRatio: props.options.TopupGroupRatio || '', TopupGroupRatio: props.options.TopupGroupRatio || '',
CustomCallbackAddress: props.options.CustomCallbackAddress || '', CustomCallbackAddress: props.options.CustomCallbackAddress || '',
PayMethods: props.options.PayMethods || '', PayMethods: props.options.PayMethods || '',
AmountOptions: props.options.AmountOptions || '',
AmountDiscount: props.options.AmountDiscount || '',
}; };
// JSON
try {
if (currentInputs.AmountOptions) {
currentInputs.AmountOptions = JSON.stringify(
JSON.parse(currentInputs.AmountOptions),
null,
2,
);
}
} catch {}
try {
if (currentInputs.AmountDiscount) {
currentInputs.AmountDiscount = JSON.stringify(
JSON.parse(currentInputs.AmountDiscount),
null,
2,
);
}
} catch {}
setInputs(currentInputs); setInputs(currentInputs);
setOriginInputs({ ...currentInputs }); setOriginInputs({ ...currentInputs });
formApiRef.current.setValues(currentInputs); formApiRef.current.setValues(currentInputs);
@ -93,6 +118,20 @@ export default function SettingsPaymentGateway(props) {
} }
} }
if (originInputs['AmountOptions'] !== inputs.AmountOptions && inputs.AmountOptions.trim() !== '') {
if (!verifyJSON(inputs.AmountOptions)) {
showError(t('自定义充值数量选项不是合法的 JSON 数组'));
return;
}
}
if (originInputs['AmountDiscount'] !== inputs.AmountDiscount && inputs.AmountDiscount.trim() !== '') {
if (!verifyJSON(inputs.AmountDiscount)) {
showError(t('充值金额折扣配置不是合法的 JSON 对象'));
return;
}
}
setLoading(true); setLoading(true);
try { try {
const options = [ const options = [
@ -123,6 +162,12 @@ export default function SettingsPaymentGateway(props) {
if (originInputs['PayMethods'] !== inputs.PayMethods) { if (originInputs['PayMethods'] !== inputs.PayMethods) {
options.push({ key: 'PayMethods', value: inputs.PayMethods }); options.push({ key: 'PayMethods', value: inputs.PayMethods });
} }
if (originInputs['AmountOptions'] !== inputs.AmountOptions) {
options.push({ key: 'payment_setting.amount_options', value: inputs.AmountOptions });
}
if (originInputs['AmountDiscount'] !== inputs.AmountDiscount) {
options.push({ key: 'payment_setting.amount_discount', value: inputs.AmountDiscount });
}
// //
const requestQueue = options.map((opt) => const requestQueue = options.map((opt) =>
@ -228,6 +273,37 @@ export default function SettingsPaymentGateway(props) {
placeholder={t('为一个 JSON 文本')} placeholder={t('为一个 JSON 文本')}
autosize autosize
/> />
<Row
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
style={{ marginTop: 16 }}
>
<Col span={24}>
<Form.TextArea
field='AmountOptions'
label={t('自定义充值数量选项')}
placeholder={t('为一个 JSON 数组,例如:[10, 20, 50, 100, 200, 500]')}
autosize
extraText={t('设置用户可选择的充值数量选项,例如:[10, 20, 50, 100, 200, 500]')}
/>
</Col>
</Row>
<Row
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
style={{ marginTop: 16 }}
>
<Col span={24}>
<Form.TextArea
field='AmountDiscount'
label={t('充值金额折扣配置')}
placeholder={t('为一个 JSON 对象,例如:{"100": 0.95, "200": 0.9, "500": 0.85}')}
autosize
extraText={t('设置不同充值金额对应的折扣,键为充值金额,值为折扣率,例如:{"100": 0.95, "200": 0.9, "500": 0.85}')}
/>
</Col>
</Row>
<Button onClick={submitPayAddress}>{t('更新支付设置')}</Button> <Button onClick={submitPayAddress}>{t('更新支付设置')}</Button>
</Form.Section> </Form.Section>
</Form> </Form>