LMS.service/LMS.Tools/MJPackage/TaskConcurrencyManager.cs

217 lines
8.0 KiB
C#
Raw Normal View History

using LMS.Common.Extensions;
using LMS.DAO;
using LMS.Repository.DB;
using LMS.Repository.MJPackage;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using System.Collections.Concurrent;
using System.Text;
namespace LMS.Tools.MJPackage
{
public class TaskConcurrencyManager : ITaskConcurrencyManager
{
    // In-memory task cache keyed by TaskId. GetTaskInfoAsync consults it before the database,
    // and GetRunningTasksAsync / CleanupTimeoutTasksAsync operate only on it.
    // NOTE(review): no member of this class ever adds entries to _activeTasks or _thirdPartyTaskMap
    // (CreateTaskAsync only writes to the database), so as written the running-task list is always
    // empty and cleanup finds nothing. Confirm whether tasks are meant to be registered here.
    private readonly ConcurrentDictionary<string, MJApiTasks> _activeTasks = new();

    // ThirdPartyTaskId -> TaskId
    private readonly ConcurrentDictionary<string, string> _thirdPartyTaskMap = new();

    private readonly TokenUsageTracker _usageTracker;

    // NOTE(review): stored but not used by any member of this class.
    private readonly IServiceScopeFactory _scopeFactory;

    private readonly ILogger<TaskConcurrencyManager> _logger;
    private readonly ApplicationDbContext _dbContext;
    private readonly ITokenService _tokenService;

    public TaskConcurrencyManager(
        TokenUsageTracker usageTracker,
        IServiceScopeFactory scopeFactory,
        ILogger<TaskConcurrencyManager> logger,
        ApplicationDbContext dbContext,
        ITokenService tokenService)
    {
        _usageTracker = usageTracker;
        _scopeFactory = scopeFactory;
        _logger = logger;
        _dbContext = dbContext;
        _tokenService = tokenService;
    }

    /// <summary>
    /// Creates a new task record for <paramref name="thirdPartyTaskId"/> under the given token and
    /// persists it to the database with status <c>NOT_START</c> and a freshly generated TaskId.
    /// </summary>
    /// <remarks>
    /// This method does not acquire a concurrency permit and does not add the task to the in-memory
    /// cache. All failures (unknown/unusable token, token lookup errors, persistence errors) are
    /// logged rather than thrown.
    /// </remarks>
    /// <param name="token">Caller token, resolved through <see cref="ITokenService"/>.</param>
    /// <param name="thirdPartyTaskId">Third-party task ID stored on the new record.</param>
    public async Task CreateTaskAsync(
        string token,
        string thirdPartyTaskId)
    {
        try
        {
            TokenCacheItem? tokenConfig = await _tokenService.GetTokenAsync(token);
            // A token is only usable if it resolves to a config with a non-blank UseToken.
            if (tokenConfig == null || string.IsNullOrWhiteSpace(tokenConfig.UseToken))
            {
                _logger.LogWarning($"无效的Token: {token}");
                return;
            }

            var taskId = Guid.NewGuid().ToString("N");
            var mJApiTasks = new MJApiTasks
            {
                TaskId = taskId,
                Token = token,
                TokenId = tokenConfig.Id,
                StartTime = BeijingTimeExtension.GetBeijingTime(),
                Status = MJTaskStatus.NOT_START,
                ThirdPartyTaskId = thirdPartyTaskId,
                Properties = null
            };

            // Persist the task; SaveTaskToDatabase logs (and swallows) its own failures.
            await SaveTaskToDatabase(mJApiTasks);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, $"开始任务时发生错误: Token={token}");
        }
    }

    /// <summary>
    /// Gets a task by its local TaskId, checking the in-memory cache first and then the database.
    /// </summary>
    /// <param name="taskId">Local task ID (the one generated by <see cref="CreateTaskAsync"/>).</param>
    /// <returns>
    /// The task, or <c>null</c> if <paramref name="taskId"/> is blank, no task matches, or the
    /// database lookup fails.
    /// </returns>
    public async Task<MJApiTasks> GetTaskInfoAsync(string taskId)
    {
        // ConcurrentDictionary.TryGetValue throws ArgumentNullException for a null key, so reject
        // blank IDs up front (same policy as GetTaskInfoByThirdPartyIdAsync).
        if (string.IsNullOrWhiteSpace(taskId))
        {
            _logger.LogWarning("任务ID为空");
            return null;
        }

        if (_activeTasks.TryGetValue(taskId, out var taskInfo))
        {
            return taskInfo;
        }

        // Not cached in memory: fall back to the database.
        return await LoadTaskFromDatabase(taskId);
    }

    /// <summary>
    /// Gets a task from the database by its third-party task ID.
    /// </summary>
    /// <remarks>
    /// Queries the database directly (the in-memory cache is not consulted). Unlike
    /// <see cref="GetTaskInfoAsync"/>, database exceptions are not caught here and propagate to the caller.
    /// </remarks>
    /// <param name="thirdPartyId">Third-party task ID to look up.</param>
    /// <returns>The first matching task, or <c>null</c> if the ID is blank or nothing matches.</returns>
    public async Task<MJApiTasks> GetTaskInfoByThirdPartyIdAsync(string thirdPartyId)
    {
        if (string.IsNullOrWhiteSpace(thirdPartyId))
        {
            _logger.LogWarning("第三方任务ID为空");
            return null;
        }

        MJApiTasks? mJApiTasks = await _dbContext.MJApiTasks.FirstOrDefaultAsync(x => x.ThirdPartyTaskId == thirdPartyId);
        return mJApiTasks;
    }

    /// <summary>
    /// Lists in-memory tasks that are not yet finished (status other than SUCCESS, FAILURE or CANCEL),
    /// optionally filtered to a single token, ordered by start time.
    /// </summary>
    /// <param name="token">When null or empty, tasks for all tokens are returned.</param>
    public async Task<IEnumerable<MJApiTasks>> GetRunningTasksAsync(string token = null)
    {
        var runningTasks = _activeTasks.Values
            .Where(t => !IsFinished(t))
            .Where(t => string.IsNullOrEmpty(token) || t.Token == token)
            .OrderBy(t => t.StartTime)
            .ToList();
        _logger.LogDebug($"当前运行中的任务数: {runningTasks.Count}" + (string.IsNullOrEmpty(token) ? "" : $", Token={token}"));
        return await Task.FromResult(runningTasks);
    }

    /// <summary>
    /// Returns the concurrency status for <paramref name="token"/> as reported by
    /// <see cref="TokenUsageTracker"/>: (maximum, currently executing, available).
    /// </summary>
    public async Task<(int maxConcurrency, int running, int available)> GetConcurrencyStatusAsync(string token)
    {
        var status = _usageTracker.GetConcurrencyStatus(token);
        return await Task.FromResult((status.maxCount, status.currentlyExecuting, status.available));
    }

    /// <summary>
    /// Evicts unfinished in-memory tasks whose StartTime is older than <paramref name="timeout"/>
    /// (measured against Beijing time) and releases one concurrency permit per evicted task.
    /// </summary>
    /// <param name="timeout">Age after which an unfinished task is considered timed out.</param>
    public Task CleanupTimeoutTasksAsync(TimeSpan timeout)
    {
        _logger.LogInformation($"开始清理超时任务,超时阈值: {timeout.TotalMinutes}分钟");
        var cutoffTime = BeijingTimeExtension.GetBeijingTime() - timeout;

        var timeoutTasks = _activeTasks
            .Where(entry => entry.Value.StartTime < cutoffTime && !IsFinished(entry.Value))
            .ToList();
        _logger.LogInformation($"发现 {timeoutTasks.Count} 个超时任务");

        foreach (var (key, task) in timeoutTasks)
        {
            // Evict before releasing: only the run that actually removes the entry releases the permit.
            // Previously timed-out tasks stayed cached, so every later cleanup released their permit again.
            if (!_activeTasks.TryRemove(key, out _))
            {
                continue;
            }

            // ConcurrentDictionary rejects null keys, hence the guard.
            if (!string.IsNullOrEmpty(task.ThirdPartyTaskId))
            {
                _thirdPartyTaskMap.TryRemove(task.ThirdPartyTaskId, out _);
            }

            _logger.LogWarning($"清理超时任务: TaskId={task.TaskId}, Token={task.Token}, 开始时间={task.StartTime:yyyy-MM-dd HH:mm:ss}");
            _usageTracker.ReleaseConcurrencyPermit(task.Token);
        }

        return Task.CompletedTask;
    }

    /// <summary>
    /// True when the task has reached a terminal status (SUCCESS, FAILURE or CANCEL).
    /// </summary>
    private static bool IsFinished(MJApiTasks task) =>
        task.Status == MJTaskStatus.SUCCESS
        || task.Status == MJTaskStatus.FAILURE
        || task.Status == MJTaskStatus.CANCEL;

    /// <summary>
    /// Inserts a new task row. Failures are logged and swallowed (best-effort).
    /// </summary>
    private async Task SaveTaskToDatabase(MJApiTasks mJApiTasks)
    {
        try
        {
            await _dbContext.MJApiTasks.AddAsync(mJApiTasks);
            await _dbContext.SaveChangesAsync();
            _logger.LogInformation($"任务已保存到数据库: TaskId={mJApiTasks.TaskId}, Token={mJApiTasks.Token}");
        }
        catch (Exception ex)
        {
            // NOTE(review): after a failed SaveChangesAsync the entity presumably remains tracked as
            // Added on this DbContext, so a later SaveChangesAsync on the same context may retry the
            // insert. Consider detaching it here — confirm against the context's lifetime.
            _logger.LogError(ex, $"保存任务到数据库失败: TaskId={mJApiTasks.TaskId}");
        }
    }

    /// <summary>
    /// Copies Status, EndTime and Properties from <paramref name="mJApiTasks"/> onto the stored task
    /// and saves. The stored task is located by ThirdPartyTaskId, or by TaskId when no third-party
    /// ID is supplied. All failures are logged, not thrown.
    /// </summary>
    /// <param name="mJApiTasks">Carrier of the lookup key(s) and the new field values.</param>
    public async Task UpdateTaskInDatabase(MJApiTasks mJApiTasks)
    {
        if (mJApiTasks == null)
        {
            _logger.LogWarning("更新任务失败: 任务信息为空");
            return;
        }

        try
        {
            string? thirdPartyTaskId = mJApiTasks.ThirdPartyTaskId;
            string? taskId = mJApiTasks.TaskId;

            // A blank key must not reach the query: `x.ThirdPartyTaskId == null` matches any row
            // without a third-party ID, which would update an unrelated task.
            MJApiTasks? apiTasks;
            if (!string.IsNullOrWhiteSpace(thirdPartyTaskId))
            {
                apiTasks = await _dbContext.MJApiTasks.FirstOrDefaultAsync(x => x.ThirdPartyTaskId == thirdPartyTaskId);
            }
            else if (!string.IsNullOrWhiteSpace(taskId))
            {
                apiTasks = await _dbContext.MJApiTasks.FirstOrDefaultAsync(x => x.TaskId == taskId);
            }
            else
            {
                _logger.LogWarning("更新任务失败: TaskId 和 ThirdPartyTaskId 均为空");
                return;
            }

            if (apiTasks == null)
            {
                _logger.LogWarning($"未找到任务: TaskId={taskId}, ThirdPartyTaskId={thirdPartyTaskId}");
                return;
            }

            apiTasks.Status = mJApiTasks.Status;
            apiTasks.EndTime = mJApiTasks.EndTime;
            apiTasks.Properties = mJApiTasks.Properties;
            _dbContext.MJApiTasks.Update(apiTasks);
            await _dbContext.SaveChangesAsync();
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, $"更新任务状态到数据库失败: TaskId={mJApiTasks.TaskId}");
        }
    }

    /// <summary>
    /// Loads a task from the database by local TaskId.
    /// </summary>
    /// <returns>The task, or <c>null</c> if not found or the query fails (both cases are logged).</returns>
    private async Task<MJApiTasks> LoadTaskFromDatabase(string taskId)
    {
        try
        {
            MJApiTasks? mJApiTasks = await _dbContext.MJApiTasks.FirstOrDefaultAsync(x => x.TaskId == taskId);
            if (mJApiTasks == null)
            {
                _logger.LogWarning($"未找到任务: TaskId={taskId}");
                return null;
            }
            return mJApiTasks;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, $"从数据库加载任务失败: TaskId={taskId}");
            return null;
        }
    }
}
}