fix: resolve model owned_by from active channels (#4416)
* fix: resolve model owned_by from active channels * fix: respect token group when resolving model owners
This commit is contained in:
parent
6f11d19877
commit
006e801652
@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
@ -109,9 +110,102 @@ func init() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListModels(c *gin.Context, modelType int) {
|
func channelOwnerName(channelType int) string {
|
||||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
apiType, success := common.ChannelType2APIType(channelType)
|
||||||
|
if !success {
|
||||||
|
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||||
|
}
|
||||||
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
|
if adaptor == nil {
|
||||||
|
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||||
|
}
|
||||||
|
adaptor.Init(&relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
ChannelType: channelType,
|
||||||
|
}})
|
||||||
|
if name := strings.TrimSpace(adaptor.GetChannelName()); name != "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
return strings.ToLower(constant.GetChannelTypeName(channelType))
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPreferredModelOwners(modelNames []string, groups []string) map[string]string {
|
||||||
|
channelTypes, err := model.GetPreferredModelOwnerChannelTypes(modelNames, groups)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("GetPreferredModelOwnerChannelTypes error: %v", err))
|
||||||
|
return map[string]string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
ownerByChannelType := make(map[int]string)
|
||||||
|
owners := make(map[string]string, len(channelTypes))
|
||||||
|
for modelName, channelType := range channelTypes {
|
||||||
|
owner, ok := ownerByChannelType[channelType]
|
||||||
|
if !ok {
|
||||||
|
owner = channelOwnerName(channelType)
|
||||||
|
ownerByChannelType[channelType] = owner
|
||||||
|
}
|
||||||
|
if owner != "" {
|
||||||
|
owners[modelName] = owner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return owners
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildOpenAIModel(modelName string, ownerByModel map[string]string) dto.OpenAIModels {
|
||||||
|
var oaiModel dto.OpenAIModels
|
||||||
|
if staticModel, ok := openAIModelsMap[modelName]; ok {
|
||||||
|
oaiModel = staticModel
|
||||||
|
} else {
|
||||||
|
oaiModel = dto.OpenAIModels{
|
||||||
|
Id: modelName,
|
||||||
|
Object: "model",
|
||||||
|
Created: 1626777600,
|
||||||
|
OwnedBy: "custom",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if owner, ok := ownerByModel[modelName]; ok && owner != "" {
|
||||||
|
oaiModel.OwnedBy = owner
|
||||||
|
}
|
||||||
|
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||||
|
return oaiModel
|
||||||
|
}
|
||||||
|
|
||||||
|
type modelListGroups struct {
|
||||||
|
userGroup string
|
||||||
|
tokenGroup string
|
||||||
|
ownerGroups []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func getModelListGroups(c *gin.Context) (modelListGroups, error) {
|
||||||
|
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||||
|
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||||
|
if userGroup == "" && (tokenGroup == "" || tokenGroup == "auto") {
|
||||||
|
var err error
|
||||||
|
userGroup, err = model.GetUserGroup(c.GetInt("id"), false)
|
||||||
|
if err != nil {
|
||||||
|
return modelListGroups{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenGroup == "auto" {
|
||||||
|
return modelListGroups{
|
||||||
|
userGroup: userGroup,
|
||||||
|
tokenGroup: tokenGroup,
|
||||||
|
ownerGroups: service.GetUserAutoGroup(userGroup),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
group := userGroup
|
||||||
|
if tokenGroup != "" {
|
||||||
|
group = tokenGroup
|
||||||
|
}
|
||||||
|
return modelListGroups{
|
||||||
|
userGroup: userGroup,
|
||||||
|
tokenGroup: tokenGroup,
|
||||||
|
ownerGroups: []string{group},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListModels(c *gin.Context, modelType int) {
|
||||||
acceptUnsetRatioModel := operation_setting.SelfUseModeEnabled
|
acceptUnsetRatioModel := operation_setting.SelfUseModeEnabled
|
||||||
if !acceptUnsetRatioModel {
|
if !acceptUnsetRatioModel {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
@ -123,6 +217,16 @@ func ListModels(c *gin.Context, modelType int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userModelNames := make([]string, 0)
|
||||||
|
groups, err := getModelListGroups(c)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "get user group failed",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ownerGroups := groups.ownerGroups
|
||||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||||
if modelLimitEnable {
|
if modelLimitEnable {
|
||||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||||
@ -138,37 +242,12 @@ func ListModels(c *gin.Context, modelType int) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
userModelNames = append(userModelNames, allowModel)
|
||||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
|
||||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
|
||||||
} else {
|
|
||||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
|
||||||
Id: allowModel,
|
|
||||||
Object: "model",
|
|
||||||
Created: 1626777600,
|
|
||||||
OwnedBy: "custom",
|
|
||||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
userId := c.GetInt("id")
|
|
||||||
userGroup, err := model.GetUserGroup(userId, false)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": "get user group failed",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
group := userGroup
|
|
||||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
|
||||||
if tokenGroup != "" {
|
|
||||||
group = tokenGroup
|
|
||||||
}
|
|
||||||
var models []string
|
var models []string
|
||||||
if tokenGroup == "auto" {
|
if groups.tokenGroup == "auto" {
|
||||||
for _, autoGroup := range service.GetUserAutoGroup(userGroup) {
|
for _, autoGroup := range ownerGroups {
|
||||||
groupModels := model.GetGroupEnabledModels(autoGroup)
|
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||||
for _, g := range groupModels {
|
for _, g := range groupModels {
|
||||||
if !common.StringsContains(models, g) {
|
if !common.StringsContains(models, g) {
|
||||||
@ -177,7 +256,7 @@ func ListModels(c *gin.Context, modelType int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
models = model.GetGroupEnabledModels(group)
|
models = model.GetGroupEnabledModels(ownerGroups[0])
|
||||||
}
|
}
|
||||||
for _, modelName := range models {
|
for _, modelName := range models {
|
||||||
if !acceptUnsetRatioModel {
|
if !acceptUnsetRatioModel {
|
||||||
@ -185,21 +264,19 @@ func ListModels(c *gin.Context, modelType int) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
userModelNames = append(userModelNames, modelName)
|
||||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
|
||||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
|
||||||
} else {
|
|
||||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
|
||||||
Id: modelName,
|
|
||||||
Object: "model",
|
|
||||||
Created: 1626777600,
|
|
||||||
OwnedBy: "custom",
|
|
||||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ownerByModel := map[string]string{}
|
||||||
|
if len(ownerGroups) > 0 {
|
||||||
|
ownerByModel = getPreferredModelOwners(userModelNames, ownerGroups)
|
||||||
|
}
|
||||||
|
userOpenAiModels := make([]dto.OpenAIModels, 0, len(userModelNames))
|
||||||
|
for _, modelName := range userModelNames {
|
||||||
|
userOpenAiModels = append(userOpenAiModels, buildOpenAIModel(modelName, ownerByModel))
|
||||||
|
}
|
||||||
|
|
||||||
switch modelType {
|
switch modelType {
|
||||||
case constant.ChannelTypeAnthropic:
|
case constant.ChannelTypeAnthropic:
|
||||||
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
|
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
|
||||||
|
|||||||
85
controller/model_owned_by_test.go
Normal file
85
controller/model_owned_by_test.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/constant"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChannelOwnerNameUsesAdaptorChannelName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
channelType int
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "openai",
|
||||||
|
channelType: constant.ChannelTypeOpenAI,
|
||||||
|
expected: "openai",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "codex",
|
||||||
|
channelType: constant.ChannelTypeCodex,
|
||||||
|
expected: "codex",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "openrouter",
|
||||||
|
channelType: constant.ChannelTypeOpenRouter,
|
||||||
|
expected: "openrouter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "azure fallback",
|
||||||
|
channelType: constant.ChannelTypeAzure,
|
||||||
|
expected: "azure",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Equal(t, tt.expected, channelOwnerName(tt.channelType))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOpenAIModelOverridesOwnedBy(t *testing.T) {
|
||||||
|
modelItem := buildOpenAIModel("gpt-5.4", map[string]string{"gpt-5.4": "openai"})
|
||||||
|
require.Equal(t, "gpt-5.4", modelItem.Id)
|
||||||
|
require.Equal(t, "openai", modelItem.OwnedBy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOpenAIModelFallsBackToCustomForUnknownModels(t *testing.T) {
|
||||||
|
modelItem := buildOpenAIModel("custom-test-model", nil)
|
||||||
|
require.Equal(t, "custom-test-model", modelItem.Id)
|
||||||
|
require.Equal(t, "custom", modelItem.OwnedBy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelListGroupsUsesUserGroupWhenTokenGroupIsEmpty(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
|
||||||
|
|
||||||
|
groups, err := getModelListGroups(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, "default", groups.userGroup)
|
||||||
|
require.Empty(t, groups.tokenGroup)
|
||||||
|
require.Equal(t, []string{"default"}, groups.ownerGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelListGroupsUsesExplicitTokenGroup(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
common.SetContextKey(ctx, constant.ContextKeyUserGroup, "default")
|
||||||
|
common.SetContextKey(ctx, constant.ContextKeyTokenGroup, "vip")
|
||||||
|
|
||||||
|
groups, err := getModelListGroups(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, "default", groups.userGroup)
|
||||||
|
require.Equal(t, "vip", groups.tokenGroup)
|
||||||
|
require.Equal(t, []string{"vip"}, groups.ownerGroups)
|
||||||
|
}
|
||||||
@ -2,6 +2,7 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
|
||||||
@ -135,6 +136,62 @@ func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeLookupValues(values []string) []string {
|
||||||
|
seen := make(map[string]struct{}, len(values))
|
||||||
|
normalized := make([]string, 0, len(values))
|
||||||
|
for _, value := range values {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[value]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[value] = struct{}{}
|
||||||
|
normalized = append(normalized, value)
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPreferredModelOwnerChannelTypes(modelNames []string, groups []string) (map[string]int, error) {
|
||||||
|
result := make(map[string]int)
|
||||||
|
modelNames = normalizeLookupValues(modelNames)
|
||||||
|
if len(modelNames) == 0 {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type row struct {
|
||||||
|
Model string
|
||||||
|
ChannelType int
|
||||||
|
}
|
||||||
|
var rows []row
|
||||||
|
|
||||||
|
query := DB.Table("abilities").
|
||||||
|
Select("abilities.model as model, channels.type as channel_type").
|
||||||
|
Joins("JOIN channels ON abilities.channel_id = channels.id").
|
||||||
|
Where("abilities.model IN ? AND abilities.enabled = ? AND channels.status = ?", modelNames, true, common.ChannelStatusEnabled).
|
||||||
|
Order("COALESCE(abilities.priority, 0) DESC").
|
||||||
|
Order("abilities.weight DESC").
|
||||||
|
Order("abilities.channel_id ASC")
|
||||||
|
|
||||||
|
groups = normalizeLookupValues(groups)
|
||||||
|
if len(groups) > 0 {
|
||||||
|
query = query.Where("abilities."+commonGroupCol+" IN ?", groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := query.Scan(&rows).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rows {
|
||||||
|
if _, ok := result[r.Model]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result[r.Model] = r.ChannelType
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
||||||
var models []*Model
|
var models []*Model
|
||||||
db := DB.Model(&Model{})
|
db := DB.Model(&Model{})
|
||||||
|
|||||||
141
model/model_owner_test.go
Normal file
141
model/model_owner_test.go
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/constant"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func clearPreferredOwnerTables(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
require.NoError(t, DB.Exec("DELETE FROM abilities").Error)
|
||||||
|
require.NoError(t, DB.Exec("DELETE FROM channels").Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func insertPreferredOwnerCandidate(
|
||||||
|
t *testing.T,
|
||||||
|
channelID int,
|
||||||
|
modelName string,
|
||||||
|
group string,
|
||||||
|
channelType int,
|
||||||
|
priority int64,
|
||||||
|
weight uint,
|
||||||
|
channelStatus int,
|
||||||
|
abilityEnabled bool,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
require.NoError(t, DB.Create(&Channel{
|
||||||
|
Id: channelID,
|
||||||
|
Type: channelType,
|
||||||
|
Key: fmt.Sprintf("key-%d", channelID),
|
||||||
|
Status: channelStatus,
|
||||||
|
Name: fmt.Sprintf("channel-%d", channelID),
|
||||||
|
}).Error)
|
||||||
|
require.NoError(t, DB.Create(&Ability{
|
||||||
|
Group: group,
|
||||||
|
Model: modelName,
|
||||||
|
ChannelId: channelID,
|
||||||
|
Enabled: abilityEnabled,
|
||||||
|
Priority: &priority,
|
||||||
|
Weight: weight,
|
||||||
|
}).Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPreferredModelOwnerChannelTypes(t *testing.T) {
|
||||||
|
const modelName = "gpt-5.4"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup func(t *testing.T)
|
||||||
|
groups []string
|
||||||
|
expected int
|
||||||
|
found bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "openai only",
|
||||||
|
setup: func(t *testing.T) {
|
||||||
|
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 0, 0, common.ChannelStatusEnabled, true)
|
||||||
|
},
|
||||||
|
groups: []string{"default"},
|
||||||
|
expected: constant.ChannelTypeOpenAI,
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "codex only",
|
||||||
|
setup: func(t *testing.T) {
|
||||||
|
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 0, 0, common.ChannelStatusEnabled, true)
|
||||||
|
},
|
||||||
|
groups: []string{"default"},
|
||||||
|
expected: constant.ChannelTypeCodex,
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "priority wins",
|
||||||
|
setup: func(t *testing.T) {
|
||||||
|
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 100, common.ChannelStatusEnabled, true)
|
||||||
|
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 2, 0, common.ChannelStatusEnabled, true)
|
||||||
|
},
|
||||||
|
groups: []string{"default"},
|
||||||
|
expected: constant.ChannelTypeCodex,
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "weight wins when priority is equal",
|
||||||
|
setup: func(t *testing.T) {
|
||||||
|
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
|
||||||
|
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 20, common.ChannelStatusEnabled, true)
|
||||||
|
},
|
||||||
|
groups: []string{"default"},
|
||||||
|
expected: constant.ChannelTypeCodex,
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "channel id stabilizes exact ties",
|
||||||
|
setup: func(t *testing.T) {
|
||||||
|
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeCodex, 1, 10, common.ChannelStatusEnabled, true)
|
||||||
|
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeOpenAI, 1, 10, common.ChannelStatusEnabled, true)
|
||||||
|
},
|
||||||
|
groups: []string{"default"},
|
||||||
|
expected: constant.ChannelTypeOpenAI,
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "group filter excludes other groups",
|
||||||
|
setup: func(t *testing.T) {
|
||||||
|
insertPreferredOwnerCandidate(t, 1, modelName, "vip", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, true)
|
||||||
|
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusEnabled, true)
|
||||||
|
},
|
||||||
|
groups: []string{"default"},
|
||||||
|
expected: constant.ChannelTypeOpenAI,
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled candidates are ignored",
|
||||||
|
setup: func(t *testing.T) {
|
||||||
|
insertPreferredOwnerCandidate(t, 1, modelName, "default", constant.ChannelTypeCodex, 10, 100, common.ChannelStatusEnabled, false)
|
||||||
|
insertPreferredOwnerCandidate(t, 2, modelName, "default", constant.ChannelTypeOpenAI, 1, 0, common.ChannelStatusManuallyDisabled, true)
|
||||||
|
},
|
||||||
|
groups: []string{"default"},
|
||||||
|
found: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
clearPreferredOwnerTables(t)
|
||||||
|
tt.setup(t)
|
||||||
|
|
||||||
|
owners, err := GetPreferredModelOwnerChannelTypes([]string{modelName}, tt.groups)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, ok := owners[modelName]
|
||||||
|
require.Equal(t, tt.found, ok)
|
||||||
|
if tt.found {
|
||||||
|
require.Equal(t, tt.expected, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -40,6 +40,7 @@ func TestMain(m *testing.M) {
|
|||||||
&Token{},
|
&Token{},
|
||||||
&Log{},
|
&Log{},
|
||||||
&Channel{},
|
&Channel{},
|
||||||
|
&Ability{},
|
||||||
&TopUp{},
|
&TopUp{},
|
||||||
&SubscriptionPlan{},
|
&SubscriptionPlan{},
|
||||||
&SubscriptionOrder{},
|
&SubscriptionOrder{},
|
||||||
@ -60,6 +61,7 @@ func truncateTables(t *testing.T) {
|
|||||||
DB.Exec("DELETE FROM tokens")
|
DB.Exec("DELETE FROM tokens")
|
||||||
DB.Exec("DELETE FROM logs")
|
DB.Exec("DELETE FROM logs")
|
||||||
DB.Exec("DELETE FROM channels")
|
DB.Exec("DELETE FROM channels")
|
||||||
|
DB.Exec("DELETE FROM abilities")
|
||||||
DB.Exec("DELETE FROM top_ups")
|
DB.Exec("DELETE FROM top_ups")
|
||||||
DB.Exec("DELETE FROM subscription_orders")
|
DB.Exec("DELETE FROM subscription_orders")
|
||||||
DB.Exec("DELETE FROM subscription_plans")
|
DB.Exec("DELETE FROM subscription_plans")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user