2025-02-24 16:20:55 +08:00
package middleware
import (
"context"
"fmt"
"net/http"
"one-api/common"
2025-04-16 10:33:43 +08:00
"one-api/common/limiter"
2025-02-24 16:20:55 +08:00
"one-api/setting"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)
// Key marks used to build per-user rate-limit keys.
const (
	// ModelRequestRateLimitCountMark tags the counter that tracks every
	// request of a user, successful or not.
	ModelRequestRateLimitCountMark = "MRRL"
	// ModelRequestRateLimitSuccessCountMark tags the counter that tracks
	// only the requests of a user that completed with a status below 400.
	ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
func checkRedisRateLimit ( ctx context . Context , rdb * redis . Client , key string , maxCount int , duration int64 ) ( bool , error ) {
if maxCount == 0 {
return true , nil
}
length , err := rdb . LLen ( ctx , key ) . Result ( )
if err != nil {
return false , err
}
if length < int64 ( maxCount ) {
return true , nil
}
oldTimeStr , _ := rdb . LIndex ( ctx , key , - 1 ) . Result ( )
oldTime , err := time . Parse ( timeFormat , oldTimeStr )
if err != nil {
return false , err
}
nowTimeStr := time . Now ( ) . Format ( timeFormat )
nowTime , err := time . Parse ( timeFormat , nowTimeStr )
if err != nil {
return false , err
}
subTime := nowTime . Sub ( oldTime ) . Seconds ( )
if int64 ( subTime ) < duration {
2025-03-02 17:34:39 +08:00
rdb . Expire ( ctx , key , time . Duration ( setting . ModelRequestRateLimitDurationMinutes ) * time . Minute )
2025-02-24 16:20:55 +08:00
return false , nil
}
return true , nil
}
func recordRedisRequest ( ctx context . Context , rdb * redis . Client , key string , maxCount int ) {
if maxCount == 0 {
return
}
now := time . Now ( ) . Format ( timeFormat )
rdb . LPush ( ctx , key , now )
rdb . LTrim ( ctx , key , 0 , int64 ( maxCount - 1 ) )
2025-03-02 17:34:39 +08:00
rdb . Expire ( ctx , key , time . Duration ( setting . ModelRequestRateLimitDurationMinutes ) * time . Minute )
2025-02-24 16:20:55 +08:00
}
func redisRateLimitHandler ( duration int64 , totalMaxCount , successMaxCount int ) gin . HandlerFunc {
return func ( c * gin . Context ) {
userId := strconv . Itoa ( c . GetInt ( "id" ) )
ctx := context . Background ( )
rdb := common . RDB
2025-04-16 10:33:43 +08:00
successKey := fmt . Sprintf ( "rateLimit:%s:%s" , ModelRequestRateLimitSuccessCountMark , userId )
allowed , err := checkRedisRateLimit ( ctx , rdb , successKey , successMaxCount , duration )
2025-02-24 16:20:55 +08:00
if err != nil {
2025-04-16 10:33:43 +08:00
fmt . Println ( "检查成功请求数限制失败:" , err . Error ( ) )
2025-02-24 16:20:55 +08:00
abortWithOpenAiMessage ( c , http . StatusInternalServerError , "rate_limit_check_failed" )
return
}
if ! allowed {
2025-04-16 10:33:43 +08:00
abortWithOpenAiMessage ( c , http . StatusTooManyRequests , fmt . Sprintf ( "您已达到请求数限制:%d分钟内最多请求%d次" , setting . ModelRequestRateLimitDurationMinutes , successMaxCount ) )
return
2025-02-24 16:20:55 +08:00
}
2025-04-16 16:36:07 +08:00
2025-04-16 10:33:43 +08:00
totalKey := fmt . Sprintf ( "rateLimit:%s" , userId )
tb := limiter . New ( ctx , rdb )
allowed , err = tb . Allow (
ctx ,
totalKey ,
limiter . WithCapacity ( int64 ( totalMaxCount ) * duration ) ,
limiter . WithRate ( int64 ( totalMaxCount ) ) ,
limiter . WithRequested ( duration ) ,
)
2025-02-24 16:20:55 +08:00
if err != nil {
2025-04-16 10:33:43 +08:00
fmt . Println ( "检查总请求数限制失败:" , err . Error ( ) )
2025-02-24 16:20:55 +08:00
abortWithOpenAiMessage ( c , http . StatusInternalServerError , "rate_limit_check_failed" )
return
}
2025-04-16 10:33:43 +08:00
2025-02-24 16:20:55 +08:00
if ! allowed {
2025-04-16 10:33:43 +08:00
abortWithOpenAiMessage ( c , http . StatusTooManyRequests , fmt . Sprintf ( "您已达到总请求数限制:%d分钟内最多请求%d次, 包括失败次数, 请检查您的请求是否正确" , setting . ModelRequestRateLimitDurationMinutes , totalMaxCount ) )
2025-02-24 16:20:55 +08:00
}
c . Next ( )
if c . Writer . Status ( ) < 400 {
recordRedisRequest ( ctx , rdb , successKey , successMaxCount )
}
}
}
func memoryRateLimitHandler ( duration int64 , totalMaxCount , successMaxCount int ) gin . HandlerFunc {
2025-03-02 17:34:39 +08:00
inMemoryRateLimiter . Init ( time . Duration ( setting . ModelRequestRateLimitDurationMinutes ) * time . Minute )
2025-02-24 16:20:55 +08:00
return func ( c * gin . Context ) {
userId := strconv . Itoa ( c . GetInt ( "id" ) )
totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId
if totalMaxCount > 0 && ! inMemoryRateLimiter . Request ( totalKey , totalMaxCount , duration ) {
c . Status ( http . StatusTooManyRequests )
c . Abort ( )
return
}
checkKey := successKey + "_check"
if ! inMemoryRateLimiter . Request ( checkKey , successMaxCount , duration ) {
c . Status ( http . StatusTooManyRequests )
c . Abort ( )
return
}
c . Next ( )
if c . Writer . Status ( ) < 400 {
inMemoryRateLimiter . Request ( successKey , successMaxCount , duration )
}
}
}
func ModelRequestRateLimit ( ) func ( c * gin . Context ) {
2025-03-06 16:32:11 +08:00
return func ( c * gin . Context ) {
if ! setting . ModelRequestRateLimitEnabled {
c . Next ( )
return
}
2025-03-02 17:34:39 +08:00
2025-03-06 16:32:11 +08:00
duration := int64 ( setting . ModelRequestRateLimitDurationMinutes * 60 )
2025-02-24 16:20:55 +08:00
2025-05-05 07:31:54 +08:00
group := c . GetString ( "token_group" )
if group == "" {
group = c . GetString ( "group" )
}
if group == "" {
2025-05-05 11:34:57 +08:00
group = "default"
2025-05-05 07:31:54 +08:00
}
2025-05-05 11:34:57 +08:00
finalTotalCount := setting . ModelRequestRateLimitCount
finalSuccessCount := setting . ModelRequestRateLimitSuccessCount
foundGroupLimit := false
2025-05-05 07:31:54 +08:00
2025-05-05 11:34:57 +08:00
groupTotalCount , groupSuccessCount , found := setting . GetGroupRateLimit ( group )
2025-05-05 07:31:54 +08:00
if found {
finalTotalCount = groupTotalCount
finalSuccessCount = groupSuccessCount
2025-05-05 11:34:57 +08:00
foundGroupLimit = true
2025-05-05 07:31:54 +08:00
common . LogWarn ( c . Request . Context ( ) , fmt . Sprintf ( "Using rate limit for group '%s': total=%d, success=%d" , group , finalTotalCount , finalSuccessCount ) )
2025-05-05 11:34:57 +08:00
}
if ! foundGroupLimit {
2025-05-05 07:31:54 +08:00
common . LogInfo ( c . Request . Context ( ) , fmt . Sprintf ( "No specific rate limit found for group '%s', using global limits: total=%d, success=%d" , group , finalTotalCount , finalSuccessCount ) )
}
2025-03-06 16:32:11 +08:00
if common . RedisEnabled {
2025-05-05 07:31:54 +08:00
redisRateLimitHandler ( duration , finalTotalCount , finalSuccessCount ) ( c )
2025-03-06 16:32:11 +08:00
} else {
2025-05-05 07:31:54 +08:00
memoryRateLimitHandler ( duration , finalTotalCount , finalSuccessCount ) ( c )
2025-03-06 16:32:11 +08:00
}
2025-02-24 16:20:55 +08:00
}
}