Merge branch 'alpha' into imageratio-and-audioratio-edit
This commit is contained in:
commit
d21886b9fb
@ -56,8 +56,6 @@
|
|||||||
# SESSION_SECRET=random_string
|
# SESSION_SECRET=random_string
|
||||||
|
|
||||||
# 其他配置
|
# 其他配置
|
||||||
# 渠道测试频率(单位:秒)
|
|
||||||
# CHANNEL_TEST_FREQUENCY=10
|
|
||||||
# 生成默认token
|
# 生成默认token
|
||||||
# GENERATE_DEFAULT_TOKEN=false
|
# GENERATE_DEFAULT_TOKEN=false
|
||||||
# Cohere 安全设置
|
# Cohere 安全设置
|
||||||
|
|||||||
21
.github/workflows/pr-target-branch-check.yml
vendored
21
.github/workflows/pr-target-branch-check.yml
vendored
@ -1,21 +0,0 @@
|
|||||||
name: Check PR Branching Strategy
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
types: [opened, synchronize, reopened, edited]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
check-branching-strategy:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Enforce branching strategy
|
|
||||||
run: |
|
|
||||||
if [[ "${{ github.base_ref }}" == "main" ]]; then
|
|
||||||
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
|
|
||||||
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
|
|
||||||
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
echo "Branching strategy check passed."
|
|
||||||
@ -96,7 +96,11 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
|||||||
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
|
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
|
||||||
16. 🔄 思考转内容功能
|
16. 🔄 思考转内容功能
|
||||||
17. 🔄 针对用户的模型限流功能
|
17. 🔄 针对用户的模型限流功能
|
||||||
18. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
|
18. 🔄 请求格式转换功能,支持以下三种格式转换:
|
||||||
|
1. OpenAI Chat Completions => Claude Messages
|
||||||
|
2. Clade Messages => OpenAI Chat Completions (可用于Claude Code调用第三方模型)
|
||||||
|
3. OpenAI Chat Completions => Gemini Chat
|
||||||
|
19. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
|
||||||
1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项
|
1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项
|
||||||
2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费
|
2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费
|
||||||
3. 支持的渠道:
|
3. 支持的渠道:
|
||||||
|
|||||||
@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
|||||||
var UsingMySQL = false
|
var UsingMySQL = false
|
||||||
var UsingClickHouse = false
|
var UsingClickHouse = false
|
||||||
|
|
||||||
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
||||||
@ -123,8 +123,16 @@ func Interface2String(inter interface{}) string {
|
|||||||
return fmt.Sprintf("%d", inter.(int))
|
return fmt.Sprintf("%d", inter.(int))
|
||||||
case float64:
|
case float64:
|
||||||
return fmt.Sprintf("%f", inter.(float64))
|
return fmt.Sprintf("%f", inter.(float64))
|
||||||
|
case bool:
|
||||||
|
if inter.(bool) {
|
||||||
|
return "true"
|
||||||
|
} else {
|
||||||
|
return "false"
|
||||||
|
}
|
||||||
|
case nil:
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
return "Not Implemented"
|
return fmt.Sprintf("%v", inter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnescapeHTML(x string) interface{} {
|
func UnescapeHTML(x string) interface{} {
|
||||||
@ -257,32 +265,32 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||||
}
|
}
|
||||||
durationStr := string(bytes.TrimSpace(output))
|
durationStr := string(bytes.TrimSpace(output))
|
||||||
if durationStr == "N/A" {
|
if durationStr == "N/A" {
|
||||||
// Create a temporary output file name
|
// Create a temporary output file name
|
||||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to create temporary file")
|
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||||
}
|
}
|
||||||
tmpName := tmpFp.Name()
|
tmpName := tmpFp.Name()
|
||||||
// Close immediately so ffmpeg can open the file on Windows.
|
// Close immediately so ffmpeg can open the file on Windows.
|
||||||
_ = tmpFp.Close()
|
_ = tmpFp.Close()
|
||||||
defer os.Remove(tmpName)
|
defer os.Remove(tmpName)
|
||||||
|
|
||||||
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||||
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||||
if err := ffmpegCmd.Run(); err != nil {
|
if err := ffmpegCmd.Run(); err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recalculate the duration of the new file
|
// Recalculate the duration of the new file
|
||||||
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||||
output, err := c.Output()
|
output, err := c.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||||
}
|
}
|
||||||
durationStr = string(bytes.TrimSpace(output))
|
durationStr = string(bytes.TrimSpace(output))
|
||||||
}
|
}
|
||||||
return strconv.ParseFloat(durationStr, 64)
|
return strconv.ParseFloat(durationStr, 64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import (
|
|||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting/operation_setting"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@ -342,7 +342,7 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
|||||||
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||||
}
|
}
|
||||||
availableBalanceCny := response.Data.AvailableBalance
|
availableBalanceCny := response.Data.AvailableBalance
|
||||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(operation_setting.Price)).InexactFloat64()
|
||||||
channel.UpdateBalance(availableBalanceUsd)
|
channel.UpdateBalance(availableBalanceUsd)
|
||||||
return availableBalanceUsd, nil
|
return availableBalanceUsd, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import (
|
|||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting/operation_setting"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -234,7 +235,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
|||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
err := service.RelayErrorHandler(httpResp, true)
|
err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
|
||||||
return testResult{
|
return testResult{
|
||||||
context: c,
|
context: c,
|
||||||
localErr: err,
|
localErr: err,
|
||||||
@ -477,15 +478,26 @@ func TestAllChannels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func AutomaticallyTestChannels(frequency int) {
|
var autoTestChannelsOnce sync.Once
|
||||||
if frequency <= 0 {
|
|
||||||
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
func AutomaticallyTestChannels() {
|
||||||
return
|
autoTestChannelsOnce.Do(func() {
|
||||||
}
|
for {
|
||||||
for {
|
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
||||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
time.Sleep(10 * time.Minute)
|
||||||
common.SysLog("testing all channels")
|
continue
|
||||||
_ = testAllChannels(false)
|
}
|
||||||
common.SysLog("channel test finished")
|
frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes
|
||||||
}
|
common.SysLog(fmt.Sprintf("automatically test channels with interval %d minutes", frequency))
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||||
|
common.SysLog("automatically testing all channels")
|
||||||
|
_ = testAllChannels(false)
|
||||||
|
common.SysLog("automatically channel test finished")
|
||||||
|
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -560,7 +561,7 @@ func AddChannel(c *gin.Context) {
|
|||||||
case "multi_to_single":
|
case "multi_to_single":
|
||||||
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
||||||
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
||||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||||
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@ -585,7 +586,7 @@ func AddChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
keys = []string{addChannelRequest.Channel.Key}
|
keys = []string{addChannelRequest.Channel.Key}
|
||||||
case "batch":
|
case "batch":
|
||||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||||
// multi json
|
// multi json
|
||||||
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -840,7 +841,7 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 处理 Vertex AI 的特殊情况
|
// 处理 Vertex AI 的特殊情况
|
||||||
if channel.Type == constant.ChannelTypeVertexAi {
|
if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||||
// 尝试解析新密钥为JSON数组
|
// 尝试解析新密钥为JSON数组
|
||||||
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
||||||
array, err := getVertexArrayKeys(channel.Key)
|
array, err := getVertexArrayKeys(channel.Key)
|
||||||
|
|||||||
@ -13,6 +13,7 @@ import (
|
|||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/system_setting"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -259,7 +260,7 @@ func GetAllMidjourney(c *gin.Context) {
|
|||||||
|
|
||||||
if setting.MjForwardUrlEnabled {
|
if setting.MjForwardUrlEnabled {
|
||||||
for i, midjourney := range items {
|
for i, midjourney := range items {
|
||||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||||
items[i] = midjourney
|
items[i] = midjourney
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -284,7 +285,7 @@ func GetUserMidjourney(c *gin.Context) {
|
|||||||
|
|
||||||
if setting.MjForwardUrlEnabled {
|
if setting.MjForwardUrlEnabled {
|
||||||
for i, midjourney := range items {
|
for i, midjourney := range items {
|
||||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||||
items[i] = midjourney
|
items[i] = midjourney
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -39,6 +39,8 @@ func TestStatus(c *gin.Context) {
|
|||||||
func GetStatus(c *gin.Context) {
|
func GetStatus(c *gin.Context) {
|
||||||
|
|
||||||
cs := console_setting.GetConsoleSetting()
|
cs := console_setting.GetConsoleSetting()
|
||||||
|
common.OptionMapRWMutex.RLock()
|
||||||
|
defer common.OptionMapRWMutex.RUnlock()
|
||||||
|
|
||||||
data := gin.H{
|
data := gin.H{
|
||||||
"version": common.Version,
|
"version": common.Version,
|
||||||
@ -56,11 +58,7 @@ func GetStatus(c *gin.Context) {
|
|||||||
"footer_html": common.Footer,
|
"footer_html": common.Footer,
|
||||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||||
"wechat_login": common.WeChatAuthEnabled,
|
"wechat_login": common.WeChatAuthEnabled,
|
||||||
"server_address": setting.ServerAddress,
|
"server_address": system_setting.ServerAddress,
|
||||||
"price": setting.Price,
|
|
||||||
"stripe_unit_price": setting.StripeUnitPrice,
|
|
||||||
"min_topup": setting.MinTopUp,
|
|
||||||
"stripe_min_topup": setting.StripeMinTopUp,
|
|
||||||
"turnstile_check": common.TurnstileCheckEnabled,
|
"turnstile_check": common.TurnstileCheckEnabled,
|
||||||
"turnstile_site_key": common.TurnstileSiteKey,
|
"turnstile_site_key": common.TurnstileSiteKey,
|
||||||
"top_up_link": common.TopUpLink,
|
"top_up_link": common.TopUpLink,
|
||||||
@ -73,15 +71,15 @@ func GetStatus(c *gin.Context) {
|
|||||||
"enable_data_export": common.DataExportEnabled,
|
"enable_data_export": common.DataExportEnabled,
|
||||||
"data_export_default_time": common.DataExportDefaultTime,
|
"data_export_default_time": common.DataExportDefaultTime,
|
||||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
|
||||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
|
||||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||||
"chats": setting.Chats,
|
"chats": setting.Chats,
|
||||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||||
"pay_methods": setting.PayMethods,
|
|
||||||
"usd_exchange_rate": setting.USDExchangeRate,
|
"usd_exchange_rate": operation_setting.USDExchangeRate,
|
||||||
|
"price": operation_setting.Price,
|
||||||
|
"stripe_unit_price": setting.StripeUnitPrice,
|
||||||
|
|
||||||
// 面板启用开关
|
// 面板启用开关
|
||||||
"api_info_enabled": cs.ApiInfoEnabled,
|
"api_info_enabled": cs.ApiInfoEnabled,
|
||||||
@ -89,6 +87,10 @@ func GetStatus(c *gin.Context) {
|
|||||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||||
"faq_enabled": cs.FAQEnabled,
|
"faq_enabled": cs.FAQEnabled,
|
||||||
|
|
||||||
|
// 模块管理配置
|
||||||
|
"HeaderNavModules": common.OptionMap["HeaderNavModules"],
|
||||||
|
"SidebarModulesAdmin": common.OptionMap["SidebarModulesAdmin"],
|
||||||
|
|
||||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||||
@ -247,7 +249,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
code := common.GenerateVerificationCode(0)
|
code := common.GenerateVerificationCode(0)
|
||||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
|
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
|
||||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||||
|
|||||||
@ -207,6 +207,7 @@ func ListModels(c *gin.Context, modelType int) {
|
|||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": userOpenAiModels,
|
"data": userOpenAiModels,
|
||||||
|
"object": "list",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
604
controller/model_sync.go
Normal file
604
controller/model_sync.go
Normal file
@ -0,0 +1,604 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 上游地址
|
||||||
|
const (
|
||||||
|
upstreamModelsURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json"
|
||||||
|
upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
func normalizeLocale(locale string) (string, bool) {
|
||||||
|
l := strings.ToLower(strings.TrimSpace(locale))
|
||||||
|
switch l {
|
||||||
|
case "en", "zh", "ja":
|
||||||
|
return l, true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUpstreamBase() string {
|
||||||
|
return common.GetEnvOrDefaultString("SYNC_UPSTREAM_BASE", "https://basellm.github.io/llm-metadata")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUpstreamURLs(locale string) (modelsURL, vendorsURL string) {
|
||||||
|
base := strings.TrimRight(getUpstreamBase(), "/")
|
||||||
|
if l, ok := normalizeLocale(locale); ok && l != "" {
|
||||||
|
return fmt.Sprintf("%s/api/i18n/%s/newapi/models.json", base, l),
|
||||||
|
fmt.Sprintf("%s/api/i18n/%s/newapi/vendors.json", base, l)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s/api/newapi/models.json", base), fmt.Sprintf("%s/api/newapi/vendors.json", base)
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamEnvelope[T any] struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data []T `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamModel struct {
|
||||||
|
Description string `json:"description"`
|
||||||
|
Endpoints json.RawMessage `json:"endpoints"`
|
||||||
|
Icon string `json:"icon"`
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
NameRule int `json:"name_rule"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
Tags string `json:"tags"`
|
||||||
|
VendorName string `json:"vendor_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type upstreamVendor struct {
|
||||||
|
Description string `json:"description"`
|
||||||
|
Icon string `json:"icon"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
etagCache = make(map[string]string)
|
||||||
|
bodyCache = make(map[string][]byte)
|
||||||
|
cacheMutex sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
type overwriteField struct {
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
Fields []string `json:"fields"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type syncRequest struct {
|
||||||
|
Overwrite []overwriteField `json:"overwrite"`
|
||||||
|
Locale string `json:"locale"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPClient() *http.Client {
|
||||||
|
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 10)
|
||||||
|
dialer := &net.Dialer{Timeout: time.Duration(timeoutSec) * time.Second}
|
||||||
|
transport := &http.Transport{
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: time.Duration(timeoutSec) * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
ResponseHeaderTimeout: time.Duration(timeoutSec) * time.Second,
|
||||||
|
}
|
||||||
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
host = addr
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(host, "github.io") {
|
||||||
|
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, "tcp6", addr)
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: transport}
|
||||||
|
}
|
||||||
|
|
||||||
|
var httpClient = newHTTPClient()
|
||||||
|
|
||||||
|
func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error {
|
||||||
|
var lastErr error
|
||||||
|
attempts := common.GetEnvOrDefault("SYNC_HTTP_RETRY", 3)
|
||||||
|
if attempts < 1 {
|
||||||
|
attempts = 1
|
||||||
|
}
|
||||||
|
baseDelay := 200 * time.Millisecond
|
||||||
|
maxMB := common.GetEnvOrDefault("SYNC_HTTP_MAX_MB", 10)
|
||||||
|
maxBytes := int64(maxMB) << 20
|
||||||
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// ETag conditional request
|
||||||
|
cacheMutex.RLock()
|
||||||
|
if et := etagCache[url]; et != "" {
|
||||||
|
req.Header.Set("If-None-Match", et)
|
||||||
|
}
|
||||||
|
cacheMutex.RUnlock()
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
// backoff with jitter
|
||||||
|
sleep := baseDelay * time.Duration(1<<attempt)
|
||||||
|
jitter := time.Duration(rand.Intn(150)) * time.Millisecond
|
||||||
|
time.Sleep(sleep + jitter)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
func() {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
switch resp.StatusCode {
|
||||||
|
case http.StatusOK:
|
||||||
|
// read body into buffer for caching and flexible decode
|
||||||
|
limited := io.LimitReader(resp.Body, maxBytes)
|
||||||
|
buf, err := io.ReadAll(limited)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// cache body and ETag
|
||||||
|
cacheMutex.Lock()
|
||||||
|
if et := resp.Header.Get("ETag"); et != "" {
|
||||||
|
etagCache[url] = et
|
||||||
|
}
|
||||||
|
bodyCache[url] = buf
|
||||||
|
cacheMutex.Unlock()
|
||||||
|
|
||||||
|
// Try decode as envelope first
|
||||||
|
if err := json.Unmarshal(buf, out); err != nil {
|
||||||
|
// Try decode as pure array
|
||||||
|
var arr []T
|
||||||
|
if err2 := json.Unmarshal(buf, &arr); err2 != nil {
|
||||||
|
lastErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out.Success = true
|
||||||
|
out.Data = arr
|
||||||
|
out.Message = ""
|
||||||
|
} else {
|
||||||
|
if !out.Success && len(out.Data) == 0 && out.Message == "" {
|
||||||
|
out.Success = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lastErr = nil
|
||||||
|
case http.StatusNotModified:
|
||||||
|
// use cache
|
||||||
|
cacheMutex.RLock()
|
||||||
|
buf := bodyCache[url]
|
||||||
|
cacheMutex.RUnlock()
|
||||||
|
if len(buf) == 0 {
|
||||||
|
lastErr = errors.New("cache miss for 304 response")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(buf, out); err != nil {
|
||||||
|
var arr []T
|
||||||
|
if err2 := json.Unmarshal(buf, &arr); err2 != nil {
|
||||||
|
lastErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out.Success = true
|
||||||
|
out.Data = arr
|
||||||
|
out.Message = ""
|
||||||
|
} else {
|
||||||
|
if !out.Success && len(out.Data) == 0 && out.Message == "" {
|
||||||
|
out.Success = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lastErr = nil
|
||||||
|
default:
|
||||||
|
lastErr = errors.New(resp.Status)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if lastErr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sleep := baseDelay * time.Duration(1<<attempt)
|
||||||
|
jitter := time.Duration(rand.Intn(150)) * time.Millisecond
|
||||||
|
time.Sleep(sleep + jitter)
|
||||||
|
}
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureVendorID(vendorName string, vendorByName map[string]upstreamVendor, vendorIDCache map[string]int, createdVendors *int) int {
|
||||||
|
if vendorName == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if id, ok := vendorIDCache[vendorName]; ok {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
var existing model.Vendor
|
||||||
|
if err := model.DB.Where("name = ?", vendorName).First(&existing).Error; err == nil {
|
||||||
|
vendorIDCache[vendorName] = existing.Id
|
||||||
|
return existing.Id
|
||||||
|
}
|
||||||
|
uv := vendorByName[vendorName]
|
||||||
|
v := &model.Vendor{
|
||||||
|
Name: vendorName,
|
||||||
|
Description: uv.Description,
|
||||||
|
Icon: coalesce(uv.Icon, ""),
|
||||||
|
Status: chooseStatus(uv.Status, 1),
|
||||||
|
}
|
||||||
|
if err := v.Insert(); err == nil {
|
||||||
|
*createdVendors++
|
||||||
|
vendorIDCache[vendorName] = v.Id
|
||||||
|
return v.Id
|
||||||
|
}
|
||||||
|
vendorIDCache[vendorName] = 0
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncUpstreamModels 同步上游模型与供应商,仅对「未配置模型」生效
|
||||||
|
func SyncUpstreamModels(c *gin.Context) {
|
||||||
|
var req syncRequest
|
||||||
|
// 允许空体
|
||||||
|
_ = c.ShouldBindJSON(&req)
|
||||||
|
// 1) 获取未配置模型列表
|
||||||
|
missing, err := model.GetMissingModels()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(missing) == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true, "data": gin.H{
|
||||||
|
"created_models": 0,
|
||||||
|
"created_vendors": 0,
|
||||||
|
"skipped_models": []string{},
|
||||||
|
}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 拉取上游 vendors 与 models
|
||||||
|
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15)
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
modelsURL, vendorsURL := getUpstreamURLs(req.Locale)
|
||||||
|
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||||
|
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||||
|
var fetchErr error
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
// vendor 失败不拦截
|
||||||
|
_ = fetchJSON(ctx, vendorsURL, &vendorsEnv)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil {
|
||||||
|
fetchErr = err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
if fetchErr != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": req.Locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 建立映射
|
||||||
|
vendorByName := make(map[string]upstreamVendor)
|
||||||
|
for _, v := range vendorsEnv.Data {
|
||||||
|
if v.Name != "" {
|
||||||
|
vendorByName[v.Name] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelByName := make(map[string]upstreamModel)
|
||||||
|
for _, m := range modelsEnv.Data {
|
||||||
|
if m.ModelName != "" {
|
||||||
|
modelByName[m.ModelName] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 执行同步:仅创建缺失模型;若上游缺失该模型则跳过
|
||||||
|
createdModels := 0
|
||||||
|
createdVendors := 0
|
||||||
|
updatedModels := 0
|
||||||
|
var skipped []string
|
||||||
|
var createdList []string
|
||||||
|
var updatedList []string
|
||||||
|
|
||||||
|
// 本地缓存:vendorName -> id
|
||||||
|
vendorIDCache := make(map[string]int)
|
||||||
|
|
||||||
|
for _, name := range missing {
|
||||||
|
up, ok := modelByName[name]
|
||||||
|
if !ok {
|
||||||
|
skipped = append(skipped, name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 若本地已存在且设置为不同步,则跳过(极端情况:缺失列表与本地状态不同步时)
|
||||||
|
var existing model.Model
|
||||||
|
if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil {
|
||||||
|
if existing.SyncOfficial == 0 {
|
||||||
|
skipped = append(skipped, name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保 vendor 存在
|
||||||
|
vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||||
|
|
||||||
|
// 创建模型
|
||||||
|
mi := &model.Model{
|
||||||
|
ModelName: name,
|
||||||
|
Description: up.Description,
|
||||||
|
Icon: up.Icon,
|
||||||
|
Tags: up.Tags,
|
||||||
|
VendorID: vendorID,
|
||||||
|
Status: chooseStatus(up.Status, 1),
|
||||||
|
NameRule: up.NameRule,
|
||||||
|
}
|
||||||
|
if err := mi.Insert(); err == nil {
|
||||||
|
createdModels++
|
||||||
|
createdList = append(createdList, name)
|
||||||
|
} else {
|
||||||
|
skipped = append(skipped, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 处理可选覆盖(更新本地已有模型的差异字段)
|
||||||
|
if len(req.Overwrite) > 0 {
|
||||||
|
// vendorIDCache 已用于创建阶段,可复用
|
||||||
|
for _, ow := range req.Overwrite {
|
||||||
|
up, ok := modelByName[ow.ModelName]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var local model.Model
|
||||||
|
if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 跳过被禁用官方同步的模型
|
||||||
|
if local.SyncOfficial == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 映射 vendor
|
||||||
|
newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||||
|
|
||||||
|
// 应用字段覆盖(事务)
|
||||||
|
_ = model.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
needUpdate := false
|
||||||
|
if containsField(ow.Fields, "description") {
|
||||||
|
local.Description = up.Description
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "icon") {
|
||||||
|
local.Icon = up.Icon
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "tags") {
|
||||||
|
local.Tags = up.Tags
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "vendor") {
|
||||||
|
local.VendorID = newVendorID
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "name_rule") {
|
||||||
|
local.NameRule = up.NameRule
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if containsField(ow.Fields, "status") {
|
||||||
|
local.Status = chooseStatus(up.Status, local.Status)
|
||||||
|
needUpdate = true
|
||||||
|
}
|
||||||
|
if !needUpdate {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := tx.Save(&local).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
updatedModels++
|
||||||
|
updatedList = append(updatedList, ow.ModelName)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"created_models": createdModels,
|
||||||
|
"created_vendors": createdVendors,
|
||||||
|
"updated_models": updatedModels,
|
||||||
|
"skipped_models": skipped,
|
||||||
|
"created_list": createdList,
|
||||||
|
"updated_list": updatedList,
|
||||||
|
"source": gin.H{
|
||||||
|
"locale": req.Locale,
|
||||||
|
"models_url": modelsURL,
|
||||||
|
"vendors_url": vendorsURL,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsField(fields []string, key string) bool {
|
||||||
|
key = strings.ToLower(strings.TrimSpace(key))
|
||||||
|
for _, f := range fields {
|
||||||
|
if strings.ToLower(strings.TrimSpace(f)) == key {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func coalesce(a, b string) string {
|
||||||
|
if strings.TrimSpace(a) != "" {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func chooseStatus(primary, fallback int) int {
|
||||||
|
if primary == 0 && fallback != 0 {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
if primary != 0 {
|
||||||
|
return primary
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncUpstreamPreview 预览上游与本地的差异(仅用于弹窗选择)
|
||||||
|
func SyncUpstreamPreview(c *gin.Context) {
|
||||||
|
// 1) 拉取上游数据
|
||||||
|
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15)
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
locale := c.Query("locale")
|
||||||
|
modelsURL, vendorsURL := getUpstreamURLs(locale)
|
||||||
|
|
||||||
|
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||||
|
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||||
|
var fetchErr error
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = fetchJSON(ctx, vendorsURL, &vendorsEnv)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil {
|
||||||
|
fetchErr = err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
if fetchErr != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vendorByName := make(map[string]upstreamVendor)
|
||||||
|
for _, v := range vendorsEnv.Data {
|
||||||
|
if v.Name != "" {
|
||||||
|
vendorByName[v.Name] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelByName := make(map[string]upstreamModel)
|
||||||
|
upstreamNames := make([]string, 0, len(modelsEnv.Data))
|
||||||
|
for _, m := range modelsEnv.Data {
|
||||||
|
if m.ModelName != "" {
|
||||||
|
modelByName[m.ModelName] = m
|
||||||
|
upstreamNames = append(upstreamNames, m.ModelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 本地已有模型
|
||||||
|
var locals []model.Model
|
||||||
|
if len(upstreamNames) > 0 {
|
||||||
|
_ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// 本地 vendor 名称映射
|
||||||
|
vendorIdSet := make(map[int]struct{})
|
||||||
|
for _, m := range locals {
|
||||||
|
if m.VendorID != 0 {
|
||||||
|
vendorIdSet[m.VendorID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vendorIDs := make([]int, 0, len(vendorIdSet))
|
||||||
|
for id := range vendorIdSet {
|
||||||
|
vendorIDs = append(vendorIDs, id)
|
||||||
|
}
|
||||||
|
idToVendorName := make(map[int]string)
|
||||||
|
if len(vendorIDs) > 0 {
|
||||||
|
var dbVendors []model.Vendor
|
||||||
|
_ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error
|
||||||
|
for _, v := range dbVendors {
|
||||||
|
idToVendorName[v.Id] = v.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 缺失且上游存在的模型
|
||||||
|
missingList, _ := model.GetMissingModels()
|
||||||
|
var missing []string
|
||||||
|
for _, name := range missingList {
|
||||||
|
if _, ok := modelByName[name]; ok {
|
||||||
|
missing = append(missing, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 计算冲突字段
|
||||||
|
type conflictField struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Local interface{} `json:"local"`
|
||||||
|
Upstream interface{} `json:"upstream"`
|
||||||
|
}
|
||||||
|
type conflictItem struct {
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
Fields []conflictField `json:"fields"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var conflicts []conflictItem
|
||||||
|
for _, local := range locals {
|
||||||
|
up, ok := modelByName[local.ModelName]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fields := make([]conflictField, 0, 6)
|
||||||
|
if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) {
|
||||||
|
fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description})
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) {
|
||||||
|
fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon})
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) {
|
||||||
|
fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags})
|
||||||
|
}
|
||||||
|
// vendor 对比使用名称
|
||||||
|
localVendor := idToVendorName[local.VendorID]
|
||||||
|
if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) {
|
||||||
|
fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName})
|
||||||
|
}
|
||||||
|
if local.NameRule != up.NameRule {
|
||||||
|
fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule})
|
||||||
|
}
|
||||||
|
if local.Status != chooseStatus(up.Status, local.Status) {
|
||||||
|
fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status})
|
||||||
|
}
|
||||||
|
if len(fields) > 0 {
|
||||||
|
conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"missing": missing,
|
||||||
|
"conflicts": conflicts,
|
||||||
|
"source": gin.H{
|
||||||
|
"locale": locale,
|
||||||
|
"models_url": modelsURL,
|
||||||
|
"vendors_url": vendorsURL,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
@ -8,7 +8,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
|
||||||
"one-api/setting/system_setting"
|
"one-api/setting/system_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -45,7 +44,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||||
values.Set("code", code)
|
values.Set("code", code)
|
||||||
values.Set("grant_type", "authorization_code")
|
values.Set("grant_type", "authorization_code")
|
||||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
|
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
|
||||||
formData := values.Encode()
|
formData := values.Encode()
|
||||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
@ -35,8 +36,13 @@ func GetOptions(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OptionUpdateRequest struct {
|
||||||
|
Key string `json:"key"`
|
||||||
|
Value any `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
func UpdateOption(c *gin.Context) {
|
func UpdateOption(c *gin.Context) {
|
||||||
var option model.Option
|
var option OptionUpdateRequest
|
||||||
err := json.NewDecoder(c.Request.Body).Decode(&option)
|
err := json.NewDecoder(c.Request.Body).Decode(&option)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
@ -45,6 +51,16 @@ func UpdateOption(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
switch option.Value.(type) {
|
||||||
|
case bool:
|
||||||
|
option.Value = common.Interface2String(option.Value.(bool))
|
||||||
|
case float64:
|
||||||
|
option.Value = common.Interface2String(option.Value.(float64))
|
||||||
|
case int:
|
||||||
|
option.Value = common.Interface2String(option.Value.(int))
|
||||||
|
default:
|
||||||
|
option.Value = fmt.Sprintf("%v", option.Value)
|
||||||
|
}
|
||||||
switch option.Key {
|
switch option.Key {
|
||||||
case "GitHubOAuthEnabled":
|
case "GitHubOAuthEnabled":
|
||||||
if option.Value == "true" && common.GitHubClientId == "" {
|
if option.Value == "true" && common.GitHubClientId == "" {
|
||||||
@ -104,7 +120,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "GroupRatio":
|
case "GroupRatio":
|
||||||
err = ratio_setting.CheckGroupRatio(option.Value)
|
err = ratio_setting.CheckGroupRatio(option.Value.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -140,7 +156,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "ModelRequestRateLimitGroup":
|
case "ModelRequestRateLimitGroup":
|
||||||
err = setting.CheckModelRequestRateLimitGroup(option.Value)
|
err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -149,7 +165,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "console_setting.api_info":
|
case "console_setting.api_info":
|
||||||
err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
|
err = console_setting.ValidateConsoleSettings(option.Value.(string), "ApiInfo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -158,7 +174,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "console_setting.announcements":
|
case "console_setting.announcements":
|
||||||
err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
|
err = console_setting.ValidateConsoleSettings(option.Value.(string), "Announcements")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -167,7 +183,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "console_setting.faq":
|
case "console_setting.faq":
|
||||||
err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
|
err = console_setting.ValidateConsoleSettings(option.Value.(string), "FAQ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -176,7 +192,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "console_setting.uptime_kuma_groups":
|
case "console_setting.uptime_kuma_groups":
|
||||||
err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
|
err = console_setting.ValidateConsoleSettings(option.Value.(string), "UptimeKumaGroups")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -185,7 +201,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = model.UpdateOption(option.Key, option.Value)
|
err = model.UpdateOption(option.Key, option.Value.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@ -1,24 +1,24 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetRatioConfig(c *gin.Context) {
|
func GetRatioConfig(c *gin.Context) {
|
||||||
if !ratio_setting.IsExposeRatioEnabled() {
|
if !ratio_setting.IsExposeRatioEnabled() {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "倍率配置接口未启用",
|
"message": "倍率配置接口未启用",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": ratio_setting.GetExposedData(),
|
"data": ratio_setting.GetExposedData(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/logger"
|
"one-api/logger"
|
||||||
"strings"
|
"strings"
|
||||||
@ -21,8 +23,26 @@ const (
|
|||||||
defaultTimeoutSeconds = 10
|
defaultTimeoutSeconds = 10
|
||||||
defaultEndpoint = "/api/ratio_config"
|
defaultEndpoint = "/api/ratio_config"
|
||||||
maxConcurrentFetches = 8
|
maxConcurrentFetches = 8
|
||||||
|
maxRatioConfigBytes = 10 << 20 // 10MB
|
||||||
|
floatEpsilon = 1e-9
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func nearlyEqual(a, b float64) bool {
|
||||||
|
if a > b {
|
||||||
|
return a-b < floatEpsilon
|
||||||
|
}
|
||||||
|
return b-a < floatEpsilon
|
||||||
|
}
|
||||||
|
|
||||||
|
func valuesEqual(a, b interface{}) bool {
|
||||||
|
af, aok := a.(float64)
|
||||||
|
bf, bok := b.(float64)
|
||||||
|
if aok && bok {
|
||||||
|
return nearlyEqual(af, bf)
|
||||||
|
}
|
||||||
|
return a == b
|
||||||
|
}
|
||||||
|
|
||||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||||
|
|
||||||
type upstreamResult struct {
|
type upstreamResult struct {
|
||||||
@ -87,7 +107,23 @@ func FetchUpstreamRatios(c *gin.Context) {
|
|||||||
|
|
||||||
sem := make(chan struct{}, maxConcurrentFetches)
|
sem := make(chan struct{}, maxConcurrentFetches)
|
||||||
|
|
||||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||||
|
transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
|
||||||
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
host = addr
|
||||||
|
}
|
||||||
|
// 对 github.io 优先尝试 IPv4,失败则回退 IPv6
|
||||||
|
if strings.HasSuffix(host, "github.io") {
|
||||||
|
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, "tcp6", addr)
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
client := &http.Client{Transport: transport}
|
||||||
|
|
||||||
for _, chn := range upstreams {
|
for _, chn := range upstreams {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@ -98,12 +134,17 @@ func FetchUpstreamRatios(c *gin.Context) {
|
|||||||
defer func() { <-sem }()
|
defer func() { <-sem }()
|
||||||
|
|
||||||
endpoint := chItem.Endpoint
|
endpoint := chItem.Endpoint
|
||||||
if endpoint == "" {
|
var fullURL string
|
||||||
endpoint = defaultEndpoint
|
if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
|
||||||
} else if !strings.HasPrefix(endpoint, "/") {
|
fullURL = endpoint
|
||||||
endpoint = "/" + endpoint
|
} else {
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = defaultEndpoint
|
||||||
|
} else if !strings.HasPrefix(endpoint, "/") {
|
||||||
|
endpoint = "/" + endpoint
|
||||||
|
}
|
||||||
|
fullURL = chItem.BaseURL + endpoint
|
||||||
}
|
}
|
||||||
fullURL := chItem.BaseURL + endpoint
|
|
||||||
|
|
||||||
uniqueName := chItem.Name
|
uniqueName := chItem.Name
|
||||||
if chItem.ID != 0 {
|
if chItem.ID != 0 {
|
||||||
@ -120,10 +161,19 @@ func FetchUpstreamRatios(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.Do(httpReq)
|
// 简单重试:最多 3 次,指数退避
|
||||||
if err != nil {
|
var resp *http.Response
|
||||||
logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
var lastErr error
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
for attempt := 0; attempt < 3; attempt++ {
|
||||||
|
resp, lastErr = client.Do(httpReq)
|
||||||
|
if lastErr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
|
||||||
|
}
|
||||||
|
if lastErr != nil {
|
||||||
|
logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error())
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
@ -132,6 +182,12 @@ func FetchUpstreamRatios(c *gin.Context) {
|
|||||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Content-Type 和响应体大小校验
|
||||||
|
if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") {
|
||||||
|
logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
|
||||||
|
}
|
||||||
|
limited := io.LimitReader(resp.Body, maxRatioConfigBytes)
|
||||||
// 兼容两种上游接口格式:
|
// 兼容两种上游接口格式:
|
||||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||||
@ -141,7 +197,7 @@ func FetchUpstreamRatios(c *gin.Context) {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
if err := json.NewDecoder(limited).Decode(&body); err != nil {
|
||||||
logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||||
return
|
return
|
||||||
@ -152,6 +208,8 @@ func FetchUpstreamRatios(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 若 Data 为空,将继续按 type1 尝试解析(与多数静态 ratio_config 兼容)
|
||||||
|
|
||||||
// 尝试按 type1 解析
|
// 尝试按 type1 解析
|
||||||
var type1Data map[string]any
|
var type1Data map[string]any
|
||||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||||
@ -357,9 +415,9 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
|||||||
upstreamValue = val
|
upstreamValue = val
|
||||||
hasUpstreamValue = true
|
hasUpstreamValue = true
|
||||||
|
|
||||||
if localValue != nil && localValue != val {
|
if localValue != nil && !valuesEqual(localValue, val) {
|
||||||
hasDifference = true
|
hasDifference = true
|
||||||
} else if localValue == val {
|
} else if valuesEqual(localValue, val) {
|
||||||
upstreamValue = "same"
|
upstreamValue = "same"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -466,6 +524,13 @@ func GetSyncableChannels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||||
|
ID: -100,
|
||||||
|
Name: "官方倍率预设",
|
||||||
|
BaseURL: "https://basellm.github.io",
|
||||||
|
Status: 1,
|
||||||
|
})
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
|
|||||||
@ -139,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|||||||
|
|
||||||
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
||||||
|
|
||||||
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// Only return quota if downstream failed and quota was actually pre-consumed
|
// Only return quota if downstream failed and quota was actually pre-consumed
|
||||||
if newAPIError != nil && preConsumedQuota != 0 {
|
if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 {
|
||||||
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
|
service.ReturnPreConsumedQuota(c, relayInfo)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -277,14 +277,13 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
|||||||
|
|
||||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||||
|
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||||
gopool.Go(func() {
|
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
gopool.Go(func() {
|
||||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
|
||||||
service.DisableChannel(channelError, err.Error())
|
service.DisableChannel(channelError, err.Error())
|
||||||
}
|
})
|
||||||
})
|
}
|
||||||
|
|
||||||
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
||||||
// 保存错误日志到mysql中
|
// 保存错误日志到mysql中
|
||||||
|
|||||||
@ -178,4 +178,4 @@ func boolToString(b bool) string {
|
|||||||
return "true"
|
return "true"
|
||||||
}
|
}
|
||||||
return "false"
|
return "false"
|
||||||
}
|
}
|
||||||
@ -94,7 +94,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||||
} else {
|
} else {
|
||||||
task.Data = responseBody
|
task.Data = redactVideoResponseBody(responseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
@ -113,11 +113,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
task.StartTime = now
|
task.StartTime = now
|
||||||
}
|
}
|
||||||
case model.TaskStatusSuccess:
|
case model.TaskStatusSuccess:
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
if task.FinishTime == 0 {
|
if task.FinishTime == 0 {
|
||||||
task.FinishTime = now
|
task.FinishTime = now
|
||||||
}
|
}
|
||||||
task.FailReason = taskResult.Url
|
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
||||||
|
task.FailReason = taskResult.Url
|
||||||
|
}
|
||||||
case model.TaskStatusFailure:
|
case model.TaskStatusFailure:
|
||||||
task.Status = model.TaskStatusFailure
|
task.Status = model.TaskStatusFailure
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
@ -146,3 +148,37 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func redactVideoResponseBody(body []byte) []byte {
|
||||||
|
var m map[string]any
|
||||||
|
if err := json.Unmarshal(body, &m); err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
resp, _ := m["response"].(map[string]any)
|
||||||
|
if resp != nil {
|
||||||
|
delete(resp, "bytesBase64Encoded")
|
||||||
|
if v, ok := resp["video"].(string); ok {
|
||||||
|
resp["video"] = truncateBase64(v)
|
||||||
|
}
|
||||||
|
if vs, ok := resp["videos"].([]any); ok {
|
||||||
|
for i := range vs {
|
||||||
|
if vm, ok := vs[i].(map[string]any); ok {
|
||||||
|
delete(vm, "bytesBase64Encoded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateBase64(s string) string {
|
||||||
|
const maxKeep = 256
|
||||||
|
if len(s) <= maxKeep {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:maxKeep] + "..."
|
||||||
|
}
|
||||||
|
|||||||
@ -9,6 +9,8 @@ import (
|
|||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/operation_setting"
|
||||||
|
"one-api/setting/system_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -19,6 +21,44 @@ import (
|
|||||||
"github.com/shopspring/decimal"
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func GetTopUpInfo(c *gin.Context) {
|
||||||
|
// 获取支付方式
|
||||||
|
payMethods := operation_setting.PayMethods
|
||||||
|
|
||||||
|
// 如果启用了 Stripe 支付,添加到支付方法列表
|
||||||
|
if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" {
|
||||||
|
// 检查是否已经包含 Stripe
|
||||||
|
hasStripe := false
|
||||||
|
for _, method := range payMethods {
|
||||||
|
if method["type"] == "stripe" {
|
||||||
|
hasStripe = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasStripe {
|
||||||
|
stripeMethod := map[string]string{
|
||||||
|
"name": "Stripe",
|
||||||
|
"type": "stripe",
|
||||||
|
"color": "rgba(var(--semi-purple-5), 1)",
|
||||||
|
"min_topup": strconv.Itoa(setting.StripeMinTopUp),
|
||||||
|
}
|
||||||
|
payMethods = append(payMethods, stripeMethod)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data := gin.H{
|
||||||
|
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
|
||||||
|
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||||
|
"pay_methods": payMethods,
|
||||||
|
"min_topup": operation_setting.MinTopUp,
|
||||||
|
"stripe_min_topup": setting.StripeMinTopUp,
|
||||||
|
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
|
||||||
|
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, data)
|
||||||
|
}
|
||||||
|
|
||||||
type EpayRequest struct {
|
type EpayRequest struct {
|
||||||
Amount int64 `json:"amount"`
|
Amount int64 `json:"amount"`
|
||||||
PaymentMethod string `json:"payment_method"`
|
PaymentMethod string `json:"payment_method"`
|
||||||
@ -31,13 +71,13 @@ type AmountRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetEpayClient() *epay.Client {
|
func GetEpayClient() *epay.Client {
|
||||||
if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
|
if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
withUrl, err := epay.NewClient(&epay.Config{
|
withUrl, err := epay.NewClient(&epay.Config{
|
||||||
PartnerID: setting.EpayId,
|
PartnerID: operation_setting.EpayId,
|
||||||
Key: setting.EpayKey,
|
Key: operation_setting.EpayKey,
|
||||||
}, setting.PayAddress)
|
}, operation_setting.PayAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -58,15 +98,23 @@ func getPayMoney(amount int64, group string) float64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
|
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
|
||||||
dPrice := decimal.NewFromFloat(setting.Price)
|
dPrice := decimal.NewFromFloat(operation_setting.Price)
|
||||||
|
// apply optional preset discount by the original request amount (if configured), default 1.0
|
||||||
|
discount := 1.0
|
||||||
|
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok {
|
||||||
|
if ds > 0 {
|
||||||
|
discount = ds
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dDiscount := decimal.NewFromFloat(discount)
|
||||||
|
|
||||||
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)
|
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio).Mul(dDiscount)
|
||||||
|
|
||||||
return payMoney.InexactFloat64()
|
return payMoney.InexactFloat64()
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMinTopup() int64 {
|
func getMinTopup() int64 {
|
||||||
minTopup := setting.MinTopUp
|
minTopup := operation_setting.MinTopUp
|
||||||
if !common.DisplayInCurrencyEnabled {
|
if !common.DisplayInCurrencyEnabled {
|
||||||
dMinTopup := decimal.NewFromInt(int64(minTopup))
|
dMinTopup := decimal.NewFromInt(int64(minTopup))
|
||||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||||
@ -99,13 +147,13 @@ func RequestEpay(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
|
||||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
callBackAddress := service.GetCallbackAddress()
|
callBackAddress := service.GetCallbackAddress()
|
||||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
returnUrl, _ := url.Parse(system_setting.ServerAddress + "/console/log")
|
||||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||||
|
|||||||
@ -8,6 +8,8 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/operation_setting"
|
||||||
|
"one-api/setting/system_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -215,8 +217,8 @@ func genStripeLink(referenceId string, customerId string, email string, amount i
|
|||||||
|
|
||||||
params := &stripe.CheckoutSessionParams{
|
params := &stripe.CheckoutSessionParams{
|
||||||
ClientReferenceID: stripe.String(referenceId),
|
ClientReferenceID: stripe.String(referenceId),
|
||||||
SuccessURL: stripe.String(setting.ServerAddress + "/log"),
|
SuccessURL: stripe.String(system_setting.ServerAddress + "/log"),
|
||||||
CancelURL: stripe.String(setting.ServerAddress + "/topup"),
|
CancelURL: stripe.String(system_setting.ServerAddress + "/topup"),
|
||||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||||
{
|
{
|
||||||
Price: stripe.String(setting.StripePriceId),
|
Price: stripe.String(setting.StripePriceId),
|
||||||
@ -254,6 +256,7 @@ func GetChargedAmount(count float64, user model.User) float64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getStripePayMoney(amount float64, group string) float64 {
|
func getStripePayMoney(amount float64, group string) float64 {
|
||||||
|
originalAmount := amount
|
||||||
if !common.DisplayInCurrencyEnabled {
|
if !common.DisplayInCurrencyEnabled {
|
||||||
amount = amount / common.QuotaPerUnit
|
amount = amount / common.QuotaPerUnit
|
||||||
}
|
}
|
||||||
@ -262,7 +265,14 @@ func getStripePayMoney(amount float64, group string) float64 {
|
|||||||
if topupGroupRatio == 0 {
|
if topupGroupRatio == 0 {
|
||||||
topupGroupRatio = 1
|
topupGroupRatio = 1
|
||||||
}
|
}
|
||||||
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
|
// apply optional preset discount by the original request amount (if configured), default 1.0
|
||||||
|
discount := 1.0
|
||||||
|
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok {
|
||||||
|
if ds > 0 {
|
||||||
|
discount = ds
|
||||||
|
}
|
||||||
|
}
|
||||||
|
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio * discount
|
||||||
return payMoney
|
return payMoney
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -31,7 +31,7 @@ type Monitor struct {
|
|||||||
|
|
||||||
type UptimeGroupResult struct {
|
type UptimeGroupResult struct {
|
||||||
CategoryName string `json:"categoryName"`
|
CategoryName string `json:"categoryName"`
|
||||||
Monitors []Monitor `json:"monitors"`
|
Monitors []Monitor `json:"monitors"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
||||||
@ -57,29 +57,29 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
|||||||
url, _ := groupConfig["url"].(string)
|
url, _ := groupConfig["url"].(string)
|
||||||
slug, _ := groupConfig["slug"].(string)
|
slug, _ := groupConfig["slug"].(string)
|
||||||
categoryName, _ := groupConfig["categoryName"].(string)
|
categoryName, _ := groupConfig["categoryName"].(string)
|
||||||
|
|
||||||
result := UptimeGroupResult{
|
result := UptimeGroupResult{
|
||||||
CategoryName: categoryName,
|
CategoryName: categoryName,
|
||||||
Monitors: []Monitor{},
|
Monitors: []Monitor{},
|
||||||
}
|
}
|
||||||
|
|
||||||
if url == "" || slug == "" {
|
if url == "" || slug == "" {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := strings.TrimSuffix(url, "/")
|
baseURL := strings.TrimSuffix(url, "/")
|
||||||
|
|
||||||
var statusData struct {
|
var statusData struct {
|
||||||
PublicGroupList []struct {
|
PublicGroupList []struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
MonitorList []struct {
|
MonitorList []struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
} `json:"monitorList"`
|
} `json:"monitorList"`
|
||||||
} `json:"publicGroupList"`
|
} `json:"publicGroupList"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var heartbeatData struct {
|
var heartbeatData struct {
|
||||||
HeartbeatList map[string][]struct {
|
HeartbeatList map[string][]struct {
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
@ -88,11 +88,11 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
|||||||
}
|
}
|
||||||
|
|
||||||
g, gCtx := errgroup.WithContext(ctx)
|
g, gCtx := errgroup.WithContext(ctx)
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||||
})
|
})
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||||
})
|
})
|
||||||
|
|
||||||
if g.Wait() != nil {
|
if g.Wait() != nil {
|
||||||
@ -139,7 +139,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
|||||||
|
|
||||||
client := &http.Client{Timeout: httpTimeout}
|
client := &http.Client{Timeout: httpTimeout}
|
||||||
results := make([]UptimeGroupResult, len(groups))
|
results := make([]UptimeGroupResult, len(groups))
|
||||||
|
|
||||||
g, gCtx := errgroup.WithContext(ctx)
|
g, gCtx := errgroup.WithContext(ctx)
|
||||||
for i, group := range groups {
|
for i, group := range groups {
|
||||||
i, group := i, group
|
i, group := i, group
|
||||||
@ -148,7 +148,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Wait()
|
g.Wait()
|
||||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -210,6 +210,7 @@ func Register(c *gin.Context) {
|
|||||||
Password: user.Password,
|
Password: user.Password,
|
||||||
DisplayName: user.Username,
|
DisplayName: user.Username,
|
||||||
InviterId: inviterId,
|
InviterId: inviterId,
|
||||||
|
Role: common.RoleCommonUser, // 明确设置角色为普通用户
|
||||||
}
|
}
|
||||||
if common.EmailVerificationEnabled {
|
if common.EmailVerificationEnabled {
|
||||||
cleanUser.Email = user.Email
|
cleanUser.Email = user.Email
|
||||||
@ -426,6 +427,7 @@ func GetAffCode(c *gin.Context) {
|
|||||||
|
|
||||||
func GetSelf(c *gin.Context) {
|
func GetSelf(c *gin.Context) {
|
||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
|
userRole := c.GetInt("role")
|
||||||
user, err := model.GetUserById(id, false)
|
user, err := model.GetUserById(id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
@ -434,14 +436,134 @@ func GetSelf(c *gin.Context) {
|
|||||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||||
user.Remark = ""
|
user.Remark = ""
|
||||||
|
|
||||||
|
// 计算用户权限信息
|
||||||
|
permissions := calculateUserPermissions(userRole)
|
||||||
|
|
||||||
|
// 获取用户设置并提取sidebar_modules
|
||||||
|
userSetting := user.GetSetting()
|
||||||
|
|
||||||
|
// 构建响应数据,包含用户信息和权限
|
||||||
|
responseData := map[string]interface{}{
|
||||||
|
"id": user.Id,
|
||||||
|
"username": user.Username,
|
||||||
|
"display_name": user.DisplayName,
|
||||||
|
"role": user.Role,
|
||||||
|
"status": user.Status,
|
||||||
|
"email": user.Email,
|
||||||
|
"group": user.Group,
|
||||||
|
"quota": user.Quota,
|
||||||
|
"used_quota": user.UsedQuota,
|
||||||
|
"request_count": user.RequestCount,
|
||||||
|
"aff_code": user.AffCode,
|
||||||
|
"aff_count": user.AffCount,
|
||||||
|
"aff_quota": user.AffQuota,
|
||||||
|
"aff_history_quota": user.AffHistoryQuota,
|
||||||
|
"inviter_id": user.InviterId,
|
||||||
|
"linux_do_id": user.LinuxDOId,
|
||||||
|
"setting": user.Setting,
|
||||||
|
"stripe_customer": user.StripeCustomer,
|
||||||
|
"sidebar_modules": userSetting.SidebarModules, // 正确提取sidebar_modules字段
|
||||||
|
"permissions": permissions, // 新增权限字段
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": user,
|
"data": responseData,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 计算用户权限的辅助函数
|
||||||
|
func calculateUserPermissions(userRole int) map[string]interface{} {
|
||||||
|
permissions := map[string]interface{}{}
|
||||||
|
|
||||||
|
// 根据用户角色计算权限
|
||||||
|
if userRole == common.RoleRootUser {
|
||||||
|
// 超级管理员不需要边栏设置功能
|
||||||
|
permissions["sidebar_settings"] = false
|
||||||
|
permissions["sidebar_modules"] = map[string]interface{}{}
|
||||||
|
} else if userRole == common.RoleAdminUser {
|
||||||
|
// 管理员可以设置边栏,但不包含系统设置功能
|
||||||
|
permissions["sidebar_settings"] = true
|
||||||
|
permissions["sidebar_modules"] = map[string]interface{}{
|
||||||
|
"admin": map[string]interface{}{
|
||||||
|
"setting": false, // 管理员不能访问系统设置
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 普通用户只能设置个人功能,不包含管理员区域
|
||||||
|
permissions["sidebar_settings"] = true
|
||||||
|
permissions["sidebar_modules"] = map[string]interface{}{
|
||||||
|
"admin": false, // 普通用户不能访问管理员区域
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return permissions
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据用户角色生成默认的边栏配置
|
||||||
|
func generateDefaultSidebarConfig(userRole int) string {
|
||||||
|
defaultConfig := map[string]interface{}{}
|
||||||
|
|
||||||
|
// 聊天区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["chat"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"playground": true,
|
||||||
|
"chat": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 控制台区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["console"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"detail": true,
|
||||||
|
"token": true,
|
||||||
|
"log": true,
|
||||||
|
"midjourney": true,
|
||||||
|
"task": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 个人中心区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["personal"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"topup": true,
|
||||||
|
"personal": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 管理员区域 - 根据角色决定
|
||||||
|
if userRole == common.RoleAdminUser {
|
||||||
|
// 管理员可以访问管理员区域,但不能访问系统设置
|
||||||
|
defaultConfig["admin"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"channel": true,
|
||||||
|
"models": true,
|
||||||
|
"redemption": true,
|
||||||
|
"user": true,
|
||||||
|
"setting": false, // 管理员不能访问系统设置
|
||||||
|
}
|
||||||
|
} else if userRole == common.RoleRootUser {
|
||||||
|
// 超级管理员可以访问所有功能
|
||||||
|
defaultConfig["admin"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"channel": true,
|
||||||
|
"models": true,
|
||||||
|
"redemption": true,
|
||||||
|
"user": true,
|
||||||
|
"setting": true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 普通用户不包含admin区域
|
||||||
|
|
||||||
|
// 转换为JSON字符串
|
||||||
|
configBytes, err := json.Marshal(defaultConfig)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog("生成默认边栏配置失败: " + err.Error())
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(configBytes)
|
||||||
|
}
|
||||||
|
|
||||||
func GetUserModels(c *gin.Context) {
|
func GetUserModels(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -528,8 +650,8 @@ func UpdateUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateSelf(c *gin.Context) {
|
func UpdateSelf(c *gin.Context) {
|
||||||
var user model.User
|
var requestData map[string]interface{}
|
||||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
err := json.NewDecoder(c.Request.Body).Decode(&requestData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -537,6 +659,60 @@ func UpdateSelf(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否是sidebar_modules更新请求
|
||||||
|
if sidebarModules, exists := requestData["sidebar_modules"]; exists {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
user, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前用户设置
|
||||||
|
currentSetting := user.GetSetting()
|
||||||
|
|
||||||
|
// 更新sidebar_modules字段
|
||||||
|
if sidebarModulesStr, ok := sidebarModules.(string); ok {
|
||||||
|
currentSetting.SidebarModules = sidebarModulesStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存更新后的设置
|
||||||
|
user.SetSetting(currentSetting)
|
||||||
|
if err := user.Update(false); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "更新设置失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "设置更新成功",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原有的用户信息更新逻辑
|
||||||
|
var user model.User
|
||||||
|
requestDataBytes, err := json.Marshal(requestData)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的参数",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(requestDataBytes, &user)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的参数",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if user.Password == "" {
|
if user.Password == "" {
|
||||||
user.Password = "$I_LOVE_U" // make Validator happy :)
|
user.Password = "$I_LOVE_U" // make Validator happy :)
|
||||||
}
|
}
|
||||||
@ -679,6 +855,7 @@ func CreateUser(c *gin.Context) {
|
|||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Password: user.Password,
|
Password: user.Password,
|
||||||
DisplayName: user.DisplayName,
|
DisplayName: user.DisplayName,
|
||||||
|
Role: user.Role, // 保持管理员设置的角色
|
||||||
}
|
}
|
||||||
if err := cleanUser.Insert(0); err != nil {
|
if err := cleanUser.Insert(0); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
@ -920,6 +1097,7 @@ type UpdateUserSettingRequest struct {
|
|||||||
WebhookUrl string `json:"webhook_url,omitempty"`
|
WebhookUrl string `json:"webhook_url,omitempty"`
|
||||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||||
NotificationEmail string `json:"notification_email,omitempty"`
|
NotificationEmail string `json:"notification_email,omitempty"`
|
||||||
|
BarkUrl string `json:"bark_url,omitempty"`
|
||||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||||
RecordIpLog bool `json:"record_ip_log"`
|
RecordIpLog bool `json:"record_ip_log"`
|
||||||
}
|
}
|
||||||
@ -935,7 +1113,7 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证预警类型
|
// 验证预警类型
|
||||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
|
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无效的预警类型",
|
"message": "无效的预警类型",
|
||||||
@ -983,6 +1161,33 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果是Bark类型,验证Bark URL
|
||||||
|
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||||
|
if req.BarkUrl == "" {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "Bark推送URL不能为空",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 验证URL格式
|
||||||
|
if _, err := url.ParseRequestURI(req.BarkUrl); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的Bark推送URL",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 检查是否是HTTP或HTTPS
|
||||||
|
if !strings.HasPrefix(req.BarkUrl, "https://") && !strings.HasPrefix(req.BarkUrl, "http://") {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "Bark推送URL必须以http://或https://开头",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
user, err := model.GetUserById(userId, true)
|
user, err := model.GetUserById(userId, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1011,6 +1216,11 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
settings.NotificationEmail = req.NotificationEmail
|
settings.NotificationEmail = req.NotificationEmail
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果是Bark类型,添加Bark URL到设置中
|
||||||
|
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||||
|
settings.BarkUrl = req.BarkUrl
|
||||||
|
}
|
||||||
|
|
||||||
// 更新用户设置
|
// 更新用户设置
|
||||||
user.SetSetting(settings)
|
user.SetSetting(settings)
|
||||||
if err := user.Update(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
|
|||||||
@ -9,6 +9,14 @@ type ChannelSettings struct {
|
|||||||
SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
|
SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type VertexKeyType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
VertexKeyTypeJSON VertexKeyType = "json"
|
||||||
|
VertexKeyTypeAPIKey VertexKeyType = "api_key"
|
||||||
|
)
|
||||||
|
|
||||||
type ChannelOtherSettings struct {
|
type ChannelOtherSettings struct {
|
||||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||||
|
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||||
}
|
}
|
||||||
|
|||||||
@ -59,6 +59,31 @@ func (i *ImageRequest) UnmarshalJSON(data []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 序列化时需要重新把字段平铺
|
||||||
|
func (r ImageRequest) MarshalJSON() ([]byte, error) {
|
||||||
|
// 将已定义字段转为 map
|
||||||
|
type Alias ImageRequest
|
||||||
|
alias := Alias(r)
|
||||||
|
base, err := common.Marshal(alias)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var baseMap map[string]json.RawMessage
|
||||||
|
if err := common.Unmarshal(base, &baseMap); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 合并 ExtraFields
|
||||||
|
for k, v := range r.Extra {
|
||||||
|
if _, exists := baseMap[k]; !exists {
|
||||||
|
baseMap[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(baseMap)
|
||||||
|
}
|
||||||
|
|
||||||
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
||||||
fields := make(map[string]struct{})
|
fields := make(map[string]struct{})
|
||||||
for i := 0; i < t.NumField(); i++ {
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
|||||||
@ -1,23 +1,23 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
type UpstreamDTO struct {
|
type UpstreamDTO struct {
|
||||||
ID int `json:"id,omitempty"`
|
ID int `json:"id,omitempty"`
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
BaseURL string `json:"base_url" binding:"required"`
|
BaseURL string `json:"base_url" binding:"required"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpstreamRequest struct {
|
type UpstreamRequest struct {
|
||||||
ChannelIDs []int64 `json:"channel_ids"`
|
ChannelIDs []int64 `json:"channel_ids"`
|
||||||
Upstreams []UpstreamDTO `json:"upstreams"`
|
Upstreams []UpstreamDTO `json:"upstreams"`
|
||||||
Timeout int `json:"timeout"`
|
Timeout int `json:"timeout"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestResult 上游测试连通性结果
|
// TestResult 上游测试连通性结果
|
||||||
type TestResult struct {
|
type TestResult struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DifferenceItem 差异项
|
// DifferenceItem 差异项
|
||||||
@ -25,14 +25,14 @@ type TestResult struct {
|
|||||||
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
||||||
|
|
||||||
type DifferenceItem struct {
|
type DifferenceItem struct {
|
||||||
Current interface{} `json:"current"`
|
Current interface{} `json:"current"`
|
||||||
Upstreams map[string]interface{} `json:"upstreams"`
|
Upstreams map[string]interface{} `json:"upstreams"`
|
||||||
Confidence map[string]bool `json:"confidence"`
|
Confidence map[string]bool `json:"confidence"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SyncableChannel struct {
|
type SyncableChannel struct {
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
BaseURL string `json:"base_url"`
|
BaseURL string `json:"base_url"`
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,11 +6,14 @@ type UserSetting struct {
|
|||||||
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
|
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
|
||||||
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
|
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
|
||||||
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
|
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
|
||||||
|
BarkUrl string `json:"bark_url,omitempty"` // BarkUrl Bark推送URL
|
||||||
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||||
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
|
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
|
||||||
|
SidebarModules string `json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
NotifyTypeEmail = "email" // Email 邮件
|
NotifyTypeEmail = "email" // Email 邮件
|
||||||
NotifyTypeWebhook = "webhook" // Webhook
|
NotifyTypeWebhook = "webhook" // Webhook
|
||||||
|
NotifyTypeBark = "bark" // Bark 推送
|
||||||
)
|
)
|
||||||
|
|||||||
12
main.go
12
main.go
@ -94,13 +94,9 @@ func main() {
|
|||||||
}
|
}
|
||||||
go controller.AutomaticallyUpdateChannels(frequency)
|
go controller.AutomaticallyUpdateChannels(frequency)
|
||||||
}
|
}
|
||||||
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
|
||||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
go controller.AutomaticallyTestChannels()
|
||||||
if err != nil {
|
|
||||||
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
|
|
||||||
}
|
|
||||||
go controller.AutomaticallyTestChannels(frequency)
|
|
||||||
}
|
|
||||||
if common.IsMasterNode && constant.UpdateTask {
|
if common.IsMasterNode && constant.UpdateTask {
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
controller.UpdateMidjourneyTaskBulk()
|
controller.UpdateMidjourneyTaskBulk()
|
||||||
@ -208,4 +204,4 @@ func InitResources() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -166,9 +166,9 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
|
||||||
relayMode := relayconstant.RelayModeUnknown
|
relayMode := relayconstant.RelayModeUnknown
|
||||||
if c.Request.Method == http.MethodPost {
|
if c.Request.Method == http.MethodPost {
|
||||||
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
relayMode = relayconstant.RelayModeVideoSubmit
|
relayMode = relayconstant.RelayModeVideoSubmit
|
||||||
} else if c.Request.Method == http.MethodGet {
|
} else if c.Request.Method == http.MethodGet {
|
||||||
relayMode = relayconstant.RelayModeVideoFetchByID
|
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||||
|
|||||||
@ -18,12 +18,12 @@ func StatsMiddleware() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// 增加活跃连接数
|
// 增加活跃连接数
|
||||||
atomic.AddInt64(&globalStats.activeConnections, 1)
|
atomic.AddInt64(&globalStats.activeConnections, 1)
|
||||||
|
|
||||||
// 确保在请求结束时减少连接数
|
// 确保在请求结束时减少连接数
|
||||||
defer func() {
|
defer func() {
|
||||||
atomic.AddInt64(&globalStats.activeConnections, -1)
|
atomic.AddInt64(&globalStats.activeConnections, -1)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -38,4 +38,4 @@ func GetStats() StatsInfo {
|
|||||||
return StatsInfo{
|
return StatsInfo{
|
||||||
ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
|
ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -42,14 +42,16 @@ type Channel struct {
|
|||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||||
OtherInfo string `json:"other_info"`
|
OtherInfo string `json:"other_info"`
|
||||||
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置
|
|
||||||
Tag *string `json:"tag" gorm:"index"`
|
Tag *string `json:"tag" gorm:"index"`
|
||||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||||
HeaderOverride *string `json:"header_override" gorm:"type:text"`
|
HeaderOverride *string `json:"header_override" gorm:"type:text"`
|
||||||
|
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
|
||||||
// add after v0.8.5
|
// add after v0.8.5
|
||||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||||
|
|
||||||
|
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置,存储azure版本等不需要检索的信息,详见dto.ChannelOtherSettings
|
||||||
|
|
||||||
// cache info
|
// cache info
|
||||||
Keys []string `json:"-" gorm:"-"`
|
Keys []string `json:"-" gorm:"-"`
|
||||||
}
|
}
|
||||||
@ -606,8 +608,12 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if channelCache.ChannelInfo.IsMultiKey {
|
if channelCache.ChannelInfo.IsMultiKey {
|
||||||
|
// Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey
|
||||||
|
pollingLock := GetChannelPollingLock(channelId)
|
||||||
|
pollingLock.Lock()
|
||||||
// 如果是多Key模式,更新缓存中的状态
|
// 如果是多Key模式,更新缓存中的状态
|
||||||
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
||||||
|
pollingLock.Unlock()
|
||||||
//CacheUpdateChannel(channelCache)
|
//CacheUpdateChannel(channelCache)
|
||||||
//return true
|
//return true
|
||||||
} else {
|
} else {
|
||||||
@ -638,7 +644,11 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
|||||||
|
|
||||||
if channel.ChannelInfo.IsMultiKey {
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
beforeStatus := channel.Status
|
beforeStatus := channel.Status
|
||||||
|
// Protect map writes with the same per-channel lock used by readers
|
||||||
|
pollingLock := GetChannelPollingLock(channelId)
|
||||||
|
pollingLock.Lock()
|
||||||
handlerMultiKeyUpdate(channel, usingKey, status, reason)
|
handlerMultiKeyUpdate(channel, usingKey, status, reason)
|
||||||
|
pollingLock.Unlock()
|
||||||
if beforeStatus != channel.Status {
|
if beforeStatus != channel.Status {
|
||||||
shouldUpdateAbilities = true
|
shouldUpdateAbilities = true
|
||||||
}
|
}
|
||||||
|
|||||||
@ -64,22 +64,6 @@ var DB *gorm.DB
|
|||||||
|
|
||||||
var LOG_DB *gorm.DB
|
var LOG_DB *gorm.DB
|
||||||
|
|
||||||
// dropIndexIfExists drops a MySQL index only if it exists to avoid noisy 1091 errors
|
|
||||||
func dropIndexIfExists(tableName string, indexName string) {
|
|
||||||
if !common.UsingMySQL {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var count int64
|
|
||||||
// Check index existence via information_schema
|
|
||||||
err := DB.Raw(
|
|
||||||
"SELECT COUNT(1) FROM information_schema.statistics WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?",
|
|
||||||
tableName, indexName,
|
|
||||||
).Scan(&count).Error
|
|
||||||
if err == nil && count > 0 {
|
|
||||||
_ = DB.Exec("ALTER TABLE " + tableName + " DROP INDEX " + indexName + ";").Error
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createRootAccountIfNeed() error {
|
func createRootAccountIfNeed() error {
|
||||||
var user User
|
var user User
|
||||||
//if user.Status != common.UserStatusEnabled {
|
//if user.Status != common.UserStatusEnabled {
|
||||||
@ -263,16 +247,6 @@ func InitLogDB() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func migrateDB() error {
|
func migrateDB() error {
|
||||||
// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
|
|
||||||
// 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引 (model_name, deleted_at) 冲突
|
|
||||||
dropIndexIfExists("models", "uk_model_name") // 新版复合索引名称(若已存在)
|
|
||||||
dropIndexIfExists("models", "model_name") // 旧版列级唯一索引名称
|
|
||||||
|
|
||||||
dropIndexIfExists("vendors", "uk_vendor_name") // 新版复合索引名称(若已存在)
|
|
||||||
dropIndexIfExists("vendors", "name") // 旧版列级唯一索引名称
|
|
||||||
//if !common.UsingPostgreSQL {
|
|
||||||
// return migrateDBFast()
|
|
||||||
//}
|
|
||||||
err := DB.AutoMigrate(
|
err := DB.AutoMigrate(
|
||||||
&Channel{},
|
&Channel{},
|
||||||
&Token{},
|
&Token{},
|
||||||
@ -299,13 +273,6 @@ func migrateDB() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func migrateDBFast() error {
|
func migrateDBFast() error {
|
||||||
// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
|
|
||||||
// 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引冲突
|
|
||||||
dropIndexIfExists("models", "uk_model_name")
|
|
||||||
dropIndexIfExists("models", "model_name")
|
|
||||||
|
|
||||||
dropIndexIfExists("vendors", "uk_vendor_name")
|
|
||||||
dropIndexIfExists("vendors", "name")
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
|||||||
@ -20,17 +20,18 @@ type BoundChannel struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"`
|
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name_delete_at,priority:1"`
|
||||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||||
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
|
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
|
||||||
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
|
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
|
||||||
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
|
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
SyncOfficial int `json:"sync_official" gorm:"default:1"`
|
||||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"`
|
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"`
|
||||||
|
|
||||||
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
|
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
|
||||||
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
|
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
"one-api/setting/config"
|
"one-api/setting/config"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/setting/operation_setting"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
|
"one-api/setting/system_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -66,16 +67,16 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["SystemName"] = common.SystemName
|
common.OptionMap["SystemName"] = common.SystemName
|
||||||
common.OptionMap["Logo"] = common.Logo
|
common.OptionMap["Logo"] = common.Logo
|
||||||
common.OptionMap["ServerAddress"] = ""
|
common.OptionMap["ServerAddress"] = ""
|
||||||
common.OptionMap["WorkerUrl"] = setting.WorkerUrl
|
common.OptionMap["WorkerUrl"] = system_setting.WorkerUrl
|
||||||
common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
|
common.OptionMap["WorkerValidKey"] = system_setting.WorkerValidKey
|
||||||
common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled)
|
common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(system_setting.WorkerAllowHttpImageRequestEnabled)
|
||||||
common.OptionMap["PayAddress"] = ""
|
common.OptionMap["PayAddress"] = ""
|
||||||
common.OptionMap["CustomCallbackAddress"] = ""
|
common.OptionMap["CustomCallbackAddress"] = ""
|
||||||
common.OptionMap["EpayId"] = ""
|
common.OptionMap["EpayId"] = ""
|
||||||
common.OptionMap["EpayKey"] = ""
|
common.OptionMap["EpayKey"] = ""
|
||||||
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
|
common.OptionMap["Price"] = strconv.FormatFloat(operation_setting.Price, 'f', -1, 64)
|
||||||
common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(setting.USDExchangeRate, 'f', -1, 64)
|
common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(operation_setting.USDExchangeRate, 'f', -1, 64)
|
||||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
common.OptionMap["MinTopUp"] = strconv.Itoa(operation_setting.MinTopUp)
|
||||||
common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
|
common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
|
||||||
common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
|
common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
|
||||||
common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
|
common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
|
||||||
@ -85,7 +86,7 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||||
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
|
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
|
||||||
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
|
common.OptionMap["PayMethods"] = operation_setting.PayMethods2JsonString()
|
||||||
common.OptionMap["GitHubClientId"] = ""
|
common.OptionMap["GitHubClientId"] = ""
|
||||||
common.OptionMap["GitHubClientSecret"] = ""
|
common.OptionMap["GitHubClientSecret"] = ""
|
||||||
common.OptionMap["TelegramBotToken"] = ""
|
common.OptionMap["TelegramBotToken"] = ""
|
||||||
@ -274,7 +275,7 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
case "SMTPSSLEnabled":
|
case "SMTPSSLEnabled":
|
||||||
common.SMTPSSLEnabled = boolValue
|
common.SMTPSSLEnabled = boolValue
|
||||||
case "WorkerAllowHttpImageRequestEnabled":
|
case "WorkerAllowHttpImageRequestEnabled":
|
||||||
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
system_setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||||
case "DefaultUseAutoGroup":
|
case "DefaultUseAutoGroup":
|
||||||
setting.DefaultUseAutoGroup = boolValue
|
setting.DefaultUseAutoGroup = boolValue
|
||||||
case "ExposeRatioEnabled":
|
case "ExposeRatioEnabled":
|
||||||
@ -296,29 +297,29 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
case "SMTPToken":
|
case "SMTPToken":
|
||||||
common.SMTPToken = value
|
common.SMTPToken = value
|
||||||
case "ServerAddress":
|
case "ServerAddress":
|
||||||
setting.ServerAddress = value
|
system_setting.ServerAddress = value
|
||||||
case "WorkerUrl":
|
case "WorkerUrl":
|
||||||
setting.WorkerUrl = value
|
system_setting.WorkerUrl = value
|
||||||
case "WorkerValidKey":
|
case "WorkerValidKey":
|
||||||
setting.WorkerValidKey = value
|
system_setting.WorkerValidKey = value
|
||||||
case "PayAddress":
|
case "PayAddress":
|
||||||
setting.PayAddress = value
|
operation_setting.PayAddress = value
|
||||||
case "Chats":
|
case "Chats":
|
||||||
err = setting.UpdateChatsByJsonString(value)
|
err = setting.UpdateChatsByJsonString(value)
|
||||||
case "AutoGroups":
|
case "AutoGroups":
|
||||||
err = setting.UpdateAutoGroupsByJsonString(value)
|
err = setting.UpdateAutoGroupsByJsonString(value)
|
||||||
case "CustomCallbackAddress":
|
case "CustomCallbackAddress":
|
||||||
setting.CustomCallbackAddress = value
|
operation_setting.CustomCallbackAddress = value
|
||||||
case "EpayId":
|
case "EpayId":
|
||||||
setting.EpayId = value
|
operation_setting.EpayId = value
|
||||||
case "EpayKey":
|
case "EpayKey":
|
||||||
setting.EpayKey = value
|
operation_setting.EpayKey = value
|
||||||
case "Price":
|
case "Price":
|
||||||
setting.Price, _ = strconv.ParseFloat(value, 64)
|
operation_setting.Price, _ = strconv.ParseFloat(value, 64)
|
||||||
case "USDExchangeRate":
|
case "USDExchangeRate":
|
||||||
setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
|
operation_setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
|
||||||
case "MinTopUp":
|
case "MinTopUp":
|
||||||
setting.MinTopUp, _ = strconv.Atoi(value)
|
operation_setting.MinTopUp, _ = strconv.Atoi(value)
|
||||||
case "StripeApiSecret":
|
case "StripeApiSecret":
|
||||||
setting.StripeApiSecret = value
|
setting.StripeApiSecret = value
|
||||||
case "StripeWebhookSecret":
|
case "StripeWebhookSecret":
|
||||||
@ -422,7 +423,7 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
case "StreamCacheQueueLength":
|
case "StreamCacheQueueLength":
|
||||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||||
case "PayMethods":
|
case "PayMethods":
|
||||||
err = setting.UpdatePayMethodsByJsonString(value)
|
err = operation_setting.UpdatePayMethodsByJsonString(value)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,7 +16,7 @@ type TwoFA struct {
|
|||||||
Id int `json:"id" gorm:"primaryKey"`
|
Id int `json:"id" gorm:"primaryKey"`
|
||||||
UserId int `json:"user_id" gorm:"unique;not null;index"`
|
UserId int `json:"user_id" gorm:"unique;not null;index"`
|
||||||
Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端
|
Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端
|
||||||
IsEnabled bool `json:"is_enabled" gorm:"default:false"`
|
IsEnabled bool `json:"is_enabled"`
|
||||||
FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
|
FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
|
||||||
LockedUntil *time.Time `json:"locked_until,omitempty"`
|
LockedUntil *time.Time `json:"locked_until,omitempty"`
|
||||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||||
@ -30,7 +30,7 @@ type TwoFABackupCode struct {
|
|||||||
Id int `json:"id" gorm:"primaryKey"`
|
Id int `json:"id" gorm:"primaryKey"`
|
||||||
UserId int `json:"user_id" gorm:"not null;index"`
|
UserId int `json:"user_id" gorm:"not null;index"`
|
||||||
CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
|
CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
|
||||||
IsUsed bool `json:"is_used" gorm:"default:false"`
|
IsUsed bool `json:"is_used"`
|
||||||
UsedAt *time.Time `json:"used_at,omitempty"`
|
UsedAt *time.Time `json:"used_at,omitempty"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
|||||||
@ -91,6 +91,68 @@ func (user *User) SetSetting(setting dto.UserSetting) {
|
|||||||
user.Setting = string(settingBytes)
|
user.Setting = string(settingBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据用户角色生成默认的边栏配置
|
||||||
|
func generateDefaultSidebarConfigForRole(userRole int) string {
|
||||||
|
defaultConfig := map[string]interface{}{}
|
||||||
|
|
||||||
|
// 聊天区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["chat"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"playground": true,
|
||||||
|
"chat": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 控制台区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["console"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"detail": true,
|
||||||
|
"token": true,
|
||||||
|
"log": true,
|
||||||
|
"midjourney": true,
|
||||||
|
"task": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 个人中心区域 - 所有用户都可以访问
|
||||||
|
defaultConfig["personal"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"topup": true,
|
||||||
|
"personal": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 管理员区域 - 根据角色决定
|
||||||
|
if userRole == common.RoleAdminUser {
|
||||||
|
// 管理员可以访问管理员区域,但不能访问系统设置
|
||||||
|
defaultConfig["admin"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"channel": true,
|
||||||
|
"models": true,
|
||||||
|
"redemption": true,
|
||||||
|
"user": true,
|
||||||
|
"setting": false, // 管理员不能访问系统设置
|
||||||
|
}
|
||||||
|
} else if userRole == common.RoleRootUser {
|
||||||
|
// 超级管理员可以访问所有功能
|
||||||
|
defaultConfig["admin"] = map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"channel": true,
|
||||||
|
"models": true,
|
||||||
|
"redemption": true,
|
||||||
|
"user": true,
|
||||||
|
"setting": true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 普通用户不包含admin区域
|
||||||
|
|
||||||
|
// 转换为JSON字符串
|
||||||
|
configBytes, err := json.Marshal(defaultConfig)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog("生成默认边栏配置失败: " + err.Error())
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(configBytes)
|
||||||
|
}
|
||||||
|
|
||||||
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
|
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
|
||||||
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
||||||
var user User
|
var user User
|
||||||
@ -320,10 +382,34 @@ func (user *User) Insert(inviterId int) error {
|
|||||||
user.Quota = common.QuotaForNewUser
|
user.Quota = common.QuotaForNewUser
|
||||||
//user.SetAccessToken(common.GetUUID())
|
//user.SetAccessToken(common.GetUUID())
|
||||||
user.AffCode = common.GetRandomString(4)
|
user.AffCode = common.GetRandomString(4)
|
||||||
|
|
||||||
|
// 初始化用户设置,包括默认的边栏配置
|
||||||
|
if user.Setting == "" {
|
||||||
|
defaultSetting := dto.UserSetting{}
|
||||||
|
// 这里暂时不设置SidebarModules,因为需要在用户创建后根据角色设置
|
||||||
|
user.SetSetting(defaultSetting)
|
||||||
|
}
|
||||||
|
|
||||||
result := DB.Create(user)
|
result := DB.Create(user)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 用户创建成功后,根据角色初始化边栏配置
|
||||||
|
// 需要重新获取用户以确保有正确的ID和Role
|
||||||
|
var createdUser User
|
||||||
|
if err := DB.Where("username = ?", user.Username).First(&createdUser).Error; err == nil {
|
||||||
|
// 生成基于角色的默认边栏配置
|
||||||
|
defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
|
||||||
|
if defaultSidebarConfig != "" {
|
||||||
|
currentSetting := createdUser.GetSetting()
|
||||||
|
currentSetting.SidebarModules = defaultSidebarConfig
|
||||||
|
createdUser.SetSetting(currentSetting)
|
||||||
|
createdUser.Update(false)
|
||||||
|
common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if common.QuotaForNewUser > 0 {
|
if common.QuotaForNewUser > 0 {
|
||||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
|
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,13 +14,13 @@ import (
|
|||||||
|
|
||||||
type Vendor struct {
|
type Vendor struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"`
|
Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name_delete_at,priority:1"`
|
||||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"`
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name_delete_at,priority:2"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert 创建新的供应商记录
|
// Insert 创建新的供应商记录
|
||||||
|
|||||||
@ -53,7 +53,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
|||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return newAPIError
|
return newAPIError
|
||||||
|
|||||||
@ -264,9 +264,8 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
|||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
|
||||||
}
|
}
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return nil, errors.New("resp is nil")
|
return nil, errors.New("resp is nil")
|
||||||
|
|||||||
@ -60,7 +60,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
// 检查是否为Nova模型
|
||||||
|
if isNovaModel(request.Model) {
|
||||||
|
novaReq := convertToNovaRequest(request)
|
||||||
|
c.Set("request_model", request.Model)
|
||||||
|
c.Set("converted_request", novaReq)
|
||||||
|
c.Set("is_nova_model", true)
|
||||||
|
return novaReq, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原有的Claude模型处理逻辑
|
||||||
var claudeReq *dto.ClaudeRequest
|
var claudeReq *dto.ClaudeRequest
|
||||||
var err error
|
var err error
|
||||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||||
@ -69,6 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
}
|
}
|
||||||
c.Set("request_model", claudeReq.Model)
|
c.Set("request_model", claudeReq.Model)
|
||||||
c.Set("converted_request", claudeReq)
|
c.Set("converted_request", claudeReq)
|
||||||
|
c.Set("is_nova_model", false)
|
||||||
return claudeReq, err
|
return claudeReq, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
package aws
|
package aws
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
var awsModelIDMap = map[string]string{
|
var awsModelIDMap = map[string]string{
|
||||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||||
"claude-2.0": "anthropic.claude-v2",
|
"claude-2.0": "anthropic.claude-v2",
|
||||||
@ -14,6 +16,11 @@ var awsModelIDMap = map[string]string{
|
|||||||
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
|
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||||
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
|
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
|
||||||
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
||||||
|
// Nova models
|
||||||
|
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
|
||||||
|
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||||
|
"nova-pro-v1:0": "amazon.nova-pro-v1:0",
|
||||||
|
"nova-premier-v1:0": "amazon.nova-premier-v1:0",
|
||||||
}
|
}
|
||||||
|
|
||||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||||
@ -58,7 +65,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
|||||||
"anthropic.claude-opus-4-1-20250805-v1:0": {
|
"anthropic.claude-opus-4-1-20250805-v1:0": {
|
||||||
"us": true,
|
"us": true,
|
||||||
},
|
},
|
||||||
}
|
// Nova models - all support three major regions
|
||||||
|
"amazon.nova-micro-v1:0": {
|
||||||
|
"us": true,
|
||||||
|
"eu": true,
|
||||||
|
"apac": true,
|
||||||
|
},
|
||||||
|
"amazon.nova-lite-v1:0": {
|
||||||
|
"us": true,
|
||||||
|
"eu": true,
|
||||||
|
"apac": true,
|
||||||
|
},
|
||||||
|
"amazon.nova-pro-v1:0": {
|
||||||
|
"us": true,
|
||||||
|
"eu": true,
|
||||||
|
"apac": true,
|
||||||
|
},
|
||||||
|
"amazon.nova-premier-v1:0": {
|
||||||
|
"us": true,
|
||||||
|
"eu": true,
|
||||||
|
"apac": true,
|
||||||
|
}}
|
||||||
|
|
||||||
var awsRegionCrossModelPrefixMap = map[string]string{
|
var awsRegionCrossModelPrefixMap = map[string]string{
|
||||||
"us": "us",
|
"us": "us",
|
||||||
@ -67,3 +94,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "aws"
|
var ChannelName = "aws"
|
||||||
|
|
||||||
|
// 判断是否为Nova模型
|
||||||
|
func isNovaModel(modelId string) bool {
|
||||||
|
return strings.HasPrefix(modelId, "nova-")
|
||||||
|
}
|
||||||
|
|||||||
@ -34,3 +34,92 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
|
|||||||
Thinking: req.Thinking,
|
Thinking: req.Thinking,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NovaMessage Nova模型使用messages-v1格式
|
||||||
|
type NovaMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content []NovaContent `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type NovaContent struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type NovaRequest struct {
|
||||||
|
SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0"
|
||||||
|
Messages []NovaMessage `json:"messages"` // 对话消息列表
|
||||||
|
InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
|
||||||
|
}
|
||||||
|
|
||||||
|
type NovaInferenceConfig struct {
|
||||||
|
MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数
|
||||||
|
Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1)
|
||||||
|
TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1)
|
||||||
|
TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128)
|
||||||
|
StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换OpenAI请求为Nova格式
|
||||||
|
func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
|
||||||
|
novaMessages := make([]NovaMessage, len(req.Messages))
|
||||||
|
for i, msg := range req.Messages {
|
||||||
|
novaMessages[i] = NovaMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: []NovaContent{{Text: msg.StringContent()}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
novaReq := &NovaRequest{
|
||||||
|
SchemaVersion: "messages-v1",
|
||||||
|
Messages: novaMessages,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置推理配置
|
||||||
|
if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
|
||||||
|
novaReq.InferenceConfig = &NovaInferenceConfig{}
|
||||||
|
if req.MaxTokens != 0 {
|
||||||
|
novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
|
||||||
|
}
|
||||||
|
if req.Temperature != nil && *req.Temperature != 0 {
|
||||||
|
novaReq.InferenceConfig.Temperature = *req.Temperature
|
||||||
|
}
|
||||||
|
if req.TopP != 0 {
|
||||||
|
novaReq.InferenceConfig.TopP = req.TopP
|
||||||
|
}
|
||||||
|
if req.TopK != 0 {
|
||||||
|
novaReq.InferenceConfig.TopK = req.TopK
|
||||||
|
}
|
||||||
|
if req.Stop != nil {
|
||||||
|
if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
|
||||||
|
novaReq.InferenceConfig.StopSequences = stopSequences
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return novaReq
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseStopSequences 解析停止序列,支持字符串或字符串数组
|
||||||
|
func parseStopSequences(stop any) []string {
|
||||||
|
if stop == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := stop.(type) {
|
||||||
|
case string:
|
||||||
|
if v != "" {
|
||||||
|
return []string{v}
|
||||||
|
}
|
||||||
|
case []string:
|
||||||
|
return v
|
||||||
|
case []interface{}:
|
||||||
|
var sequences []string
|
||||||
|
for _, item := range v {
|
||||||
|
if str, ok := item.(string); ok && str != "" {
|
||||||
|
sequences = append(sequences, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sequences
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package aws
|
package aws
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@ -93,7 +94,19 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
awsModelId := awsModelID(c.GetString("request_model"))
|
awsModelId := awsModelID(c.GetString("request_model"))
|
||||||
|
// 检查是否为Nova模型
|
||||||
|
isNova, _ := c.Get("is_nova_model")
|
||||||
|
if isNova == true {
|
||||||
|
// Nova模型也支持跨区域
|
||||||
|
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||||
|
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||||
|
if canCrossRegion {
|
||||||
|
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||||
|
}
|
||||||
|
return handleNovaRequest(c, awsCli, info, awsModelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原有的Claude处理逻辑
|
||||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||||
if canCrossRegion {
|
if canCrossRegion {
|
||||||
@ -130,7 +143,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
|||||||
Usage: &dto.Usage{},
|
Usage: &dto.Usage{},
|
||||||
}
|
}
|
||||||
|
|
||||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
// 复制上游 Content-Type 到客户端响应头
|
||||||
|
if awsResp.ContentType != nil && *awsResp.ContentType != "" {
|
||||||
|
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
|
||||||
if handlerErr != nil {
|
if handlerErr != nil {
|
||||||
return handlerErr, nil
|
return handlerErr, nil
|
||||||
}
|
}
|
||||||
@ -204,3 +222,74 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
||||||
return nil, claudeInfo.Usage
|
return nil, claudeInfo.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Nova模型处理函数
|
||||||
|
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
|
||||||
|
novaReq_, ok := c.Get("converted_request")
|
||||||
|
if !ok {
|
||||||
|
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
|
||||||
|
}
|
||||||
|
novaReq := novaReq_.(*NovaRequest)
|
||||||
|
|
||||||
|
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||||
|
awsReq := &bedrockruntime.InvokeModelInput{
|
||||||
|
ModelId: aws.String(awsModelId),
|
||||||
|
Accept: aws.String("application/json"),
|
||||||
|
ContentType: aws.String("application/json"),
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody, err := json.Marshal(novaReq)
|
||||||
|
if err != nil {
|
||||||
|
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
|
||||||
|
}
|
||||||
|
awsReq.Body = reqBody
|
||||||
|
|
||||||
|
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||||
|
if err != nil {
|
||||||
|
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析Nova响应
|
||||||
|
var novaResp struct {
|
||||||
|
Output struct {
|
||||||
|
Message struct {
|
||||||
|
Content []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"output"`
|
||||||
|
Usage struct {
|
||||||
|
InputTokens int `json:"inputTokens"`
|
||||||
|
OutputTokens int `json:"outputTokens"`
|
||||||
|
TotalTokens int `json:"totalTokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
|
||||||
|
return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构造OpenAI格式响应
|
||||||
|
response := dto.OpenAITextResponse{
|
||||||
|
Id: helper.GetResponseID(c),
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: info.UpstreamModelName,
|
||||||
|
Choices: []dto.OpenAITextResponseChoice{{
|
||||||
|
Index: 0,
|
||||||
|
Message: dto.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: novaResp.Output.Message.Content[0].Text,
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
}},
|
||||||
|
Usage: dto.Usage{
|
||||||
|
PromptTokens: novaResp.Usage.InputTokens,
|
||||||
|
CompletionTokens: novaResp.Usage.OutputTokens,
|
||||||
|
TotalTokens: novaResp.Usage.TotalTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, response)
|
||||||
|
return nil, &response.Usage
|
||||||
|
}
|
||||||
|
|||||||
@ -32,7 +32,7 @@ func stopReasonClaude2OpenAI(reason string) string {
|
|||||||
case "end_turn":
|
case "end_turn":
|
||||||
return "stop"
|
return "stop"
|
||||||
case "max_tokens":
|
case "max_tokens":
|
||||||
return "max_tokens"
|
return "length"
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
return "tool_calls"
|
return "tool_calls"
|
||||||
default:
|
default:
|
||||||
@ -274,19 +274,28 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
|||||||
|
|
||||||
claudeMessages := make([]dto.ClaudeMessage, 0)
|
claudeMessages := make([]dto.ClaudeMessage, 0)
|
||||||
isFirstMessage := true
|
isFirstMessage := true
|
||||||
|
// 初始化system消息数组,用于累积多个system消息
|
||||||
|
var systemMessages []dto.ClaudeMediaMessage
|
||||||
|
|
||||||
for _, message := range formatMessages {
|
for _, message := range formatMessages {
|
||||||
if message.Role == "system" {
|
if message.Role == "system" {
|
||||||
|
// 根据Claude API规范,system字段使用数组格式更有通用性
|
||||||
if message.IsStringContent() {
|
if message.IsStringContent() {
|
||||||
claudeRequest.System = message.StringContent()
|
systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
|
||||||
|
Type: "text",
|
||||||
|
Text: common.GetPointer[string](message.StringContent()),
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
contents := message.ParseContent()
|
// 支持复合内容的system消息(虽然不常见,但需要考虑完整性)
|
||||||
content := ""
|
for _, ctx := range message.ParseContent() {
|
||||||
for _, ctx := range contents {
|
|
||||||
if ctx.Type == "text" {
|
if ctx.Type == "text" {
|
||||||
content += ctx.Text
|
systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
|
||||||
|
Type: "text",
|
||||||
|
Text: common.GetPointer[string](ctx.Text),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
// 未来可以在这里扩展对图片等其他类型的支持
|
||||||
}
|
}
|
||||||
claudeRequest.System = content
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if isFirstMessage {
|
if isFirstMessage {
|
||||||
@ -392,6 +401,12 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
|||||||
claudeMessages = append(claudeMessages, claudeMessage)
|
claudeMessages = append(claudeMessages, claudeMessage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 设置累积的system消息
|
||||||
|
if len(systemMessages) > 0 {
|
||||||
|
claudeRequest.System = systemMessages
|
||||||
|
}
|
||||||
|
|
||||||
claudeRequest.Prompt = ""
|
claudeRequest.Prompt = ""
|
||||||
claudeRequest.Messages = claudeMessages
|
claudeRequest.Messages = claudeMessages
|
||||||
return &claudeRequest, nil
|
return &claudeRequest, nil
|
||||||
@ -426,7 +441,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
|
|||||||
choice.Delta.Role = "assistant"
|
choice.Delta.Role = "assistant"
|
||||||
} else if claudeResponse.Type == "content_block_start" {
|
} else if claudeResponse.Type == "content_block_start" {
|
||||||
if claudeResponse.ContentBlock != nil {
|
if claudeResponse.ContentBlock != nil {
|
||||||
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
|
// 如果是文本块,尽可能发送首段文本(若存在)
|
||||||
|
if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil {
|
||||||
|
choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
|
||||||
|
}
|
||||||
if claudeResponse.ContentBlock.Type == "tool_use" {
|
if claudeResponse.ContentBlock.Type == "tool_use" {
|
||||||
tools = append(tools, dto.ToolCallResponse{
|
tools = append(tools, dto.ToolCallResponse{
|
||||||
Index: common.GetPointer(fcIdx),
|
Index: common.GetPointer(fcIdx),
|
||||||
@ -698,7 +716,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
return claudeInfo.Usage, nil
|
return claudeInfo.Usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
|
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte, requestMode int) *types.NewAPIError {
|
||||||
var claudeResponse dto.ClaudeResponse
|
var claudeResponse dto.ClaudeResponse
|
||||||
err := common.Unmarshal(data, &claudeResponse)
|
err := common.Unmarshal(data, &claudeResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -736,7 +754,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
|
c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
|
||||||
}
|
}
|
||||||
|
|
||||||
service.IOCopyBytesGracefully(c, nil, responseData)
|
service.IOCopyBytesGracefully(c, httpResp, responseData)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -757,7 +775,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
|||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
println("responseBody: ", string(responseBody))
|
println("responseBody: ", string(responseBody))
|
||||||
}
|
}
|
||||||
handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
|
handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody, requestMode)
|
||||||
if handleErr != nil {
|
if handleErr != nil {
|
||||||
return nil, handleErr
|
return nil, handleErr
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
|
||||||
var geminiSupportedMimeTypes = map[string]bool{
|
var geminiSupportedMimeTypes = map[string]bool{
|
||||||
"application/pdf": true,
|
"application/pdf": true,
|
||||||
"audio/mpeg": true,
|
"audio/mpeg": true,
|
||||||
@ -30,6 +31,7 @@ var geminiSupportedMimeTypes = map[string]bool{
|
|||||||
"audio/wav": true,
|
"audio/wav": true,
|
||||||
"image/png": true,
|
"image/png": true,
|
||||||
"image/jpeg": true,
|
"image/jpeg": true,
|
||||||
|
"image/webp": true,
|
||||||
"text/plain": true,
|
"text/plain": true,
|
||||||
"video/mov": true,
|
"video/mov": true,
|
||||||
"video/mpeg": true,
|
"video/mpeg": true,
|
||||||
|
|||||||
@ -6,4 +6,4 @@ var ModelList = []string{
|
|||||||
"m3e-small",
|
"m3e-small",
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "mokaai"
|
var ChannelName = "mokaai"
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
@ -280,11 +281,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
|
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
|
||||||
defer service.CloseResponseBodyGracefully(resp)
|
defer service.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
// count tokens by audio file duration
|
|
||||||
audioTokens, err := countAudioTokens(c)
|
|
||||||
if err != nil {
|
|
||||||
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
|
|
||||||
}
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||||
@ -292,6 +288,26 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
// 写入新的 response body
|
// 写入新的 response body
|
||||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
|
|
||||||
|
var responseData struct {
|
||||||
|
Usage *dto.Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
|
||||||
|
if responseData.Usage.TotalTokens > 0 {
|
||||||
|
usage := responseData.Usage
|
||||||
|
if usage.PromptTokens == 0 {
|
||||||
|
usage.PromptTokens = usage.InputTokens
|
||||||
|
}
|
||||||
|
if usage.CompletionTokens == 0 {
|
||||||
|
usage.CompletionTokens = usage.OutputTokens
|
||||||
|
}
|
||||||
|
return nil, usage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
audioTokens, err := countAudioTokens(c)
|
||||||
|
if err != nil {
|
||||||
|
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
|
||||||
|
}
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
usage.PromptTokens = audioTokens
|
usage.PromptTokens = audioTokens
|
||||||
usage.CompletionTokens = 0
|
usage.CompletionTokens = 0
|
||||||
|
|||||||
@ -46,9 +46,17 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
|||||||
usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if info == nil || info.ResponsesUsageInfo == nil || info.ResponsesUsageInfo.BuiltInTools == nil {
|
||||||
|
return &usage, nil
|
||||||
|
}
|
||||||
// 解析 Tools 用量
|
// 解析 Tools 用量
|
||||||
for _, tool := range responsesResponse.Tools {
|
for _, tool := range responsesResponse.Tools {
|
||||||
info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
|
buildToolinfo, ok := info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])]
|
||||||
|
if !ok || buildToolinfo == nil {
|
||||||
|
logger.LogError(c, fmt.Sprintf("BuiltInTools not found for tool type: %v", tool["type"]))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
buildToolinfo.CallCount++
|
||||||
}
|
}
|
||||||
return &usage, nil
|
return &usage, nil
|
||||||
}
|
}
|
||||||
@ -72,7 +80,7 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
|||||||
sendResponsesStreamData(c, streamResponse, data)
|
sendResponsesStreamData(c, streamResponse, data)
|
||||||
switch streamResponse.Type {
|
switch streamResponse.Type {
|
||||||
case "response.completed":
|
case "response.completed":
|
||||||
if streamResponse.Response.Usage != nil {
|
if streamResponse.Response != nil && streamResponse.Response.Usage != nil {
|
||||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,7 +18,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
@ -89,22 +88,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||||
// Accept only POST /v1/video/generations as "generate" action.
|
// Accept only POST /v1/video/generations as "generate" action.
|
||||||
action := constant.TaskActionGenerate
|
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||||
info.Action = action
|
|
||||||
|
|
||||||
req := relaycommon.TaskSubmitReq{}
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
|
||||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.Prompt) == "" {
|
|
||||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store into context for later usage
|
|
||||||
c.Set("task_request", req)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildRequestURL constructs the upstream URL.
|
// BuildRequestURL constructs the upstream URL.
|
||||||
@ -334,11 +318,11 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle one-of image_urls or binary_data_base64
|
// Handle one-of image_urls or binary_data_base64
|
||||||
if req.Image != "" {
|
if req.HasImage() {
|
||||||
if strings.HasPrefix(req.Image, "http") {
|
if strings.HasPrefix(req.Images[0], "http") {
|
||||||
r.ImageUrls = []string{req.Image}
|
r.ImageUrls = req.Images
|
||||||
} else {
|
} else {
|
||||||
r.BinaryDataBase64 = []string{req.Image}
|
r.BinaryDataBase64 = req.Images
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
metadata := req.Metadata
|
metadata := req.Metadata
|
||||||
|
|||||||
@ -16,7 +16,6 @@ import (
|
|||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
@ -28,16 +27,6 @@ import (
|
|||||||
// Request / Response structures
|
// Request / Response structures
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type SubmitReq struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Model string `json:"model,omitempty"`
|
|
||||||
Mode string `json:"mode,omitempty"`
|
|
||||||
Image string `json:"image,omitempty"`
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
Duration int `json:"duration,omitempty"`
|
|
||||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TrajectoryPoint struct {
|
type TrajectoryPoint struct {
|
||||||
X int `json:"x"`
|
X int `json:"x"`
|
||||||
Y int `json:"y"`
|
Y int `json:"y"`
|
||||||
@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
|
|
||||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||||
// Accept only POST /v1/video/generations as "generate" action.
|
// Use the standard validation method for TaskSubmitReq
|
||||||
action := constant.TaskActionGenerate
|
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||||
info.Action = action
|
|
||||||
|
|
||||||
var req SubmitReq
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
|
||||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(req.Prompt) == "" {
|
|
||||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store into context for later usage
|
|
||||||
c.Set("task_request", req)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildRequestURL constructs the upstream URL.
|
// BuildRequestURL constructs the upstream URL.
|
||||||
@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if !exists {
|
if !exists {
|
||||||
return nil, fmt.Errorf("request not found in context")
|
return nil, fmt.Errorf("request not found in context")
|
||||||
}
|
}
|
||||||
req := v.(SubmitReq)
|
req := v.(relaycommon.TaskSubmitReq)
|
||||||
|
|
||||||
body, err := a.convertToRequestPayload(&req)
|
body, err := a.convertToRequestPayload(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
|||||||
// helpers
|
// helpers
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||||
r := requestPayload{
|
r := requestPayload{
|
||||||
Prompt: req.Prompt,
|
Prompt: req.Prompt,
|
||||||
Image: req.Image,
|
Image: req.Image,
|
||||||
|
|||||||
355
relay/channel/task/vertex/adaptor.go
Normal file
355
relay/channel/task/vertex/adaptor.go
Normal file
@ -0,0 +1,355 @@
|
|||||||
|
package vertex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/model"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
vertexcore "one-api/relay/channel/vertex"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Request / Response structures
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
type requestPayload struct {
|
||||||
|
Instances []map[string]any `json:"instances"`
|
||||||
|
Parameters map[string]any `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type submitResponse struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type operationVideo struct {
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||||
|
Encoding string `json:"encoding"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type operationResponse struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Done bool `json:"done"`
|
||||||
|
Response struct {
|
||||||
|
Type string `json:"@type"`
|
||||||
|
RaiMediaFilteredCount int `json:"raiMediaFilteredCount"`
|
||||||
|
Videos []operationVideo `json:"videos"`
|
||||||
|
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||||
|
Encoding string `json:"encoding"`
|
||||||
|
Video string `json:"video"`
|
||||||
|
} `json:"response"`
|
||||||
|
Error struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
} `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Adaptor implementation
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
type TaskAdaptor struct {
|
||||||
|
ChannelType int
|
||||||
|
apiKey string
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||||
|
a.ChannelType = info.ChannelType
|
||||||
|
a.baseURL = info.ChannelBaseUrl
|
||||||
|
a.apiKey = info.ApiKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||||
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||||
|
// Use the standard validation method for TaskSubmitReq
|
||||||
|
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestURL constructs the upstream URL.
|
||||||
|
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
|
adc := &vertexcore.Credentials{}
|
||||||
|
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to decode credentials: %w", err)
|
||||||
|
}
|
||||||
|
modelName := info.OriginModelName
|
||||||
|
if modelName == "" {
|
||||||
|
modelName = "veo-3.0-generate-001"
|
||||||
|
}
|
||||||
|
|
||||||
|
region := vertexcore.GetModelRegion(info.ApiVersion, modelName)
|
||||||
|
if strings.TrimSpace(region) == "" {
|
||||||
|
region = "global"
|
||||||
|
}
|
||||||
|
if region == "global" {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
|
||||||
|
adc.ProjectID,
|
||||||
|
modelName,
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
|
||||||
|
region,
|
||||||
|
adc.ProjectID,
|
||||||
|
region,
|
||||||
|
modelName,
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestHeader sets required headers.
|
||||||
|
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
adc := &vertexcore.Credentials{}
|
||||||
|
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||||
|
return fmt.Errorf("failed to decode credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to acquire access token: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestBody converts request into Vertex specific format.
|
||||||
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||||
|
v, ok := c.Get("task_request")
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("request not found in context")
|
||||||
|
}
|
||||||
|
req := v.(relaycommon.TaskSubmitReq)
|
||||||
|
|
||||||
|
body := requestPayload{
|
||||||
|
Instances: []map[string]any{{"prompt": req.Prompt}},
|
||||||
|
Parameters: map[string]any{},
|
||||||
|
}
|
||||||
|
if req.Metadata != nil {
|
||||||
|
if v, ok := req.Metadata["storageUri"]; ok {
|
||||||
|
body.Parameters["storageUri"] = v
|
||||||
|
}
|
||||||
|
if v, ok := req.Metadata["sampleCount"]; ok {
|
||||||
|
body.Parameters["sampleCount"] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, ok := body.Parameters["sampleCount"]; !ok {
|
||||||
|
body.Parameters["sampleCount"] = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return bytes.NewReader(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoRequest delegates to common helper.
|
||||||
|
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoResponse handles upstream response, returns taskID etc.
|
||||||
|
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
var s submitResponse
|
||||||
|
if err := json.Unmarshal(responseBody, &s); err != nil {
|
||||||
|
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(s.Name) == "" {
|
||||||
|
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
localID := encodeLocalTaskID(s.Name)
|
||||||
|
c.JSON(http.StatusOK, gin.H{"task_id": localID})
|
||||||
|
return localID, responseBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} }
|
||||||
|
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
||||||
|
|
||||||
|
// FetchTask fetch task status
|
||||||
|
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||||
|
taskID, ok := body["task_id"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid task_id")
|
||||||
|
}
|
||||||
|
upstreamName, err := decodeLocalTaskID(taskID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||||
|
}
|
||||||
|
region := extractRegionFromOperationName(upstreamName)
|
||||||
|
if region == "" {
|
||||||
|
region = "us-central1"
|
||||||
|
}
|
||||||
|
project := extractProjectFromOperationName(upstreamName)
|
||||||
|
modelName := extractModelFromOperationName(upstreamName)
|
||||||
|
if project == "" || modelName == "" {
|
||||||
|
return nil, fmt.Errorf("cannot extract project/model from operation name")
|
||||||
|
}
|
||||||
|
var url string
|
||||||
|
if region == "global" {
|
||||||
|
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
|
||||||
|
} else {
|
||||||
|
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
|
||||||
|
}
|
||||||
|
payload := map[string]string{"operationName": upstreamName}
|
||||||
|
data, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
adc := &vertexcore.Credentials{}
|
||||||
|
if err := json.Unmarshal([]byte(key), adc); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode credentials: %w", err)
|
||||||
|
}
|
||||||
|
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to acquire access token: %w", err)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||||
|
return service.GetHttpClient().Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||||
|
var op operationResponse
|
||||||
|
if err := json.Unmarshal(respBody, &op); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
|
||||||
|
}
|
||||||
|
ti := &relaycommon.TaskInfo{}
|
||||||
|
if op.Error.Message != "" {
|
||||||
|
ti.Status = model.TaskStatusFailure
|
||||||
|
ti.Reason = op.Error.Message
|
||||||
|
ti.Progress = "100%"
|
||||||
|
return ti, nil
|
||||||
|
}
|
||||||
|
if !op.Done {
|
||||||
|
ti.Status = model.TaskStatusInProgress
|
||||||
|
ti.Progress = "50%"
|
||||||
|
return ti, nil
|
||||||
|
}
|
||||||
|
ti.Status = model.TaskStatusSuccess
|
||||||
|
ti.Progress = "100%"
|
||||||
|
if len(op.Response.Videos) > 0 {
|
||||||
|
v0 := op.Response.Videos[0]
|
||||||
|
if v0.BytesBase64Encoded != "" {
|
||||||
|
mime := strings.TrimSpace(v0.MimeType)
|
||||||
|
if mime == "" {
|
||||||
|
enc := strings.TrimSpace(v0.Encoding)
|
||||||
|
if enc == "" {
|
||||||
|
enc = "mp4"
|
||||||
|
}
|
||||||
|
if strings.Contains(enc, "/") {
|
||||||
|
mime = enc
|
||||||
|
} else {
|
||||||
|
mime = "video/" + enc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded
|
||||||
|
return ti, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if op.Response.BytesBase64Encoded != "" {
|
||||||
|
enc := strings.TrimSpace(op.Response.Encoding)
|
||||||
|
if enc == "" {
|
||||||
|
enc = "mp4"
|
||||||
|
}
|
||||||
|
mime := enc
|
||||||
|
if !strings.Contains(enc, "/") {
|
||||||
|
mime = "video/" + enc
|
||||||
|
}
|
||||||
|
ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded
|
||||||
|
return ti, nil
|
||||||
|
}
|
||||||
|
if op.Response.Video != "" { // some variants use `video` as base64
|
||||||
|
enc := strings.TrimSpace(op.Response.Encoding)
|
||||||
|
if enc == "" {
|
||||||
|
enc = "mp4"
|
||||||
|
}
|
||||||
|
mime := enc
|
||||||
|
if !strings.Contains(enc, "/") {
|
||||||
|
mime = "video/" + enc
|
||||||
|
}
|
||||||
|
ti.Url = "data:" + mime + ";base64," + op.Response.Video
|
||||||
|
return ti, nil
|
||||||
|
}
|
||||||
|
return ti, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// helpers
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
func encodeLocalTaskID(name string) string {
|
||||||
|
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeLocalTaskID(local string) (string, error) {
|
||||||
|
b, err := base64.RawURLEncoding.DecodeString(local)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
|
||||||
|
|
||||||
|
func extractRegionFromOperationName(name string) string {
|
||||||
|
m := regionRe.FindStringSubmatch(name)
|
||||||
|
if len(m) == 2 {
|
||||||
|
return m[1]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
|
||||||
|
|
||||||
|
func extractModelFromOperationName(name string) string {
|
||||||
|
m := modelRe.FindStringSubmatch(name)
|
||||||
|
if len(m) == 2 {
|
||||||
|
return m[1]
|
||||||
|
}
|
||||||
|
idx := strings.Index(name, "models/")
|
||||||
|
if idx >= 0 {
|
||||||
|
s := name[idx+len("models/"):]
|
||||||
|
if p := strings.Index(s, "/operations/"); p > 0 {
|
||||||
|
return s[:p]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`)
|
||||||
|
|
||||||
|
func extractProjectFromOperationName(name string) string {
|
||||||
|
m := projectRe.FindStringSubmatch(name)
|
||||||
|
if len(m) == 2 {
|
||||||
|
return m[1]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@ -23,16 +23,6 @@ import (
|
|||||||
// Request / Response structures
|
// Request / Response structures
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type SubmitReq struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Model string `json:"model,omitempty"`
|
|
||||||
Mode string `json:"mode,omitempty"`
|
|
||||||
Image string `json:"image,omitempty"`
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
Duration int `json:"duration,omitempty"`
|
|
||||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type requestPayload struct {
|
type requestPayload struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Images []string `json:"images"`
|
Images []string `json:"images"`
|
||||||
@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||||
var req SubmitReq
|
// Use the unified validation method for TaskSubmitReq with image-based action determination
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
|
||||||
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Prompt == "" {
|
|
||||||
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Image != "" {
|
|
||||||
info.Action = constant.TaskActionGenerate
|
|
||||||
} else {
|
|
||||||
info.Action = constant.TaskActionTextGenerate
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Set("task_request", req)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
||||||
@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
|
|||||||
if !exists {
|
if !exists {
|
||||||
return nil, fmt.Errorf("request not found in context")
|
return nil, fmt.Errorf("request not found in context")
|
||||||
}
|
}
|
||||||
req := v.(SubmitReq)
|
req := v.(relaycommon.TaskSubmitReq)
|
||||||
|
|
||||||
body, err := a.convertToRequestPayload(&req)
|
body, err := a.convertToRequestPayload(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
|||||||
// helpers
|
// helpers
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||||
var images []string
|
var images []string
|
||||||
if req.Image != "" {
|
if req.Image != "" {
|
||||||
images = []string{req.Image}
|
images = []string{req.Image}
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/claude"
|
"one-api/relay/channel/claude"
|
||||||
@ -80,16 +81,64 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
|
||||||
adc := &Credentials{}
|
|
||||||
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
|
||||||
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
|
||||||
}
|
|
||||||
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
|
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
|
||||||
a.AccountCredentials = *adc
|
if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||||
|
adc := &Credentials{}
|
||||||
|
if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
||||||
|
}
|
||||||
|
a.AccountCredentials = *adc
|
||||||
|
|
||||||
|
if a.RequestMode == RequestModeLlama {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||||
|
region,
|
||||||
|
adc.ProjectID,
|
||||||
|
region,
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if region == "global" {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||||
|
adc.ProjectID,
|
||||||
|
modelName,
|
||||||
|
suffix,
|
||||||
|
), nil
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||||
|
region,
|
||||||
|
adc.ProjectID,
|
||||||
|
region,
|
||||||
|
modelName,
|
||||||
|
suffix,
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if region == "global" {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
|
||||||
|
modelName,
|
||||||
|
suffix,
|
||||||
|
info.ApiKey,
|
||||||
|
), nil
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
|
||||||
|
region,
|
||||||
|
modelName,
|
||||||
|
suffix,
|
||||||
|
info.ApiKey,
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
suffix := ""
|
suffix := ""
|
||||||
if a.RequestMode == RequestModeGemini {
|
if a.RequestMode == RequestModeGemini {
|
||||||
|
|
||||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||||
@ -111,24 +160,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||||
suffix = "predict"
|
suffix = "predict"
|
||||||
}
|
}
|
||||||
|
return a.getRequestUrl(info, info.UpstreamModelName, suffix)
|
||||||
if region == "global" {
|
|
||||||
return fmt.Sprintf(
|
|
||||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
|
||||||
adc.ProjectID,
|
|
||||||
info.UpstreamModelName,
|
|
||||||
suffix,
|
|
||||||
), nil
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf(
|
|
||||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
|
||||||
region,
|
|
||||||
adc.ProjectID,
|
|
||||||
region,
|
|
||||||
info.UpstreamModelName,
|
|
||||||
suffix,
|
|
||||||
), nil
|
|
||||||
}
|
|
||||||
} else if a.RequestMode == RequestModeClaude {
|
} else if a.RequestMode == RequestModeClaude {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
suffix = "streamRawPredict?alt=sse"
|
suffix = "streamRawPredict?alt=sse"
|
||||||
@ -139,41 +171,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||||
model = v
|
model = v
|
||||||
}
|
}
|
||||||
if region == "global" {
|
return a.getRequestUrl(info, model, suffix)
|
||||||
return fmt.Sprintf(
|
|
||||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
|
|
||||||
adc.ProjectID,
|
|
||||||
model,
|
|
||||||
suffix,
|
|
||||||
), nil
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf(
|
|
||||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
|
||||||
region,
|
|
||||||
adc.ProjectID,
|
|
||||||
region,
|
|
||||||
model,
|
|
||||||
suffix,
|
|
||||||
), nil
|
|
||||||
}
|
|
||||||
} else if a.RequestMode == RequestModeLlama {
|
} else if a.RequestMode == RequestModeLlama {
|
||||||
return fmt.Sprintf(
|
return a.getRequestUrl(info, "", "")
|
||||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
|
||||||
region,
|
|
||||||
adc.ProjectID,
|
|
||||||
region,
|
|
||||||
), nil
|
|
||||||
}
|
}
|
||||||
return "", errors.New("unsupported request mode")
|
return "", errors.New("unsupported request mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
accessToken, err := getAccessToken(a, info)
|
if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||||
if err != nil {
|
accessToken, err := getAccessToken(a, info)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
}
|
||||||
|
if a.AccountCredentials.ProjectID != "" {
|
||||||
|
req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
|
||||||
}
|
}
|
||||||
req.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,10 @@ func GetModelRegion(other string, localModelName string) string {
|
|||||||
if m[localModelName] != nil {
|
if m[localModelName] != nil {
|
||||||
return m[localModelName].(string)
|
return m[localModelName].(string)
|
||||||
} else {
|
} else {
|
||||||
return m["default"].(string)
|
if v, ok := m["default"]; ok {
|
||||||
|
return v.(string)
|
||||||
|
}
|
||||||
|
return "global"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return other
|
return other
|
||||||
|
|||||||
@ -6,14 +6,15 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/bytedance/gopkg/cache/asynccache"
|
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bytedance/gopkg/cache/asynccache"
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -137,3 +138,45 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s
|
|||||||
|
|
||||||
return "", fmt.Errorf("failed to get access token: %v", result)
|
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AcquireAccessToken(creds Credentials, proxy string) (string, error) {
|
||||||
|
signedJWT, err := createSignedJWT(creds.ClientEmail, creds.PrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to create signed JWT: %w", err)
|
||||||
|
}
|
||||||
|
return exchangeJwtForAccessTokenWithProxy(signedJWT, proxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func exchangeJwtForAccessTokenWithProxy(signedJWT string, proxy string) (string, error) {
|
||||||
|
authURL := "https://www.googleapis.com/oauth2/v4/token"
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
||||||
|
data.Set("assertion", signedJWT)
|
||||||
|
|
||||||
|
var client *http.Client
|
||||||
|
var err error
|
||||||
|
if proxy != "" {
|
||||||
|
client, err = service.NewProxyHttpClient(proxy)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("new proxy http client failed: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
client = service.GetHttpClient()
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.PostForm(authURL, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if accessToken, ok := result["access_token"].(string); ok {
|
||||||
|
return accessToken, nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||||
|
}
|
||||||
|
|||||||
@ -111,7 +111,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return newAPIError
|
return newAPIError
|
||||||
|
|||||||
@ -5,6 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -151,7 +153,9 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
|
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
|
||||||
value := gjson.Get(jsonStr, condition.Path)
|
// 处理负数索引
|
||||||
|
path := processNegativeIndex(jsonStr, condition.Path)
|
||||||
|
value := gjson.Get(jsonStr, path)
|
||||||
if !value.Exists() {
|
if !value.Exists() {
|
||||||
if condition.PassMissingKey {
|
if condition.PassMissingKey {
|
||||||
return true, nil
|
return true, nil
|
||||||
@ -177,6 +181,37 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func processNegativeIndex(jsonStr string, path string) string {
|
||||||
|
re := regexp.MustCompile(`\.(-\d+)`)
|
||||||
|
matches := re.FindAllStringSubmatch(path, -1)
|
||||||
|
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
result := path
|
||||||
|
for _, match := range matches {
|
||||||
|
negIndex := match[1]
|
||||||
|
index, _ := strconv.Atoi(negIndex)
|
||||||
|
|
||||||
|
arrayPath := strings.Split(path, negIndex)[0]
|
||||||
|
if strings.HasSuffix(arrayPath, ".") {
|
||||||
|
arrayPath = arrayPath[:len(arrayPath)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
array := gjson.Get(jsonStr, arrayPath)
|
||||||
|
if array.IsArray() {
|
||||||
|
length := len(array.Array())
|
||||||
|
actualIndex := length + index
|
||||||
|
if actualIndex >= 0 && actualIndex < length {
|
||||||
|
result = strings.Replace(result, match[0], "."+strconv.Itoa(actualIndex), 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式
|
// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式
|
||||||
func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
|
func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
|
||||||
switch mode {
|
switch mode {
|
||||||
@ -274,21 +309,25 @@ func applyOperations(jsonStr string, operations []ParamOperation) (string, error
|
|||||||
if !ok {
|
if !ok {
|
||||||
continue // 条件不满足,跳过当前操作
|
continue // 条件不满足,跳过当前操作
|
||||||
}
|
}
|
||||||
|
// 处理路径中的负数索引
|
||||||
|
opPath := processNegativeIndex(result, op.Path)
|
||||||
|
opFrom := processNegativeIndex(result, op.From)
|
||||||
|
opTo := processNegativeIndex(result, op.To)
|
||||||
|
|
||||||
switch op.Mode {
|
switch op.Mode {
|
||||||
case "delete":
|
case "delete":
|
||||||
result, err = sjson.Delete(result, op.Path)
|
result, err = sjson.Delete(result, opPath)
|
||||||
case "set":
|
case "set":
|
||||||
if op.KeepOrigin && gjson.Get(result, op.Path).Exists() {
|
if op.KeepOrigin && gjson.Get(result, opPath).Exists() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
result, err = sjson.Set(result, op.Path, op.Value)
|
result, err = sjson.Set(result, opPath, op.Value)
|
||||||
case "move":
|
case "move":
|
||||||
result, err = moveValue(result, op.From, op.To)
|
result, err = moveValue(result, opFrom, opTo)
|
||||||
case "prepend":
|
case "prepend":
|
||||||
result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, true)
|
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true)
|
||||||
case "append":
|
case "append":
|
||||||
result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, false)
|
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false)
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("unknown operation: %s", op.Mode)
|
return "", fmt.Errorf("unknown operation: %s", op.Mode)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -481,11 +481,20 @@ type TaskSubmitReq struct {
|
|||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Mode string `json:"mode,omitempty"`
|
Mode string `json:"mode,omitempty"`
|
||||||
Image string `json:"image,omitempty"`
|
Image string `json:"image,omitempty"`
|
||||||
|
Images []string `json:"images,omitempty"`
|
||||||
Size string `json:"size,omitempty"`
|
Size string `json:"size,omitempty"`
|
||||||
Duration int `json:"duration,omitempty"`
|
Duration int `json:"duration,omitempty"`
|
||||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t TaskSubmitReq) GetPrompt() string {
|
||||||
|
return t.Prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t TaskSubmitReq) HasImage() bool {
|
||||||
|
return len(t.Images) > 0
|
||||||
|
}
|
||||||
|
|
||||||
type TaskInfo struct {
|
type TaskInfo struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"code"`
|
||||||
TaskID string `json:"task_id"`
|
TaskID string `json:"task_id"`
|
||||||
|
|||||||
@ -2,14 +2,23 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"net/http"
|
||||||
_ "image/gif"
|
"one-api/common"
|
||||||
_ "image/jpeg"
|
|
||||||
_ "image/png"
|
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type HasPrompt interface {
|
||||||
|
GetPrompt() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type HasImage interface {
|
||||||
|
HasImage() bool
|
||||||
|
}
|
||||||
|
|
||||||
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
|
||||||
@ -32,3 +41,72 @@ func GetAPIVersion(c *gin.Context) string {
|
|||||||
}
|
}
|
||||||
return apiVersion
|
return apiVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
|
||||||
|
return &dto.TaskError{
|
||||||
|
Code: code,
|
||||||
|
Message: err.Error(),
|
||||||
|
StatusCode: statusCode,
|
||||||
|
LocalError: localError,
|
||||||
|
Error: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
|
||||||
|
info.Action = action
|
||||||
|
c.Set("task_request", requestObj)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validatePrompt(prompt string) *dto.TaskError {
|
||||||
|
if strings.TrimSpace(prompt) == "" {
|
||||||
|
return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
|
||||||
|
var req TaskSubmitReq
|
||||||
|
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||||
|
return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
if taskErr := validatePrompt(req.Prompt); taskErr != nil {
|
||||||
|
return taskErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
|
||||||
|
// 兼容单图上传
|
||||||
|
req.Images = []string{req.Image}
|
||||||
|
}
|
||||||
|
|
||||||
|
storeTaskRequest(c, info, action, req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
|
||||||
|
hasPrompt, ok := requestObj.(HasPrompt)
|
||||||
|
if !ok {
|
||||||
|
return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
|
||||||
|
return taskErr
|
||||||
|
}
|
||||||
|
|
||||||
|
action := constant.TaskActionTextGenerate
|
||||||
|
if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
|
||||||
|
action = constant.TaskActionGenerate
|
||||||
|
}
|
||||||
|
|
||||||
|
storeTaskRequest(c, info, action, requestObj)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||||
|
var req TaskSubmitReq
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ValidateTaskRequestWithImage(c, info, req)
|
||||||
|
}
|
||||||
|
|||||||
@ -158,7 +158,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
|||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
newApiErr := service.RelayErrorHandler(httpResp, false)
|
newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||||
return newApiErr
|
return newApiErr
|
||||||
@ -195,6 +195,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
imageTokens := usage.PromptTokensDetails.ImageTokens
|
imageTokens := usage.PromptTokensDetails.ImageTokens
|
||||||
audioTokens := usage.PromptTokensDetails.AudioTokens
|
audioTokens := usage.PromptTokensDetails.AudioTokens
|
||||||
completionTokens := usage.CompletionTokens
|
completionTokens := usage.CompletionTokens
|
||||||
|
cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||||
|
|
||||||
modelName := relayInfo.OriginModelName
|
modelName := relayInfo.OriginModelName
|
||||||
|
|
||||||
tokenName := ctx.GetString("token_name")
|
tokenName := ctx.GetString("token_name")
|
||||||
@ -204,6 +206,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
modelRatio := relayInfo.PriceData.ModelRatio
|
modelRatio := relayInfo.PriceData.ModelRatio
|
||||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||||
modelPrice := relayInfo.PriceData.ModelPrice
|
modelPrice := relayInfo.PriceData.ModelPrice
|
||||||
|
cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
|
||||||
|
|
||||||
// Convert values to decimal for precise calculation
|
// Convert values to decimal for precise calculation
|
||||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||||
@ -211,12 +214,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
dImageTokens := decimal.NewFromInt(int64(imageTokens))
|
dImageTokens := decimal.NewFromInt(int64(imageTokens))
|
||||||
dAudioTokens := decimal.NewFromInt(int64(audioTokens))
|
dAudioTokens := decimal.NewFromInt(int64(audioTokens))
|
||||||
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
|
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
|
||||||
|
dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
|
||||||
dCompletionRatio := decimal.NewFromFloat(completionRatio)
|
dCompletionRatio := decimal.NewFromFloat(completionRatio)
|
||||||
dCacheRatio := decimal.NewFromFloat(cacheRatio)
|
dCacheRatio := decimal.NewFromFloat(cacheRatio)
|
||||||
dImageRatio := decimal.NewFromFloat(imageRatio)
|
dImageRatio := decimal.NewFromFloat(imageRatio)
|
||||||
dModelRatio := decimal.NewFromFloat(modelRatio)
|
dModelRatio := decimal.NewFromFloat(modelRatio)
|
||||||
dGroupRatio := decimal.NewFromFloat(groupRatio)
|
dGroupRatio := decimal.NewFromFloat(groupRatio)
|
||||||
dModelPrice := decimal.NewFromFloat(modelPrice)
|
dModelPrice := decimal.NewFromFloat(modelPrice)
|
||||||
|
dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
|
||||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||||
|
|
||||||
ratio := dModelRatio.Mul(dGroupRatio)
|
ratio := dModelRatio.Mul(dGroupRatio)
|
||||||
@ -284,6 +289,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||||
}
|
}
|
||||||
|
var dCachedCreationTokensWithRatio decimal.Decimal
|
||||||
|
if !dCachedCreationTokens.IsZero() {
|
||||||
|
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||||
|
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
|
||||||
|
}
|
||||||
|
|
||||||
// 减去 image tokens
|
// 减去 image tokens
|
||||||
var imageTokensWithRatio decimal.Decimal
|
var imageTokensWithRatio decimal.Decimal
|
||||||
@ -302,7 +312,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
|
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio)
|
promptQuota := baseTokens.Add(cachedTokensWithRatio).
|
||||||
|
Add(imageTokensWithRatio).
|
||||||
|
Add(dCachedCreationTokensWithRatio)
|
||||||
|
|
||||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||||
|
|
||||||
@ -384,6 +396,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
other["image_ratio"] = imageRatio
|
other["image_ratio"] = imageRatio
|
||||||
other["image_output"] = imageTokens
|
other["image_output"] = imageTokens
|
||||||
}
|
}
|
||||||
|
if cachedCreationTokens != 0 {
|
||||||
|
other["cache_creation_tokens"] = cachedCreationTokens
|
||||||
|
other["cache_creation_ratio"] = cachedCreationRatio
|
||||||
|
}
|
||||||
if !dWebSearchQuota.IsZero() {
|
if !dWebSearchQuota.IsZero() {
|
||||||
if relayInfo.ResponsesUsageInfo != nil {
|
if relayInfo.ResponsesUsageInfo != nil {
|
||||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
|
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
|
||||||
|
|||||||
@ -58,7 +58,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
|||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return newAPIError
|
return newAPIError
|
||||||
|
|||||||
@ -152,7 +152,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return newAPIError
|
return newAPIError
|
||||||
@ -249,7 +249,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
|||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return newAPIError
|
return newAPIError
|
||||||
}
|
}
|
||||||
|
|||||||
@ -91,7 +91,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
|||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return newAPIError
|
return newAPIError
|
||||||
@ -120,7 +120,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
|||||||
var logContent string
|
var logContent string
|
||||||
|
|
||||||
if len(request.Size) > 0 {
|
if len(request.Size) > 0 {
|
||||||
logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
|
logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
|
||||||
}
|
}
|
||||||
|
|
||||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||||
|
|||||||
@ -16,6 +16,7 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/system_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -131,7 +132,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
|||||||
midjourneyTask.FinishTime = originTask.FinishTime
|
midjourneyTask.FinishTime = originTask.FinishTime
|
||||||
midjourneyTask.ImageUrl = ""
|
midjourneyTask.ImageUrl = ""
|
||||||
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
|
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
|
||||||
midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId
|
midjourneyTask.ImageUrl = system_setting.ServerAddress + "/mj/image/" + originTask.MjId
|
||||||
if originTask.Status != "SUCCESS" {
|
if originTask.Status != "SUCCESS" {
|
||||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
package relay
|
package relay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/ali"
|
"one-api/relay/channel/ali"
|
||||||
@ -28,6 +27,7 @@ import (
|
|||||||
taskjimeng "one-api/relay/channel/task/jimeng"
|
taskjimeng "one-api/relay/channel/task/jimeng"
|
||||||
"one-api/relay/channel/task/kling"
|
"one-api/relay/channel/task/kling"
|
||||||
"one-api/relay/channel/task/suno"
|
"one-api/relay/channel/task/suno"
|
||||||
|
taskvertex "one-api/relay/channel/task/vertex"
|
||||||
taskVidu "one-api/relay/channel/task/vidu"
|
taskVidu "one-api/relay/channel/task/vidu"
|
||||||
"one-api/relay/channel/tencent"
|
"one-api/relay/channel/tencent"
|
||||||
"one-api/relay/channel/vertex"
|
"one-api/relay/channel/vertex"
|
||||||
@ -37,6 +37,8 @@ import (
|
|||||||
"one-api/relay/channel/zhipu"
|
"one-api/relay/channel/zhipu"
|
||||||
"one-api/relay/channel/zhipu_4v"
|
"one-api/relay/channel/zhipu_4v"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAdaptor(apiType int) channel.Adaptor {
|
func GetAdaptor(apiType int) channel.Adaptor {
|
||||||
@ -126,6 +128,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
|||||||
return &kling.TaskAdaptor{}
|
return &kling.TaskAdaptor{}
|
||||||
case constant.ChannelTypeJimeng:
|
case constant.ChannelTypeJimeng:
|
||||||
return &taskjimeng.TaskAdaptor{}
|
return &taskjimeng.TaskAdaptor{}
|
||||||
|
case constant.ChannelTypeVertexAi:
|
||||||
|
return &taskvertex.TaskAdaptor{}
|
||||||
case constant.ChannelTypeVidu:
|
case constant.ChannelTypeVidu:
|
||||||
return &taskVidu.TaskAdaptor{}
|
return &taskVidu.TaskAdaptor{}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,6 +15,8 @@ import (
|
|||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@ -33,6 +35,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
|||||||
platform = GetTaskPlatform(c)
|
platform = GetTaskPlatform(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
info.InitChannelMeta(c)
|
||||||
adaptor := GetTaskAdaptor(platform)
|
adaptor := GetTaskAdaptor(platform)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||||
@ -197,6 +200,9 @@ func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
|||||||
if taskErr != nil {
|
if taskErr != nil {
|
||||||
return taskErr
|
return taskErr
|
||||||
}
|
}
|
||||||
|
if len(respBody) == 0 {
|
||||||
|
respBody = []byte("{\"code\":\"success\",\"data\":null}")
|
||||||
|
}
|
||||||
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||||
@ -276,10 +282,92 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
func() {
|
||||||
Code: "success",
|
channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
|
||||||
Data: TaskModel2Dto(originTask),
|
if err2 != nil {
|
||||||
})
|
return
|
||||||
|
}
|
||||||
|
if channelModel.Type != constant.ChannelTypeVertexAi {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
baseURL := constant.ChannelBaseURLs[channelModel.Type]
|
||||||
|
if channelModel.GetBaseURL() != "" {
|
||||||
|
baseURL = channelModel.GetBaseURL()
|
||||||
|
}
|
||||||
|
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
||||||
|
if adaptor == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
||||||
|
"task_id": originTask.TaskID,
|
||||||
|
"action": originTask.Action,
|
||||||
|
})
|
||||||
|
if err2 != nil || resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, err2 := io.ReadAll(resp.Body)
|
||||||
|
if err2 != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ti, err2 := adaptor.ParseTaskResult(body)
|
||||||
|
if err2 == nil && ti != nil {
|
||||||
|
if ti.Status != "" {
|
||||||
|
originTask.Status = model.TaskStatus(ti.Status)
|
||||||
|
}
|
||||||
|
if ti.Progress != "" {
|
||||||
|
originTask.Progress = ti.Progress
|
||||||
|
}
|
||||||
|
if ti.Url != "" {
|
||||||
|
originTask.FailReason = ti.Url
|
||||||
|
}
|
||||||
|
_ = originTask.Update()
|
||||||
|
var raw map[string]any
|
||||||
|
_ = json.Unmarshal(body, &raw)
|
||||||
|
format := "mp4"
|
||||||
|
if respObj, ok := raw["response"].(map[string]any); ok {
|
||||||
|
if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
|
||||||
|
if v0, ok := vids[0].(map[string]any); ok {
|
||||||
|
if mt, ok := v0["mimeType"].(string); ok && mt != "" {
|
||||||
|
if strings.Contains(mt, "mp4") {
|
||||||
|
format = "mp4"
|
||||||
|
} else {
|
||||||
|
format = mt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
status := "processing"
|
||||||
|
switch originTask.Status {
|
||||||
|
case model.TaskStatusSuccess:
|
||||||
|
status = "succeeded"
|
||||||
|
case model.TaskStatusFailure:
|
||||||
|
status = "failed"
|
||||||
|
case model.TaskStatusQueued, model.TaskStatusSubmitted:
|
||||||
|
status = "queued"
|
||||||
|
}
|
||||||
|
out := map[string]any{
|
||||||
|
"error": nil,
|
||||||
|
"format": format,
|
||||||
|
"metadata": nil,
|
||||||
|
"status": status,
|
||||||
|
"task_id": originTask.TaskID,
|
||||||
|
"url": originTask.FailReason,
|
||||||
|
}
|
||||||
|
respBody, _ = json.Marshal(dto.TaskResponse[any]{
|
||||||
|
Code: "success",
|
||||||
|
Data: out,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if len(respBody) == 0 {
|
||||||
|
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||||
|
Code: "success",
|
||||||
|
Data: TaskModel2Dto(originTask),
|
||||||
|
})
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -81,7 +81,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return newAPIError
|
return newAPIError
|
||||||
|
|||||||
@ -41,7 +41,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
|||||||
}
|
}
|
||||||
adaptor.Init(info)
|
adaptor.Init(info)
|
||||||
var requestBody io.Reader
|
var requestBody io.Reader
|
||||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||||
body, err := common.GetRequestBody(c)
|
body, err := common.GetRequestBody(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
|
||||||
@ -82,7 +82,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
|||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
|
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return newAPIError
|
return newAPIError
|
||||||
|
|||||||
@ -60,6 +60,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
selfRoute.DELETE("/self", controller.DeleteSelf)
|
selfRoute.DELETE("/self", controller.DeleteSelf)
|
||||||
selfRoute.GET("/token", controller.GenerateAccessToken)
|
selfRoute.GET("/token", controller.GenerateAccessToken)
|
||||||
selfRoute.GET("/aff", controller.GetAffCode)
|
selfRoute.GET("/aff", controller.GetAffCode)
|
||||||
|
selfRoute.GET("/topup/info", controller.GetTopUpInfo)
|
||||||
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
|
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
|
||||||
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay)
|
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay)
|
||||||
selfRoute.POST("/amount", controller.RequestAmount)
|
selfRoute.POST("/amount", controller.RequestAmount)
|
||||||
@ -224,6 +225,8 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
modelsRoute := apiRouter.Group("/models")
|
modelsRoute := apiRouter.Group("/models")
|
||||||
modelsRoute.Use(middleware.AdminAuth())
|
modelsRoute.Use(middleware.AdminAuth())
|
||||||
{
|
{
|
||||||
|
modelsRoute.GET("/sync_upstream/preview", controller.SyncUpstreamPreview)
|
||||||
|
modelsRoute.POST("/sync_upstream", controller.SyncUpstreamModels)
|
||||||
modelsRoute.GET("/missing", controller.GetMissingModels)
|
modelsRoute.GET("/missing", controller.GetMissingModels)
|
||||||
modelsRoute.GET("/", controller.GetAllModelsMeta)
|
modelsRoute.GET("/", controller.GetAllModelsMeta)
|
||||||
modelsRoute.GET("/search", controller.SearchModelsMeta)
|
modelsRoute.GET("/search", controller.SearchModelsMeta)
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/setting"
|
"one-api/setting/operation_setting"
|
||||||
|
"one-api/setting/system_setting"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetCallbackAddress() string {
|
func GetCallbackAddress() string {
|
||||||
if setting.CustomCallbackAddress == "" {
|
if operation_setting.CustomCallbackAddress == "" {
|
||||||
return setting.ServerAddress
|
return system_setting.ServerAddress
|
||||||
}
|
}
|
||||||
return setting.CustomCallbackAddress
|
return operation_setting.CustomCallbackAddress
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,12 +1,14 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -78,7 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
|
|||||||
return claudeErr
|
return claudeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
||||||
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
@ -94,7 +96,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
|
|||||||
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
|
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
|
||||||
} else {
|
} else {
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
||||||
}
|
}
|
||||||
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
|
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,6 +5,9 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"image"
|
"image"
|
||||||
|
_ "image/gif"
|
||||||
|
_ "image/jpeg"
|
||||||
|
_ "image/png"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
|||||||
@ -13,13 +13,13 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
|
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||||
if preConsumedQuota != 0 {
|
if relayInfo.FinalPreConsumedQuota != 0 {
|
||||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota)))
|
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota)))
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
relayInfoCopy := *relayInfo
|
relayInfoCopy := *relayInfo
|
||||||
|
|
||||||
err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
|
err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog("error return pre-consumed quota: " + err.Error())
|
common.SysLog("error return pre-consumed quota: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -29,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr
|
|||||||
|
|
||||||
// PreConsumeQuota checks if the user has enough quota to pre-consume.
|
// PreConsumeQuota checks if the user has enough quota to pre-consume.
|
||||||
// It returns the pre-consumed quota if successful, or an error if not.
|
// It returns the pre-consumed quota if successful, or an error if not.
|
||||||
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
|
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
|
||||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
if userQuota <= 0 {
|
if userQuota <= 0 {
|
||||||
return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||||
}
|
}
|
||||||
if userQuota-preConsumedQuota < 0 {
|
if userQuota-preConsumedQuota < 0 {
|
||||||
return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||||
}
|
}
|
||||||
|
|
||||||
trustQuota := common.GetTrustQuota()
|
trustQuota := common.GetTrustQuota()
|
||||||
@ -65,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
if preConsumedQuota > 0 {
|
if preConsumedQuota > 0 {
|
||||||
err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||||
}
|
}
|
||||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||||
}
|
}
|
||||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
|
logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
|
||||||
}
|
}
|
||||||
relayInfo.FinalPreConsumedQuota = preConsumedQuota
|
relayInfo.FinalPreConsumedQuota = preConsumedQuota
|
||||||
return preConsumedQuota, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,8 +11,8 @@ import (
|
|||||||
"one-api/logger"
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/setting"
|
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
|
"one-api/setting/system_setting"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -534,9 +534,28 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
|
|||||||
}
|
}
|
||||||
if quotaTooLow {
|
if quotaTooLow {
|
||||||
prompt := "您的额度即将用尽"
|
prompt := "您的额度即将用尽"
|
||||||
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
topUpLink := fmt.Sprintf("%s/topup", system_setting.ServerAddress)
|
||||||
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
|
|
||||||
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
|
// 根据通知方式生成不同的内容格式
|
||||||
|
var content string
|
||||||
|
var values []interface{}
|
||||||
|
|
||||||
|
notifyType := userSetting.NotifyType
|
||||||
|
if notifyType == "" {
|
||||||
|
notifyType = dto.NotifyTypeEmail
|
||||||
|
}
|
||||||
|
|
||||||
|
if notifyType == dto.NotifyTypeBark {
|
||||||
|
// Bark推送使用简短文本,不支持HTML
|
||||||
|
content = "{{value}},剩余额度:{{value}},请及时充值"
|
||||||
|
values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota)}
|
||||||
|
} else {
|
||||||
|
// 默认内容格式,适用于Email和Webhook
|
||||||
|
content = "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
|
||||||
|
values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, values))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
|
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -5,6 +5,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"image"
|
"image"
|
||||||
|
_ "image/gif"
|
||||||
|
_ "image/jpeg"
|
||||||
|
_ "image/png"
|
||||||
"log"
|
"log"
|
||||||
"math"
|
"math"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@ -357,33 +360,6 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
|||||||
return tkm, nil
|
return tkm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
|
|
||||||
// tkm := 0
|
|
||||||
// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
|
|
||||||
// if err != nil {
|
|
||||||
// return 0, err
|
|
||||||
// }
|
|
||||||
// tkm += msgTokens
|
|
||||||
// if request.Tools != nil {
|
|
||||||
// openaiTools := request.Tools
|
|
||||||
// countStr := ""
|
|
||||||
// for _, tool := range openaiTools {
|
|
||||||
// countStr = tool.Function.Name
|
|
||||||
// if tool.Function.Description != "" {
|
|
||||||
// countStr += tool.Function.Description
|
|
||||||
// }
|
|
||||||
// if tool.Function.Parameters != nil {
|
|
||||||
// countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// toolTokens := CountTokenInput(countStr, request.Model)
|
|
||||||
// tkm += 8
|
|
||||||
// tkm += toolTokens
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// return tkm, nil
|
|
||||||
//}
|
|
||||||
|
|
||||||
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
|
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
|
||||||
tkm := 0
|
tkm := 0
|
||||||
|
|
||||||
@ -543,56 +519,6 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|||||||
return textToken, audioToken, nil
|
return textToken, audioToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
|
|
||||||
// //recover when panic
|
|
||||||
// tokenEncoder := getTokenEncoder(model)
|
|
||||||
// // Reference:
|
|
||||||
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
|
||||||
// // https://github.com/pkoukk/tiktoken-go/issues/6
|
|
||||||
// //
|
|
||||||
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
|
||||||
// var tokensPerMessage int
|
|
||||||
// var tokensPerName int
|
|
||||||
//
|
|
||||||
// tokensPerMessage = 3
|
|
||||||
// tokensPerName = 1
|
|
||||||
//
|
|
||||||
// tokenNum := 0
|
|
||||||
// for _, message := range messages {
|
|
||||||
// tokenNum += tokensPerMessage
|
|
||||||
// tokenNum += getTokenNum(tokenEncoder, message.Role)
|
|
||||||
// if message.Content != nil {
|
|
||||||
// if message.Name != nil {
|
|
||||||
// tokenNum += tokensPerName
|
|
||||||
// tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
|
||||||
// }
|
|
||||||
// arrayContent := message.ParseContent()
|
|
||||||
// for _, m := range arrayContent {
|
|
||||||
// if m.Type == dto.ContentTypeImageURL {
|
|
||||||
// imageUrl := m.GetImageMedia()
|
|
||||||
// imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
|
|
||||||
// if err != nil {
|
|
||||||
// return 0, err
|
|
||||||
// }
|
|
||||||
// tokenNum += imageTokenNum
|
|
||||||
// log.Printf("image token num: %d", imageTokenNum)
|
|
||||||
// } else if m.Type == dto.ContentTypeInputAudio {
|
|
||||||
// // TODO: 音频token数量计算
|
|
||||||
// tokenNum += 100
|
|
||||||
// } else if m.Type == dto.ContentTypeFile {
|
|
||||||
// tokenNum += 5000
|
|
||||||
// } else if m.Type == dto.ContentTypeVideoUrl {
|
|
||||||
// tokenNum += 5000
|
|
||||||
// } else {
|
|
||||||
// tokenNum += getTokenNum(tokenEncoder, m.Text)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
|
||||||
// return tokenNum, nil
|
|
||||||
//}
|
|
||||||
|
|
||||||
func CountTokenInput(input any, model string) int {
|
func CountTokenInput(input any, model string) int {
|
||||||
switch v := input.(type) {
|
switch v := input.(type) {
|
||||||
case string:
|
case string:
|
||||||
|
|||||||
@ -2,9 +2,12 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"one-api/setting"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -51,6 +54,13 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
|
|||||||
// 获取 webhook secret
|
// 获取 webhook secret
|
||||||
webhookSecret := userSetting.WebhookSecret
|
webhookSecret := userSetting.WebhookSecret
|
||||||
return SendWebhookNotify(webhookURLStr, webhookSecret, data)
|
return SendWebhookNotify(webhookURLStr, webhookSecret, data)
|
||||||
|
case dto.NotifyTypeBark:
|
||||||
|
barkURL := userSetting.BarkUrl
|
||||||
|
if barkURL == "" {
|
||||||
|
common.SysLog(fmt.Sprintf("user %d has no bark url, skip sending bark", userId))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return sendBarkNotify(barkURL, data)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -64,3 +74,67 @@ func sendEmailNotify(userEmail string, data dto.Notify) error {
|
|||||||
}
|
}
|
||||||
return common.SendEmail(data.Title, userEmail, content)
|
return common.SendEmail(data.Title, userEmail, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sendBarkNotify(barkURL string, data dto.Notify) error {
|
||||||
|
// 处理占位符
|
||||||
|
content := data.Content
|
||||||
|
for _, value := range data.Values {
|
||||||
|
content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 替换模板变量
|
||||||
|
finalURL := strings.ReplaceAll(barkURL, "{{title}}", url.QueryEscape(data.Title))
|
||||||
|
finalURL = strings.ReplaceAll(finalURL, "{{content}}", url.QueryEscape(content))
|
||||||
|
|
||||||
|
// 发送GET请求到Bark
|
||||||
|
var req *http.Request
|
||||||
|
var resp *http.Response
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if setting.EnableWorker() {
|
||||||
|
// 使用worker发送请求
|
||||||
|
workerReq := &WorkerRequest{
|
||||||
|
URL: finalURL,
|
||||||
|
Key: setting.WorkerValidKey,
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Headers: map[string]string{
|
||||||
|
"User-Agent": "OneAPI-Bark-Notify/1.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = DoWorkerRequest(workerReq)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send bark request through worker: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 检查响应状态
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 直接发送请求
|
||||||
|
req, err = http.NewRequest(http.MethodGet, finalURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create bark request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置User-Agent
|
||||||
|
req.Header.Set("User-Agent", "OneAPI-Bark-Notify/1.0")
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
client := GetHttpClient()
|
||||||
|
resp, err = client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send bark request: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 检查响应状态
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@ -3,37 +3,37 @@ package console_setting
|
|||||||
import "one-api/setting/config"
|
import "one-api/setting/config"
|
||||||
|
|
||||||
type ConsoleSetting struct {
|
type ConsoleSetting struct {
|
||||||
ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串)
|
ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串)
|
||||||
UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串)
|
UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串)
|
||||||
Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串)
|
Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串)
|
||||||
FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串)
|
FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串)
|
||||||
ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板
|
ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板
|
||||||
UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板
|
UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板
|
||||||
AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板
|
AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板
|
||||||
FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板
|
FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板
|
||||||
}
|
}
|
||||||
|
|
||||||
// 默认配置
|
// 默认配置
|
||||||
var defaultConsoleSetting = ConsoleSetting{
|
var defaultConsoleSetting = ConsoleSetting{
|
||||||
ApiInfo: "",
|
ApiInfo: "",
|
||||||
UptimeKumaGroups: "",
|
UptimeKumaGroups: "",
|
||||||
Announcements: "",
|
Announcements: "",
|
||||||
FAQ: "",
|
FAQ: "",
|
||||||
ApiInfoEnabled: true,
|
ApiInfoEnabled: true,
|
||||||
UptimeKumaEnabled: true,
|
UptimeKumaEnabled: true,
|
||||||
AnnouncementsEnabled: true,
|
AnnouncementsEnabled: true,
|
||||||
FAQEnabled: true,
|
FAQEnabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 全局实例
|
// 全局实例
|
||||||
var consoleSetting = defaultConsoleSetting
|
var consoleSetting = defaultConsoleSetting
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// 注册到全局配置管理器,键名为 console_setting
|
// 注册到全局配置管理器,键名为 console_setting
|
||||||
config.GlobalConfig.Register("console_setting", &consoleSetting)
|
config.GlobalConfig.Register("console_setting", &consoleSetting)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConsoleSetting 获取 ConsoleSetting 配置实例
|
// GetConsoleSetting 获取 ConsoleSetting 配置实例
|
||||||
func GetConsoleSetting() *ConsoleSetting {
|
func GetConsoleSetting() *ConsoleSetting {
|
||||||
return &consoleSetting
|
return &consoleSetting
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,304 +1,304 @@
|
|||||||
package console_setting
|
package console_setting
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"sort"
|
||||||
"time"
|
"strings"
|
||||||
"sort"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
urlRegex = regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?:\:[0-9]{1,5})?(?:/.*)?$`)
|
urlRegex = regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?:\:[0-9]{1,5})?(?:/.*)?$`)
|
||||||
dangerousChars = []string{"<script", "<iframe", "javascript:", "onload=", "onerror=", "onclick="}
|
dangerousChars = []string{"<script", "<iframe", "javascript:", "onload=", "onerror=", "onclick="}
|
||||||
validColors = map[string]bool{
|
validColors = map[string]bool{
|
||||||
"blue": true, "green": true, "cyan": true, "purple": true, "pink": true,
|
"blue": true, "green": true, "cyan": true, "purple": true, "pink": true,
|
||||||
"red": true, "orange": true, "amber": true, "yellow": true, "lime": true,
|
"red": true, "orange": true, "amber": true, "yellow": true, "lime": true,
|
||||||
"light-green": true, "teal": true, "light-blue": true, "indigo": true,
|
"light-green": true, "teal": true, "light-blue": true, "indigo": true,
|
||||||
"violet": true, "grey": true,
|
"violet": true, "grey": true,
|
||||||
}
|
}
|
||||||
slugRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
slugRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseJSONArray(jsonStr string, typeName string) ([]map[string]interface{}, error) {
|
func parseJSONArray(jsonStr string, typeName string) ([]map[string]interface{}, error) {
|
||||||
var list []map[string]interface{}
|
var list []map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(jsonStr), &list); err != nil {
|
if err := json.Unmarshal([]byte(jsonStr), &list); err != nil {
|
||||||
return nil, fmt.Errorf("%s格式错误:%s", typeName, err.Error())
|
return nil, fmt.Errorf("%s格式错误:%s", typeName, err.Error())
|
||||||
}
|
}
|
||||||
return list, nil
|
return list, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateURL(urlStr string, index int, itemType string) error {
|
func validateURL(urlStr string, index int, itemType string) error {
|
||||||
if !urlRegex.MatchString(urlStr) {
|
if !urlRegex.MatchString(urlStr) {
|
||||||
return fmt.Errorf("第%d个%s的URL格式不正确", index, itemType)
|
return fmt.Errorf("第%d个%s的URL格式不正确", index, itemType)
|
||||||
}
|
}
|
||||||
if _, err := url.Parse(urlStr); err != nil {
|
if _, err := url.Parse(urlStr); err != nil {
|
||||||
return fmt.Errorf("第%d个%s的URL无法解析:%s", index, itemType, err.Error())
|
return fmt.Errorf("第%d个%s的URL无法解析:%s", index, itemType, err.Error())
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkDangerousContent(content string, index int, itemType string) error {
|
func checkDangerousContent(content string, index int, itemType string) error {
|
||||||
lower := strings.ToLower(content)
|
lower := strings.ToLower(content)
|
||||||
for _, d := range dangerousChars {
|
for _, d := range dangerousChars {
|
||||||
if strings.Contains(lower, d) {
|
if strings.Contains(lower, d) {
|
||||||
return fmt.Errorf("第%d个%s包含不允许的内容", index, itemType)
|
return fmt.Errorf("第%d个%s包含不允许的内容", index, itemType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getJSONList(jsonStr string) []map[string]interface{} {
|
func getJSONList(jsonStr string) []map[string]interface{} {
|
||||||
if jsonStr == "" {
|
if jsonStr == "" {
|
||||||
return []map[string]interface{}{}
|
return []map[string]interface{}{}
|
||||||
}
|
}
|
||||||
var list []map[string]interface{}
|
var list []map[string]interface{}
|
||||||
json.Unmarshal([]byte(jsonStr), &list)
|
json.Unmarshal([]byte(jsonStr), &list)
|
||||||
return list
|
return list
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateConsoleSettings(settingsStr string, settingType string) error {
|
func ValidateConsoleSettings(settingsStr string, settingType string) error {
|
||||||
if settingsStr == "" {
|
if settingsStr == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch settingType {
|
switch settingType {
|
||||||
case "ApiInfo":
|
case "ApiInfo":
|
||||||
return validateApiInfo(settingsStr)
|
return validateApiInfo(settingsStr)
|
||||||
case "Announcements":
|
case "Announcements":
|
||||||
return validateAnnouncements(settingsStr)
|
return validateAnnouncements(settingsStr)
|
||||||
case "FAQ":
|
case "FAQ":
|
||||||
return validateFAQ(settingsStr)
|
return validateFAQ(settingsStr)
|
||||||
case "UptimeKumaGroups":
|
case "UptimeKumaGroups":
|
||||||
return validateUptimeKumaGroups(settingsStr)
|
return validateUptimeKumaGroups(settingsStr)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("未知的设置类型:%s", settingType)
|
return fmt.Errorf("未知的设置类型:%s", settingType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateApiInfo(apiInfoStr string) error {
|
func validateApiInfo(apiInfoStr string) error {
|
||||||
apiInfoList, err := parseJSONArray(apiInfoStr, "API信息")
|
apiInfoList, err := parseJSONArray(apiInfoStr, "API信息")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(apiInfoList) > 50 {
|
if len(apiInfoList) > 50 {
|
||||||
return fmt.Errorf("API信息数量不能超过50个")
|
return fmt.Errorf("API信息数量不能超过50个")
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, apiInfo := range apiInfoList {
|
for i, apiInfo := range apiInfoList {
|
||||||
urlStr, ok := apiInfo["url"].(string)
|
urlStr, ok := apiInfo["url"].(string)
|
||||||
if !ok || urlStr == "" {
|
if !ok || urlStr == "" {
|
||||||
return fmt.Errorf("第%d个API信息缺少URL字段", i+1)
|
return fmt.Errorf("第%d个API信息缺少URL字段", i+1)
|
||||||
}
|
}
|
||||||
route, ok := apiInfo["route"].(string)
|
route, ok := apiInfo["route"].(string)
|
||||||
if !ok || route == "" {
|
if !ok || route == "" {
|
||||||
return fmt.Errorf("第%d个API信息缺少线路描述字段", i+1)
|
return fmt.Errorf("第%d个API信息缺少线路描述字段", i+1)
|
||||||
}
|
}
|
||||||
description, ok := apiInfo["description"].(string)
|
description, ok := apiInfo["description"].(string)
|
||||||
if !ok || description == "" {
|
if !ok || description == "" {
|
||||||
return fmt.Errorf("第%d个API信息缺少说明字段", i+1)
|
return fmt.Errorf("第%d个API信息缺少说明字段", i+1)
|
||||||
}
|
}
|
||||||
color, ok := apiInfo["color"].(string)
|
color, ok := apiInfo["color"].(string)
|
||||||
if !ok || color == "" {
|
if !ok || color == "" {
|
||||||
return fmt.Errorf("第%d个API信息缺少颜色字段", i+1)
|
return fmt.Errorf("第%d个API信息缺少颜色字段", i+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateURL(urlStr, i+1, "API信息"); err != nil {
|
if err := validateURL(urlStr, i+1, "API信息"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(urlStr) > 500 {
|
if len(urlStr) > 500 {
|
||||||
return fmt.Errorf("第%d个API信息的URL长度不能超过500字符", i+1)
|
return fmt.Errorf("第%d个API信息的URL长度不能超过500字符", i+1)
|
||||||
}
|
}
|
||||||
if len(route) > 100 {
|
if len(route) > 100 {
|
||||||
return fmt.Errorf("第%d个API信息的线路描述长度不能超过100字符", i+1)
|
return fmt.Errorf("第%d个API信息的线路描述长度不能超过100字符", i+1)
|
||||||
}
|
}
|
||||||
if len(description) > 200 {
|
if len(description) > 200 {
|
||||||
return fmt.Errorf("第%d个API信息的说明长度不能超过200字符", i+1)
|
return fmt.Errorf("第%d个API信息的说明长度不能超过200字符", i+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !validColors[color] {
|
if !validColors[color] {
|
||||||
return fmt.Errorf("第%d个API信息的颜色值不合法", i+1)
|
return fmt.Errorf("第%d个API信息的颜色值不合法", i+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checkDangerousContent(description, i+1, "API信息"); err != nil {
|
if err := checkDangerousContent(description, i+1, "API信息"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := checkDangerousContent(route, i+1, "API信息"); err != nil {
|
if err := checkDangerousContent(route, i+1, "API信息"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetApiInfo() []map[string]interface{} {
|
func GetApiInfo() []map[string]interface{} {
|
||||||
return getJSONList(GetConsoleSetting().ApiInfo)
|
return getJSONList(GetConsoleSetting().ApiInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateAnnouncements(announcementsStr string) error {
|
func validateAnnouncements(announcementsStr string) error {
|
||||||
list, err := parseJSONArray(announcementsStr, "系统公告")
|
list, err := parseJSONArray(announcementsStr, "系统公告")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(list) > 100 {
|
if len(list) > 100 {
|
||||||
return fmt.Errorf("系统公告数量不能超过100个")
|
return fmt.Errorf("系统公告数量不能超过100个")
|
||||||
}
|
}
|
||||||
validTypes := map[string]bool{
|
validTypes := map[string]bool{
|
||||||
"default": true, "ongoing": true, "success": true, "warning": true, "error": true,
|
"default": true, "ongoing": true, "success": true, "warning": true, "error": true,
|
||||||
}
|
}
|
||||||
for i, ann := range list {
|
for i, ann := range list {
|
||||||
content, ok := ann["content"].(string)
|
content, ok := ann["content"].(string)
|
||||||
if !ok || content == "" {
|
if !ok || content == "" {
|
||||||
return fmt.Errorf("第%d个公告缺少内容字段", i+1)
|
return fmt.Errorf("第%d个公告缺少内容字段", i+1)
|
||||||
}
|
}
|
||||||
publishDateAny, exists := ann["publishDate"]
|
publishDateAny, exists := ann["publishDate"]
|
||||||
if !exists {
|
if !exists {
|
||||||
return fmt.Errorf("第%d个公告缺少发布日期字段", i+1)
|
return fmt.Errorf("第%d个公告缺少发布日期字段", i+1)
|
||||||
}
|
}
|
||||||
publishDateStr, ok := publishDateAny.(string)
|
publishDateStr, ok := publishDateAny.(string)
|
||||||
if !ok || publishDateStr == "" {
|
if !ok || publishDateStr == "" {
|
||||||
return fmt.Errorf("第%d个公告的发布日期不能为空", i+1)
|
return fmt.Errorf("第%d个公告的发布日期不能为空", i+1)
|
||||||
}
|
}
|
||||||
if _, err := time.Parse(time.RFC3339, publishDateStr); err != nil {
|
if _, err := time.Parse(time.RFC3339, publishDateStr); err != nil {
|
||||||
return fmt.Errorf("第%d个公告的发布日期格式错误", i+1)
|
return fmt.Errorf("第%d个公告的发布日期格式错误", i+1)
|
||||||
}
|
}
|
||||||
if t, exists := ann["type"]; exists {
|
if t, exists := ann["type"]; exists {
|
||||||
if typeStr, ok := t.(string); ok {
|
if typeStr, ok := t.(string); ok {
|
||||||
if !validTypes[typeStr] {
|
if !validTypes[typeStr] {
|
||||||
return fmt.Errorf("第%d个公告的类型值不合法", i+1)
|
return fmt.Errorf("第%d个公告的类型值不合法", i+1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(content) > 500 {
|
if len(content) > 500 {
|
||||||
return fmt.Errorf("第%d个公告的内容长度不能超过500字符", i+1)
|
return fmt.Errorf("第%d个公告的内容长度不能超过500字符", i+1)
|
||||||
}
|
}
|
||||||
if extra, exists := ann["extra"]; exists {
|
if extra, exists := ann["extra"]; exists {
|
||||||
if extraStr, ok := extra.(string); ok && len(extraStr) > 200 {
|
if extraStr, ok := extra.(string); ok && len(extraStr) > 200 {
|
||||||
return fmt.Errorf("第%d个公告的说明长度不能超过200字符", i+1)
|
return fmt.Errorf("第%d个公告的说明长度不能超过200字符", i+1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateFAQ(faqStr string) error {
|
func validateFAQ(faqStr string) error {
|
||||||
list, err := parseJSONArray(faqStr, "FAQ信息")
|
list, err := parseJSONArray(faqStr, "FAQ信息")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(list) > 100 {
|
if len(list) > 100 {
|
||||||
return fmt.Errorf("FAQ数量不能超过100个")
|
return fmt.Errorf("FAQ数量不能超过100个")
|
||||||
}
|
}
|
||||||
for i, faq := range list {
|
for i, faq := range list {
|
||||||
question, ok := faq["question"].(string)
|
question, ok := faq["question"].(string)
|
||||||
if !ok || question == "" {
|
if !ok || question == "" {
|
||||||
return fmt.Errorf("第%d个FAQ缺少问题字段", i+1)
|
return fmt.Errorf("第%d个FAQ缺少问题字段", i+1)
|
||||||
}
|
}
|
||||||
answer, ok := faq["answer"].(string)
|
answer, ok := faq["answer"].(string)
|
||||||
if !ok || answer == "" {
|
if !ok || answer == "" {
|
||||||
return fmt.Errorf("第%d个FAQ缺少答案字段", i+1)
|
return fmt.Errorf("第%d个FAQ缺少答案字段", i+1)
|
||||||
}
|
}
|
||||||
if len(question) > 200 {
|
if len(question) > 200 {
|
||||||
return fmt.Errorf("第%d个FAQ的问题长度不能超过200字符", i+1)
|
return fmt.Errorf("第%d个FAQ的问题长度不能超过200字符", i+1)
|
||||||
}
|
}
|
||||||
if len(answer) > 1000 {
|
if len(answer) > 1000 {
|
||||||
return fmt.Errorf("第%d个FAQ的答案长度不能超过1000字符", i+1)
|
return fmt.Errorf("第%d个FAQ的答案长度不能超过1000字符", i+1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPublishTime(item map[string]interface{}) time.Time {
|
func getPublishTime(item map[string]interface{}) time.Time {
|
||||||
if v, ok := item["publishDate"]; ok {
|
if v, ok := item["publishDate"]; ok {
|
||||||
if s, ok2 := v.(string); ok2 {
|
if s, ok2 := v.(string); ok2 {
|
||||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return time.Time{}
|
return time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAnnouncements() []map[string]interface{} {
|
func GetAnnouncements() []map[string]interface{} {
|
||||||
list := getJSONList(GetConsoleSetting().Announcements)
|
list := getJSONList(GetConsoleSetting().Announcements)
|
||||||
sort.SliceStable(list, func(i, j int) bool {
|
sort.SliceStable(list, func(i, j int) bool {
|
||||||
return getPublishTime(list[i]).After(getPublishTime(list[j]))
|
return getPublishTime(list[i]).After(getPublishTime(list[j]))
|
||||||
})
|
})
|
||||||
return list
|
return list
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetFAQ() []map[string]interface{} {
|
func GetFAQ() []map[string]interface{} {
|
||||||
return getJSONList(GetConsoleSetting().FAQ)
|
return getJSONList(GetConsoleSetting().FAQ)
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateUptimeKumaGroups(groupsStr string) error {
|
func validateUptimeKumaGroups(groupsStr string) error {
|
||||||
groups, err := parseJSONArray(groupsStr, "Uptime Kuma分组配置")
|
groups, err := parseJSONArray(groupsStr, "Uptime Kuma分组配置")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(groups) > 20 {
|
if len(groups) > 20 {
|
||||||
return fmt.Errorf("Uptime Kuma分组数量不能超过20个")
|
return fmt.Errorf("Uptime Kuma分组数量不能超过20个")
|
||||||
}
|
}
|
||||||
|
|
||||||
nameSet := make(map[string]bool)
|
nameSet := make(map[string]bool)
|
||||||
|
|
||||||
for i, group := range groups {
|
for i, group := range groups {
|
||||||
categoryName, ok := group["categoryName"].(string)
|
categoryName, ok := group["categoryName"].(string)
|
||||||
if !ok || categoryName == "" {
|
if !ok || categoryName == "" {
|
||||||
return fmt.Errorf("第%d个分组缺少分类名称字段", i+1)
|
return fmt.Errorf("第%d个分组缺少分类名称字段", i+1)
|
||||||
}
|
}
|
||||||
if nameSet[categoryName] {
|
if nameSet[categoryName] {
|
||||||
return fmt.Errorf("第%d个分组的分类名称与其他分组重复", i+1)
|
return fmt.Errorf("第%d个分组的分类名称与其他分组重复", i+1)
|
||||||
}
|
}
|
||||||
nameSet[categoryName] = true
|
nameSet[categoryName] = true
|
||||||
urlStr, ok := group["url"].(string)
|
urlStr, ok := group["url"].(string)
|
||||||
if !ok || urlStr == "" {
|
if !ok || urlStr == "" {
|
||||||
return fmt.Errorf("第%d个分组缺少URL字段", i+1)
|
return fmt.Errorf("第%d个分组缺少URL字段", i+1)
|
||||||
}
|
}
|
||||||
slug, ok := group["slug"].(string)
|
slug, ok := group["slug"].(string)
|
||||||
if !ok || slug == "" {
|
if !ok || slug == "" {
|
||||||
return fmt.Errorf("第%d个分组缺少Slug字段", i+1)
|
return fmt.Errorf("第%d个分组缺少Slug字段", i+1)
|
||||||
}
|
}
|
||||||
description, ok := group["description"].(string)
|
description, ok := group["description"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
description = ""
|
description = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateURL(urlStr, i+1, "分组"); err != nil {
|
if err := validateURL(urlStr, i+1, "分组"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(categoryName) > 50 {
|
if len(categoryName) > 50 {
|
||||||
return fmt.Errorf("第%d个分组的分类名称长度不能超过50字符", i+1)
|
return fmt.Errorf("第%d个分组的分类名称长度不能超过50字符", i+1)
|
||||||
}
|
}
|
||||||
if len(urlStr) > 500 {
|
if len(urlStr) > 500 {
|
||||||
return fmt.Errorf("第%d个分组的URL长度不能超过500字符", i+1)
|
return fmt.Errorf("第%d个分组的URL长度不能超过500字符", i+1)
|
||||||
}
|
}
|
||||||
if len(slug) > 100 {
|
if len(slug) > 100 {
|
||||||
return fmt.Errorf("第%d个分组的Slug长度不能超过100字符", i+1)
|
return fmt.Errorf("第%d个分组的Slug长度不能超过100字符", i+1)
|
||||||
}
|
}
|
||||||
if len(description) > 200 {
|
if len(description) > 200 {
|
||||||
return fmt.Errorf("第%d个分组的描述长度不能超过200字符", i+1)
|
return fmt.Errorf("第%d个分组的描述长度不能超过200字符", i+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !slugRegex.MatchString(slug) {
|
if !slugRegex.MatchString(slug) {
|
||||||
return fmt.Errorf("第%d个分组的Slug只能包含字母、数字、下划线和连字符", i+1)
|
return fmt.Errorf("第%d个分组的Slug只能包含字母、数字、下划线和连字符", i+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := checkDangerousContent(description, i+1, "分组"); err != nil {
|
if err := checkDangerousContent(description, i+1, "分组"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := checkDangerousContent(categoryName, i+1, "分组"); err != nil {
|
if err := checkDangerousContent(categoryName, i+1, "分组"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUptimeKumaGroups() []map[string]interface{} {
|
func GetUptimeKumaGroups() []map[string]interface{} {
|
||||||
return getJSONList(GetConsoleSetting().UptimeKumaGroups)
|
return getJSONList(GetConsoleSetting().UptimeKumaGroups)
|
||||||
}
|
}
|
||||||
|
|||||||
34
setting/operation_setting/monitor_setting.go
Normal file
34
setting/operation_setting/monitor_setting.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package operation_setting
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/setting/config"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MonitorSetting struct {
|
||||||
|
AutoTestChannelEnabled bool `json:"auto_test_channel_enabled"`
|
||||||
|
AutoTestChannelMinutes int `json:"auto_test_channel_minutes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 默认配置
|
||||||
|
var monitorSetting = MonitorSetting{
|
||||||
|
AutoTestChannelEnabled: false,
|
||||||
|
AutoTestChannelMinutes: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// 注册到全局配置管理器
|
||||||
|
config.GlobalConfig.Register("monitor_setting", &monitorSetting)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMonitorSetting() *MonitorSetting {
|
||||||
|
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
||||||
|
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
||||||
|
if err == nil && frequency > 0 {
|
||||||
|
monitorSetting.AutoTestChannelEnabled = true
|
||||||
|
monitorSetting.AutoTestChannelMinutes = frequency
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &monitorSetting
|
||||||
|
}
|
||||||
23
setting/operation_setting/payment_setting.go
Normal file
23
setting/operation_setting/payment_setting.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package operation_setting
|
||||||
|
|
||||||
|
import "one-api/setting/config"
|
||||||
|
|
||||||
|
type PaymentSetting struct {
|
||||||
|
AmountOptions []int `json:"amount_options"`
|
||||||
|
AmountDiscount map[int]float64 `json:"amount_discount"` // 充值金额对应的折扣,例如 100 元 0.9 表示 100 元充值享受 9 折优惠
|
||||||
|
}
|
||||||
|
|
||||||
|
// 默认配置
|
||||||
|
var paymentSetting = PaymentSetting{
|
||||||
|
AmountOptions: []int{10, 20, 50, 100, 200, 500},
|
||||||
|
AmountDiscount: map[int]float64{},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// 注册到全局配置管理器
|
||||||
|
config.GlobalConfig.Register("payment_setting", &paymentSetting)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPaymentSetting() *PaymentSetting {
|
||||||
|
return &paymentSetting
|
||||||
|
}
|
||||||
@ -1,6 +1,13 @@
|
|||||||
package setting
|
/**
|
||||||
|
此文件为旧版支付设置文件,如需增加新的参数、变量等,请在 payment_setting.go 中添加
|
||||||
|
This file is the old version of the payment settings file. If you need to add new parameters, variables, etc., please add them in payment_setting.go
|
||||||
|
*/
|
||||||
|
|
||||||
import "encoding/json"
|
package operation_setting
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
var PayAddress = ""
|
var PayAddress = ""
|
||||||
var CustomCallbackAddress = ""
|
var CustomCallbackAddress = ""
|
||||||
@ -21,15 +28,21 @@ var PayMethods = []map[string]string{
|
|||||||
"color": "rgba(var(--semi-green-5), 1)",
|
"color": "rgba(var(--semi-green-5), 1)",
|
||||||
"type": "wxpay",
|
"type": "wxpay",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "自定义1",
|
||||||
|
"color": "black",
|
||||||
|
"type": "custom1",
|
||||||
|
"min_topup": "50",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdatePayMethodsByJsonString(jsonString string) error {
|
func UpdatePayMethodsByJsonString(jsonString string) error {
|
||||||
PayMethods = make([]map[string]string, 0)
|
PayMethods = make([]map[string]string, 0)
|
||||||
return json.Unmarshal([]byte(jsonString), &PayMethods)
|
return common.Unmarshal([]byte(jsonString), &PayMethods)
|
||||||
}
|
}
|
||||||
|
|
||||||
func PayMethods2JsonString() string {
|
func PayMethods2JsonString() string {
|
||||||
jsonBytes, err := json.Marshal(PayMethods)
|
jsonBytes, err := common.Marshal(PayMethods)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "[]"
|
return "[]"
|
||||||
}
|
}
|
||||||
@ -5,13 +5,13 @@ import "sync/atomic"
|
|||||||
var exposeRatioEnabled atomic.Bool
|
var exposeRatioEnabled atomic.Bool
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
exposeRatioEnabled.Store(false)
|
exposeRatioEnabled.Store(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetExposeRatioEnabled(enabled bool) {
|
func SetExposeRatioEnabled(enabled bool) {
|
||||||
exposeRatioEnabled.Store(enabled)
|
exposeRatioEnabled.Store(enabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsExposeRatioEnabled() bool {
|
func IsExposeRatioEnabled() bool {
|
||||||
return exposeRatioEnabled.Load()
|
return exposeRatioEnabled.Load()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,55 +1,55 @@
|
|||||||
package ratio_setting
|
package ratio_setting
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const exposedDataTTL = 30 * time.Second
|
const exposedDataTTL = 30 * time.Second
|
||||||
|
|
||||||
type exposedCache struct {
|
type exposedCache struct {
|
||||||
data gin.H
|
data gin.H
|
||||||
expiresAt time.Time
|
expiresAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
exposedData atomic.Value
|
exposedData atomic.Value
|
||||||
rebuildMu sync.Mutex
|
rebuildMu sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
func InvalidateExposedDataCache() {
|
func InvalidateExposedDataCache() {
|
||||||
exposedData.Store((*exposedCache)(nil))
|
exposedData.Store((*exposedCache)(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func cloneGinH(src gin.H) gin.H {
|
func cloneGinH(src gin.H) gin.H {
|
||||||
dst := make(gin.H, len(src))
|
dst := make(gin.H, len(src))
|
||||||
for k, v := range src {
|
for k, v := range src {
|
||||||
dst[k] = v
|
dst[k] = v
|
||||||
}
|
}
|
||||||
return dst
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetExposedData() gin.H {
|
func GetExposedData() gin.H {
|
||||||
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||||
return cloneGinH(c.data)
|
return cloneGinH(c.data)
|
||||||
}
|
}
|
||||||
rebuildMu.Lock()
|
rebuildMu.Lock()
|
||||||
defer rebuildMu.Unlock()
|
defer rebuildMu.Unlock()
|
||||||
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||||
return cloneGinH(c.data)
|
return cloneGinH(c.data)
|
||||||
}
|
}
|
||||||
newData := gin.H{
|
newData := gin.H{
|
||||||
"model_ratio": GetModelRatioCopy(),
|
"model_ratio": GetModelRatioCopy(),
|
||||||
"completion_ratio": GetCompletionRatioCopy(),
|
"completion_ratio": GetCompletionRatioCopy(),
|
||||||
"cache_ratio": GetCacheRatioCopy(),
|
"cache_ratio": GetCacheRatioCopy(),
|
||||||
"model_price": GetModelPriceCopy(),
|
"model_price": GetModelPriceCopy(),
|
||||||
}
|
}
|
||||||
exposedData.Store(&exposedCache{
|
exposedData.Store(&exposedCache{
|
||||||
data: newData,
|
data: newData,
|
||||||
expiresAt: time.Now().Add(exposedDataTTL),
|
expiresAt: time.Now().Add(exposedDataTTL),
|
||||||
})
|
})
|
||||||
return cloneGinH(newData)
|
return cloneGinH(newData)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
package setting
|
package system_setting
|
||||||
|
|
||||||
var ServerAddress = "http://localhost:3000"
|
var ServerAddress = "http://localhost:3000"
|
||||||
var WorkerUrl = ""
|
var WorkerUrl = ""
|
||||||
@ -185,6 +185,14 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
|
|||||||
type NewAPIErrorOptions func(*NewAPIError)
|
type NewAPIErrorOptions func(*NewAPIError)
|
||||||
|
|
||||||
func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
|
func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||||
|
var newErr *NewAPIError
|
||||||
|
// 保留深层传递的 new err
|
||||||
|
if errors.As(err, &newErr) {
|
||||||
|
for _, op := range ops {
|
||||||
|
op(newErr)
|
||||||
|
}
|
||||||
|
return newErr
|
||||||
|
}
|
||||||
e := &NewAPIError{
|
e := &NewAPIError{
|
||||||
Err: err,
|
Err: err,
|
||||||
RelayError: nil,
|
RelayError: nil,
|
||||||
@ -199,8 +207,21 @@ func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPI
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||||
if errorCode == ErrorCodeDoRequestFailed {
|
var newErr *NewAPIError
|
||||||
err = errors.New("upstream error: do request failed")
|
// 保留深层传递的 new err
|
||||||
|
if errors.As(err, &newErr) {
|
||||||
|
if newErr.RelayError == nil {
|
||||||
|
openaiError := OpenAIError{
|
||||||
|
Message: newErr.Error(),
|
||||||
|
Type: string(errorCode),
|
||||||
|
Code: errorCode,
|
||||||
|
}
|
||||||
|
newErr.RelayError = openaiError
|
||||||
|
}
|
||||||
|
for _, op := range ops {
|
||||||
|
op(newErr)
|
||||||
|
}
|
||||||
|
return newErr
|
||||||
}
|
}
|
||||||
openaiError := OpenAIError{
|
openaiError := OpenAIError{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
@ -305,6 +326,15 @@ func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions {
|
||||||
|
return func(e *NewAPIError) {
|
||||||
|
if common.DebugEnabled {
|
||||||
|
fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err)
|
||||||
|
}
|
||||||
|
e.Err = errors.New(replaceStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func IsRecordErrorLog(e *NewAPIError) bool {
|
func IsRecordErrorLog(e *NewAPIError) bool {
|
||||||
if e == nil {
|
if e == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||||||
For commercial licensing, please contact support@quantumnous.com
|
For commercial licensing, please contact support@quantumnous.com
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import React, { lazy, Suspense } from 'react';
|
import React, { lazy, Suspense, useContext, useMemo } from 'react';
|
||||||
import { Route, Routes, useLocation } from 'react-router-dom';
|
import { Route, Routes, useLocation } from 'react-router-dom';
|
||||||
import Loading from './components/common/ui/Loading';
|
import Loading from './components/common/ui/Loading';
|
||||||
import User from './pages/User';
|
import User from './pages/User';
|
||||||
@ -27,6 +27,7 @@ import LoginForm from './components/auth/LoginForm';
|
|||||||
import NotFound from './pages/NotFound';
|
import NotFound from './pages/NotFound';
|
||||||
import Forbidden from './pages/Forbidden';
|
import Forbidden from './pages/Forbidden';
|
||||||
import Setting from './pages/Setting';
|
import Setting from './pages/Setting';
|
||||||
|
import { StatusContext } from './context/Status';
|
||||||
|
|
||||||
import PasswordResetForm from './components/auth/PasswordResetForm';
|
import PasswordResetForm from './components/auth/PasswordResetForm';
|
||||||
import PasswordResetConfirm from './components/auth/PasswordResetConfirm';
|
import PasswordResetConfirm from './components/auth/PasswordResetConfirm';
|
||||||
@ -53,6 +54,29 @@ const About = lazy(() => import('./pages/About'));
|
|||||||
|
|
||||||
function App() {
|
function App() {
|
||||||
const location = useLocation();
|
const location = useLocation();
|
||||||
|
const [statusState] = useContext(StatusContext);
|
||||||
|
|
||||||
|
// 获取模型广场权限配置
|
||||||
|
const pricingRequireAuth = useMemo(() => {
|
||||||
|
const headerNavModulesConfig = statusState?.status?.HeaderNavModules;
|
||||||
|
if (headerNavModulesConfig) {
|
||||||
|
try {
|
||||||
|
const modules = JSON.parse(headerNavModulesConfig);
|
||||||
|
|
||||||
|
// 处理向后兼容性:如果pricing是boolean,默认不需要登录
|
||||||
|
if (typeof modules.pricing === 'boolean') {
|
||||||
|
return false; // 默认不需要登录鉴权
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是对象格式,使用requireAuth配置
|
||||||
|
return modules.pricing?.requireAuth === true;
|
||||||
|
} catch (error) {
|
||||||
|
console.error('解析顶栏模块配置失败:', error);
|
||||||
|
return false; // 默认不需要登录
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false; // 默认不需要登录
|
||||||
|
}, [statusState?.status?.HeaderNavModules]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<SetupCheck>
|
<SetupCheck>
|
||||||
@ -253,9 +277,20 @@ function App() {
|
|||||||
<Route
|
<Route
|
||||||
path='/pricing'
|
path='/pricing'
|
||||||
element={
|
element={
|
||||||
<Suspense fallback={<Loading></Loading>} key={location.pathname}>
|
pricingRequireAuth ? (
|
||||||
<Pricing />
|
<PrivateRoute>
|
||||||
</Suspense>
|
<Suspense
|
||||||
|
fallback={<Loading></Loading>}
|
||||||
|
key={location.pathname}
|
||||||
|
>
|
||||||
|
<Pricing />
|
||||||
|
</Suspense>
|
||||||
|
</PrivateRoute>
|
||||||
|
) : (
|
||||||
|
<Suspense fallback={<Loading></Loading>} key={location.pathname}>
|
||||||
|
<Pricing />
|
||||||
|
</Suspense>
|
||||||
|
)
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
<Route
|
<Route
|
||||||
|
|||||||
@ -135,7 +135,7 @@ const TwoFactorAuthModal = ({
|
|||||||
autoFocus
|
autoFocus
|
||||||
/>
|
/>
|
||||||
<Typography.Text type='tertiary' size='small' className='mt-2 block'>
|
<Typography.Text type='tertiary' size='small' className='mt-2 block'>
|
||||||
{t('支持6位TOTP验证码或8位备用码')}
|
{t('支持6位TOTP验证码或8位备用码,可到`个人设置-安全设置-两步验证设置`配置或查看。')}
|
||||||
</Typography.Text>
|
</Typography.Text>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -443,7 +443,7 @@ const JSONEditor = ({
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Row key={pair.id} gutter={8} align='middle'>
|
<Row key={pair.id} gutter={8} align='middle'>
|
||||||
<Col span={6}>
|
<Col span={10}>
|
||||||
<div className='relative'>
|
<div className='relative'>
|
||||||
<Input
|
<Input
|
||||||
placeholder={t('键名')}
|
placeholder={t('键名')}
|
||||||
@ -470,7 +470,7 @@ const JSONEditor = ({
|
|||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
</Col>
|
</Col>
|
||||||
<Col span={16}>{renderValueInput(pair.id, pair.value)}</Col>
|
<Col span={12}>{renderValueInput(pair.id, pair.value)}</Col>
|
||||||
<Col span={2}>
|
<Col span={2}>
|
||||||
<Button
|
<Button
|
||||||
icon={<IconDelete />}
|
icon={<IconDelete />}
|
||||||
|
|||||||
@ -100,7 +100,7 @@ const ApiInfoPanel = ({
|
|||||||
</React.Fragment>
|
</React.Fragment>
|
||||||
))
|
))
|
||||||
) : (
|
) : (
|
||||||
<div className='flex justify-center items-center py-8'>
|
<div className='flex justify-center items-center min-h-[20rem] w-full'>
|
||||||
<Empty
|
<Empty
|
||||||
image={<IllustrationConstruction style={ILLUSTRATION_SIZE} />}
|
image={<IllustrationConstruction style={ILLUSTRATION_SIZE} />}
|
||||||
darkModeImage={
|
darkModeImage={
|
||||||
|
|||||||
@ -20,11 +20,6 @@ For commercial licensing, please contact support@quantumnous.com
|
|||||||
import React from 'react';
|
import React from 'react';
|
||||||
import { Card, Tabs, TabPane } from '@douyinfe/semi-ui';
|
import { Card, Tabs, TabPane } from '@douyinfe/semi-ui';
|
||||||
import { PieChart } from 'lucide-react';
|
import { PieChart } from 'lucide-react';
|
||||||
import {
|
|
||||||
IconHistogram,
|
|
||||||
IconPulse,
|
|
||||||
IconPieChart2Stroked,
|
|
||||||
} from '@douyinfe/semi-icons';
|
|
||||||
import { VChart } from '@visactor/react-vchart';
|
import { VChart } from '@visactor/react-vchart';
|
||||||
|
|
||||||
const ChartsPanel = ({
|
const ChartsPanel = ({
|
||||||
@ -51,46 +46,14 @@ const ChartsPanel = ({
|
|||||||
{t('模型数据分析')}
|
{t('模型数据分析')}
|
||||||
</div>
|
</div>
|
||||||
<Tabs
|
<Tabs
|
||||||
type='button'
|
type='slash'
|
||||||
activeKey={activeChartTab}
|
activeKey={activeChartTab}
|
||||||
onChange={setActiveChartTab}
|
onChange={setActiveChartTab}
|
||||||
>
|
>
|
||||||
<TabPane
|
<TabPane tab={<span>{t('消耗分布')}</span>} itemKey='1' />
|
||||||
tab={
|
<TabPane tab={<span>{t('消耗趋势')}</span>} itemKey='2' />
|
||||||
<span>
|
<TabPane tab={<span>{t('调用次数分布')}</span>} itemKey='3' />
|
||||||
<IconHistogram />
|
<TabPane tab={<span>{t('调用次数排行')}</span>} itemKey='4' />
|
||||||
{t('消耗分布')}
|
|
||||||
</span>
|
|
||||||
}
|
|
||||||
itemKey='1'
|
|
||||||
/>
|
|
||||||
<TabPane
|
|
||||||
tab={
|
|
||||||
<span>
|
|
||||||
<IconPulse />
|
|
||||||
{t('消耗趋势')}
|
|
||||||
</span>
|
|
||||||
}
|
|
||||||
itemKey='2'
|
|
||||||
/>
|
|
||||||
<TabPane
|
|
||||||
tab={
|
|
||||||
<span>
|
|
||||||
<IconPieChart2Stroked />
|
|
||||||
{t('调用次数分布')}
|
|
||||||
</span>
|
|
||||||
}
|
|
||||||
itemKey='3'
|
|
||||||
/>
|
|
||||||
<TabPane
|
|
||||||
tab={
|
|
||||||
<span>
|
|
||||||
<IconHistogram />
|
|
||||||
{t('调用次数排行')}
|
|
||||||
</span>
|
|
||||||
}
|
|
||||||
itemKey='4'
|
|
||||||
/>
|
|
||||||
</Tabs>
|
</Tabs>
|
||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,20 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (C) 2025 QuantumNous
|
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
|
||||||
it under the terms of the GNU Affero General Public License as
|
|
||||||
published by the Free Software Foundation, either version 3 of the
|
|
||||||
License, or (at your option) any later version.
|
|
||||||
|
|
||||||
This program is distributed in the hope that it will be useful,
|
|
||||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
GNU Affero General Public License for more details.
|
|
||||||
|
|
||||||
You should have received a copy of the GNU Affero General Public License
|
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
For commercial licensing, please contact support@quantumnous.com
|
|
||||||
*/
|
|
||||||
|
|
||||||
export { default } from './HeaderBar/index';
|
|
||||||
@ -1,148 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright (C) 2025 QuantumNous
|
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
|
||||||
it under the terms of the GNU Affero General Public License as
|
|
||||||
published by the Free Software Foundation, either version 3 of the
|
|
||||||
License, or (at your option) any later version.
|
|
||||||
|
|
||||||
This program is distributed in the hope that it will be useful,
|
|
||||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
GNU Affero General Public License for more details.
|
|
||||||
|
|
||||||
You should have received a copy of the GNU Affero General Public License
|
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
For commercial licensing, please contact support@quantumnous.com
|
|
||||||
*/
|
|
||||||
|
|
||||||
import React from 'react';
|
|
||||||
import { Skeleton } from '@douyinfe/semi-ui';
|
|
||||||
|
|
||||||
const SkeletonWrapper = ({
|
|
||||||
loading = false,
|
|
||||||
type = 'text',
|
|
||||||
count = 1,
|
|
||||||
width = 60,
|
|
||||||
height = 16,
|
|
||||||
isMobile = false,
|
|
||||||
className = '',
|
|
||||||
children,
|
|
||||||
...props
|
|
||||||
}) => {
|
|
||||||
if (!loading) {
|
|
||||||
return children;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 导航链接骨架屏
|
|
||||||
const renderNavigationSkeleton = () => {
|
|
||||||
const skeletonLinkClasses = isMobile
|
|
||||||
? 'flex items-center gap-1 p-1 w-full rounded-md'
|
|
||||||
: 'flex items-center gap-1 p-2 rounded-md';
|
|
||||||
|
|
||||||
return Array(count)
|
|
||||||
.fill(null)
|
|
||||||
.map((_, index) => (
|
|
||||||
<div key={index} className={skeletonLinkClasses}>
|
|
||||||
<Skeleton
|
|
||||||
loading={true}
|
|
||||||
active
|
|
||||||
placeholder={
|
|
||||||
<Skeleton.Title
|
|
||||||
active
|
|
||||||
style={{ width: isMobile ? 40 : width, height }}
|
|
||||||
/>
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
));
|
|
||||||
};
|
|
||||||
|
|
||||||
// 用户区域骨架屏 (头像 + 文本)
|
|
||||||
const renderUserAreaSkeleton = () => {
|
|
||||||
return (
|
|
||||||
<div
|
|
||||||
className={`flex items-center p-1 rounded-full bg-semi-color-fill-0 dark:bg-semi-color-fill-1 ${className}`}
|
|
||||||
>
|
|
||||||
<Skeleton
|
|
||||||
loading={true}
|
|
||||||
active
|
|
||||||
placeholder={
|
|
||||||
<Skeleton.Avatar active size='extra-small' className='shadow-sm' />
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
<div className='ml-1.5 mr-1'>
|
|
||||||
<Skeleton
|
|
||||||
loading={true}
|
|
||||||
active
|
|
||||||
placeholder={
|
|
||||||
<Skeleton.Title
|
|
||||||
active
|
|
||||||
style={{ width: isMobile ? 15 : width, height: 12 }}
|
|
||||||
/>
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Logo图片骨架屏
|
|
||||||
const renderImageSkeleton = () => {
|
|
||||||
return (
|
|
||||||
<Skeleton
|
|
||||||
loading={true}
|
|
||||||
active
|
|
||||||
placeholder={
|
|
||||||
<Skeleton.Image
|
|
||||||
active
|
|
||||||
className={`absolute inset-0 !rounded-full ${className}`}
|
|
||||||
style={{ width: '100%', height: '100%' }}
|
|
||||||
/>
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
// 系统名称骨架屏
|
|
||||||
const renderTitleSkeleton = () => {
|
|
||||||
return (
|
|
||||||
<Skeleton
|
|
||||||
loading={true}
|
|
||||||
active
|
|
||||||
placeholder={<Skeleton.Title active style={{ width, height: 24 }} />}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
// 通用文本骨架屏
|
|
||||||
const renderTextSkeleton = () => {
|
|
||||||
return (
|
|
||||||
<div className={className}>
|
|
||||||
<Skeleton
|
|
||||||
loading={true}
|
|
||||||
active
|
|
||||||
placeholder={<Skeleton.Title active style={{ width, height }} />}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
// 根据类型渲染不同的骨架屏
|
|
||||||
switch (type) {
|
|
||||||
case 'navigation':
|
|
||||||
return renderNavigationSkeleton();
|
|
||||||
case 'userArea':
|
|
||||||
return renderUserAreaSkeleton();
|
|
||||||
case 'image':
|
|
||||||
return renderImageSkeleton();
|
|
||||||
case 'title':
|
|
||||||
return renderTitleSkeleton();
|
|
||||||
case 'text':
|
|
||||||
default:
|
|
||||||
return renderTextSkeleton();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
export default SkeletonWrapper;
|
|
||||||
@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||||||
For commercial licensing, please contact support@quantumnous.com
|
For commercial licensing, please contact support@quantumnous.com
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import HeaderBar from './HeaderBar';
|
import HeaderBar from './headerbar';
|
||||||
import { Layout } from '@douyinfe/semi-ui';
|
import { Layout } from '@douyinfe/semi-ui';
|
||||||
import SiderBar from './SiderBar';
|
import SiderBar from './SiderBar';
|
||||||
import App from '../../App';
|
import App from '../../App';
|
||||||
|
|||||||
@ -23,7 +23,10 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import { getLucideIcon } from '../../helpers/render';
|
import { getLucideIcon } from '../../helpers/render';
|
||||||
import { ChevronLeft } from 'lucide-react';
|
import { ChevronLeft } from 'lucide-react';
|
||||||
import { useSidebarCollapsed } from '../../hooks/common/useSidebarCollapsed';
|
import { useSidebarCollapsed } from '../../hooks/common/useSidebarCollapsed';
|
||||||
|
import { useSidebar } from '../../hooks/common/useSidebar';
|
||||||
|
import { useMinimumLoadingTime } from '../../hooks/common/useMinimumLoadingTime';
|
||||||
import { isAdmin, isRoot, showError } from '../../helpers';
|
import { isAdmin, isRoot, showError } from '../../helpers';
|
||||||
|
import SkeletonWrapper from './components/SkeletonWrapper';
|
||||||
|
|
||||||
import { Nav, Divider, Button } from '@douyinfe/semi-ui';
|
import { Nav, Divider, Button } from '@douyinfe/semi-ui';
|
||||||
|
|
||||||
@ -49,6 +52,13 @@ const routerMap = {
|
|||||||
const SiderBar = ({ onNavigate = () => {} }) => {
|
const SiderBar = ({ onNavigate = () => {} }) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const [collapsed, toggleCollapsed] = useSidebarCollapsed();
|
const [collapsed, toggleCollapsed] = useSidebarCollapsed();
|
||||||
|
const {
|
||||||
|
isModuleVisible,
|
||||||
|
hasSectionVisibleModules,
|
||||||
|
loading: sidebarLoading,
|
||||||
|
} = useSidebar();
|
||||||
|
|
||||||
|
const showSkeleton = useMinimumLoadingTime(sidebarLoading);
|
||||||
|
|
||||||
const [selectedKeys, setSelectedKeys] = useState(['home']);
|
const [selectedKeys, setSelectedKeys] = useState(['home']);
|
||||||
const [chatItems, setChatItems] = useState([]);
|
const [chatItems, setChatItems] = useState([]);
|
||||||
@ -56,8 +66,8 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
|||||||
const location = useLocation();
|
const location = useLocation();
|
||||||
const [routerMapState, setRouterMapState] = useState(routerMap);
|
const [routerMapState, setRouterMapState] = useState(routerMap);
|
||||||
|
|
||||||
const workspaceItems = useMemo(
|
const workspaceItems = useMemo(() => {
|
||||||
() => [
|
const items = [
|
||||||
{
|
{
|
||||||
text: t('数据看板'),
|
text: t('数据看板'),
|
||||||
itemKey: 'detail',
|
itemKey: 'detail',
|
||||||
@ -93,17 +103,25 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
|||||||
className:
|
className:
|
||||||
localStorage.getItem('enable_task') === 'true' ? '' : 'tableHiddle',
|
localStorage.getItem('enable_task') === 'true' ? '' : 'tableHiddle',
|
||||||
},
|
},
|
||||||
],
|
];
|
||||||
[
|
|
||||||
localStorage.getItem('enable_data_export'),
|
|
||||||
localStorage.getItem('enable_drawing'),
|
|
||||||
localStorage.getItem('enable_task'),
|
|
||||||
t,
|
|
||||||
],
|
|
||||||
);
|
|
||||||
|
|
||||||
const financeItems = useMemo(
|
// 根据配置过滤项目
|
||||||
() => [
|
const filteredItems = items.filter((item) => {
|
||||||
|
const configVisible = isModuleVisible('console', item.itemKey);
|
||||||
|
return configVisible;
|
||||||
|
});
|
||||||
|
|
||||||
|
return filteredItems;
|
||||||
|
}, [
|
||||||
|
localStorage.getItem('enable_data_export'),
|
||||||
|
localStorage.getItem('enable_drawing'),
|
||||||
|
localStorage.getItem('enable_task'),
|
||||||
|
t,
|
||||||
|
isModuleVisible,
|
||||||
|
]);
|
||||||
|
|
||||||
|
const financeItems = useMemo(() => {
|
||||||
|
const items = [
|
||||||
{
|
{
|
||||||
text: t('钱包管理'),
|
text: t('钱包管理'),
|
||||||
itemKey: 'topup',
|
itemKey: 'topup',
|
||||||
@ -114,12 +132,19 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
|||||||
itemKey: 'personal',
|
itemKey: 'personal',
|
||||||
to: '/personal',
|
to: '/personal',
|
||||||
},
|
},
|
||||||
],
|
];
|
||||||
[t],
|
|
||||||
);
|
|
||||||
|
|
||||||
const adminItems = useMemo(
|
// 根据配置过滤项目
|
||||||
() => [
|
const filteredItems = items.filter((item) => {
|
||||||
|
const configVisible = isModuleVisible('personal', item.itemKey);
|
||||||
|
return configVisible;
|
||||||
|
});
|
||||||
|
|
||||||
|
return filteredItems;
|
||||||
|
}, [t, isModuleVisible]);
|
||||||
|
|
||||||
|
const adminItems = useMemo(() => {
|
||||||
|
const items = [
|
||||||
{
|
{
|
||||||
text: t('渠道管理'),
|
text: t('渠道管理'),
|
||||||
itemKey: 'channel',
|
itemKey: 'channel',
|
||||||
@ -150,12 +175,19 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
|||||||
to: '/setting',
|
to: '/setting',
|
||||||
className: isRoot() ? '' : 'tableHiddle',
|
className: isRoot() ? '' : 'tableHiddle',
|
||||||
},
|
},
|
||||||
],
|
];
|
||||||
[isAdmin(), isRoot(), t],
|
|
||||||
);
|
|
||||||
|
|
||||||
const chatMenuItems = useMemo(
|
// 根据配置过滤项目
|
||||||
() => [
|
const filteredItems = items.filter((item) => {
|
||||||
|
const configVisible = isModuleVisible('admin', item.itemKey);
|
||||||
|
return configVisible;
|
||||||
|
});
|
||||||
|
|
||||||
|
return filteredItems;
|
||||||
|
}, [isAdmin(), isRoot(), t, isModuleVisible]);
|
||||||
|
|
||||||
|
const chatMenuItems = useMemo(() => {
|
||||||
|
const items = [
|
||||||
{
|
{
|
||||||
text: t('操练场'),
|
text: t('操练场'),
|
||||||
itemKey: 'playground',
|
itemKey: 'playground',
|
||||||
@ -166,9 +198,16 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
|||||||
itemKey: 'chat',
|
itemKey: 'chat',
|
||||||
items: chatItems,
|
items: chatItems,
|
||||||
},
|
},
|
||||||
],
|
];
|
||||||
[chatItems, t],
|
|
||||||
);
|
// 根据配置过滤项目
|
||||||
|
const filteredItems = items.filter((item) => {
|
||||||
|
const configVisible = isModuleVisible('chat', item.itemKey);
|
||||||
|
return configVisible;
|
||||||
|
});
|
||||||
|
|
||||||
|
return filteredItems;
|
||||||
|
}, [chatItems, t, isModuleVisible]);
|
||||||
|
|
||||||
// 更新路由映射,添加聊天路由
|
// 更新路由映射,添加聊天路由
|
||||||
const updateRouterMapWithChats = (chats) => {
|
const updateRouterMapWithChats = (chats) => {
|
||||||
@ -213,7 +252,6 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
|||||||
updateRouterMapWithChats(chats);
|
updateRouterMapWithChats(chats);
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error(e);
|
|
||||||
showError('聊天数据解析失败');
|
showError('聊天数据解析失败');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -267,14 +305,12 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
|||||||
key={item.itemKey}
|
key={item.itemKey}
|
||||||
itemKey={item.itemKey}
|
itemKey={item.itemKey}
|
||||||
text={
|
text={
|
||||||
<div className='flex items-center'>
|
<span
|
||||||
<span
|
className='truncate font-medium text-sm'
|
||||||
className='truncate font-medium text-sm'
|
style={{ color: textColor }}
|
||||||
style={{ color: textColor }}
|
>
|
||||||
>
|
{item.text}
|
||||||
{item.text}
|
</span>
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
}
|
}
|
||||||
icon={
|
icon={
|
||||||
<div className='sidebar-icon-container flex-shrink-0'>
|
<div className='sidebar-icon-container flex-shrink-0'>
|
||||||
@ -297,14 +333,12 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
|||||||
key={item.itemKey}
|
key={item.itemKey}
|
||||||
itemKey={item.itemKey}
|
itemKey={item.itemKey}
|
||||||
text={
|
text={
|
||||||
<div className='flex items-center'>
|
<span
|
||||||
<span
|
className='truncate font-medium text-sm'
|
||||||
className='truncate font-medium text-sm'
|
style={{ color: textColor }}
|
||||||
style={{ color: textColor }}
|
>
|
||||||
>
|
{item.text}
|
||||||
{item.text}
|
</span>
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
}
|
}
|
||||||
icon={
|
icon={
|
||||||
<div className='sidebar-icon-container flex-shrink-0'>
|
<div className='sidebar-icon-container flex-shrink-0'>
|
||||||
@ -341,110 +375,142 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
|||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className='sidebar-container'
|
className='sidebar-container'
|
||||||
style={{ width: 'var(--sidebar-current-width)' }}
|
style={{
|
||||||
|
width: 'var(--sidebar-current-width)',
|
||||||
|
background: 'var(--semi-color-bg-0)',
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
<Nav
|
<SkeletonWrapper
|
||||||
className='sidebar-nav'
|
loading={showSkeleton}
|
||||||
defaultIsCollapsed={collapsed}
|
type='sidebar'
|
||||||
isCollapsed={collapsed}
|
className=''
|
||||||
onCollapseChange={toggleCollapsed}
|
collapsed={collapsed}
|
||||||
selectedKeys={selectedKeys}
|
showAdmin={isAdmin()}
|
||||||
itemStyle='sidebar-nav-item'
|
|
||||||
hoverStyle='sidebar-nav-item:hover'
|
|
||||||
selectedStyle='sidebar-nav-item-selected'
|
|
||||||
renderWrapper={({ itemElement, props }) => {
|
|
||||||
const to = routerMapState[props.itemKey] || routerMap[props.itemKey];
|
|
||||||
|
|
||||||
// 如果没有路由,直接返回元素
|
|
||||||
if (!to) return itemElement;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Link
|
|
||||||
style={{ textDecoration: 'none' }}
|
|
||||||
to={to}
|
|
||||||
onClick={onNavigate}
|
|
||||||
>
|
|
||||||
{itemElement}
|
|
||||||
</Link>
|
|
||||||
);
|
|
||||||
}}
|
|
||||||
onSelect={(key) => {
|
|
||||||
// 如果点击的是已经展开的子菜单的父项,则收起子菜单
|
|
||||||
if (openedKeys.includes(key.itemKey)) {
|
|
||||||
setOpenedKeys(openedKeys.filter((k) => k !== key.itemKey));
|
|
||||||
}
|
|
||||||
|
|
||||||
setSelectedKeys([key.itemKey]);
|
|
||||||
}}
|
|
||||||
openKeys={openedKeys}
|
|
||||||
onOpenChange={(data) => {
|
|
||||||
setOpenedKeys(data.openKeys);
|
|
||||||
}}
|
|
||||||
>
|
>
|
||||||
{/* 聊天区域 */}
|
<Nav
|
||||||
<div className='sidebar-section'>
|
className='sidebar-nav'
|
||||||
{!collapsed && <div className='sidebar-group-label'>{t('聊天')}</div>}
|
defaultIsCollapsed={collapsed}
|
||||||
{chatMenuItems.map((item) => renderSubItem(item))}
|
isCollapsed={collapsed}
|
||||||
</div>
|
onCollapseChange={toggleCollapsed}
|
||||||
|
selectedKeys={selectedKeys}
|
||||||
|
itemStyle='sidebar-nav-item'
|
||||||
|
hoverStyle='sidebar-nav-item:hover'
|
||||||
|
selectedStyle='sidebar-nav-item-selected'
|
||||||
|
renderWrapper={({ itemElement, props }) => {
|
||||||
|
const to =
|
||||||
|
routerMapState[props.itemKey] || routerMap[props.itemKey];
|
||||||
|
|
||||||
{/* 控制台区域 */}
|
// 如果没有路由,直接返回元素
|
||||||
<Divider className='sidebar-divider' />
|
if (!to) return itemElement;
|
||||||
<div>
|
|
||||||
{!collapsed && (
|
|
||||||
<div className='sidebar-group-label'>{t('控制台')}</div>
|
|
||||||
)}
|
|
||||||
{workspaceItems.map((item) => renderNavItem(item))}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* 个人中心区域 */}
|
return (
|
||||||
<Divider className='sidebar-divider' />
|
<Link
|
||||||
<div>
|
style={{ textDecoration: 'none' }}
|
||||||
{!collapsed && (
|
to={to}
|
||||||
<div className='sidebar-group-label'>{t('个人中心')}</div>
|
onClick={onNavigate}
|
||||||
)}
|
>
|
||||||
{financeItems.map((item) => renderNavItem(item))}
|
{itemElement}
|
||||||
</div>
|
</Link>
|
||||||
|
);
|
||||||
|
}}
|
||||||
|
onSelect={(key) => {
|
||||||
|
// 如果点击的是已经展开的子菜单的父项,则收起子菜单
|
||||||
|
if (openedKeys.includes(key.itemKey)) {
|
||||||
|
setOpenedKeys(openedKeys.filter((k) => k !== key.itemKey));
|
||||||
|
}
|
||||||
|
|
||||||
{/* 管理员区域 - 只在管理员时显示 */}
|
setSelectedKeys([key.itemKey]);
|
||||||
{isAdmin() && (
|
}}
|
||||||
<>
|
openKeys={openedKeys}
|
||||||
<Divider className='sidebar-divider' />
|
onOpenChange={(data) => {
|
||||||
<div>
|
setOpenedKeys(data.openKeys);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{/* 聊天区域 */}
|
||||||
|
{hasSectionVisibleModules('chat') && (
|
||||||
|
<div className='sidebar-section'>
|
||||||
{!collapsed && (
|
{!collapsed && (
|
||||||
<div className='sidebar-group-label'>{t('管理员')}</div>
|
<div className='sidebar-group-label'>{t('聊天')}</div>
|
||||||
)}
|
)}
|
||||||
{adminItems.map((item) => renderNavItem(item))}
|
{chatMenuItems.map((item) => renderSubItem(item))}
|
||||||
</div>
|
</div>
|
||||||
</>
|
)}
|
||||||
)}
|
|
||||||
</Nav>
|
{/* 控制台区域 */}
|
||||||
|
{hasSectionVisibleModules('console') && (
|
||||||
|
<>
|
||||||
|
<Divider className='sidebar-divider' />
|
||||||
|
<div>
|
||||||
|
{!collapsed && (
|
||||||
|
<div className='sidebar-group-label'>{t('控制台')}</div>
|
||||||
|
)}
|
||||||
|
{workspaceItems.map((item) => renderNavItem(item))}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 个人中心区域 */}
|
||||||
|
{hasSectionVisibleModules('personal') && (
|
||||||
|
<>
|
||||||
|
<Divider className='sidebar-divider' />
|
||||||
|
<div>
|
||||||
|
{!collapsed && (
|
||||||
|
<div className='sidebar-group-label'>{t('个人中心')}</div>
|
||||||
|
)}
|
||||||
|
{financeItems.map((item) => renderNavItem(item))}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 管理员区域 - 只在管理员时显示且配置允许时显示 */}
|
||||||
|
{isAdmin() && hasSectionVisibleModules('admin') && (
|
||||||
|
<>
|
||||||
|
<Divider className='sidebar-divider' />
|
||||||
|
<div>
|
||||||
|
{!collapsed && (
|
||||||
|
<div className='sidebar-group-label'>{t('管理员')}</div>
|
||||||
|
)}
|
||||||
|
{adminItems.map((item) => renderNavItem(item))}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Nav>
|
||||||
|
</SkeletonWrapper>
|
||||||
|
|
||||||
{/* 底部折叠按钮 */}
|
{/* 底部折叠按钮 */}
|
||||||
<div className='sidebar-collapse-button'>
|
<div className='sidebar-collapse-button'>
|
||||||
<Button
|
<SkeletonWrapper
|
||||||
theme='outline'
|
loading={showSkeleton}
|
||||||
type='tertiary'
|
type='button'
|
||||||
size='small'
|
width={collapsed ? 36 : 156}
|
||||||
icon={
|
height={24}
|
||||||
<ChevronLeft
|
className='w-full'
|
||||||
size={16}
|
|
||||||
strokeWidth={2.5}
|
|
||||||
color='var(--semi-color-text-2)'
|
|
||||||
style={{
|
|
||||||
transform: collapsed ? 'rotate(180deg)' : 'rotate(0deg)',
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
}
|
|
||||||
onClick={toggleCollapsed}
|
|
||||||
icononly={collapsed}
|
|
||||||
style={
|
|
||||||
collapsed
|
|
||||||
? { padding: '4px', width: '100%' }
|
|
||||||
: { padding: '4px 12px', width: '100%' }
|
|
||||||
}
|
|
||||||
>
|
>
|
||||||
{!collapsed ? t('收起侧边栏') : null}
|
<Button
|
||||||
</Button>
|
theme='outline'
|
||||||
|
type='tertiary'
|
||||||
|
size='small'
|
||||||
|
icon={
|
||||||
|
<ChevronLeft
|
||||||
|
size={16}
|
||||||
|
strokeWidth={2.5}
|
||||||
|
color='var(--semi-color-text-2)'
|
||||||
|
style={{
|
||||||
|
transform: collapsed ? 'rotate(180deg)' : 'rotate(0deg)',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
onClick={toggleCollapsed}
|
||||||
|
icononly={collapsed}
|
||||||
|
style={
|
||||||
|
collapsed
|
||||||
|
? { width: 36, height: 24, padding: 0 }
|
||||||
|
: { padding: '4px 12px', width: '100%' }
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{!collapsed ? t('收起侧边栏') : null}
|
||||||
|
</Button>
|
||||||
|
</SkeletonWrapper>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
394
web/src/components/layout/components/SkeletonWrapper.jsx
Normal file
394
web/src/components/layout/components/SkeletonWrapper.jsx
Normal file
@ -0,0 +1,394 @@
|
|||||||
|
/*
|
||||||
|
Copyright (C) 2025 QuantumNous
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as
|
||||||
|
published by the Free Software Foundation, either version 3 of the
|
||||||
|
License, or (at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
For commercial licensing, please contact support@quantumnous.com
|
||||||
|
*/
|
||||||
|
|
||||||
|
import React from 'react';
|
||||||
|
import { Skeleton } from '@douyinfe/semi-ui';
|
||||||
|
|
||||||
|
const SkeletonWrapper = ({
|
||||||
|
loading = false,
|
||||||
|
type = 'text',
|
||||||
|
count = 1,
|
||||||
|
width = 60,
|
||||||
|
height = 16,
|
||||||
|
isMobile = false,
|
||||||
|
className = '',
|
||||||
|
collapsed = false,
|
||||||
|
showAdmin = true,
|
||||||
|
children,
|
||||||
|
...props
|
||||||
|
}) => {
|
||||||
|
if (!loading) {
|
||||||
|
return children;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 导航链接骨架屏
|
||||||
|
const renderNavigationSkeleton = () => {
|
||||||
|
const skeletonLinkClasses = isMobile
|
||||||
|
? 'flex items-center gap-1 p-1 w-full rounded-md'
|
||||||
|
: 'flex items-center gap-1 p-2 rounded-md';
|
||||||
|
|
||||||
|
return Array(count)
|
||||||
|
.fill(null)
|
||||||
|
.map((_, index) => (
|
||||||
|
<div key={index} className={skeletonLinkClasses}>
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Title
|
||||||
|
active
|
||||||
|
style={{ width: isMobile ? 40 : width, height }}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
// 用户区域骨架屏 (头像 + 文本)
|
||||||
|
const renderUserAreaSkeleton = () => {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={`flex items-center p-1 rounded-full bg-semi-color-fill-0 dark:bg-semi-color-fill-1 ${className}`}
|
||||||
|
>
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Avatar active size='extra-small' className='shadow-sm' />
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<div className='ml-1.5 mr-1'>
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Title
|
||||||
|
active
|
||||||
|
style={{ width: isMobile ? 15 : width, height: 12 }}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Logo图片骨架屏
|
||||||
|
const renderImageSkeleton = () => {
|
||||||
|
return (
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Image
|
||||||
|
active
|
||||||
|
className={`absolute inset-0 !rounded-full ${className}`}
|
||||||
|
style={{ width: '100%', height: '100%' }}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
// 系统名称骨架屏
|
||||||
|
const renderTitleSkeleton = () => {
|
||||||
|
return (
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={<Skeleton.Title active style={{ width, height: 24 }} />}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
// 通用文本骨架屏
|
||||||
|
const renderTextSkeleton = () => {
|
||||||
|
return (
|
||||||
|
<div className={className}>
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={<Skeleton.Title active style={{ width, height }} />}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
// 按钮骨架屏(支持圆角)
|
||||||
|
const renderButtonSkeleton = () => {
|
||||||
|
return (
|
||||||
|
<div className={className}>
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Title
|
||||||
|
active
|
||||||
|
style={{ width, height, borderRadius: 9999 }}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
// 侧边栏导航项骨架屏 (图标 + 文本)
|
||||||
|
const renderSidebarNavItemSkeleton = () => {
|
||||||
|
return Array(count)
|
||||||
|
.fill(null)
|
||||||
|
.map((_, index) => (
|
||||||
|
<div
|
||||||
|
key={index}
|
||||||
|
className={`flex items-center p-2 mb-1 rounded-md ${className}`}
|
||||||
|
>
|
||||||
|
{/* 图标骨架屏 */}
|
||||||
|
<div className='sidebar-icon-container flex-shrink-0'>
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Avatar active size='extra-small' shape='square' />
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{/* 文本骨架屏 */}
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Title
|
||||||
|
active
|
||||||
|
style={{ width: width || 80, height: height || 14 }}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
// 侧边栏组标题骨架屏
|
||||||
|
const renderSidebarGroupTitleSkeleton = () => {
|
||||||
|
return (
|
||||||
|
<div className={`mb-2 ${className}`}>
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Title
|
||||||
|
active
|
||||||
|
style={{ width: width || 60, height: height || 12 }}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
// 完整侧边栏骨架屏 - 1:1 还原,去重实现
|
||||||
|
const renderSidebarSkeleton = () => {
|
||||||
|
const NAV_WIDTH = 164;
|
||||||
|
const NAV_HEIGHT = 30;
|
||||||
|
const COLLAPSED_WIDTH = 44;
|
||||||
|
const COLLAPSED_HEIGHT = 44;
|
||||||
|
const ICON_SIZE = 16;
|
||||||
|
const TITLE_HEIGHT = 12;
|
||||||
|
const TEXT_HEIGHT = 16;
|
||||||
|
|
||||||
|
const renderIcon = () => (
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Avatar
|
||||||
|
active
|
||||||
|
shape='square'
|
||||||
|
style={{ width: ICON_SIZE, height: ICON_SIZE }}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
|
||||||
|
const renderLabel = (labelWidth) => (
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Title
|
||||||
|
active
|
||||||
|
style={{ width: labelWidth, height: TEXT_HEIGHT }}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
|
||||||
|
const NavRow = ({ labelWidth }) => (
|
||||||
|
<div
|
||||||
|
className='flex items-center p-2 mb-1 rounded-md'
|
||||||
|
style={{
|
||||||
|
width: `${NAV_WIDTH}px`,
|
||||||
|
height: `${NAV_HEIGHT}px`,
|
||||||
|
margin: '3px 8px',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div className='sidebar-icon-container flex-shrink-0'>
|
||||||
|
{renderIcon()}
|
||||||
|
</div>
|
||||||
|
{renderLabel(labelWidth)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
|
const CollapsedRow = ({ keyPrefix, index }) => (
|
||||||
|
<div
|
||||||
|
key={`${keyPrefix}-${index}`}
|
||||||
|
className='flex items-center justify-center'
|
||||||
|
style={{
|
||||||
|
width: `${COLLAPSED_WIDTH}px`,
|
||||||
|
height: `${COLLAPSED_HEIGHT}px`,
|
||||||
|
margin: '0 8px 4px 8px',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Avatar
|
||||||
|
active
|
||||||
|
shape='square'
|
||||||
|
style={{ width: ICON_SIZE, height: ICON_SIZE }}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
|
if (collapsed) {
|
||||||
|
return (
|
||||||
|
<div className={`w-full ${className}`} style={{ paddingTop: '12px' }}>
|
||||||
|
{Array(2)
|
||||||
|
.fill(null)
|
||||||
|
.map((_, i) => (
|
||||||
|
<CollapsedRow keyPrefix='c-chat' index={i} />
|
||||||
|
))}
|
||||||
|
{Array(5)
|
||||||
|
.fill(null)
|
||||||
|
.map((_, i) => (
|
||||||
|
<CollapsedRow keyPrefix='c-console' index={i} />
|
||||||
|
))}
|
||||||
|
{Array(2)
|
||||||
|
.fill(null)
|
||||||
|
.map((_, i) => (
|
||||||
|
<CollapsedRow keyPrefix='c-personal' index={i} />
|
||||||
|
))}
|
||||||
|
{Array(5)
|
||||||
|
.fill(null)
|
||||||
|
.map((_, i) => (
|
||||||
|
<CollapsedRow keyPrefix='c-admin' index={i} />
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const sections = [
|
||||||
|
{ key: 'chat', titleWidth: 32, itemWidths: [54, 32], wrapper: 'section' },
|
||||||
|
{ key: 'console', titleWidth: 48, itemWidths: [64, 64, 64, 64, 64] },
|
||||||
|
{ key: 'personal', titleWidth: 64, itemWidths: [64, 64] },
|
||||||
|
...(showAdmin
|
||||||
|
? [{ key: 'admin', titleWidth: 48, itemWidths: [64, 64, 80, 64, 64] }]
|
||||||
|
: []),
|
||||||
|
];
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={`w-full ${className}`} style={{ paddingTop: '12px' }}>
|
||||||
|
{sections.map((sec, idx) => (
|
||||||
|
<React.Fragment key={sec.key}>
|
||||||
|
{sec.wrapper === 'section' ? (
|
||||||
|
<div className='sidebar-section'>
|
||||||
|
<div
|
||||||
|
className='sidebar-group-label'
|
||||||
|
style={{ padding: '4px 15px 8px' }}
|
||||||
|
>
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Title
|
||||||
|
active
|
||||||
|
style={{ width: sec.titleWidth, height: TITLE_HEIGHT }}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{sec.itemWidths.map((w, i) => (
|
||||||
|
<NavRow key={`${sec.key}-${i}`} labelWidth={w} />
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div>
|
||||||
|
<div
|
||||||
|
className='sidebar-group-label'
|
||||||
|
style={{ padding: '4px 15px 8px' }}
|
||||||
|
>
|
||||||
|
<Skeleton
|
||||||
|
loading={true}
|
||||||
|
active
|
||||||
|
placeholder={
|
||||||
|
<Skeleton.Title
|
||||||
|
active
|
||||||
|
style={{ width: sec.titleWidth, height: TITLE_HEIGHT }}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{sec.itemWidths.map((w, i) => (
|
||||||
|
<NavRow key={`${sec.key}-${i}`} labelWidth={w} />
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</React.Fragment>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
// 根据类型渲染不同的骨架屏
|
||||||
|
switch (type) {
|
||||||
|
case 'navigation':
|
||||||
|
return renderNavigationSkeleton();
|
||||||
|
case 'userArea':
|
||||||
|
return renderUserAreaSkeleton();
|
||||||
|
case 'image':
|
||||||
|
return renderImageSkeleton();
|
||||||
|
case 'title':
|
||||||
|
return renderTitleSkeleton();
|
||||||
|
case 'sidebarNavItem':
|
||||||
|
return renderSidebarNavItemSkeleton();
|
||||||
|
case 'sidebarGroupTitle':
|
||||||
|
return renderSidebarGroupTitleSkeleton();
|
||||||
|
case 'sidebar':
|
||||||
|
return renderSidebarSkeleton();
|
||||||
|
case 'button':
|
||||||
|
return renderButtonSkeleton();
|
||||||
|
case 'text':
|
||||||
|
default:
|
||||||
|
return renderTextSkeleton();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SkeletonWrapper;
|
||||||
@ -20,7 +20,7 @@ For commercial licensing, please contact support@quantumnous.com
|
|||||||
import React from 'react';
|
import React from 'react';
|
||||||
import { Link } from 'react-router-dom';
|
import { Link } from 'react-router-dom';
|
||||||
import { Typography, Tag } from '@douyinfe/semi-ui';
|
import { Typography, Tag } from '@douyinfe/semi-ui';
|
||||||
import SkeletonWrapper from './SkeletonWrapper';
|
import SkeletonWrapper from '../components/SkeletonWrapper';
|
||||||
|
|
||||||
const HeaderLogo = ({
|
const HeaderLogo = ({
|
||||||
isMobile,
|
isMobile,
|
||||||
@ -19,9 +19,15 @@ For commercial licensing, please contact support@quantumnous.com
|
|||||||
|
|
||||||
import React from 'react';
|
import React from 'react';
|
||||||
import { Link } from 'react-router-dom';
|
import { Link } from 'react-router-dom';
|
||||||
import SkeletonWrapper from './SkeletonWrapper';
|
import SkeletonWrapper from '../components/SkeletonWrapper';
|
||||||
|
|
||||||
const Navigation = ({ mainNavLinks, isMobile, isLoading, userState }) => {
|
const Navigation = ({
|
||||||
|
mainNavLinks,
|
||||||
|
isMobile,
|
||||||
|
isLoading,
|
||||||
|
userState,
|
||||||
|
pricingRequireAuth,
|
||||||
|
}) => {
|
||||||
const renderNavLinks = () => {
|
const renderNavLinks = () => {
|
||||||
const baseClasses =
|
const baseClasses =
|
||||||
'flex-shrink-0 flex items-center gap-1 font-semibold rounded-md transition-all duration-200 ease-in-out';
|
'flex-shrink-0 flex items-center gap-1 font-semibold rounded-md transition-all duration-200 ease-in-out';
|
||||||
@ -51,6 +57,9 @@ const Navigation = ({ mainNavLinks, isMobile, isLoading, userState }) => {
|
|||||||
if (link.itemKey === 'console' && !userState.user) {
|
if (link.itemKey === 'console' && !userState.user) {
|
||||||
targetPath = '/login';
|
targetPath = '/login';
|
||||||
}
|
}
|
||||||
|
if (link.itemKey === 'pricing' && pricingRequireAuth && !userState.user) {
|
||||||
|
targetPath = '/login';
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Link key={link.itemKey} to={targetPath} className={commonLinkClasses}>
|
<Link key={link.itemKey} to={targetPath} className={commonLinkClasses}>
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user