fix: apply group filter to channel list queries (#4885)
This commit is contained in:
parent
cb7a61466e
commit
2d968c3eab
@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/service"
|
"github.com/QuantumNous/new-api/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpenAIModel struct {
|
type OpenAIModel struct {
|
||||||
@ -68,12 +69,33 @@ func clearChannelInfo(channel *model.Channel) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func applyChannelStatusFilter(query *gorm.DB, statusFilter int) *gorm.DB {
|
||||||
|
if statusFilter == common.ChannelStatusEnabled {
|
||||||
|
return query.Where("status = ?", common.ChannelStatusEnabled)
|
||||||
|
}
|
||||||
|
if statusFilter == 0 {
|
||||||
|
return query.Where("status != ?", common.ChannelStatusEnabled)
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildChannelListQuery(group string, statusFilter int, typeFilter int) *gorm.DB {
|
||||||
|
query := model.DB.Model(&model.Channel{})
|
||||||
|
query = model.ApplyChannelGroupFilter(query, group)
|
||||||
|
query = applyChannelStatusFilter(query, statusFilter)
|
||||||
|
if typeFilter >= 0 {
|
||||||
|
query = query.Where("type = ?", typeFilter)
|
||||||
|
}
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
func GetAllChannels(c *gin.Context) {
|
func GetAllChannels(c *gin.Context) {
|
||||||
pageInfo := common.GetPageQuery(c)
|
pageInfo := common.GetPageQuery(c)
|
||||||
channelData := make([]*model.Channel, 0)
|
channelData := make([]*model.Channel, 0)
|
||||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||||
sortOptions := model.NewChannelSortOptions(c.Query("sort_by"), c.Query("sort_order"), idSort)
|
sortOptions := model.NewChannelSortOptions(c.Query("sort_by"), c.Query("sort_order"), idSort)
|
||||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||||
|
groupFilter := model.NormalizeChannelGroupFilter(c.Query("group"))
|
||||||
statusParam := c.Query("status")
|
statusParam := c.Query("status")
|
||||||
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
|
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
|
||||||
statusFilter := parseStatusFilter(statusParam)
|
statusFilter := parseStatusFilter(statusParam)
|
||||||
@ -85,69 +107,49 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
typeFilter = t
|
typeFilter = t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// group filter
|
|
||||||
groupFilter := c.Query("group")
|
|
||||||
|
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
if enableTagMode {
|
if enableTagMode {
|
||||||
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
tags, err := model.GetPaginatedChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter), pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to get paginated tags: " + err.Error())
|
common.SysError("failed to get paginated tags: " + err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"})
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
total, err = model.CountChannelTags(buildChannelListQuery(groupFilter, statusFilter, typeFilter))
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to count tags: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签数量失败,请稍后重试"})
|
||||||
|
return
|
||||||
|
}
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
if tag == nil || *tag == "" {
|
if tag == nil || *tag == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
tagChannels, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
|
var tagChannels []*model.Channel
|
||||||
|
err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter).Where("tag = ?", *tag)).
|
||||||
|
Omit("key").
|
||||||
|
Find(&tagChannels).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
common.SysError("failed to get channels by tag: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签渠道失败,请稍后重试"})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
filtered := make([]*model.Channel, 0)
|
channelData = append(channelData, tagChannels...)
|
||||||
for _, ch := range tagChannels {
|
|
||||||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if typeFilter >= 0 && ch.Type != typeFilter {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if groupFilter != "" && groupFilter != "null" {
|
|
||||||
if !strings.Contains(","+ch.Group+",", ","+groupFilter+",") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
filtered = append(filtered, ch)
|
|
||||||
}
|
|
||||||
channelData = append(channelData, filtered...)
|
|
||||||
}
|
}
|
||||||
total, _ = model.CountAllTags()
|
|
||||||
} else {
|
} else {
|
||||||
baseQuery := model.DB.Model(&model.Channel{})
|
if err := buildChannelListQuery(groupFilter, statusFilter, typeFilter).Count(&total).Error; err != nil {
|
||||||
if typeFilter >= 0 {
|
common.SysError("failed to count channels: " + err.Error())
|
||||||
baseQuery = baseQuery.Where("type = ?", typeFilter)
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道数量失败,请稍后重试"})
|
||||||
}
|
return
|
||||||
if statusFilter == common.ChannelStatusEnabled {
|
|
||||||
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
|
|
||||||
} else if statusFilter == 0 {
|
|
||||||
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
|
|
||||||
}
|
|
||||||
if groupFilter != "" && groupFilter != "null" {
|
|
||||||
if common.UsingMySQL {
|
|
||||||
baseQuery = baseQuery.Where("CONCAT(',', `group`, ',') LIKE ?", "%,"+groupFilter+",%")
|
|
||||||
} else {
|
|
||||||
// SQLite, PostgreSQL
|
|
||||||
baseQuery = baseQuery.Where("(',' || \"group\" || ',') LIKE ?", "%,"+groupFilter+",%")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
baseQuery.Count(&total)
|
err := sortOptions.Apply(buildChannelListQuery(groupFilter, statusFilter, typeFilter)).
|
||||||
|
Limit(pageInfo.GetPageSize()).
|
||||||
err := sortOptions.Apply(baseQuery).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
|
Offset(pageInfo.GetStartIdx()).
|
||||||
|
Omit("key").
|
||||||
|
Find(&channelData).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to get channels: " + err.Error())
|
common.SysError("failed to get channels: " + err.Error())
|
||||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
|
||||||
@ -159,17 +161,16 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
clearChannelInfo(datum)
|
clearChannelInfo(datum)
|
||||||
}
|
}
|
||||||
|
|
||||||
countQuery := model.DB.Model(&model.Channel{})
|
countQuery := buildChannelListQuery(groupFilter, statusFilter, -1)
|
||||||
if statusFilter == common.ChannelStatusEnabled {
|
|
||||||
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
|
||||||
} else if statusFilter == 0 {
|
|
||||||
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
|
|
||||||
}
|
|
||||||
var results []struct {
|
var results []struct {
|
||||||
Type int64
|
Type int64
|
||||||
Count int64
|
Count int64
|
||||||
}
|
}
|
||||||
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
|
if err := countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error; err != nil {
|
||||||
|
common.SysError("failed to count channel types: " + err.Error())
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道类型统计失败,请稍后重试"})
|
||||||
|
return
|
||||||
|
}
|
||||||
typeCounts := make(map[int64]int64)
|
typeCounts := make(map[int64]int64)
|
||||||
for _, r := range results {
|
for _, r := range results {
|
||||||
typeCounts[r.Type] = r.Count
|
typeCounts[r.Type] = r.Count
|
||||||
@ -277,10 +278,18 @@ func SearchChannels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
if tag != nil && *tag != "" {
|
if tag != nil && *tag != "" {
|
||||||
tagChannel, err := model.GetChannelsByTag(*tag, idSort, false, sortOptions)
|
var tagChannels []*model.Channel
|
||||||
if err == nil {
|
err := sortOptions.Apply(buildChannelListQuery(group, -1, -1).Where("tag = ?", *tag)).
|
||||||
channelData = append(channelData, tagChannel...)
|
Omit("key").
|
||||||
|
Find(&tagChannels).Error
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
channelData = append(channelData, tagChannels...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -128,6 +128,38 @@ func resolveChannelSortOptions(idSort bool, sortOptions []ChannelSortOptions) Ch
|
|||||||
return options
|
return options
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NormalizeChannelGroupFilter(group string) string {
|
||||||
|
group = strings.TrimSpace(group)
|
||||||
|
if group == "" || strings.EqualFold(group, "all") || strings.EqualFold(group, "null") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return group
|
||||||
|
}
|
||||||
|
|
||||||
|
func channelGroupFilterCondition() string {
|
||||||
|
if common.UsingMySQL {
|
||||||
|
return `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ? ESCAPE '!'`
|
||||||
|
}
|
||||||
|
return `(',' || ` + commonGroupCol + ` || ',') LIKE ? ESCAPE '!'`
|
||||||
|
}
|
||||||
|
|
||||||
|
func channelGroupFilterPattern(group string) string {
|
||||||
|
group = strings.NewReplacer(
|
||||||
|
"!", "!!",
|
||||||
|
"%", "!%",
|
||||||
|
"_", "!_",
|
||||||
|
).Replace(group)
|
||||||
|
return "%," + group + ",%"
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyChannelGroupFilter(query *gorm.DB, group string) *gorm.DB {
|
||||||
|
group = NormalizeChannelGroupFilter(group)
|
||||||
|
if group == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
return query.Where(channelGroupFilterCondition(), channelGroupFilterPattern(group))
|
||||||
|
}
|
||||||
|
|
||||||
// Value implements driver.Valuer interface
|
// Value implements driver.Valuer interface
|
||||||
func (c ChannelInfo) Value() (driver.Value, error) {
|
func (c ChannelInfo) Value() (driver.Value, error) {
|
||||||
return common.Marshal(&c)
|
return common.Marshal(&c)
|
||||||
@ -365,25 +397,12 @@ func SearchChannels(keyword string, group string, model string, idSort bool, sor
|
|||||||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||||||
|
|
||||||
// 构造WHERE子句
|
// 构造WHERE子句
|
||||||
var whereClause string
|
whereClause := "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||||
var args []interface{}
|
args := []any{common.String2Int(keyword), "%" + keyword + "%", keyword, "%" + keyword + "%", "%" + model + "%"}
|
||||||
if group != "" && group != "null" {
|
baseQuery = ApplyChannelGroupFilter(baseQuery.Where(whereClause, args...), group)
|
||||||
var groupCondition string
|
|
||||||
if common.UsingMySQL {
|
|
||||||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
|
||||||
} else {
|
|
||||||
// sqlite, PostgreSQL
|
|
||||||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
|
||||||
}
|
|
||||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
|
||||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
|
||||||
} else {
|
|
||||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
|
||||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 执行查询
|
// 执行查询
|
||||||
err := order.Apply(baseQuery.Where(whereClause, args...)).Find(&channels).Error
|
err := order.Apply(baseQuery).Find(&channels).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -828,8 +847,18 @@ func DeleteDisabledChannel() (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetPaginatedTags(offset int, limit int) ([]*string, error) {
|
func GetPaginatedTags(offset int, limit int) ([]*string, error) {
|
||||||
|
return GetPaginatedChannelTags(DB.Model(&Channel{}), offset, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPaginatedChannelTags(query *gorm.DB, offset int, limit int) ([]*string, error) {
|
||||||
var tags []*string
|
var tags []*string
|
||||||
err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error
|
err := query.
|
||||||
|
Select("DISTINCT tag").
|
||||||
|
Where("tag is not null AND tag != ''").
|
||||||
|
Order(clause.OrderByColumn{Column: clause.Column{Name: "tag"}}).
|
||||||
|
Offset(offset).
|
||||||
|
Limit(limit).
|
||||||
|
Find(&tags).Error
|
||||||
return tags, err
|
return tags, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -857,24 +886,11 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
|||||||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||||||
|
|
||||||
// 构造WHERE子句
|
// 构造WHERE子句
|
||||||
var whereClause string
|
whereClause := "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||||
var args []interface{}
|
args := []any{common.String2Int(keyword), "%" + keyword + "%", keyword, "%" + keyword + "%", "%" + model + "%"}
|
||||||
if group != "" && group != "null" {
|
baseQuery = ApplyChannelGroupFilter(baseQuery.Where(whereClause, args...), group)
|
||||||
var groupCondition string
|
|
||||||
if common.UsingMySQL {
|
|
||||||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
|
||||||
} else {
|
|
||||||
// sqlite, PostgreSQL
|
|
||||||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
|
||||||
}
|
|
||||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
|
||||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
|
||||||
} else {
|
|
||||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
|
||||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
|
||||||
}
|
|
||||||
|
|
||||||
subQuery := baseQuery.Where(whereClause, args...).
|
subQuery := baseQuery.
|
||||||
Select("tag").
|
Select("tag").
|
||||||
Where("tag != ''").
|
Where("tag != ''").
|
||||||
Order(order)
|
Order(order)
|
||||||
@ -1015,8 +1031,12 @@ func CountAllChannels() (int64, error) {
|
|||||||
|
|
||||||
// CountAllTags returns number of non-empty distinct tags
|
// CountAllTags returns number of non-empty distinct tags
|
||||||
func CountAllTags() (int64, error) {
|
func CountAllTags() (int64, error) {
|
||||||
|
return CountChannelTags(DB.Model(&Channel{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func CountChannelTags(query *gorm.DB) (int64, error) {
|
||||||
var total int64
|
var total int64
|
||||||
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
err := query.Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
||||||
return total, err
|
return total, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user