fix: resolve model owned_by from active channels (#4416)
* fix: resolve model owned_by from active channels * fix: respect token group when resolving model owners
This commit is contained in:
parent
6f11d19877
commit
006e801652
@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@ -109,9 +110,102 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context, modelType int) {
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||
func channelOwnerName(channelType int) string {
|
||||
apiType, success := common.ChannelType2APIType(channelType)
|
||||
if !success {
|
||||
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||
}
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||
}
|
||||
adaptor.Init(&relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: channelType,
|
||||
}})
|
||||
if name := strings.TrimSpace(adaptor.GetChannelName()); name != "" {
|
||||
return name
|
||||
}
|
||||
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||
}
|
||||
|
||||
func getPreferredModelOwners(modelNames []string, groups []string) map[string]string {
|
||||
channelTypes, err := model.GetPreferredModelOwnerChannelTypes(modelNames, groups)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("GetPreferredModelOwnerChannelTypes error: %v", err))
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
ownerByChannelType := make(map[int]string)
|
||||
owners := make(map[string]string, len(channelTypes))
|
||||
for modelName, channelType := range channelTypes {
|
||||
owner, ok := ownerByChannelType[channelType]
|
||||
if !ok {
|
||||
owner = channelOwnerName(channelType)
|
||||
ownerByChannelType[channelType] = owner
|
||||
}
|
||||
if owner != "" {
|
||||
owners[modelName] = owner
|
||||
}
|
||||
}
|
||||
return owners
|
||||
}
|
||||
|
||||
func buildOpenAIModel(modelName string, ownerByModel map[string]string) dto.OpenAIModels {
|
||||
var oaiModel dto.OpenAIModels
|
||||
if staticModel, ok := openAIModelsMap[modelName]; ok {
|
||||
oaiModel = staticModel
|
||||
} else {
|
||||
oaiModel = dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
}
|
||||
}
|
||||
if owner, ok := ownerByModel[modelName]; ok && owner != "" {
|
||||
oaiModel.OwnedBy = owner
|
||||
}
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||
return oaiModel
|
||||
}
|
||||
|
||||
type modelListGroups struct {
|
||||
userGroup string
|
||||
tokenGroup string
|
||||
ownerGroups []string
|
||||
}
|
||||
|
||||
func getModelListGroups(c *gin.Context) (modelListGroups, error) {
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
if userGroup == "" && (tokenGroup == "" || tokenGroup == "auto") {
|
||||
var err error
|
||||
userGroup, err = model.GetUserGroup(c.GetInt("id"), false)
|
||||
if err != nil {
|
||||
return modelListGroups{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if tokenGroup == "auto" {
|
||||
return modelListGroups{
|
||||
userGroup: userGroup,
|
||||
tokenGroup: tokenGroup,
|
||||
ownerGroups: service.GetUserAutoGroup(userGroup),
|
||||
}, nil
|
||||
}
|
||||
|
||||
group := userGroup
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
return modelListGroups{
|
||||
userGroup: userGroup,
|
||||
tokenGroup: tokenGroup,
|
||||
ownerGroups: []string{group},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context, modelType int) {
|
||||
acceptUnsetRatioModel := operation_setting.SelfUseModeEnabled
|
||||
if !acceptUnsetRatioModel {
|
||||
userId := c.GetInt("id")
|
||||
@ -123,6 +217,16 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
}
|
||||
|
||||
userModelNames := make([]string, 0)
|
||||
groups, err := getModelListGroups(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "get user group failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
ownerGroups := groups.ownerGroups
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
@ -138,37 +242,12 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||||
})
|
||||
}
|
||||
userModelNames = append(userModelNames, allowModel)
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userGroup, err := model.GetUserGroup(userId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "get user group failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
group := userGroup
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range service.GetUserAutoGroup(userGroup) {
|
||||
if groups.tokenGroup == "auto" {
|
||||
for _, autoGroup := range ownerGroups {
|
||||
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
@ -177,7 +256,7 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupEnabledModels(group)
|
||||
models = model.GetGroupEnabledModels(ownerGroups[0])
|
||||
}
|
||||
for _, modelName := range models {
|
||||
if !acceptUnsetRatioModel {
|
||||
@ -185,21 +264,19 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||||
})
|
||||
}
|
||||
userModelNames = append(userModelNames, modelName)
|
||||
}
|
||||
}
|
||||
|
||||
ownerByModel := map[string]string{}
|
||||
if len(ownerGroups) > 0 {
|
||||
ownerByModel = getPreferredModelOwners(userModelNames, ownerGroups)
|
||||
}
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0, len(userModelNames))
|
||||
for _, modelName := range userModelNames {
|
||||
userOpenAiModels = append(userOpenAiModels, buildOpenAIModel(modelName, ownerByModel))
|
||||
}
|
||||
|
||||
switch modelType {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
|
||||
|
||||
85
controller/model_owned_by_test.go
Normal file
85
controller/model_owned_by_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChannelOwnerNameUsesAdaptorChannelName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
channelType int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "openai",
|
||||
channelType: constant.ChannelTypeOpenAI,
|
||||
expected: "openai",
|
||||
},
|
||||
{
|
||||
name: "codex",
|
||||
channelType: constant.ChannelTypeCodex,
|
||||
expected: "codex",
|
||||
},
|
||||
{
|
||||
name: "openrouter",
|
||||
channelType: constant.ChannelTypeOpenRouter,
|
||||
expected: "openrouter",
|
||||
},
|
||||
{
|
||||
name: "azure fallback",
|
||||
channelType: constant.ChannelTypeAzure,
|
||||
expected: "azure",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.expected, channelOwnerName(tt.channelType))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIModelOverridesOwnedBy(t *testing.T) {
|
||||
modelItem := buildOpenAIModel("gpt-5.4", map[string]string{"gpt-5.4": "openai"})
|
||||
require.Equal(t, "gpt-5.4", modelItem.Id)
|
||||
require.Equal(t, "openai", modelItem.OwnedBy)
|
||||
}
|
||||
|
||||
func TestBuildOpenAIModelFallsBackToCustomForUnknownModels(t *testing.T) {
|
||||
modelItem := buildOpenAIModel("custom-test-model", nil)
|
||||
require.Equal(t, "custom-test-model", modelItem.Id)
|
||||
require.Equal(t, "custom", modelItem.OwnedBy)
|
||||
}
|
||||
|
||||
func TestGetModelListGroupsUsesUserGroupWhenTokenGroupIsEmpty(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
|
||||
|
||||
groups, err := getModelListGroups(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "default", groups.userGroup)
|
||||
require.Empty(t, groups.tokenGroup)
|
||||
require.Equal(t, []string{"default"}, groups.ownerGroups)
|
||||
}
|
||||
|
||||
func TestGetModelListGroupsUsesExplicitTokenGroup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
|
||||
common.SetContextKey(ctx, constant.ContextKeyTokenGroup, "vip")
|
||||
|
||||
groups, err := getModelListGroups(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "default", groups.userGroup)
|
||||
require.Equal(t, "vip", groups.tokenGroup)
|
||||
require.Equal(t, []string{"vip"}, groups.ownerGroups)
|
||||
}
|
||||
@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
|
||||
@ -135,6 +136,62 @@ func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func normalizeLookupValues(values []string) []string {
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[value]; ok {
|
||||
continue
|
||||
}
|
||||
seen[value] = struct{}{}
|
||||
normalized = append(normalized, value)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func GetPreferredModelOwnerChannelTypes(modelNames []string, groups []string) (map[string]int, error) {
|
||||
result := make(map[string]int)
|
||||
modelNames = normalizeLookupValues(modelNames)
|
||||
if len(modelNames) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type row struct {
|
||||
Model string
|
||||
ChannelType int
|
||||
}
|
||||
var rows []row
|
||||
|
||||
query := DB.Table("abilities").
|
||||
Select("abilities.model as model, channels.type as channel_type").
|
||||
Joins("JOIN channels ON abilities.channel_id = channels.id").
|
||||
Where("abilities.model IN ? AND abilities.enabled = ? AND channels.status = ?", modelNames, true, common.ChannelStatusEnabled).
|
||||
Order("COALESCE(abilities.priority, 0) DESC").
|
||||
Order("abilities.weight DESC").
|
||||
Order("abilities.channel_id ASC")
|
||||
|
||||
groups = normalizeLookupValues(groups)
|
||||
if len(groups) > 0 {
|
||||
query = query.Where("abilities."+commonGroupCol+" IN ?", groups)
|
||||
}
|
||||
|
||||
if err := query.Scan(&rows).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, r := range rows {
|
||||
if _, ok := result[r.Model]; ok {
|
||||
continue
|
||||
}
|
||||
result[r.Model] = r.ChannelType
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
||||
var models []*Model
|
||||
db := DB.Model(&Model{})
|
||||
|
||||
141
model/model_owner_test.go
Normal file
141
model/model_owner_test.go
Normal file
@ -0,0 +1,141 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func clearPreferredOwnerTables(t *testing.T) {
|
||||
t.Helper()
|
||||
require.NoError(t, DB.Exec("DELETE FROM abilities").Error)
|
||||
require.NoError(t, DB.Exec("DELETE FROM channels").Error)
|
||||
}
|
||||
|
||||
func insertPreferredOwnerCandidate(
|
||||
t *testing.T,
|
||||
channelID int,
|
||||
modelName string,
|
||||
group string,
|
||||
channelType int,
|
||||
priority int64,
|
||||
weight uint,
|
||||
channelStatus int,
|
||||
abilityEnabled bool,
|
||||
) {
|
||||
t.Helper()
|
||||
require.NoError(t, DB.Create(&Channel{
|
||||
Id: channelID,
|
||||
Type: channelType,
|
||||
Key: fmt.Sprintf("key-%d", channelID),
|
||||
Status: channelStatus,
|
||||
Name: fmt.Sprintf("channel-%d", channelID),
|
||||
}).Error)
|
||||
require.NoError(t, DB.Create(&Ability{
|
||||
Group: group,
|
||||
Model: modelName,
|
||||
ChannelId: channelID,
|
||||
Enabled: abilityEnabled,
|
||||
Priority: &priority,
|
||||
Weight: weight,
|
||||
}).Error)
|
||||
}
|
||||
|
||||
func TestGetPreferredModelOwnerChannelTypes(t *testing.T) {
|
||||
const modelName = "gpt-5.4"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(t *testing.T)
|
||||
groups []string
|
||||
expected int
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "openai only",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 0, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "codex only",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 0, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "priority wins",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 100, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 2, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "weight wins when priority is equal",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 20, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeCodex,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "channel id stabilizes exact ties",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 10, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "group filter excludes other groups",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "vip", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, true)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusEnabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
expected: constant.ChannelTypeOpenAI,
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "disabled candidates are ignored",
|
||||
setup: func(t *testing.T) {
|
||||
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, false)
|
||||
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusManuallyDisabled, true)
|
||||
},
|
||||
groups: []string{"default"},
|
||||
found: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clearPreferredOwnerTables(t)
|
||||
tt.setup(t)
|
||||
|
||||
owners, err := GetPreferredModelOwnerChannelTypes([]string{modelName}, tt.groups)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, ok := owners[modelName]
|
||||
require.Equal(t, tt.found, ok)
|
||||
if tt.found {
|
||||
require.Equal(t, tt.expected, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -40,6 +40,7 @@ func TestMain(m *testing.M) {
|
||||
&Token{},
|
||||
&Log{},
|
||||
&Channel{},
|
||||
&Ability{},
|
||||
&TopUp{},
|
||||
&SubscriptionPlan{},
|
||||
&SubscriptionOrder{},
|
||||
@ -60,6 +61,7 @@ func truncateTables(t *testing.T) {
|
||||
DB.Exec("DELETE FROM tokens")
|
||||
DB.Exec("DELETE FROM logs")
|
||||
DB.Exec("DELETE FROM channels")
|
||||
DB.Exec("DELETE FROM abilities")
|
||||
DB.Exec("DELETE FROM top_ups")
|
||||
DB.Exec("DELETE FROM subscription_orders")
|
||||
DB.Exec("DELETE FROM subscription_plans")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user