fix: apply group filter to channel list queries (#4885)

This commit is contained in:
yyhhyyyyyy 2026-05-17 11:44:07 +08:00 committed by GitHub
parent cb7a61466e
commit 2d968c3eab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 119 additions and 90 deletions

View File

@ -19,6 +19,7 @@ import (
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type OpenAIModel struct {
@ -68,12 +69,33 @@ func clearChannelInfo(channel *model.Channel) {
}
}
func applyChannelStatusFilter(query *gorm.DB, statusFilter int) *gorm.DB {
if statusFilter == common.ChannelStatusEnabled {
return query.Where("status = ?", common.ChannelStatusEnabled)
}
if statusFilter == 0 {
return query.Where("status != ?", common.ChannelStatusEnabled)
}
return query
}
func buildChannelListQuery(group string, statusFilter int, typeFilter int) *gorm.DB {
query := model.DB.Model(&model.Channel{})
query = model.ApplyChannelGroupFilter(query, group)
query = applyChannelStatusFilter(query, statusFilter)
if typeFilter >= 0 {
query = query.Where("type = ?", typeFilter)
}
return query
}
func GetAllChannels(c *gin.Context) {
pageInfo := common.GetPageQuery(c)
channelData := make([]*model.Channel, 0)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
sortOptions := model.NewChannelSortOptions(c.Query("sort_by"), c.Query("sort_order"), idSort)
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
groupFilter := model.NormalizeChannelGroupFilter(c.Query("group"))
statusParam := c.Query("status")
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
statusFilter := parseStatusFilter(statusParam)
@ -85,69 +107,49 @@ func GetAllChannels(c *gin.Context) {
typeFilter = t
}
}
// group filter
groupFilter := c.Query("group")
var total int64
if enableTagMode {
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
tags, err := model.GetPaginatedChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter), pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.SysError("failed to get paginated tags: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"})
return
}
total, err = model.CountChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter))
if err != nil {
common.SysError("failed to count tags: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签数量失败,请稍后重试"})
return
}
for _, tag := range tags {
if tag == nil || *tag == "" {
continue
}
tagChannels, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
var tagChannels []*model.Channel
err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter).Where("tag = ?", *tag)).
Omit("key").
Find(&tagChannels).Error
if err != nil {
continue
common.SysError("failed to get channels by tag: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签渠道失败,请稍后重试"})
return
}
filtered := make([]*model.Channel, 0)
for _, ch := range tagChannels {
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
continue
}
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
continue
}
if typeFilter >= 0 && ch.Type != typeFilter {
continue
}
if groupFilter != "" && groupFilter != "null" {
if !strings.Contains(","+ch.Group+",", ","+groupFilter+",") {
continue
}
}
filtered = append(filtered, ch)
}
channelData = append(channelData, filtered...)
channelData = append(channelData, tagChannels...)
}
total, _ = model.CountAllTags()
} else {
baseQuery := model.DB.Model(&model.Channel{})
if typeFilter >= 0 {
baseQuery = baseQuery.Where("type = ?", typeFilter)
}
if statusFilter == common.ChannelStatusEnabled {
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
} else if statusFilter == 0 {
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
}
if groupFilter != "" && groupFilter != "null" {
if common.UsingMySQL {
baseQuery = baseQuery.Where("CONCAT(',', `group`, ',') LIKE ?", "%,"+groupFilter+",%")
} else {
// SQLite, PostgreSQL
baseQuery = baseQuery.Where("(',' || \"group\" || ',') LIKE ?", "%,"+groupFilter+",%")
}
if err := buildChannelListQuery(groupFilter, statusFilter, typeFilter).Count(&total).Error; err != nil {
common.SysError("failed to count channels: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道数量失败,请稍后重试"})
return
}
baseQuery.Count(&total)
err := sortOptions.Apply(baseQuery).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter)).
Limit(pageInfo.GetPageSize()).
Offset(pageInfo.GetStartIdx()).
Omit("key").
Find(&channelData).Error
if err != nil {
common.SysError("failed to get channels: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
@ -159,17 +161,16 @@ func GetAllChannels(c *gin.Context) {
clearChannelInfo(datum)
}
countQuery := model.DB.Model(&model.Channel{})
if statusFilter == common.ChannelStatusEnabled {
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
} else if statusFilter == 0 {
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
}
countQuery := buildChannelListQuery(groupFilter, statusFilter, -1)
var results []struct {
Type int64
Count int64
}
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
if err := countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error; err != nil {
common.SysError("failed to count channel types: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道类型统计失败,请稍后重试"})
return
}
typeCounts := make(map[int64]int64)
for _, r := range results {
typeCounts[r.Type] = r.Count
@ -277,10 +278,18 @@ func SearchChannels(c *gin.Context) {
}
for _, tag := range tags {
if tag != nil && *tag != "" {
tagChannel, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
if err == nil {
channelData = append(channelData, tagChannel...)
var tagChannels []*model.Channel
err := sortOptions.Apply(buildChannelListQuery(group, -1, -1).Where("tag = ?", *tag)).
Omit("key").
Find(&tagChannels).Error
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channelData = append(channelData, tagChannels...)
}
}
} else {

View File

@ -128,6 +128,38 @@ func resolveChannelSortOptions(idSort bool, sortOptions []ChannelSortOptions) Ch
return options
}
func NormalizeChannelGroupFilter(group string) string {
group = strings.TrimSpace(group)
if group == "" || strings.EqualFold(group, "all") || strings.EqualFold(group, "null") {
return ""
}
return group
}
func channelGroupFilterCondition() string {
if common.UsingMySQL {
return `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ? ESCAPE '!'`
}
return `(',' || ` + commonGroupCol + ` || ',') LIKE ? ESCAPE '!'`
}
func channelGroupFilterPattern(group string) string {
group = strings.NewReplacer(
"!", "!!",
"%", "!%",
"_", "!_",
).Replace(group)
return "%," + group + ",%"
}
func ApplyChannelGroupFilter(query *gorm.DB, group string) *gorm.DB {
group = NormalizeChannelGroupFilter(group)
if group == "" {
return query
}
return query.Where(channelGroupFilterCondition(), channelGroupFilterPattern(group))
}
// Value implements driver.Valuer interface
func (c ChannelInfo) Value() (driver.Value, error) {
return common.Marshal(&c)
@ -365,25 +397,12 @@ func SearchChannels(keyword string, group string, model string, idSort bool, sor
baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句
var whereClause string
var args []interface{}
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
whereClause := "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args := []any{common.String2Int(keyword), "%" + keyword + "%", keyword, "%" + keyword + "%", "%" + model + "%"}
baseQuery = ApplyChannelGroupFilter(baseQuery.Where(whereClause, args...), group)
// 执行查询
err := order.Apply(baseQuery.Where(whereClause, args...)).Find(&channels).Error
err := order.Apply(baseQuery).Find(&channels).Error
if err != nil {
return nil, err
}
@ -828,8 +847,18 @@ func DeleteDisabledChannel() (int64, error) {
}
func GetPaginatedTags(offset int, limit int) ([]*string, error) {
return GetPaginatedChannelTags(DB.Model(&Channel{}), offset, limit)
}
func GetPaginatedChannelTags(query *gorm.DB, offset int, limit int) ([]*string, error) {
var tags []*string
err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error
err := query.
Select("DISTINCT tag").
Where("tag is not null AND tag != ''").
Order(clause.OrderByColumn{Column: clause.Column{Name: "tag"}}).
Offset(offset).
Limit(limit).
Find(&tags).Error
return tags, err
}
@ -857,24 +886,11 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句
var whereClause string
var args []interface{}
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
whereClause := "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args := []any{common.String2Int(keyword), "%" + keyword + "%", keyword, "%" + keyword + "%", "%" + model + "%"}
baseQuery = ApplyChannelGroupFilter(baseQuery.Where(whereClause, args...), group)
subQuery := baseQuery.Where(whereClause, args...).
subQuery := baseQuery.
Select("tag").
Where("tag != ''").
Order(order)
@ -1015,8 +1031,12 @@ func CountAllChannels() (int64, error) {
// CountAllTags returns number of non-empty distinct tags
func CountAllTags() (int64, error) {
return CountChannelTags(DB.Model(&Channel{}))
}
func CountChannelTags(query *gorm.DB) (int64, error) {
var total int64
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
err := query.Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
return total, err
}