diff --git a/service/channel_affinity.go b/service/channel_affinity.go index e09cb01f..f16c350b 100644 --- a/service/channel_affinity.go +++ b/service/channel_affinity.go @@ -302,6 +302,11 @@ func extractChannelAffinityValue(c *gin.Context, src operation_setting.ChannelAf return "" } return strings.TrimSpace(c.GetString(src.Key)) + case "request_header": + if c == nil || c.Request == nil || src.Key == "" { + return "" + } + return strings.TrimSpace(c.Request.Header.Get(src.Key)) case "gjson": if src.Path == "" { return "" diff --git a/service/channel_affinity_template_test.go b/service/channel_affinity_template_test.go index 033cbd83..91844fc3 100644 --- a/service/channel_affinity_template_test.go +++ b/service/channel_affinity_template_test.go @@ -176,6 +176,66 @@ func TestShouldSkipRetryAfterChannelAffinityFailure(t *testing.T) { } } +func TestExtractChannelAffinityValue_RequestHeader(t *testing.T) { + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + ctx.Request.Header.Set("X-Affinity-Key", " tenant-123 ") + + value := extractChannelAffinityValue(ctx, operation_setting.ChannelAffinityKeySource{ + Type: "request_header", + Key: "X-Affinity-Key", + }) + + require.Equal(t, "tenant-123", value) +} + +func TestGetPreferredChannelByAffinity_RequestHeaderKeySource(t *testing.T) { + gin.SetMode(gin.TestMode) + + rule := operation_setting.ChannelAffinityRule{ + Name: "header-affinity", + ModelRegex: []string{"^gpt-.*$"}, + PathRegex: []string{"/v1/responses"}, + KeySources: []operation_setting.ChannelAffinityKeySource{ + {Type: "request_header", Key: "X-Affinity-Key"}, + }, + IncludeRuleName: true, + IncludeModelName: true, + } + + affinityValue := fmt.Sprintf("header-hit-%d", time.Now().UnixNano()) + cacheKeySuffix := buildChannelAffinityCacheKeySuffix(rule, "gpt-5", "default", affinityValue) + + cache := getChannelAffinityCache() + require.NoError(t, cache.SetWithTTL(cacheKeySuffix, 9528, time.Minute)) + t.Cleanup(func() { + _, _ = cache.DeleteMany([]string{cacheKeySuffix}) + }) + + setting := operation_setting.GetChannelAffinitySetting() + originalRules := setting.Rules + setting.Rules = append([]operation_setting.ChannelAffinityRule{rule}, originalRules...) + t.Cleanup(func() { + setting.Rules = originalRules + }) + + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + ctx.Request.Header.Set("X-Affinity-Key", affinityValue) + + channelID, found := GetPreferredChannelByAffinity(ctx, "gpt-5", "default") + require.True(t, found) + require.Equal(t, 9528, channelID) + + meta, ok := getChannelAffinityMeta(ctx) + require.True(t, ok) + require.Equal(t, "request_header", meta.KeySourceType) + require.Equal(t, "X-Affinity-Key", meta.KeySourceKey) + require.Equal(t, buildChannelAffinityKeyHint(affinityValue), meta.KeyHint) +} + func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/setting/operation_setting/channel_affinity_setting.go b/setting/operation_setting/channel_affinity_setting.go index 8cd605c0..bd573696 100644 --- a/setting/operation_setting/channel_affinity_setting.go +++ b/setting/operation_setting/channel_affinity_setting.go @@ -3,7 +3,7 @@ package operation_setting import "github.com/QuantumNous/new-api/setting/config" type ChannelAffinityKeySource struct { - Type string `json:"type"` // context_int, context_string, gjson + Type string `json:"type"` // context_int, context_string, request_header, gjson Key string `json:"key,omitempty"` Path string `json:"path,omitempty"` } diff --git a/web/classic/src/pages/Setting/Operation/SettingsChannelAffinity.jsx b/web/classic/src/pages/Setting/Operation/SettingsChannelAffinity.jsx index 76ad7aec..66b91443 100644 --- a/web/classic/src/pages/Setting/Operation/SettingsChannelAffinity.jsx +++ b/web/classic/src/pages/Setting/Operation/SettingsChannelAffinity.jsx @@ -69,6 +69,7 @@ const KEY_RULES = 'channel_affinity_setting.rules'; const KEY_SOURCE_TYPES = [ { label: 'context_int', value: 'context_int' }, { label: 'context_string', value: 'context_string' }, + { label: 'request_header', value: 'request_header' }, { label: 'gjson', value: 'gjson' }, ]; @@ -659,7 +660,11 @@ export default function SettingsChannelAffinity(props) { const xs = (keySources || []).map(normalizeKeySource).filter((x) => x.type); if (xs.length === 0) return { ok: false, message: 'Key 来源不能为空' }; for (const x of xs) { - if (x.type === 'context_int' || x.type === 'context_string') { + if ( + x.type === 'context_int' || + x.type === 'context_string' || + x.type === 'request_header' + ) { if (!x.key) return { ok: false, message: 'Key 不能为空' }; } else if (x.type === 'gjson') { if (!x.path) return { ok: false, message: 'Path 不能为空' }; @@ -1316,7 +1321,7 @@ export default function SettingsChannelAffinity(props) { {t( - 'context_int/context_string 从请求上下文读取;gjson 从入口请求的 JSON body 按 gjson path 读取。', + 'context_int/context_string 从请求上下文读取;request_header 从用户请求头读取;gjson 从入口请求的 JSON body 按 gjson path 读取。', )}
@@ -1358,7 +1363,7 @@ export default function SettingsChannelAffinity(props) { return ( . For commercial licensing, please contact support@quantumnous.com */ export interface KeySource { - type: 'context_int' | 'context_string' | 'gjson' + type: 'context_int' | 'context_string' | 'request_header' | 'gjson' key?: string path?: string }