From 6f8668e4c3a005117ed4e2c854a10bd9c1b96e2a Mon Sep 17 00:00:00 2001 From: yyhhyyyyyy Date: Sat, 16 May 2026 14:54:47 +0800 Subject: [PATCH] fix: enforce header nav access control for public modules (#4889) --- controller/rankings.go | 40 ---- middleware/header_nav.go | 135 ++++++++++++++ middleware/header_nav_test.go | 167 +++++++++++++++++ router/api-router.go | 6 +- .../layout/components/public-header.tsx | 172 ++++++++++++++++-- web/default/src/components/layout/types.ts | 1 + web/default/src/hooks/use-top-nav-links.ts | 79 +------- web/default/src/i18n/locales/en.json | 4 + web/default/src/i18n/locales/fr.json | 4 + web/default/src/i18n/locales/ja.json | 4 + web/default/src/i18n/locales/ru.json | 4 + web/default/src/i18n/locales/vi.json | 4 + web/default/src/i18n/locales/zh.json | 4 + web/default/src/lib/nav-modules.ts | 165 +++++++++++++++-- .../src/routes/pricing/$modelId/index.tsx | 19 +- web/default/src/routes/pricing/index.tsx | 19 +- web/default/src/routes/rankings/index.tsx | 13 +- 17 files changed, 689 insertions(+), 151 deletions(-) create mode 100644 middleware/header_nav.go create mode 100644 middleware/header_nav_test.go diff --git a/controller/rankings.go b/controller/rankings.go index a3fdf2b5..5a7fdaae 100644 --- a/controller/rankings.go +++ b/controller/rankings.go @@ -3,51 +3,11 @@ package controller import ( "net/http" - "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/service" "github.com/gin-gonic/gin" ) -func isRankingsEnabled() bool { - common.OptionMapRWMutex.RLock() - raw := common.OptionMap["HeaderNavModules"] - common.OptionMapRWMutex.RUnlock() - - if raw == "" { - return true - } - - var parsed map[string]interface{} - if err := common.Unmarshal([]byte(raw), &parsed); err != nil { - return true - } - rankings, ok := parsed["rankings"] - if !ok { - return true - } - switch v := rankings.(type) { - case bool: - return v - case map[string]interface{}: - if enabled, ok := v["enabled"]; ok { - if b, ok := enabled.(bool); ok { - return b - } - } - return true - } - return true -} - func GetRankings(c *gin.Context) { - if !isRankingsEnabled() { - c.JSON(http.StatusForbidden, gin.H{ - "success": false, - "message": "rankings is disabled", - }) - return - } - result, err := service.GetRankingsSnapshot(c.DefaultQuery("period", "week")) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ diff --git a/middleware/header_nav.go b/middleware/header_nav.go new file mode 100644 index 00000000..70aa1869 --- /dev/null +++ b/middleware/header_nav.go @@ -0,0 +1,135 @@ +package middleware + +import ( + "fmt" + "net/http" + "strings" + + "github.com/QuantumNous/new-api/common" + "github.com/gin-gonic/gin" +) + +type headerNavAccess struct { + Enabled bool + RequireAuth bool +} + +func getHeaderNavAccess(module string) headerNavAccess { + fallback := headerNavAccess{ + Enabled: true, + RequireAuth: false, + } + + common.OptionMapRWMutex.RLock() + raw := common.OptionMap["HeaderNavModules"] + common.OptionMapRWMutex.RUnlock() + + if strings.TrimSpace(raw) == "" { + return fallback + } + + var parsed map[string]any + if err := common.Unmarshal([]byte(raw), &parsed); err != nil { + return fallback + } + + return parseHeaderNavAccess(parsed[module], fallback) +} + +func parseHeaderNavAccess(raw any, fallback headerNavAccess) headerNavAccess { + switch value := raw.(type) { + case bool: + return headerNavAccess{ + Enabled: value, + RequireAuth: fallback.RequireAuth, + } + case string: + return headerNavAccess{ + Enabled: parseHeaderNavBool(value, fallback.Enabled), + RequireAuth: fallback.RequireAuth, + } + case float64: + return headerNavAccess{ + Enabled: parseHeaderNavBool(value, fallback.Enabled), + RequireAuth: fallback.RequireAuth, + } + case map[string]any: + access := fallback + if enabled, ok := value["enabled"]; ok { + access.Enabled = parseHeaderNavBool(enabled, fallback.Enabled) + } + if requireAuth, ok := value["requireAuth"]; ok { + access.RequireAuth = parseHeaderNavBool(requireAuth, fallback.RequireAuth) + } + return access + default: + return fallback + } +} + +func parseHeaderNavBool(value any, fallback bool) bool { + switch v := value.(type) { + case bool: + return v + case string: + switch strings.ToLower(strings.TrimSpace(v)) { + case "true", "1": + return true + case "false", "0": + return false + default: + return fallback + } + case float64: + if v == 1 { + return true + } + if v == 0 { + return false + } + return fallback + case int: + if v == 1 { + return true + } + if v == 0 { + return false + } + return fallback + default: + return fallback + } +} + +func HeaderNavModuleAuth(module string) gin.HandlerFunc { + return func(c *gin.Context) { + access := getHeaderNavAccess(module) + if !access.Enabled { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": fmt.Sprintf("%s is disabled", module), + }) + c.Abort() + return + } + + if access.RequireAuth { + UserAuth()(c) + return + } + + TryUserAuth()(c) + } +} + +func HeaderNavModulePublicOrUserAuth(module string) gin.HandlerFunc { + return func(c *gin.Context) { + access := getHeaderNavAccess(module) + if !access.Enabled || access.RequireAuth { + UserAuth()(c) + return + } + + TryUserAuth()(c) + } +} diff --git a/middleware/header_nav_test.go b/middleware/header_nav_test.go new file mode 100644 index 00000000..d4c9c221 --- /dev/null +++ b/middleware/header_nav_test.go @@ -0,0 +1,167 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/cookie" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func withHeaderNavModules(t *testing.T, raw string) { + t.Helper() + + common.OptionMapRWMutex.Lock() + if common.OptionMap == nil { + common.OptionMap = map[string]string{} + } + previous, hadPrevious := common.OptionMap["HeaderNavModules"] + common.OptionMap["HeaderNavModules"] = raw + common.OptionMapRWMutex.Unlock() + + t.Cleanup(func() { + common.OptionMapRWMutex.Lock() + defer common.OptionMapRWMutex.Unlock() + if hadPrevious { + common.OptionMap["HeaderNavModules"] = previous + return + } + delete(common.OptionMap, "HeaderNavModules") + }) +} + +func performHeaderNavRequest(t *testing.T, handler gin.HandlerFunc, authenticated bool) *httptest.ResponseRecorder { + t.Helper() + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(sessions.Sessions("session", cookie.NewStore([]byte("header-nav-test")))) + router.GET("/login", func(c *gin.Context) { + session := sessions.Default(c) + session.Set("username", "tester") + session.Set("role", common.RoleCommonUser) + session.Set("id", 1) + session.Set("status", common.UserStatusEnabled) + session.Set("group", "default") + if err := session.Save(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"success": false}) + return + } + c.Status(http.StatusNoContent) + }) + router.GET("/api/test", handler, func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"success": true}) + }) + + var cookies []*http.Cookie + if authenticated { + loginRecorder := httptest.NewRecorder() + loginRequest := httptest.NewRequest(http.MethodGet, "/login", nil) + router.ServeHTTP(loginRecorder, loginRequest) + require.Equal(t, http.StatusNoContent, loginRecorder.Code) + cookies = loginRecorder.Result().Cookies() + } + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/api/test", nil) + if authenticated { + request.Header.Set("New-Api-User", "1") + for _, cookie := range cookies { + request.AddCookie(cookie) + } + } + router.ServeHTTP(recorder, request) + return recorder +} + +func TestHeaderNavModuleAuthAllowsDefaultPublicAccess(t *testing.T) { + withHeaderNavModules(t, "") + + recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("pricing"), false) + + require.Equal(t, http.StatusOK, recorder.Code) +} + +func TestHeaderNavModuleAuthRejectsDisabledPricing(t *testing.T) { + raw := `{"pricing":{"enabled":false,"requireAuth":false}}` + withHeaderNavModules(t, raw) + + recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("pricing"), false) + + require.Equal(t, http.StatusForbidden, recorder.Code) +} + +func TestHeaderNavModuleAuthRequiresLoginForPricing(t *testing.T) { + raw := `{"pricing":{"enabled":true,"requireAuth":true}}` + withHeaderNavModules(t, raw) + + recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("pricing"), false) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) +} + +func TestHeaderNavModuleAuthRequiresLoginForRankings(t *testing.T) { + raw := `{"rankings":{"enabled":true,"requireAuth":true}}` + withHeaderNavModules(t, raw) + + recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("rankings"), false) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) +} + +func TestHeaderNavModuleAuthRejectsLegacyDisabledModule(t *testing.T) { + raw := `{"rankings":false}` + withHeaderNavModules(t, raw) + + recorder := performHeaderNavRequest(t, HeaderNavModuleAuth("rankings"), false) + + require.Equal(t, http.StatusForbidden, recorder.Code) +} + +func TestHeaderNavModulePublicOrUserAuthAllowsDefaultPublicAccess(t *testing.T) { + withHeaderNavModules(t, "") + + recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false) + + require.Equal(t, http.StatusOK, recorder.Code) +} + +func TestHeaderNavModulePublicOrUserAuthRequiresLoginWhenDisabled(t *testing.T) { + raw := `{"pricing":{"enabled":false,"requireAuth":false}}` + withHeaderNavModules(t, raw) + + recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) +} + +func TestHeaderNavModulePublicOrUserAuthAllowsLoggedInWhenDisabled(t *testing.T) { + raw := `{"pricing":{"enabled":false,"requireAuth":false}}` + withHeaderNavModules(t, raw) + + recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), true) + + require.Equal(t, http.StatusOK, recorder.Code) +} + +func TestHeaderNavModulePublicOrUserAuthRequiresLoginWhenRequireAuth(t *testing.T) { + raw := `{"pricing":{"enabled":true,"requireAuth":true}}` + withHeaderNavModules(t, raw) + + recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) +} + +func TestHeaderNavModulePublicOrUserAuthRequiresLoginForLegacyDisabledModule(t *testing.T) { + raw := `{"pricing":false}` + withHeaderNavModules(t, raw) + + recorder := performHeaderNavRequest(t, HeaderNavModulePublicOrUserAuth("pricing"), false) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) +} diff --git a/router/api-router.go b/router/api-router.go index 64ccbe15..da026ed9 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -30,14 +30,14 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/about", controller.GetAbout) //apiRouter.GET("/midjourney", controller.GetMidjourney) apiRouter.GET("/home_page_content", controller.GetHomePageContent) - apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing) + apiRouter.GET("/pricing", middleware.HeaderNavModuleAuth("pricing"), controller.GetPricing) perfMetricsRoute := apiRouter.Group("/perf-metrics") - perfMetricsRoute.Use(middleware.TryUserAuth()) + perfMetricsRoute.Use(middleware.HeaderNavModulePublicOrUserAuth("pricing")) { perfMetricsRoute.GET("/summary", controller.GetPerfMetricsSummary) perfMetricsRoute.GET("", controller.GetPerfMetrics) } - apiRouter.GET("/rankings", controller.GetRankings) + apiRouter.GET("/rankings", middleware.HeaderNavModuleAuth("rankings"), controller.GetRankings) apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) diff --git a/web/default/src/components/layout/components/public-header.tsx b/web/default/src/components/layout/components/public-header.tsx index 1854ed05..4aca8365 100644 --- a/web/default/src/components/layout/components/public-header.tsx +++ b/web/default/src/components/layout/components/public-header.tsx @@ -16,8 +16,8 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import { useState, useEffect } from 'react' -import { Link, useRouterState } from '@tanstack/react-router' +import { useCallback, useEffect, useState } from 'react' +import { Link, useNavigate, useRouterState } from '@tanstack/react-router' import { useTranslation } from 'react-i18next' import { useAuthStore } from '@/stores/auth-store' import { cn } from '@/lib/utils' @@ -25,6 +25,14 @@ import { useNotifications } from '@/hooks/use-notifications' import { useSystemConfig } from '@/hooks/use-system-config' import { useTopNavLinks } from '@/hooks/use-top-nav-links' import { Button } from '@/components/ui/button' +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from '@/components/ui/dialog' import { Skeleton } from '@/components/ui/skeleton' import { LanguageSwitcher } from '@/components/language-switcher' import { NotificationButton } from '@/components/notification-button' @@ -35,6 +43,13 @@ import { defaultTopNavLinks } from '../config/top-nav.config' import type { TopNavLink } from '../types' import { HeaderLogo } from './header-logo' +const AUTH_PROMPT_SECONDS = 5 + +type AuthPromptTarget = { + title: string + href: string +} + export interface PublicHeaderProps { navLinks?: TopNavLink[] mobileLinks?: TopNavLink[] @@ -65,8 +80,13 @@ export function PublicHeader(props: PublicHeaderProps) { } = props const { t } = useTranslation() + const navigate = useNavigate() const [scrolled, setScrolled] = useState(false) const [mobileOpen, setMobileOpen] = useState(false) + const [authPromptTarget, setAuthPromptTarget] = + useState(null) + const [authPromptSecondsLeft, setAuthPromptSecondsLeft] = + useState(AUTH_PROMPT_SECONDS) const { auth } = useAuthStore() const { systemName, @@ -98,6 +118,67 @@ export function PublicHeader(props: PublicHeaderProps) { } }, [mobileOpen]) + useEffect(() => { + if (!authPromptTarget) return + + const intervalId = window.setInterval(() => { + setAuthPromptSecondsLeft((seconds) => Math.max(seconds - 1, 0)) + }, 1000) + + const timeoutId = window.setTimeout(() => { + const redirect = authPromptTarget.href + setAuthPromptTarget(null) + navigate({ to: '/sign-in', search: { redirect } }) + }, AUTH_PROMPT_SECONDS * 1000) + + return () => { + window.clearInterval(intervalId) + window.clearTimeout(timeoutId) + } + }, [authPromptTarget, navigate]) + + const closeAuthPrompt = useCallback(() => { + setAuthPromptTarget(null) + setAuthPromptSecondsLeft(AUTH_PROMPT_SECONDS) + }, []) + + const navigateToSignIn = useCallback(() => { + const redirect = authPromptTarget?.href || '/' + setAuthPromptTarget(null) + navigate({ to: '/sign-in', search: { redirect } }) + }, [authPromptTarget?.href, navigate]) + + const handleNavLinkClick = useCallback( + ( + event: React.MouseEvent, + link: TopNavLink, + closeMobile = false + ) => { + if (link.disabled) { + event.preventDefault() + return + } + + if (link.requiresAuth) { + event.preventDefault() + if (closeMobile) { + setMobileOpen(false) + } + setAuthPromptSecondsLeft(AUTH_PROMPT_SECONDS) + setAuthPromptTarget({ + title: t(link.title), + href: link.href, + }) + return + } + + if (closeMobile) { + setMobileOpen(false) + } + }, + [t] + ) + return ( <>
@@ -150,7 +231,13 @@ export function PublicHeader(props: PublicHeaderProps) { href={link.href} target='_blank' rel='noopener noreferrer' - className='text-muted-foreground hover:text-foreground rounded-lg px-3 py-1.5 text-[13px] font-medium transition-colors duration-200' + aria-disabled={link.disabled} + tabIndex={link.disabled ? -1 : undefined} + onClick={(event) => handleNavLinkClick(event, link)} + className={cn( + 'text-muted-foreground hover:text-foreground rounded-lg px-3 py-1.5 text-[13px] font-medium transition-colors duration-200', + link.disabled && 'pointer-events-none opacity-50' + )} > {t(link.title)} @@ -160,11 +247,14 @@ export function PublicHeader(props: PublicHeaderProps) { handleNavLinkClick(event, link)} className={cn( 'rounded-lg px-3 py-1.5 text-[13px] font-medium transition-colors duration-200', isActive ? 'text-foreground' - : 'text-muted-foreground hover:text-foreground' + : 'text-muted-foreground hover:text-foreground', + link.disabled && 'pointer-events-none opacity-50' )} > {t(link.title)} @@ -260,21 +350,42 @@ export function PublicHeader(props: PublicHeaderProps) {