diff --git a/model/topup.go b/model/topup.go index 00ad5c49..d1b0c5cb 100644 --- a/model/topup.go +++ b/model/topup.go @@ -110,6 +110,14 @@ func Recharge(referenceId string, customerId string, callerIp string) (err error return nil } +// topUpQueryWindowSeconds 限制充值记录查询的时间窗口(秒)。 +const topUpQueryWindowSeconds int64 = 30 * 24 * 60 * 60 + +// topUpQueryCutoff 返回允许查询的最早 create_time(秒级 Unix 时间戳)。 +func topUpQueryCutoff() int64 { + return common.GetTimestamp() - topUpQueryWindowSeconds +} + func GetUserTopUps(userId int, pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) { // Start transaction tx := DB.Begin() @@ -122,15 +130,17 @@ func GetUserTopUps(userId int, pageInfo *common.PageInfo) (topups []*TopUp, tota } }() + cutoff := topUpQueryCutoff() + // Get total count within transaction - err = tx.Model(&TopUp{}).Where("user_id = ?", userId).Count(&total).Error + err = tx.Model(&TopUp{}).Where("user_id = ? AND create_time >= ?", userId, cutoff).Count(&total).Error if err != nil { tx.Rollback() return nil, 0, err } // Get paginated topups within same transaction - err = tx.Where("user_id = ?", userId).Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Find(&topups).Error + err = tx.Where("user_id = ? AND create_time >= ?", userId, cutoff).Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Find(&topups).Error if err != nil { tx.Rollback() return nil, 0, err @@ -144,7 +154,7 @@ func GetUserTopUps(userId int, pageInfo *common.PageInfo) (topups []*TopUp, tota return topups, total, nil } -// GetAllTopUps 获取全平台的充值记录(管理员使用) +// GetAllTopUps 获取全平台的充值记录(管理员使用,不限制时间窗口) func GetAllTopUps(pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) { tx := DB.Begin() if tx.Error != nil { @@ -173,6 +183,10 @@ func GetAllTopUps(pageInfo *common.PageInfo) (topups []*TopUp, total int64, err return topups, total, nil } +// searchTopUpCountHardLimit 搜索充值记录时 COUNT 的安全上限, +// 防止对超大表执行无界 COUNT 触发 DoS。 +const searchTopUpCountHardLimit = 10000 + // SearchUserTopUps 按订单号搜索某用户的充值记录 func SearchUserTopUps(userId int, keyword string, pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) { tx := DB.Begin() @@ -185,20 +199,26 @@ func SearchUserTopUps(userId int, keyword string, pageInfo *common.PageInfo) (to } }() - query := tx.Model(&TopUp{}).Where("user_id = ?", userId) + query := tx.Model(&TopUp{}).Where("user_id = ? AND create_time >= ?", userId, topUpQueryCutoff()) if keyword != "" { - like := "%%" + keyword + "%%" - query = query.Where("trade_no LIKE ?", like) + pattern, perr := sanitizeLikePattern(keyword) + if perr != nil { + tx.Rollback() + return nil, 0, perr + } + query = query.Where("trade_no LIKE ? ESCAPE '!'", pattern) } - if err = query.Count(&total).Error; err != nil { + if err = query.Limit(searchTopUpCountHardLimit).Count(&total).Error; err != nil { tx.Rollback() - return nil, 0, err + common.SysError("failed to count search topups: " + err.Error()) + return nil, 0, errors.New("搜索充值记录失败") } if err = query.Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Find(&topups).Error; err != nil { tx.Rollback() - return nil, 0, err + common.SysError("failed to search topups: " + err.Error()) + return nil, 0, errors.New("搜索充值记录失败") } if err = tx.Commit().Error; err != nil { @@ -207,7 +227,7 @@ func SearchUserTopUps(userId int, keyword string, pageInfo *common.PageInfo) (to return topups, total, nil } -// SearchAllTopUps 按订单号搜索全平台充值记录(管理员使用) +// SearchAllTopUps 按订单号搜索全平台充值记录(管理员使用,不限制时间窗口) func SearchAllTopUps(keyword string, pageInfo *common.PageInfo) (topups []*TopUp, total int64, err error) { tx := DB.Begin() if tx.Error != nil { @@ -221,18 +241,24 @@ func SearchAllTopUps(keyword string, pageInfo *common.PageInfo) (topups []*TopUp query := tx.Model(&TopUp{}) if keyword != "" { - like := "%%" + keyword + "%%" - query = query.Where("trade_no LIKE ?", like) + pattern, perr := sanitizeLikePattern(keyword) + if perr != nil { + tx.Rollback() + return nil, 0, perr + } + query = query.Where("trade_no LIKE ? ESCAPE '!'", pattern) } - if err = query.Count(&total).Error; err != nil { + if err = query.Limit(searchTopUpCountHardLimit).Count(&total).Error; err != nil { tx.Rollback() - return nil, 0, err + common.SysError("failed to count search topups: " + err.Error()) + return nil, 0, errors.New("搜索充值记录失败") } if err = query.Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Find(&topups).Error; err != nil { tx.Rollback() - return nil, 0, err + common.SysError("failed to search topups: " + err.Error()) + return nil, 0, errors.New("搜索充值记录失败") } if err = tx.Commit().Error; err != nil {