374 lines
14 KiB
JavaScript
374 lines
14 KiB
JavaScript
import axios from "axios";
|
||
import path from "path";
|
||
import { DEFINE_STRING } from "../../define/define_string";
|
||
import { define } from "../../define/define";
|
||
import { ImageStyleDefine } from "../../define/iamgeStyleDefine";
|
||
import { cloneDeep } from 'lodash';
|
||
let fspromises = require("fs").promises;
|
||
const sharp = require('sharp');
|
||
import { SdSettingDefine } from "../../define/setting/sdSettingDefine";
|
||
import { PublicMethod } from "./publicMethod";
|
||
import { Tools } from "../tools";
|
||
import { errorMessage, successMessage } from "../generalTools";
|
||
import { SdApi } from "../../api/sdApi";
|
||
const { v4: uuidv4 } = require('uuid');
|
||
|
||
/**
 * Facade over the Stable Diffusion WebUI integration: lists server resources
 * (checkpoints, LoRAs, samplers), resolves image-style metadata, and runs
 * txt2img / img2img generations whose results are written to disk.
 */
export class SD {
    /**
     * @param {object} global Shared application state. This class reads
     *   `global.config`, `global.requestQuene`, `global.fileQueue` and
     *   `global.newWindow` from it.
     */
    constructor(global) {
        this.global = global;
        this.pm = new PublicMethod(global);
        this.tools = new Tools();
        this.sdApi = new SdApi();
    }

    /**
     * Fetch every LoRA known to the SD server.
     *
     * @param {*} baseURL Forwarded unchanged to `sdApi.getAllLoras`.
     * @returns `successMessage(result)` on success, `errorMessage(text)` on failure;
     *   this method never rejects.
     */
    async GetAllLoras(baseURL = null) {
        try {
            const data = await this.sdApi.getAllLoras(baseURL);
            return successMessage(data);
        } catch (error) {
            return errorMessage(error.toString());
        }
    }

    /**
     * Fetch every checkpoint model known to the SD server.
     *
     * @param {*} baseURL Forwarded unchanged to `sdApi.getAllSDModel`.
     * @returns `successMessage(result)` on success, `errorMessage(text)` on failure;
     *   this method never rejects.
     */
    async GetAllSDModel(baseURL = null) {
        try {
            const data = await this.sdApi.getAllSDModel(baseURL);
            return successMessage(data);
        } catch (error) {
            return errorMessage(error.toString());
        }
    }

    /**
     * Fetch every sampler known to the SD server.
     *
     * @param {*} baseURL Forwarded unchanged to `sdApi.getAllSamplers`.
     * @returns `successMessage(result)` on success, `errorMessage(text)` on failure;
     *   this method never rejects.
     */
    async GetAllSamplers(baseURL = null) {
        try {
            const data = await this.sdApi.getAllSamplers(baseURL);
            return successMessage(data);
        } catch (error) {
            return errorMessage(error.toString());
        }
    }

    /**
     * Load checkpoints, LoRAs and samplers from the SD server, prepend a "无"
     * (none) placeholder entry to each list, strip the `metadata` field from the
     * LoRA entries, and persist the three lists through `SdSettingDefine`.
     *
     * @param {*} baseURL Forwarded to the three fetch methods.
     * @returns `successMessage({ sd_model, lora, sampler })`, or an
     *   `errorMessage` describing the failure; this method never rejects.
     */
    async LoadSDServiceData(baseURL = null) {
        try {
            // The three lookups are independent and their wrappers never reject,
            // so they can run concurrently.
            const [sd_model, lora, sampler] = await Promise.all([
                this.GetAllSDModel(baseURL),
                this.GetAllLoras(baseURL),
                this.GetAllSamplers(baseURL),
            ]);

            // Validate BEFORE touching `.data.data`, so that a failed lookup is
            // reported with this message instead of a property-access TypeError.
            if (!(sd_model.code && lora.code && sampler.code)) {
                throw new Error("获取SD数据错误,请检查SD WEBUI链接!");
            }

            // Build fresh arrays rather than mutating the API responses.
            const data = {
                sd_model: [
                    { title: "无", name: "无", description: "无" },
                    ...sd_model.data.data,
                ],
                lora: [
                    { Key: "无", name: "无", description: "无" },
                    // `metadata` is dropped from every LoRA entry before saving.
                    ...lora.data.data.map(({ metadata, ...rest }) => rest),
                ],
                sampler: [
                    { name: "无", description: "无" },
                    ...sampler.data.data,
                ],
            };

            // Persist the fetched lists. Kept sequential on purpose.
            // NOTE(review): presumably these writes share one settings file — confirm
            // before parallelising.
            await SdSettingDefine.SavePropertyValue("sd_model", data.sd_model);
            await SdSettingDefine.SavePropertyValue("lora", data.lora);
            await SdSettingDefine.SavePropertyValue("sampler", data.sampler);

            return successMessage(data);
        } catch (error) {
            return errorMessage("加载数据失败,错误信息如下:" + error.toString());
        }
    }

    /**
     * Get the image-style menu.
     *
     * @returns `{ code: 1, data }` with the menu, or `{ code: 0, message }`
     *   carrying the actual error text if building the menu throws.
     */
    async GetImageStyleMenu() {
        try {
            const style = ImageStyleDefine.getImageStyleMenu();
            return {
                code: 1,
                data: style,
            };
        } catch (error) {
            // Report the real failure instead of a fixed placeholder text.
            return {
                code: 0,
                message: error.toString(),
            };
        }
    }

    /**
     * Resolve a list of style ids to their style records.
     *
     * @param {*} value JSON string encoding an array of style ids. A falsy value,
     *   or JSON that does not decode to an array, is treated as an empty list.
     * @returns `{ code: 1, data }` where `data` holds deep copies of the matching
     *   styles, in the order of the requested ids, with `image` rewritten to a
     *   path under `define.image_path`; or `{ code: 0, message }` on failure.
     */
    async GetImageStyleInfomation(value) {
        try {
            const parsed = value ? JSON.parse(value) : [];
            const ids = Array.isArray(parsed) ? parsed : [];
            const allStyles = ImageStyleDefine.getAllSubStyle();

            const matched = [];
            for (const id of ids) {
                // Loose equality is intentional: ids may arrive as strings or numbers.
                const found = allStyles.find((item) => item.id == id);
                if (found) {
                    matched.push(found);
                }
            }

            // Clone so the shared style definitions are not modified below.
            const newSubStyle = cloneDeep(matched);
            for (const element of newSubStyle) {
                element.image = path.join(define.image_path, "style/" + element.image);
            }
            return {
                code: 1,
                data: newSubStyle,
            };
        } catch (error) {
            return {
                code: 0,
                message: error.toString(),
            };
        }
    }

    /**
     * Get the sub-styles that belong to one style category.
     *
     * @param {*} value Category id, forwarded to `ImageStyleDefine.getImagePathById`.
     * @returns `{ code: 1, data }` with deep copies of the sub-styles whose
     *   `image` is rewritten to a path under `define.image_path`; or
     *   `{ code: 0, message }` on failure.
     */
    async GetStyleImageSubList(value) {
        try {
            const subStyle = ImageStyleDefine.getImagePathById(value);
            // Clone so the shared style definitions are not modified below.
            const newSubStyle = cloneDeep(subStyle);
            for (let i = 0; i < newSubStyle.length; i++) {
                const element = newSubStyle[i];
                element.image = path.join(define.image_path, "style/" + element.image);
            }
            return {
                code: 1,
                data: newSubStyle,
            };
        } catch (error) {
            return {
                code: 0,
                message: error.toString(),
            };
        }
    }

    /**
     * Generate image(s) from a text prompt and save each one as a PNG under
     * `define.temp_sd_image`.
     *
     * @param {*} value JSON string encoding an array. Element 0 is the generation
     *   parameter object passed to `sdApi.txt2img`; further elements are not
     *   read by this method.
     * @returns `successMessage([{ base64, image_path }, ...])`, or an
     *   `errorMessage` on failure; this method never rejects.
     */
    async txt2img(value) {
        try {
            value = JSON.parse(value);
            const data = value[0];
            const res = await this.sdApi.txt2img(data);

            // Write every returned base64 image to a uniquely named file and
            // report both the raw payload and the saved path.
            const image_paths = [];
            const images = res.data.images || [];
            for (const element of images) {
                const image_name = `sd_${Date.now()}_${uuidv4()}.png`;
                const image_path = await this.tools.saveBase64ToImage(
                    element,
                    path.join(define.temp_sd_image, image_name)
                );
                image_paths.push({
                    base64: element,
                    image_path: image_path,
                });
            }
            return successMessage(image_paths);
        } catch (error) {
            return errorMessage("生图失败,错误信息如下:" + error.toString());
        }
    }

    /**
     * Post one generation request to the WebUI, write the first returned image
     * to `image_path`, then hand it to `tools.deletePngAndDeleteExifData` to
     * produce `target_image_path`. Internal helper of `OneImageGeneration`.
     *
     * @private
     * @param {string} web_api Full endpoint URL.
     * @param {object} body Request payload.
     * @param {*} seed Requested seed; `-1` means "use the seed the server reports".
     * @param {string} image_path Temporary output file.
     * @param {string} target_image_path Final output file.
     * @returns The seed that applies to the generated image.
     */
    async RequestAndSaveImage(web_api, body, seed, image_path, target_image_path) {
        const response = await axios.post(web_api, body);
        const info = JSON.parse(response.data.info);
        // Loose equality is intentional: the seed may arrive as a string.
        const usedSeed = seed == -1 ? info.seed : seed;

        // Only a single image is generated per request for now.
        // `split(",", 1)[0]` keeps the text before the first comma, which is the
        // whole string for a plain base64 payload.
        // NOTE(review): a "data:...;base64,XXXX" payload would yield the prefix
        // rather than the data — confirm the API never returns that form.
        const imageData = Buffer.from(response.data.images[0].split(",", 1)[0], 'base64');

        await sharp(imageData).toFile(image_path);
        await this.tools.deletePngAndDeleteExifData(image_path, target_image_path);
        return usedSeed;
    }

    /**
     * Generate a single image for one task, in either "img2img" or "txt2img"
     * mode as selected by the `model` field of `<image>.json`.
     *
     * On failure the whole batch is removed from the request queue, the task is
     * marked as "error" in `scripts/task_list.json`, the renderer is notified,
     * and the original error is rethrown.
     *
     * @param {string} image Path of the source image; its settings are read from
     *   `image + '.json'`.
     * @param {object} task_list Task record (`out_folder`, `id`, and optionally
     *   `image_style`, `image_style_list`, `lora`).
     * @param {*} seed Seed to use; `-1` lets the server choose one.
     * @returns The seed that was used for the generated image.
     * @throws Rethrows any error raised while generating, after recording it.
     */
    async OneImageGeneration(image, task_list, seed = -1) {
        const taskPath = path.join(this.global.config.project_path, "scripts/task_list.json");
        try {
            // The per-image settings file is read once and reused throughout.
            const imageJson = JSON.parse(await fspromises.readFile(image + '.json', 'utf-8'));
            const sd_setting = JSON.parse(await fspromises.readFile(define.sd_setting, 'utf-8'));
            const model = imageJson.model;

            let image_path = "";
            let target_image_path = "";
            if (imageJson.name) {
                image_path = path.join(this.global.config.project_path, `tmp/${task_list.out_folder}/tmp_${imageJson.name}`);
                target_image_path = path.join(this.global.config.project_path, `tmp/${task_list.out_folder}/${imageJson.name}`);
            } else {
                image_path = image.replaceAll("input_crop", task_list.out_folder).split(".png")[0] + "_tmp.png";
                target_image_path = image.replaceAll("input_crop", task_list.out_folder);
            }

            // Assemble the prompt: global prefix, selected styles, optional
            // emphasised style and LoRA tags, then the image's own prompt.
            const image_styles = await ImageStyleDefine.getImageStyleStringByIds(task_list.image_style_list ? task_list.image_style_list : []);
            let prompt = sd_setting.webui.prompt + image_styles;
            if (task_list.image_style != null) {
                prompt += `((${task_list.image_style})), `;
            }
            if (task_list.lora != null) {
                prompt += `${task_list.lora}, `;
            }
            prompt += imageJson.webui_config.prompt;

            // ADetailer script arguments, attached only when the image enables it.
            const ADetailer = {
                args: sd_setting.adetailer,
            };

            if (model == "img2img") {
                const web_api = this.global.config.webui_api_url + 'sdapi/v1/img2img';
                const sd_config = imageJson["webui_config"];
                sd_config.prompt = prompt;
                sd_config.seed = seed;
                // Without an encoding, readFile yields a Buffer that can be
                // base64-encoded directly.
                const im = await fspromises.readFile(image);
                sd_config.init_images = [im.toString('base64')];
                if (imageJson.adetailer) {
                    sd_config.alwayson_scripts = {
                        ADetailer: ADetailer,
                    };
                }
                sd_config.height = sd_setting.webui.height;
                sd_config.width = sd_setting.webui.width;

                return await this.RequestAndSaveImage(web_api, sd_config, seed, image_path, target_image_path);
            } else if (model == "txt2img") {
                const body = {
                    "prompt": prompt,
                    "negative_prompt": imageJson.webui_config.negative_prompt,
                    "seed": seed,
                    "sampler_name": imageJson.webui_config.sampler_name,
                    // Prompt relevance (CFG scale).
                    "cfg_scale": imageJson.webui_config.cfg_scale,
                    "width": sd_setting.webui.width,
                    "height": sd_setting.webui.height,
                    "batch_size": 1,
                    "n_iter": 1,
                    "steps": imageJson.webui_config.steps,
                    "save_images": false,
                };
                const web_api = this.global.config.webui_api_url + 'sdapi/v1/txt2img';
                if (imageJson.adetailer) {
                    body.alwayson_scripts = {
                        ADetailer: ADetailer,
                    };
                }

                return await this.RequestAndSaveImage(web_api, body, seed, image_path, target_image_path);
            } else {
                throw new Error("SD 模式错误");
            }
        } catch (error) {
            // A failed image invalidates the whole batch: drop its queued tasks.
            this.global.requestQuene.removeTask(task_list.out_folder, null);
            // Record the failure through the file queue (enqueued, not awaited here).
            this.global.fileQueue.enqueue(async () => {
                const task_list_json = JSON.parse(await fspromises.readFile(taskPath, 'utf-8'));
                // Flag the matching task entry with the error details.
                for (const a of task_list_json.task_list) {
                    if (a.id == task_list.id) {
                        a.status = "error";
                        a.errorMessage = error.toString();
                    }
                }
                await fspromises.writeFile(taskPath, JSON.stringify(task_list_json));
                this.global.newWindow[0].win.webContents.send(DEFINE_STRING.IMAGE_TASK_STATUS_REFRESH, {
                    out_folder: task_list.out_folder,
                    status: "error",
                });
            });
            throw error;
        }
    }
}