diff --git a/controller/model.go b/controller/model.go index 4dbd4583..cc2b1eff 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,6 +3,7 @@ package controller import ( "fmt" "net/http" + "strings" "time" "github.com/QuantumNous/new-api/common" @@ -109,9 +110,102 @@ func init() { }) } -func ListModels(c *gin.Context, modelType int) { - userOpenAiModels := make([]dto.OpenAIModels, 0) +func channelOwnerName(channelType int) string { + apiType, success := common.ChannelType2APIType(channelType) + if !success { + return strings.ToLower(constant.GetChannelTypeName(channelType)) + } + adaptor := relay.GetAdaptor(apiType) + if adaptor == nil { + return strings.ToLower(constant.GetChannelTypeName(channelType)) + } + adaptor.Init(&relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{ + ChannelType: channelType, + }}) + if name := strings.TrimSpace(adaptor.GetChannelName()); name != "" { + return name + } + return strings.ToLower(constant.GetChannelTypeName(channelType)) +} +func getPreferredModelOwners(modelNames []string, groups []string) map[string]string { + channelTypes, err := model.GetPreferredModelOwnerChannelTypes(modelNames, groups) + if err != nil { + common.SysLog(fmt.Sprintf("GetPreferredModelOwnerChannelTypes error: %v", err)) + return map[string]string{} + } + + ownerByChannelType := make(map[int]string) + owners := make(map[string]string, len(channelTypes)) + for modelName, channelType := range channelTypes { + owner, ok := ownerByChannelType[channelType] + if !ok { + owner = channelOwnerName(channelType) + ownerByChannelType[channelType] = owner + } + if owner != "" { + owners[modelName] = owner + } + } + return owners +} + +func buildOpenAIModel(modelName string, ownerByModel map[string]string) dto.OpenAIModels { + var oaiModel dto.OpenAIModels + if staticModel, ok := openAIModelsMap[modelName]; ok { + oaiModel = staticModel + } else { + oaiModel = dto.OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + } + } + if owner, ok := ownerByModel[modelName]; ok && owner != "" { + oaiModel.OwnedBy = owner + } + oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName) + return oaiModel +} + +type modelListGroups struct { + userGroup string + tokenGroup string + ownerGroups []string +} + +func getModelListGroups(c *gin.Context) (modelListGroups, error) { + tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) + userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) + if userGroup == "" && (tokenGroup == "" || tokenGroup == "auto") { + var err error + userGroup, err = model.GetUserGroup(c.GetInt("id"), false) + if err != nil { + return modelListGroups{}, err + } + } + + if tokenGroup == "auto" { + return modelListGroups{ + userGroup: userGroup, + tokenGroup: tokenGroup, + ownerGroups: service.GetUserAutoGroup(userGroup), + }, nil + } + + group := userGroup + if tokenGroup != "" { + group = tokenGroup + } + return modelListGroups{ + userGroup: userGroup, + tokenGroup: tokenGroup, + ownerGroups: []string{group}, + }, nil +} + +func ListModels(c *gin.Context, modelType int) { acceptUnsetRatioModel := operation_setting.SelfUseModeEnabled if !acceptUnsetRatioModel { userId := c.GetInt("id") @@ -123,6 +217,16 @@ func ListModels(c *gin.Context, modelType int) { } } + userModelNames := make([]string, 0) + groups, err := getModelListGroups(c) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "get user group failed", + }) + return + } + ownerGroups := groups.ownerGroups modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) if modelLimitEnable { s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit) @@ -138,37 +242,12 @@ func ListModels(c *gin.Context, modelType int) { continue } } - if oaiModel, ok := openAIModelsMap[allowModel]; ok { - oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel) - userOpenAiModels = append(userOpenAiModels, oaiModel) - } else { - userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ - Id: allowModel, - Object: "model", - Created: 1626777600, - OwnedBy: "custom", - SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel), - }) - } + userModelNames = append(userModelNames, allowModel) } } else { - userId := c.GetInt("id") - userGroup, err := model.GetUserGroup(userId, false) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "get user group failed", - }) - return - } - group := userGroup - tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) - if tokenGroup != "" { - group = tokenGroup - } var models []string - if tokenGroup == "auto" { - for _, autoGroup := range service.GetUserAutoGroup(userGroup) { + if groups.tokenGroup == "auto" { + for _, autoGroup := range ownerGroups { groupModels := model.GetGroupEnabledModels(autoGroup) for _, g := range groupModels { if !common.StringsContains(models, g) { @@ -177,7 +256,7 @@ func ListModels(c *gin.Context, modelType int) { } } } else { - models = model.GetGroupEnabledModels(group) + models = model.GetGroupEnabledModels(ownerGroups[0]) } for _, modelName := range models { if !acceptUnsetRatioModel { @@ -185,21 +264,19 @@ func ListModels(c *gin.Context, modelType int) { continue } } - if oaiModel, ok := openAIModelsMap[modelName]; ok { - oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName) - userOpenAiModels = append(userOpenAiModels, oaiModel) - } else { - userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "custom", - SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName), - }) - } + userModelNames = append(userModelNames, modelName) } } + ownerByModel := map[string]string{} + if len(ownerGroups) > 0 { + ownerByModel = getPreferredModelOwners(userModelNames, ownerGroups) + } + userOpenAiModels := make([]dto.OpenAIModels, 0, len(userModelNames)) + for _, modelName := range userModelNames { + userOpenAiModels = append(userOpenAiModels, buildOpenAIModel(modelName, ownerByModel)) + } + switch modelType { case constant.ChannelTypeAnthropic: useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels)) diff --git a/controller/model_owned_by_test.go b/controller/model_owned_by_test.go new file mode 100644 index 00000000..bc2ef32f --- /dev/null +++ b/controller/model_owned_by_test.go @@ -0,0 +1,85 @@ +package controller + +import ( + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestChannelOwnerNameUsesAdaptorChannelName(t *testing.T) { + tests := []struct { + name string + channelType int + expected string + }{ + { + name: "openai", + channelType: constant.ChannelTypeOpenAI, + expected: "openai", + }, + { + name: "codex", + channelType: constant.ChannelTypeCodex, + expected: "codex", + }, + { + name: "openrouter", + channelType: constant.ChannelTypeOpenRouter, + expected: "openrouter", + }, + { + name: "azure fallback", + channelType: constant.ChannelTypeAzure, + expected: "azure", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, channelOwnerName(tt.channelType)) + }) + } +} + +func TestBuildOpenAIModelOverridesOwnedBy(t *testing.T) { + modelItem := buildOpenAIModel("gpt-5.4", map[string]string{"gpt-5.4": "openai"}) + require.Equal(t, "gpt-5.4", modelItem.Id) + require.Equal(t, "openai", modelItem.OwnedBy) +} + +func TestBuildOpenAIModelFallsBackToCustomForUnknownModels(t *testing.T) { + modelItem := buildOpenAIModel("custom-test-model", nil) + require.Equal(t, "custom-test-model", modelItem.Id) + require.Equal(t, "custom", modelItem.OwnedBy) +} + +func TestGetModelListGroupsUsesUserGroupWhenTokenGroupIsEmpty(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) + common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default") + + groups, err := getModelListGroups(ctx) + require.NoError(t, err) + + require.Equal(t, "default", groups.userGroup) + require.Empty(t, groups.tokenGroup) + require.Equal(t, []string{"default"}, groups.ownerGroups) +} + +func TestGetModelListGroupsUsesExplicitTokenGroup(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) + common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default") + common.SetContextKey(ctx, constant.ContextKeyTokenGroup, "vip") + + groups, err := getModelListGroups(ctx) + require.NoError(t, err) + + require.Equal(t, "default", groups.userGroup) + require.Equal(t, "vip", groups.tokenGroup) + require.Equal(t, []string{"vip"}, groups.ownerGroups) +} diff --git a/model/model_meta.go b/model/model_meta.go index 860b9602..86421277 100644 --- a/model/model_meta.go +++ b/model/model_meta.go @@ -2,6 +2,7 @@ package model import ( "strconv" + "strings" "github.com/QuantumNous/new-api/common" @@ -135,6 +136,62 @@ func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel return result, nil } +func normalizeLookupValues(values []string) []string { + seen := make(map[string]struct{}, len(values)) + normalized := make([]string, 0, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + if value == "" { + continue + } + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + normalized = append(normalized, value) + } + return normalized +} + +func GetPreferredModelOwnerChannelTypes(modelNames []string, groups []string) (map[string]int, error) { + result := make(map[string]int) + modelNames = normalizeLookupValues(modelNames) + if len(modelNames) == 0 { + return result, nil + } + + type row struct { + Model string + ChannelType int + } + var rows []row + + query := DB.Table("abilities"). + Select("abilities.model as model, channels.type as channel_type"). + Joins("JOIN channels ON abilities.channel_id = channels.id"). + Where("abilities.model IN ? AND abilities.enabled = ? AND channels.status = ?", modelNames, true, common.ChannelStatusEnabled). + Order("COALESCE(abilities.priority, 0) DESC"). + Order("abilities.weight DESC"). + Order("abilities.channel_id ASC") + + groups = normalizeLookupValues(groups) + if len(groups) > 0 { + query = query.Where("abilities."+commonGroupCol+" IN ?", groups) + } + + if err := query.Scan(&rows).Error; err != nil { + return nil, err + } + + for _, r := range rows { + if _, ok := result[r.Model]; ok { + continue + } + result[r.Model] = r.ChannelType + } + return result, nil +} + func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) { var models []*Model db := DB.Model(&Model{}) diff --git a/model/model_owner_test.go b/model/model_owner_test.go new file mode 100644 index 00000000..88766317 --- /dev/null +++ b/model/model_owner_test.go @@ -0,0 +1,141 @@ +package model + +import ( + "fmt" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/stretchr/testify/require" +) + +func clearPreferredOwnerTables(t *testing.T) { + t.Helper() + require.NoError(t, DB.Exec("DELETE FROM abilities").Error) + require.NoError(t, DB.Exec("DELETE FROM channels").Error) +} + +func insertPreferredOwnerCandidate( + t *testing.T, + channelID int, + modelName string, + group string, + channelType int, + priority int64, + weight uint, + channelStatus int, + abilityEnabled bool, +) { + t.Helper() + require.NoError(t, DB.Create(&Channel{ + Id: channelID, + Type: channelType, + Key: fmt.Sprintf("key-%d", channelID), + Status: channelStatus, + Name: fmt.Sprintf("channel-%d", channelID), + }).Error) + require.NoError(t, DB.Create(&Ability{ + Group: group, + Model: modelName, + ChannelId: channelID, + Enabled: abilityEnabled, + Priority: &priority, + Weight: weight, + }).Error) +} + +func TestGetPreferredModelOwnerChannelTypes(t *testing.T) { + const modelName = "gpt-5.4" + + tests := []struct { + name string + setup func(t *testing.T) + groups []string + expected int + found bool + }{ + { + name: "openai only", + setup: func(t *testing.T) { + insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 0, 0, common.ChannelStatusEnabled, true) + }, + groups: []string{"default"}, + expected: constant.ChannelTypeOpenAI, + found: true, + }, + { + name: "codex only", + setup: func(t *testing.T) { + insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 0, 0, common.ChannelStatusEnabled, true) + }, + groups: []string{"default"}, + expected: constant.ChannelTypeCodex, + found: true, + }, + { + name: "priority wins", + setup: func(t *testing.T) { + insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 100, common.ChannelStatusEnabled, true) + insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 2, 0, common.ChannelStatusEnabled, true) + }, + groups: []string{"default"}, + expected: constant.ChannelTypeCodex, + found: true, + }, + { + name: "weight wins when priority is equal", + setup: func(t *testing.T) { + insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true) + insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 20, common.ChannelStatusEnabled, true) + }, + groups: []string{"default"}, + expected: constant.ChannelTypeCodex, + found: true, + }, + { + name: "channel id stabilizes exact ties", + setup: func(t *testing.T) { + insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 10, common.ChannelStatusEnabled, true) + insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true) + }, + groups: []string{"default"}, + expected: constant.ChannelTypeOpenAI, + found: true, + }, + { + name: "group filter excludes other groups", + setup: func(t *testing.T) { + insertPreferredOwnerCandidate(t, 1, modelName, "vip", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, true) + insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusEnabled, true) + }, + groups: []string{"default"}, + expected: constant.ChannelTypeOpenAI, + found: true, + }, + { + name: "disabled candidates are ignored", + setup: func(t *testing.T) { + insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, false) + insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusManuallyDisabled, true) + }, + groups: []string{"default"}, + found: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clearPreferredOwnerTables(t) + tt.setup(t) + + owners, err := GetPreferredModelOwnerChannelTypes([]string{modelName}, tt.groups) + require.NoError(t, err) + + got, ok := owners[modelName] + require.Equal(t, tt.found, ok) + if tt.found { + require.Equal(t, tt.expected, got) + } + }) + } +} diff --git a/model/task_cas_test.go b/model/task_cas_test.go index 29455e3a..052bf638 100644 --- a/model/task_cas_test.go +++ b/model/task_cas_test.go @@ -40,6 +40,7 @@ func TestMain(m *testing.M) { &Token{}, &Log{}, &Channel{}, + &Ability{}, &TopUp{}, &SubscriptionPlan{}, &SubscriptionOrder{}, @@ -60,6 +61,7 @@ func truncateTables(t *testing.T) { DB.Exec("DELETE FROM tokens") DB.Exec("DELETE FROM logs") DB.Exec("DELETE FROM channels") + DB.Exec("DELETE FROM abilities") DB.Exec("DELETE FROM top_ups") DB.Exec("DELETE FROM subscription_orders") DB.Exec("DELETE FROM subscription_plans")