diff --git a/controller/perf_metrics.go b/controller/perf_metrics.go index 51e8d9ec..66d0787f 100644 --- a/controller/perf_metrics.go +++ b/controller/perf_metrics.go @@ -8,6 +8,7 @@ import ( "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/gin-gonic/gin" + "github.com/samber/lo" ) func GetPerfMetricsSummary(c *gin.Context) { @@ -18,7 +19,8 @@ func GetPerfMetricsSummary(c *gin.Context) { } } - result, err := perfmetrics.QuerySummaryAll(hours) + activeGroups := append(lo.Keys(ratio_setting.GetGroupRatioCopy()), "auto") + result, err := perfmetrics.QuerySummaryAll(hours, activeGroups) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, @@ -72,12 +74,9 @@ func GetPerfMetrics(c *gin.Context) { } func filterActiveGroups(groups []perfmetrics.GroupResult) []perfmetrics.GroupResult { - activeGroups := ratio_setting.GetGroupRatioCopy() - filtered := make([]perfmetrics.GroupResult, 0, len(groups)) - for _, g := range groups { - if _, ok := activeGroups[g.Group]; ok || g.Group == "auto" { - filtered = append(filtered, g) - } - } - return filtered + activeRatios := ratio_setting.GetGroupRatioCopy() + return lo.Filter(groups, func(g perfmetrics.GroupResult, _ int) bool { + _, ok := activeRatios[g.Group] + return ok || g.Group == "auto" + }) } diff --git a/model/perf_metric.go b/model/perf_metric.go index dcc4a8bc..1667adfb 100644 --- a/model/perf_metric.go +++ b/model/perf_metric.go @@ -68,11 +68,18 @@ type PerfMetricSummary struct { GenerationMs int64 `json:"generation_ms"` } -func GetPerfMetricsSummaryAll(startTs int64, endTs int64) ([]PerfMetricSummary, error) { +func GetPerfMetricsSummaryAll(startTs int64, endTs int64, groups []string) ([]PerfMetricSummary, error) { var summaries []PerfMetricSummary - err := DB.Model(&PerfMetric{}). + query := DB.Model(&PerfMetric{}). Select("model_name, SUM(request_count) as request_count, SUM(success_count) as success_count, SUM(total_latency_ms) as total_latency_ms, SUM(output_tokens) as output_tokens, SUM(generation_ms) as generation_ms"). - Where("bucket_ts >= ? AND bucket_ts <= ?", startTs, endTs). + Where("bucket_ts >= ? AND bucket_ts <= ?", startTs, endTs) + if groups != nil { + if len(groups) == 0 { + return summaries, nil + } + query = query.Where(commonGroupCol+" IN ?", groups) + } + err := query. Group("model_name"). Having("SUM(request_count) > 0"). Find(&summaries).Error diff --git a/model/task_cas_test.go b/model/task_cas_test.go index ba34a732..29455e3a 100644 --- a/model/task_cas_test.go +++ b/model/task_cas_test.go @@ -26,6 +26,7 @@ func TestMain(m *testing.M) { common.RedisEnabled = false common.BatchUpdateEnabled = false common.LogConsumeEnabled = true + initCol() sqlDB, err := db.DB() if err != nil { @@ -43,6 +44,7 @@ func TestMain(m *testing.M) { &SubscriptionPlan{}, &SubscriptionOrder{}, &UserSubscription{}, + &PerfMetric{}, ); err != nil { panic("failed to migrate: " + err.Error()) } @@ -62,6 +64,7 @@ func truncateTables(t *testing.T) { DB.Exec("DELETE FROM subscription_orders") DB.Exec("DELETE FROM subscription_plans") DB.Exec("DELETE FROM user_subscriptions") + DB.Exec("DELETE FROM perf_metrics") }) } diff --git a/pkg/perf_metrics/metrics.go b/pkg/perf_metrics/metrics.go index e258ae1a..ab505d0d 100644 --- a/pkg/perf_metrics/metrics.go +++ b/pkg/perf_metrics/metrics.go @@ -122,7 +122,7 @@ func Query(params QueryParams) (QueryResult, error) { return buildQueryResult(params.Model, merged), nil } -func QuerySummaryAll(hours int) (SummaryAllResult, error) { +func QuerySummaryAll(hours int, groups []string) (SummaryAllResult, error) { if hours <= 0 { hours = 24 } @@ -131,8 +131,9 @@ func QuerySummaryAll(hours int) (SummaryAllResult, error) { } endTs := time.Now().Unix() startTs := endTs - int64(hours)*3600 + allowedGroups := allowedGroupSet(groups) - rows, err := model.GetPerfMetricsSummaryAll(startTs, endTs) + rows, err := model.GetPerfMetricsSummaryAll(startTs, endTs, groups) if err != nil { return SummaryAllResult{}, err } @@ -153,6 +154,11 @@ func QuerySummaryAll(hours int) (SummaryAllResult, error) { if k.bucketTs < startTs || k.bucketTs > endTs { return true } + if allowedGroups != nil { + if _, ok := allowedGroups[k.group]; !ok { + return true + } + } snap := value.(*atomicBucket).snapshot() if snap.requestCount == 0 { return true @@ -193,6 +199,17 @@ func QuerySummaryAll(hours int) (SummaryAllResult, error) { return SummaryAllResult{Models: models}, nil } +func allowedGroupSet(groups []string) map[string]struct{} { + if groups == nil { + return nil + } + allowed := make(map[string]struct{}, len(groups)) + for _, group := range groups { + allowed[group] = struct{}{} + } + return allowed +} + func bucketStart(ts int64) int64 { bucketSeconds := perf_metrics_setting.GetBucketSeconds() if bucketSeconds <= 0 {