diff --git a/model/log.go b/model/log.go index 7edf24b4..8ec7807e 100644 --- a/model/log.go +++ b/model/log.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "github.com/QuantumNous/new-api/common" @@ -308,15 +309,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName tx = LOG_DB.Where("logs.type = ?", logType) } - if modelName != "" { - tx = tx.Where("logs.model_name like ?", modelName) - } - if username != "" { - tx = tx.Where("logs.username = ?", username) - } - if tokenName != "" { - tx = tx.Where("logs.token_name = ?", tokenName) - } + tx = applyLogContainsFilter(tx, "logs.model_name", modelName) + tx = applyLogContainsFilter(tx, "logs.username", username) + tx = applyLogContainsFilter(tx, "logs.token_name", tokenName) if requestId != "" { tx = tx.Where("logs.request_id = ?", requestId) } @@ -397,16 +392,8 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType) } - if modelName != "" { - modelNamePattern, err := sanitizeLikePattern(modelName) - if err != nil { - return nil, 0, err - } - tx = tx.Where("logs.model_name LIKE ? ESCAPE '!'", modelNamePattern) - } - if tokenName != "" { - tx = tx.Where("logs.token_name = ?", tokenName) - } + tx = applyLogContainsFilter(tx, "logs.model_name", modelName) + tx = applyLogContainsFilter(tx, "logs.token_name", tokenName) if requestId != "" { tx = tx.Where("logs.request_id = ?", requestId) } @@ -443,34 +430,42 @@ type Stat struct { Tpm int `json:"tpm"` } +func logContainsPattern(input string) (string, bool) { + input = strings.TrimSpace(input) + if input == "" { + return "", false + } + + replacer := strings.NewReplacer("!", "!!", "%", "!%", "_", "!_") + return "%" + replacer.Replace(input) + "%", true +} + +func applyLogContainsFilter(tx *gorm.DB, column string, value string) *gorm.DB { + pattern, ok := logContainsPattern(value) + if !ok { + return tx + } + return tx.Where(column+" LIKE ? ESCAPE '!'", pattern) +} + func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) { tx := LOG_DB.Table("logs").Select("sum(quota) quota") // 为rpm和tpm创建单独的查询 rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") - if username != "" { - tx = tx.Where("username = ?", username) - rpmTpmQuery = rpmTpmQuery.Where("username = ?", username) - } - if tokenName != "" { - tx = tx.Where("token_name = ?", tokenName) - rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName) - } + tx = applyLogContainsFilter(tx, "username", username) + rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "username", username) + tx = applyLogContainsFilter(tx, "token_name", tokenName) + rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "token_name", tokenName) if startTimestamp != 0 { tx = tx.Where("created_at >= ?", startTimestamp) } if endTimestamp != 0 { tx = tx.Where("created_at <= ?", endTimestamp) } - if modelName != "" { - modelNamePattern, err := sanitizeLikePattern(modelName) - if err != nil { - return stat, err - } - tx = tx.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern) - rpmTpmQuery = rpmTpmQuery.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern) - } + tx = applyLogContainsFilter(tx, "model_name", modelName) + rpmTpmQuery = applyLogContainsFilter(rpmTpmQuery, "model_name", modelName) if channel != 0 { tx = tx.Where("channel_id = ?", channel) rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel) diff --git a/web/default/src/features/usage-logs/components/usage-logs-table.tsx b/web/default/src/features/usage-logs/components/usage-logs-table.tsx index 333bfb08..996e6353 100644 --- a/web/default/src/features/usage-logs/components/usage-logs-table.tsx +++ b/web/default/src/features/usage-logs/components/usage-logs-table.tsx @@ -150,6 +150,7 @@ export function UsageLogsTable({ logCategory }: UsageLogsTableProps) { getFacetedRowModel: getFacetedRowModel(), getFacetedUniqueValues: getFacetedUniqueValues(), manualPagination: true, + manualFiltering: true, pageCount: Math.ceil((data?.total || 0) / pagination.pageSize), })