From 2d968c3eab2f0504c658d996da9a1973dc6764f2 Mon Sep 17 00:00:00 2001 From: yyhhyyyyyy Date: Sun, 17 May 2026 11:44:07 +0800 Subject: [PATCH] fix: apply group filter to channel list queries (#4885) --- controller/channel.go | 117 +++++++++++++++++++++++------------------- model/channel.go | 92 ++++++++++++++++++++------------- 2 files changed, 119 insertions(+), 90 deletions(-) diff --git a/controller/channel.go b/controller/channel.go index 9eef9f0e..21735170 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -19,6 +19,7 @@ import ( "github.com/QuantumNous/new-api/service" "github.com/gin-gonic/gin" + "gorm.io/gorm" ) type OpenAIModel struct { @@ -68,12 +69,33 @@ func clearChannelInfo(channel *model.Channel) { } } +func applyChannelStatusFilter(query *gorm.DB, statusFilter int) *gorm.DB { + if statusFilter == common.ChannelStatusEnabled { + return query.Where("status = ?", common.ChannelStatusEnabled) + } + if statusFilter == 0 { + return query.Where("status != ?", common.ChannelStatusEnabled) + } + return query +} + +func buildChannelListQuery(group string, statusFilter int, typeFilter int) *gorm.DB { + query := model.DB.Model(&model.Channel{}) + query = model.ApplyChannelGroupFilter(query, group) + query = applyChannelStatusFilter(query, statusFilter) + if typeFilter >= 0 { + query = query.Where("type = ?", typeFilter) + } + return query +} + func GetAllChannels(c *gin.Context) { pageInfo := common.GetPageQuery(c) channelData := make([]*model.Channel, 0) idSort, _ := strconv.ParseBool(c.Query("id_sort")) sortOptions := model.NewChannelSortOptions(c.Query("sort_by"), c.Query("sort_order"), idSort) enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) + groupFilter := model.NormalizeChannelGroupFilter(c.Query("group")) statusParam := c.Query("status") // statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual) statusFilter := parseStatusFilter(statusParam) @@ -85,69 +107,49 @@ func GetAllChannels(c *gin.Context) { typeFilter = t } } - // group filter - groupFilter := c.Query("group") var total int64 if enableTagMode { - tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + tags, err := model.GetPaginatedChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter), pageInfo.GetStartIdx(), pageInfo.GetPageSize()) if err != nil { common.SysError("failed to get paginated tags: " + err.Error()) c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"}) return } + total, err = model.CountChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter)) + if err != nil { + common.SysError("failed to count tags: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签数量失败,请稍后重试"}) + return + } for _, tag := range tags { if tag == nil || *tag == "" { continue } - tagChannels, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions) + var tagChannels []*model.Channel + err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter).Where("tag = ?", *tag)). + Omit("key"). + Find(&tagChannels).Error if err != nil { - continue + common.SysError("failed to get channels by tag: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签渠道失败,请稍后重试"}) + return } - filtered := make([]*model.Channel, 0) - for _, ch := range tagChannels { - if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled { - continue - } - if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled { - continue - } - if typeFilter >= 0 && ch.Type != typeFilter { - continue - } - if groupFilter != "" && groupFilter != "null" { - if !strings.Contains(","+ch.Group+",", ","+groupFilter+",") { - continue - } - } - filtered = append(filtered, ch) - } - channelData = append(channelData, filtered...) + channelData = append(channelData, tagChannels...) } - total, _ = model.CountAllTags() } else { - baseQuery := model.DB.Model(&model.Channel{}) - if typeFilter >= 0 { - baseQuery = baseQuery.Where("type = ?", typeFilter) - } - if statusFilter == common.ChannelStatusEnabled { - baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled) - } else if statusFilter == 0 { - baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled) - } - if groupFilter != "" && groupFilter != "null" { - if common.UsingMySQL { - baseQuery = baseQuery.Where("CONCAT(',', `group`, ',') LIKE ?", "%,"+groupFilter+",%") - } else { - // SQLite, PostgreSQL - baseQuery = baseQuery.Where("(',' || \"group\" || ',') LIKE ?", "%,"+groupFilter+",%") - } + if err := buildChannelListQuery(groupFilter, statusFilter, typeFilter).Count(&total).Error; err != nil { + common.SysError("failed to count channels: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道数量失败,请稍后重试"}) + return } - baseQuery.Count(&total) - - err := sortOptions.Apply(baseQuery).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error + err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter)). + Limit(pageInfo.GetPageSize()). + Offset(pageInfo.GetStartIdx()). + Omit("key"). + Find(&channelData).Error if err != nil { common.SysError("failed to get channels: " + err.Error()) c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"}) @@ -159,17 +161,16 @@ func GetAllChannels(c *gin.Context) { clearChannelInfo(datum) } - countQuery := model.DB.Model(&model.Channel{}) - if statusFilter == common.ChannelStatusEnabled { - countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled) - } else if statusFilter == 0 { - countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled) - } + countQuery := buildChannelListQuery(groupFilter, statusFilter, -1) var results []struct { Type int64 Count int64 } - _ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error + if err := countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error; err != nil { + common.SysError("failed to count channel types: " + err.Error()) + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道类型统计失败,请稍后重试"}) + return + } typeCounts := make(map[int64]int64) for _, r := range results { typeCounts[r.Type] = r.Count @@ -277,10 +278,18 @@ func SearchChannels(c *gin.Context) { } for _, tag := range tags { if tag != nil && *tag != "" { - tagChannel, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions) - if err == nil { - channelData = append(channelData, tagChannel...) + var tagChannels []*model.Channel + err := sortOptions.Apply(buildChannelListQuery(group, -1, -1).Where("tag = ?", *tag)). + Omit("key"). + Find(&tagChannels).Error + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return } + channelData = append(channelData, tagChannels...) } } } else { diff --git a/model/channel.go b/model/channel.go index 76b0554d..370a99bf 100644 --- a/model/channel.go +++ b/model/channel.go @@ -128,6 +128,38 @@ func resolveChannelSortOptions(idSort bool, sortOptions []ChannelSortOptions) Ch return options } +func NormalizeChannelGroupFilter(group string) string { + group = strings.TrimSpace(group) + if group == "" || strings.EqualFold(group, "all") || strings.EqualFold(group, "null") { + return "" + } + return group +} + +func channelGroupFilterCondition() string { + if common.UsingMySQL { + return `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ? ESCAPE '!'` + } + return `(',' || ` + commonGroupCol + ` || ',') LIKE ? ESCAPE '!'` +} + +func channelGroupFilterPattern(group string) string { + group = strings.NewReplacer( + "!", "!!", + "%", "!%", + "_", "!_", + ).Replace(group) + return "%," + group + ",%" +} + +func ApplyChannelGroupFilter(query *gorm.DB, group string) *gorm.DB { + group = NormalizeChannelGroupFilter(group) + if group == "" { + return query + } + return query.Where(channelGroupFilterCondition(), channelGroupFilterPattern(group)) +} + // Value implements driver.Valuer interface func (c ChannelInfo) Value() (driver.Value, error) { return common.Marshal(&c) @@ -365,25 +397,12 @@ func SearchChannels(keyword string, group string, model string, idSort bool, sor baseQuery := DB.Model(&Channel{}).Omit("key") // 构造WHERE子句 - var whereClause string - var args []interface{} - if group != "" && group != "null" { - var groupCondition string - if common.UsingMySQL { - groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?` - } else { - // sqlite, PostgreSQL - groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?` - } - whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition - args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%") - } else { - whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" - args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%") - } + whereClause := "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" + args := []any{common.String2Int(keyword), "%" + keyword + "%", keyword, "%" + keyword + "%", "%" + model + "%"} + baseQuery = ApplyChannelGroupFilter(baseQuery.Where(whereClause, args...), group) // 执行查询 - err := order.Apply(baseQuery.Where(whereClause, args...)).Find(&channels).Error + err := order.Apply(baseQuery).Find(&channels).Error if err != nil { return nil, err } @@ -828,8 +847,18 @@ func DeleteDisabledChannel() (int64, error) { } func GetPaginatedTags(offset int, limit int) ([]*string, error) { + return GetPaginatedChannelTags(DB.Model(&Channel{}), offset, limit) +} + +func GetPaginatedChannelTags(query *gorm.DB, offset int, limit int) ([]*string, error) { var tags []*string - err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error + err := query. + Select("DISTINCT tag"). + Where("tag is not null AND tag != ''"). + Order(clause.OrderByColumn{Column: clause.Column{Name: "tag"}}). + Offset(offset). + Limit(limit). + Find(&tags).Error return tags, err } @@ -857,24 +886,11 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str baseQuery := DB.Model(&Channel{}).Omit("key") // 构造WHERE子句 - var whereClause string - var args []interface{} - if group != "" && group != "null" { - var groupCondition string - if common.UsingMySQL { - groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?` - } else { - // sqlite, PostgreSQL - groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?` - } - whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition - args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%") - } else { - whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" - args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%") - } + whereClause := "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" + args := []any{common.String2Int(keyword), "%" + keyword + "%", keyword, "%" + keyword + "%", "%" + model + "%"} + baseQuery = ApplyChannelGroupFilter(baseQuery.Where(whereClause, args...), group) - subQuery := baseQuery.Where(whereClause, args...). + subQuery := baseQuery. Select("tag"). Where("tag != ''"). Order(order) @@ -1015,8 +1031,12 @@ func CountAllChannels() (int64, error) { // CountAllTags returns number of non-empty distinct tags func CountAllTags() (int64, error) { + return CountChannelTags(DB.Model(&Channel{})) +} + +func CountChannelTags(query *gorm.DB) (int64, error) { var total int64 - err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error + err := query.Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error return total, err }