using LMS.Common.Extensions;
using LMS.DAO;
using LMS.Repository.DB;
using LMS.Repository.MJPackage;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using System.Collections.Concurrent;
using System.Text;

namespace LMS.Tools.MJPackage
{
    /// <summary>
    /// Creates MJ API task records, looks them up by third-party task id
    /// (cache first, then database) and writes cached task state back to the database.
    /// </summary>
    public class TaskConcurrencyManager : ITaskConcurrencyManager
    {
        private readonly TokenUsageTracker _usageTracker;
        private readonly ILogger _logger;
        private readonly ApplicationDbContext _dbContext;
        private readonly ITokenService _tokenService;

        /// <summary>
        /// Stores the collaborators used by this manager.
        /// </summary>
        /// <param name="usageTracker">Provides the in-memory task cache.</param>
        /// <param name="logger">Receives warnings, errors and batch statistics.</param>
        /// <param name="dbContext">Database context holding the <c>MJApiTasks</c> set.</param>
        /// <param name="tokenService">Resolves a token string to its configuration.</param>
        public TaskConcurrencyManager(
            TokenUsageTracker usageTracker,
            ILogger logger,
            ApplicationDbContext dbContext,
            ITokenService tokenService)
        {
            _usageTracker = usageTracker;
            _logger = logger;
            _dbContext = dbContext;
            _tokenService = tokenService;
        }

        /// <summary>
        /// Validates the token and persists a new task record in the NOT_START state.
        /// </summary>
        /// <param name="token">Caller token; must resolve to a configuration with a non-blank <c>UseToken</c>.</param>
        /// <param name="thirdPartyTaskId">Task id issued by the third-party service.</param>
        /// <param name="model">Submit request; only its <c>Prompt</c> is read here.</param>
        /// <returns>A task that completes once the record is saved, or once a failure has been logged.</returns>
        /// <remarks>
        /// Best effort: an invalid token is logged as a warning and any exception is logged
        /// as an error; neither is propagated to the caller.
        /// </remarks>
        public async Task CreateTaskAsync(
            string token,
            string thirdPartyTaskId,
            MJSubmitImageModel model)
        {
            try
            {
                TokenCacheItem? tokenConfig = await _tokenService.GetTokenAsync(token);
                if (tokenConfig == null || string.IsNullOrWhiteSpace(tokenConfig.UseToken))
                {
                    // NOTE(review): the raw token value is written to the log here — confirm this is acceptable.
                    _logger.LogWarning($"无效的Token: {token}");
                    return;
                }

                // Build the task record; Properties holds the initial task payload serialized as JSON.
                var taskId = Guid.NewGuid().ToString("N");
                var mJApiTasks = new MJApiTasks
                {
                    TaskId = taskId,
                    Token = token,
                    TokenId = tokenConfig.Id,
                    StartTime = BeijingTimeExtension.GetBeijingTime(),
                    Status = MJTaskStatus.NOT_START,
                    ThirdPartyTaskId = thirdPartyTaskId,
                    Properties = JsonConvert.SerializeObject(new
                    {
                        id = thirdPartyTaskId,
                        action = "IMAGINE",
                        customId = "",
                        botType = "",
                        prompt = model.Prompt,
                        promptEn = "",
                        description = "提交成功",
                        state = "",
                        mode = "",
                        proxy = "",
                        submitTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(),
                        startTime = 0,
                        finishTime = 0,
                        imageUrl = "",
                        imageUrls = null as string[],
                        imageHeight = 0,
                        imageWidth = 0,
                        videoUrl = "",
                        status = "",
                        progress = "0%",
                        failReason = "",
                        buttons = null as object[],
                        maskBase64 = "",
                        properties = null as object,
                    }),
                };

                // Persist the task record to the database.
                await _dbContext.AddAsync(mJApiTasks);
                await _dbContext.SaveChangesAsync();
            }
            catch (Exception ex)
            {
                // NOTE(review): the raw token value is written to the log here as well.
                _logger.LogError(ex, $"开始任务时发生错误: Token={token}");
            }
        }

        /// <summary>
        /// Gets a task by its third-party task id, checking the cache before the database.
        /// </summary>
        /// <param name="thirdPartyId">Task id issued by the third-party service.</param>
        /// <returns>
        /// The matching task, or <c>null</c> when the id is blank or the task is found
        /// in neither the cache nor the database.
        /// </returns>
        public async Task<MJApiTasks?> GetTaskInfoByThirdPartyIdAsync(string thirdPartyId)
        {
            if (string.IsNullOrWhiteSpace(thirdPartyId))
            {
                _logger.LogWarning("第三方任务ID为空");
                return null;
            }

            // Try the in-memory cache first.
            MJApiTasks? mjApiTasks = _usageTracker.TryGetTaskCache(thirdPartyId);

            // Fall back to the database on a cache miss.
            mjApiTasks ??= await LoadTaskFromDatabaseByThirdPartyId(thirdPartyId);

            if (mjApiTasks == null)
            {
                _logger.LogWarning($"缓存和数据库中均未找到任务: ThirdPartyTaskId={thirdPartyId}");
                return null;
            }

            return mjApiTasks;
        }

        /// <summary>
        /// Copies status, end time and properties from the given task onto the stored row
        /// that has the same third-party task id, then saves.
        /// </summary>
        /// <param name="mJApiTasks">Task carrying the new values; matched by <c>ThirdPartyTaskId</c>.</param>
        /// <returns>A task that completes once the row is saved, or once a failure has been logged.</returns>
        /// <remarks>
        /// Best effort: a missing row is logged as a warning and exceptions are logged as errors;
        /// neither is propagated to the caller.
        /// </remarks>
        public async Task UpdateTaskInDatabase(MJApiTasks mJApiTasks)
        {
            try
            {
                MJApiTasks? apiTasks = await _dbContext.MJApiTasks
                    .FirstOrDefaultAsync(x => x.ThirdPartyTaskId == mJApiTasks.ThirdPartyTaskId);
                if (apiTasks == null)
                {
                    _logger.LogWarning($"未找到任务: TaskId={mJApiTasks.TaskId}");
                    return;
                }

                apiTasks.Status = mJApiTasks.Status;
                apiTasks.EndTime = mJApiTasks.EndTime;
                apiTasks.Properties = mJApiTasks.Properties;

                _dbContext.MJApiTasks.Update(apiTasks);
                await _dbContext.SaveChangesAsync();
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, $"更新任务状态到数据库失败: TaskId={mJApiTasks.TaskId}");
            }
        }

        /// <summary>
        /// Writes every cached task to the database in one batch, then removes tasks whose
        /// status is SUCCESS, FAILURE or CANCEL from the cache.
        /// </summary>
        /// <returns>A task that completes once the batch is saved, or once a failure has been logged.</returns>
        /// <remarks>
        /// Best effort: exceptions are logged as errors and not propagated. Cache cleanup runs
        /// only after the save has succeeded.
        /// </remarks>
        public async Task BatchUpdateTaskChaheToDatabaseAsync()
        {
            var startTime = BeijingTimeExtension.GetBeijingTime();
            try
            {
                // Snapshot of everything currently in the cache.
                var tasks = _usageTracker.GetAllTaskCaches();
                if (tasks == null || tasks.Count == 0)
                {
                    _logger.LogInformation("缓存中没有需要更新的任务");
                    return;
                }

                // Re-read each entry from the cache and keep only those still present.
                var taskList = new List<MJApiTasks>();
                foreach (var task in tasks)
                {
                    MJApiTasks? mJApiTasks = _usageTracker.TryGetTaskCache(task.ThirdPartyTaskId);
                    if (mJApiTasks != null)
                    {
                        taskList.Add(mJApiTasks);
                    }
                }

                if (taskList.Count == 0)
                {
                    _logger.LogInformation("缓存中没有需要更新的任务");
                    return;
                }

                // Write the whole batch to the database.
                _dbContext.MJApiTasks.UpdateRange(taskList);
                await _dbContext.SaveChangesAsync();

                // Remove finished tasks from the cache and count the successful removals.
                int count = 0;
                foreach (var task in taskList)
                {
                    bool isFinished = task.Status == MJTaskStatus.SUCCESS
                        || task.Status == MJTaskStatus.FAILURE
                        || task.Status == MJTaskStatus.CANCEL;
                    if (isFinished && _usageTracker.RemoveTaskCache(task.ThirdPartyTaskId))
                    {
                        count++;
                    }
                }

                var duration = BeijingTimeExtension.GetBeijingTime() - startTime;
                _logger.LogInformation($"批量更新了 {taskList.Count} 个缓存中的任务到数据库,耗费时间: {duration.TotalMilliseconds}, 缓存中删除了 {count} 个完成的任务");
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "批量更新任务到数据库失败");
            }
        }

        /// <summary>
        /// Loads a task from the database by its third-party task id.
        /// </summary>
        /// <param name="thirdPartyId">Task id issued by the third-party service.</param>
        /// <returns>The matching task, or <c>null</c> when no row matches or the query throws.</returns>
        private async Task<MJApiTasks?> LoadTaskFromDatabaseByThirdPartyId(string thirdPartyId)
        {
            try
            {
                return await _dbContext.MJApiTasks
                    .FirstOrDefaultAsync(x => x.ThirdPartyTaskId == thirdPartyId);
            }
            catch (Exception ex)
            {
                // Still best effort (callers get null), but the failure is now logged
                // so it can be told apart from a plain "not found".
                _logger.LogError(ex, $"从数据库加载任务失败: ThirdPartyTaskId={thirdPartyId}");
                return null;
            }
        }
    }
}