// 470 lines
// 16 KiB
// JavaScript
// Raw Normal View History
//
// 2024-05-15 12:57:15 +08:00
import axios from "axios";
import path from "path";
import { DEFINE_STRING } from "../../define/define_string";
import { define } from "../../define/define";
let fspromises = require("fs").promises;
import { gptDefine } from "../../define/gptDefine";
// 2024-06-01 15:08:22 +08:00
import { apiUrl } from "../../define/api/apiUrlDefine";
// 2024-08-03 12:46:12 +08:00
import { successMessage } from "../Public/generalTools";
// 2024-05-15 12:57:15 +08:00
/**
 * Wrapper around chat-completion style GPT endpoints (OpenAI-compatible APIs and
 * Alibaba DashScope). Used for prompt inference on text segments, character analysis,
 * single-sentence rewriting, connection testing and managing GPT option lists.
 */
export class GPT {
    /**
     * @param {*} global Shared application state. This class reads `global.config`,
     *   `global.requestQuene`, `global.fileQueue` and `global.newWindow`.
     */
    constructor(global) {
        this.global = global;
    }

    /**
     * Runs a custom prompt definition once and returns the model output (preview / test case).
     * @param {string} value JSON string of the whole prompt definition, handed to `gptDefine.CustomizeGptPrompt`.
     * @returns {Promise<{code: number, data?: string, message?: string}>} `code: 1` with the output, or `code: 0` with the error text.
     */
    async GenerateGptExampleOut(value) {
        try {
            const data = JSON.parse(value);
            const message = gptDefine.CustomizeGptPrompt(data);
            const content = await this.FetchGpt(message);
            console.log(content);
            return {
                code: 1,
                data: content
            };
        } catch (error) {
            return {
                code: 0,
                message: error.toString()
            };
        }
    }

    /**
     * Asks GPT to infer a prompt for a single text segment.
     *
     * The mode comes from `config.gpt_auto_inference`:
     * - "customize": use the user-defined prompt whose id equals `config.customize_gpt_prompt`.
     * - "superSinglePrompt" / "onlyPromptMJ": use gptDefine's example conversation plus the current sentence.
     * - anything else: a system/user pair whose system text holds the current sentence surrounded by
     *   up to `gpt_count` neighbouring segments on each side.
     *
     * @param {*} element Current segment. Reads `after_gpt` (its text) and `no` (treated as a
     *   1-based position in `all_data` — assumed, confirm with callers).
     * @param {number|string} gpt_count Number of neighbouring segments to include before and after.
     * @param {*} auto_analyze_character Character data forwarded to the system prompt.
     * @param {Array} [all_data=this.all_data] All segments, used to build the context window.
     *   Defaults to the list stored by the most recent `GPTPrompt` call.
     * @returns {Promise<string>} The model reply.
     * @throws {Error} If customize mode has no matching prompt selected, or the request fails.
     */
    async GPTPromptGenerate(element, gpt_count, auto_analyze_character, all_data = this.all_data) {
        const config = this.global.config;
        const mode = config.gpt_auto_inference;
        let message;

        if (mode === "customize") {
            const promptRes = await gptDefine.getGptDataByTypeAndProperty("dynamic", "customize_gpt_prompt", []);
            const prompts = (promptRes && promptRes.data) || [];
            const selectedId = config.customize_gpt_prompt;
            // Ids may be stored as numbers or strings, so compare their string forms.
            const selected = selectedId == null
                ? undefined
                : prompts.find(item => String(item.id) === String(selectedId));
            // Previously an unselected prompt slipped through and `undefined` reached CustomizeGptPrompt.
            if (!selected) {
                throw new Error("自定义推理默认要选择对应的自定义推理词");
            }
            // Copy before appending, in case gptDefine hands back a shared/cached array.
            message = [...gptDefine.CustomizeGptPrompt(selected, element.after_gpt)];
            message.push({
                "role": "user",
                "content": element.after_gpt
            });
        } else {
            // Number() so a string setting cannot turn `current + count` into string concatenation.
            const count = Number(gpt_count);
            const current = element.no - 1;
            // Joins the texts of all segments whose index lies in [from, to] (inclusive).
            const joinRange = (from, to) => all_data
                .filter((item, index) => index >= from && index <= to)
                .map(item => item.after_gpt)
                .join('\r\n');
            const prefix_word = joinRange(current - count, current - 1);
            const suffix_word = joinRange(current + 1, current + count);
            const word = `${prefix_word}\r\n${element.after_gpt}\r\n${suffix_word}`;
            const single_word = element.after_gpt;

            if (["superSinglePrompt", "onlyPromptMJ"].includes(mode)) {
                // Modes that ship with example Q/A pairs; copy so the example array is never mutated.
                message = [...gptDefine.GetExamplePromptMessage(mode)];
                message.push({
                    "role": "user",
                    "content": single_word
                });
            } else {
                // Modes without examples: everything goes into a system/user pair.
                message = [
                    {
                        "role": "system",
                        "content": gptDefine.getSystemContentByType(mode, {
                            textContent: word,
                            characterContent: auto_analyze_character
                        })
                    },
                    {
                        "role": "user",
                        "content": gptDefine.getUserContentByType(mode, {
                            textContent: single_word,
                            wordCount: config.gpt_model && config.gpt_model.includes("gpt-4") ? '20' : '40'
                        })
                    }
                ];
            }
        }
        return this.FetchGpt(message);
    }

    /**
     * Enqueues one prompt-inference task per segment on `global.requestQuene`.
     *
     * Each finished task does two things: it sends the cleaned prompt to the main window on
     * `GPT_GENERATE_PROMPT_RETURN`, and it queues a write of `gpt_prompt` into the segment's
     * `prompt_json` file. When the batch ends, a summary dialog is shown.
     *
     * @param {Array} data `[0]` JSON array of segments to infer, `[1]` whether to show a dialog
     *   on full success, `[2]` JSON array of all segments (context source).
     * @returns {Promise<{code: number, message?: string}>} `code: 1` once tasks are queued, else `code: 0` with the error.
     */
    async GPTPrompt(data) {
        try {
            console.log(data);
            const value = JSON.parse(data[0]);
            const show_global_message = data[1];
            const all_data = JSON.parse(data[2]);
            // Still exposed on the instance for callers that read it. Queued tasks use the local
            // snapshot, so a later GPTPrompt call cannot change their context.
            this.all_data = all_data;
            const batch = DEFINE_STRING.QUEUE_BATCH.SD_ORIGINAL_GPT_PROMPT;
            // Character data comes from the project's scripts/config.json.
            const config_json = JSON.parse(await fspromises.readFile(path.join(this.global.config.project_path, "scripts/config.json"), 'utf-8'));
            const auto_analyze_character = config_json.auto_analyze_character;
            const gpt_count = this.global.config.gpt_count ? this.global.config.gpt_count : 10;
            // The window is resolved at send time rather than captured up front.
            const sendToMainWindow = (channel, payload) => this.global.newWindow[0].win.webContents.send(channel, payload);

            for (const element of value) {
                this.global.requestQuene.enqueue(async () => {
                    let content = await this.GPTPromptGenerate(element, gpt_count, auto_analyze_character, all_data);
                    if (content) {
                        // "(a) (b)" -> "a, b": merge parenthesised groups and strip the outer parentheses.
                        content = content.replace(/\)\s*\(/g, ", ").replace(/^\(/, "").replace(/\)$/, "");
                    }
                    sendToMainWindow(DEFINE_STRING.GPT_GENERATE_PROMPT_RETURN, {
                        id: element.id,
                        gpt_prompt: content
                    });
                    this.global.fileQueue.enqueue(async () => {
                        const json_config = JSON.parse(await fspromises.readFile(element.prompt_json, 'utf-8'));
                        json_config.gpt_prompt = content;
                        await fspromises.writeFile(element.prompt_json, JSON.stringify(json_config));
                    });
                }, `${batch}_${element.id}`, batch);
            }

            this.global.requestQuene.setBatchCompletionCallback(batch, (failedTasks) => {
                if (failedTasks.length > 0) {
                    let message = "\n推理提示词任务都已完成\n但是以下任务执行失败\n";
                    for (const { taskId, error } of failedTasks) {
                        message += `${taskId}-, \n 错误信息: ${error}\n`;
                    }
                    sendToMainWindow(DEFINE_STRING.SHOW_MESSAGE_DIALOG, {
                        code: 0,
                        message: message
                    });
                } else if (show_global_message) {
                    sendToMainWindow(DEFINE_STRING.SHOW_MESSAGE_DIALOG, {
                        code: 1,
                        message: "所有推理任务完成"
                    });
                }
            });
            return {
                code: 1,
            };
        } catch (error) {
            return {
                code: 0,
                message: error.toString()
            };
        }
    }

    /**
     * Adapts the OpenAI-style request body to the provider behind `gpt_url`.
     * DashScope expects `{ model, input: { messages }, parameters: { result_format: "message" } }`.
     * Every other provider gets `data` unchanged (the same object).
     * @param {string} gpt_url Request URL.
     * @param {{model: string, messages: Array}} data OpenAI-style body.
     * @returns {object} Body to send.
     */
    ModifyData(gpt_url, data) {
        let res = data;
        if (gpt_url.includes("dashscope.aliyuncs.com")) {
            res = {
                "model": data.model,
                "input": {
                    "messages": data.messages,
                },
                "parameters": {
                    "result_format": "message"
                }
            };
        }
        return res;
    }

    /**
     * Extracts the first choice's message content from a response.
     * @param {string} gpt_url Request URL; selects the DashScope or OpenAI response layout.
     * @param {{data: *}} res axios response.
     * @returns {string} The message content.
     * @throws {Error} If the response does not hold `choices[0].message`.
     */
    GetResponseContent(gpt_url, res) {
        const payload = res && res.data;
        const choices = gpt_url.includes("dashscope.aliyuncs.com")
            ? payload && payload.output && payload.output.choices
            : payload && payload.choices;
        const first = Array.isArray(choices) ? choices[0] : undefined;
        if (!first || !first.message) {
            // Previously surfaced as an opaque "Cannot read properties of undefined".
            throw new Error(`GPT返回的数据格式异常: ${JSON.stringify(payload)}`);
        }
        return first.message.content;
    }

    /**
     * Sends a chat request and returns the reply text.
     * @param {Array} message Chat messages.
     * @param {string} gpt_url Full http(s) URL, or the `value` id of a provider option from
     *   `GetGPTBusinessOption("all")`. Defaults to `config.gpt_business`.
     * @param {string} gpt_key Bearer token. Defaults to `config.gpt_key`.
     * @param {string} gpt_model Model name. Defaults to `config.gpt_model`.
     * @returns {Promise<string>} Reply content.
     * @throws {Error} If the provider cannot be resolved, the request fails, or the response is malformed.
     */
    async FetchGpt(message,
        gpt_url = this.global.config.gpt_business,
        gpt_key = this.global.config.gpt_key,
        gpt_model = this.global.config.gpt_model) {
        if (!gpt_url) {
            throw new Error("获取GPT的服务商配置失败");
        }
        // Not a URL: treat it as a provider id. Options are only loaded in this case.
        if (!String(gpt_url).startsWith("http")) {
            const optionsRes = await this.GetGPTBusinessOption("all");
            const all_options = (optionsRes && optionsRes.data) || [];
            const option = all_options.find(item => String(item.value) === String(gpt_url) && item.gpt_url);
            if (!option) {
                throw new Error("获取GPT的服务商配置失败");
            }
            gpt_url = option.gpt_url;
        }
        const body = this.ModifyData(gpt_url, {
            "model": gpt_model,
            "messages": message
        });
        const res = await axios.request({
            method: 'post',
            maxBodyLength: Infinity,
            url: gpt_url,
            headers: {
                'Authorization': `Bearer ${gpt_key}`,
                'Content-Type': 'application/json'
            },
            data: JSON.stringify(body)
        });
        return this.GetResponseContent(gpt_url, res);
    }

    /**
     * Asks GPT to analyse a text and return its characters/scenes.
     * @param {string} value Text to analyse.
     * @returns {Promise<{code: number, data?: string, message?: string}>}
     */
    async AutoAnalyzeCharacter(value) {
        try {
            const message = [
                {
                    "role": "system",
                    "content": gptDefine.getSystemContentByType("character", { textContent: value })
                },
                {
                    "role": "user",
                    "content": gptDefine.getUserContentByType("character", {})
                }
            ];
            const content = await this.FetchGpt(message);
            return {
                code: 1,
                data: content
            };
        } catch (error) {
            return {
                code: 0,
                message: error.toString()
            };
        }
    }

    /**
     * Reads the GPT provider options (built-in and custom).
     * @param {string} value Which option set to read, passed to gptDefine (e.g. "all").
     * @param {Function|null} callback Optional; called with the option data on success.
     * @returns {Promise<object>} gptDefine's failure result as-is, or `successMessage(data)`.
     */
    async GetGPTBusinessOption(value, callback = null) {
        const res = await gptDefine.getGptDataByTypeAndProperty(value, "gpt_options", []);
        if (res.code == 0) {
            return res;
        }
        if (callback) {
            callback(res.data);
        }
        return successMessage(res.data);
    }

    /**
     * Reads the GPT model options (built-in and custom).
     * @param {string} value Which option set to read, passed to gptDefine.
     * @returns {Promise<object>}
     */
    async GetGPTModelOption(value) {
        return await gptDefine.getGptDataByTypeAndProperty(value, "gpt_model_options", []);
    }

    /**
     * Reads the automatic prompt-inference mode options (built-in and custom).
     * @param {string} value Which option set to read, passed to gptDefine.
     * @returns {Promise<object>}
     */
    async GetGptAutoInferenceOptions(value) {
        return await gptDefine.getGptDataByTypeAndProperty(value, "gpt_auto_inference", []);
    }

    /**
     * Reads the user-defined inference prompts.
     * @param {string} value Which option set to read, passed to gptDefine (e.g. "dynamic").
     * @returns {Promise<object>}
     */
    async GetCustomizeGptPrompt(value) {
        return await gptDefine.getGptDataByTypeAndProperty(value, "customize_gpt_prompt", []);
    }

    /**
     * Saves a custom GPT option entry.
     * @param {*} value `[0]` the data to save, `[1]` the property name.
     * @returns {Promise<{code: number, message?: string}>}
     */
    async SaveDynamicGPTOption(value) {
        try {
            await gptDefine.saveDynamicGPTOption(value);
            return {
                code: 1,
            };
        } catch (error) {
            return {
                code: 0,
                message: error.toString()
            };
        }
    }

    /**
     * Deletes a custom GPT option entry by id.
     * @param {*} value `[0]` the entry to delete, `[1]` the property name.
     * @returns {Promise<{code: number, data?: *, message?: string}>}
     */
    async DeleteDynamicGPTOption(value) {
        try {
            const res = await gptDefine.deleteDynamicGPTOption(value);
            return {
                code: 1,
                data: res
            };
        } catch (error) {
            return {
                code: 0,
                message: error.toString()
            };
        }
    }

    /**
     * Checks that a GPT provider/key/model combination answers a trivial request.
     * @param {string} value JSON string with `gpt_business`, `gpt_key` and `gpt_model`.
     * @returns {Promise<{code: number, message?: string}>}
     */
    async TestGPTConnection(value) {
        try {
            const options = JSON.parse(value);
            const message = [
                {
                    "role": "system",
                    "content": "你好"
                },
                {
                    "role": "user",
                    "content": "你好"
                }
            ];
            await this.FetchGpt(message, options.gpt_business, options.gpt_key, options.gpt_model);
            return {
                code: 1,
            };
        } catch (error) {
            return {
                code: 0,
                message: error.toString()
            };
        }
    }

    /**
     * Rewrites a single sentence with GPT, keeping its structure and meaning.
     * @param {Array} value `[0]` the sentence number (echoed back), `[1]` the sentence text.
     * @returns {Promise<{code: number, data?: {no: *, content: string}, message?: string}>}
     */
    async AIModifyOneWord(value) {
        try {
            const message = [
                {
                    "role": "system",
                    "content": "You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible."
                },
                {
                    "role": "user",
                    "content": `请您扮演一个抖音网文改写专家,我会给你一句文案,请你不要改变文案的结构,不改变原来的意思,仅对文案进行同义转换改写,不要有奇怪的写法,说法通俗一点,不要其他的标点符号,每一小句话之间都是以句号连接,参考抖音网文解说,以下是文案:${value[1]}`
                }
            ];
            const content = await this.FetchGpt(message);
            return {
                code: 1,
                data: { no: value[0], content: content }
            };
        } catch (error) {
            return {
                code: 0,
                message: error.toString()
            };
        }
    }
}