1. (聚合推文)MJ反推、SD反推 添加剪映分镜 2. (聚合推文)完善SD反推分类(界面同MJ反推,些许不一致) 3. (聚合推文)完成一键合成视频(单个和批量) 4. 修改聚合推文进入界面小说批次任务表格样式 5. (聚合推文)完善一键重置 6. (聚合推文)完善一键删除
381 lines
11 KiB
Python
381 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
||
import io
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
import pandas as pd
from PIL import Image

import public_tools
|
||
|
||
# sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
|
||
|
||
# Tag-mode identifiers; init() compares the configured tag mode against these.
TAG_MODE_ACTION_COMMON = "action_common"
TAG_MODE_ACTION = "action"

# Base directory used to locate the bundled model files.
# When sys.frozen is set the directory of the executable is used; otherwise
# the directory containing this source file.
# NOTE(review): if neither branch is taken, cript_directory stays undefined.
if getattr(sys, "frozen", False):
    cript_directory = os.path.dirname(sys.executable)
elif __file__:
    cript_directory = os.path.dirname(__file__)
|
||
|
||
|
||
def make_square(img, target_size):
    """Pad ``img`` with white borders until it is square.

    The side of the result is the largest of the image height, the image
    width and ``target_size``. Padding is split evenly between opposite
    edges; when the amount is odd the extra pixel goes to the bottom/right.

    Args:
        img: Image array whose first two ``shape`` entries are height, width.
        target_size: Minimum side length of the padded result.

    Returns:
        The padded image produced by ``cv2.copyMakeBorder``.
    """
    height, width = img.shape[:2]
    side = max(max(height, width), target_size)

    pad_h = side - height
    pad_w = side - width
    top = pad_h // 2
    left = pad_w // 2

    white = [255, 255, 255]
    return cv2.copyMakeBorder(
        img,
        top,
        pad_h - top,
        left,
        pad_w - left,
        cv2.BORDER_CONSTANT,
        value=white,
    )
|
||
|
||
|
||
def smart_resize(img, size):
    """Resize ``img`` to ``size`` x ``size``, choosing interpolation by direction.

    Only the height (``img.shape[0]``) is inspected; the image is assumed to
    have been squared by ``make_square`` beforehand. Shrinking uses
    ``cv2.INTER_AREA``, enlarging uses ``cv2.INTER_CUBIC``, and an image whose
    height already equals ``size`` is returned unchanged.

    Args:
        img: Image array (assumed square).
        size: Target side length in pixels.

    Returns:
        The resized image, or ``img`` itself when no resize is needed.
    """
    height = img.shape[0]
    if height == size:
        return img

    interpolation = cv2.INTER_AREA if height > size else cv2.INTER_CUBIC
    return cv2.resize(img, (size, size), interpolation=interpolation)
|
||
|
||
|
||
# Running on the GPU rarely succeeds; the environment requirements are too
# demanding (original author's note).
use_cpu = False

# Device string derived from the flag above.
tf_device_name = "/cpu:0" if use_cpu else "/gpu:0"
|
||
|
||
|
||
class Interrogator:
    """Base class for image taggers.

    Subclasses provide ``load`` and ``interrogate``; this class supplies the
    shared tag post-processing and the logic for dropping a loaded model.
    """

    @staticmethod
    def postprocess_tags(
        tags: Dict[str, float],
        threshold=0.35,  # confidence cutoff, 0.35 by default
        additional_tags: Optional[List[str]] = None,
        exclude_tags: Optional[List[str]] = None,
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=False,
        replace_underscore_excludes: Optional[List[str]] = None,
        escape_tag=False,
    ) -> Dict[str, float]:
        """Filter, order and format a tag -> confidence mapping.

        The input mapping is never modified; a new dict is returned.

        Args:
            tags: Mapping of tag name to confidence.
            threshold: Tags with a confidence below this value are dropped.
            additional_tags: Tags forced into the result with confidence 1.0.
            exclude_tags: Tag names (before any renaming) to drop.
            sort_by_alphabetical_order: Sort ascending by name when true,
                otherwise descending by confidence.
            add_confident_as_weight: Render each key as ``(tag:confidence)``.
            replace_underscore: Replace ``_`` with a space in tag names.
            replace_underscore_excludes: Tag names exempt from that
                replacement.
            escape_tag: Accepted for interface compatibility; it currently
                has no effect because no escaping is implemented here.

        Returns:
            A new dict of formatted tag name to confidence, in sorted order.
        """
        additional_tags = additional_tags or []
        exclude_tags = exclude_tags or []
        replace_underscore_excludes = replace_underscore_excludes or []

        # Copy first so adding the forced tags does not leak into the
        # caller's dict.
        scored = dict(tags)
        for extra in additional_tags:
            scored[extra] = 1.0

        # Sort by name (index 0) or by confidence (index 1).
        sort_index = 0 if sort_by_alphabetical_order else 1
        ordered = sorted(
            scored.items(),
            key=lambda item: item[sort_index],
            reverse=not sort_by_alphabetical_order,
        )

        result = {}
        for tag, confidence in ordered:
            # Keep only tags at or above the threshold that are not excluded.
            if not (confidence >= threshold and tag not in exclude_tags):
                continue

            label = tag
            if replace_underscore and tag not in replace_underscore_excludes:
                label = label.replace("_", " ")
            if add_confident_as_weight:
                label = f"({label}:{confidence})"

            result[label] = confidence

        return result

    def __init__(self, name: str) -> None:
        """Store the display name used in log messages."""
        self.name = name

    def load(self):
        """Load the underlying model; must be provided by subclasses."""
        raise NotImplementedError()

    def unload(self) -> bool:
        """Drop the loaded model and tag table, if present.

        Returns:
            True when a model was attached and has been removed, else False.
        """
        unloaded = False

        if hasattr(self, "model") and self.model is not None:
            del self.model
            unloaded = True
            print(f"Unloaded {self.name}")

        if hasattr(self, "tags"):
            del self.tags

        return unloaded

    def interrogate(
        self, image: "Image.Image"
    ) -> Tuple[Dict[str, float], Dict[str, float]]:
        """Return ``(rating confidences, tag confidences)`` for ``image``.

        Must be provided by subclasses.
        """
        raise NotImplementedError()
|
||
|
||
|
||
class WaifuDiffusionInterrogator(Interrogator):
    """Tagger backed by an ONNX model loaded through onnxruntime.

    The model and its tag table are read lazily, on the first call to
    ``interrogate``, from ``model/tag`` under ``cript_directory``.
    """

    def __init__(
        self,
        name: str,
        model_path="model.onnx",
        tags_path="selected_tags.csv",
        **kwargs,
    ) -> None:
        """Record the configuration without loading anything.

        NOTE(review): ``model_path``, ``tags_path`` and ``kwargs`` are stored
        on the instance, but ``interrogate`` reads fixed paths and does not
        consult them.
        """
        super().__init__(name)
        self.model_path, self.tags_path = model_path, tags_path
        self.kwargs = kwargs

    def interrogate(
        self, image: Image
    ) -> Tuple[
        Dict[str, float], Dict[str, float]  # rating confidents  # tag confidents
    ]:
        """Run the model on ``image`` and split its output into two dicts.

        Args:
            image: A PIL image; it is converted to RGBA internally.

        Returns:
            ``(ratings, tags)``: the first four rows of the tag table mapped
            to their confidences, and the remaining rows mapped to theirs.
        """
        # Initialise the model on first use.
        if getattr(self, "model", None) is None:
            from onnxruntime import InferenceSession

            model_path = os.path.join(cript_directory, "model/tag/model.onnx")
            tags_path = os.path.join(cript_directory, "model/tag/selected_tags.csv")

            self.model = InferenceSession(
                str(model_path),
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            )
            print(f"从{model_path} 读取 {self.name}模型")
            self.tags = pd.read_csv(tags_path)

        model_input = self.model.get_inputs()[0]
        _, height, _, _ = model_input.shape

        # Flatten transparency onto a white background.
        rgba = image.convert("RGBA")
        canvas = Image.new("RGBA", rgba.size, "WHITE")
        canvas.paste(rgba, mask=rgba)
        pixels = np.asarray(canvas.convert("RGB"))

        # Reverse the channel axis (RGB order -> BGR order).
        pixels = pixels[:, :, ::-1]

        # Pad to a square, scale to the model's input height, add a batch axis.
        pixels = make_square(pixels, height)
        pixels = smart_resize(pixels, height)
        batch = np.expand_dims(pixels.astype(np.float32), 0)

        # Run inference on the first input / first output of the model.
        output_name = self.model.get_outputs()[0].name
        confidents = self.model.run([output_name], {model_input.name: batch})[0]

        scored = self.tags[:][["name"]]
        scored["confidents"] = confidents[0]

        # Per the original author's note, the first four rows are the rating
        # classes (general, sensitive, questionable, explicit).
        ratings = dict(scored[:4].values)

        # Everything after that is a regular tag.
        general_tags = dict(scored[4:].values)

        return ratings, general_tags
|
||
|
||
|
||
def getTags(model, img_path):
    """Tag one image and return its tags as a comma-separated string.

    Args:
        model: Object exposing ``interrogate(image)`` and
            ``postprocess_tags(tags)``.
        img_path: Path of the image file to open.

    Returns:
        The keys of the post-processed tag mapping joined with ``","``.
    """
    # The context manager closes the image even if interrogation raises.
    with Image.open(img_path) as img:
        _, tags = model.interrogate(img)

    tags = model.postprocess_tags(tags)
    return ",".join(tags.keys())
|
||
|
||
|
||
# Splits a tag on runs of non-word characters.
pattern_word_split = re.compile(r"\W+")


def is_tag_in_list(tag, rule_list):
    """Return True if any word of ``tag`` appears in ``rule_list``.

    ``tag`` is split on runs of non-word characters; underscores are word
    characters, so ``"long_hair"`` stays a single word.
    """
    return any(word in rule_list for word in pattern_word_split.split(tag))
|
||
|
||
|
||
def filter_action(
    tag_actions: List[str], tags: List[str]
) -> Tuple[List[str], List[str]]:
    """Split ``tags`` into action tags and all other tags.

    A tag counts as an action tag when ``is_tag_in_list`` finds one of its
    words in ``tag_actions``. Tags for which ``public_tools.is_empty``
    returns true are dropped from both lists.

    Args:
        tag_actions: Words that mark a tag as an action.
        tags: Tags to classify.

    Returns:
        ``(action_tags, other_tags)``, each in the input order.
    """
    action_tags = []
    other_tags = []
    for tag in tags:
        if public_tools.is_empty(tag):
            continue
        bucket = action_tags if is_tag_in_list(tag, tag_actions) else other_tags
        bucket.append(tag)
    return action_tags, other_tags
|
||
|
||
|
||
def getAssignTxt(txtPath):
    """Tag every image listed in a text file.

    ``txtPath`` is a UTF-8 text file with one image path per line. For each
    listed image a ``<image stem>.txt`` file holding the comma-separated tags
    is written into the image's own directory.

    Args:
        txtPath: Path of the list file.

    Raises:
        FileNotFoundError: If ``txtPath`` does not exist.
    """
    # The path is opened as a file below, so a missing path is an error.
    # Creating a directory with that name would only make the open() fail.
    if not os.path.exists(txtPath):
        raise FileNotFoundError(f"image list file not found: {txtPath}")

    # load model
    model = WaifuDiffusionInterrogator(
        "wd14-convnextv2-v2",
        repo_id="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        revision="v2.0",
    )

    # Strip surrounding whitespace and skip blank lines, which would
    # otherwise be handed on as image paths.
    with open(txtPath, "r", encoding="utf-8") as file:
        frame_files = sorted(line.strip() for line in file if line.strip())

    # Process the images in sorted order.
    for frame_file in frame_files:
        txt = getTags(model, frame_file)

        # save tag
        txt_file = os.path.join(
            os.path.dirname(frame_file), f"{Path(frame_file).stem}.txt"
        )
        with open(txt_file, "w", encoding="utf-8") as tags:
            tags.write(txt)
        print(f"{frame_file} 提示词反推完成")
        sys.stdout.flush()
|
||
|
||
|
||
|
||
def getAssignImage(imagePath):
    """Tag a single image and write the tags next to it.

    The comma-separated tags are written to ``<image stem>.txt`` in the same
    directory as the image.

    Args:
        imagePath: Path of the image file.

    Raises:
        FileNotFoundError: If ``imagePath`` does not exist.
    """
    # The path is opened as an image below, so a missing path is an error.
    # Creating a directory with that name would only make the open fail.
    if not os.path.exists(imagePath):
        raise FileNotFoundError(f"image file not found: {imagePath}")

    # load model
    model = WaifuDiffusionInterrogator(
        "wd14-convnextv2-v2",
        repo_id="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        revision="v2.0",
    )

    txt = getTags(model, imagePath)

    # save tag
    txt_file = os.path.join(
        os.path.dirname(imagePath), f"{Path(imagePath).stem}.txt"
    )
    with open(txt_file, "w", encoding="utf-8") as tags:
        tags.write(txt)
    print(f"{imagePath} 提示词反推完成")
    sys.stdout.flush()
|
||
|
||
|
||
def getAssignDir(imagePath):
    """Tag every ``.png`` file in a directory.

    The directory is created if it does not exist. For each file whose name
    ends with ``.png`` (case-sensitive), taken in sorted name order, a
    ``<stem>.txt`` file holding the comma-separated tags is written into the
    same directory.

    Args:
        imagePath: Directory containing the images.
    """
    if not os.path.exists(imagePath):
        os.makedirs(imagePath)

    # load model
    model = WaifuDiffusionInterrogator(
        "wd14-convnextv2-v2",
        repo_id="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        revision="v2.0",
    )

    # Collect the PNG frames and walk them in sorted order.
    png_names = sorted(
        name for name in os.listdir(imagePath) if name.endswith(".png")
    )

    for frame in png_names:
        frame_file = os.path.join(imagePath, frame)
        txt = getTags(model, frame_file)

        # save tag
        stem = Path(frame_file).stem
        with open(
            os.path.join(imagePath, f"{stem}.txt"), "w", encoding="utf-8"
        ) as tags:
            tags.write(txt)
        print(f"{frame} 提示词反推完成")
        sys.stdout.flush()
|
||
|
||
|
||
def init(sd_setting, m, project_path):
|
||
try:
|
||
setting_json = public_tools.read_config(sd_setting, webui=False)
|
||
except Exception as e:
|
||
print("Error: read config", e)
|
||
exit(0)
|
||
|
||
setting_config = public_tools.SettingConfig(setting_json, project_path)
|
||
|
||
# workspace path config
|
||
workspace = setting_config.get_workspace_config()
|
||
|
||
if not os.path.exists(workspace.input_tag):
|
||
os.makedirs(workspace.input_tag)
|
||
|
||
# 可选功能
|
||
if setting_config.enable_tag():
|
||
|
||
# load model
|
||
model = WaifuDiffusionInterrogator(
|
||
"wd14-convnextv2-v2",
|
||
repo_id="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
|
||
revision="v2.0",
|
||
)
|
||
|
||
tag_mode = setting_config.get_tag_mode()
|
||
tag_actions = setting_config.get_tag_actions()
|
||
|
||
# 轮询开始输出
|
||
frame_files = [
|
||
f for f in os.listdir(workspace.input_crop) if f.endswith(".png")
|
||
]
|
||
frame_files.sort()
|
||
|
||
common_tags = dict()
|
||
for frame in frame_files:
|
||
frame_file = os.path.join(workspace.input_crop, frame)
|
||
txt = getTags(model, frame_file)
|
||
tags = txt.split(",")
|
||
|
||
if tag_mode == TAG_MODE_ACTION:
|
||
actions, others = filter_action(tag_actions, tags)
|
||
# 替换 txt 为 action txt
|
||
txt = ",".join(actions) if len(actions) > 0 else ""
|
||
elif tag_mode == TAG_MODE_ACTION_COMMON:
|
||
actions, others = filter_action(tag_actions, tags)
|
||
txt = ",".join(actions) if len(actions) > 0 else ""
|
||
# tag 计数
|
||
for tag in others:
|
||
if tag in common_tags:
|
||
common_tags[tag] = common_tags[tag] + 1
|
||
else:
|
||
common_tags[tag] = 1
|
||
|
||
# save tag
|
||
txt_file = os.path.join(workspace.input_tag, f"{Path(frame_file).stem}.txt")
|
||
with open(txt_file, "w", encoding="utf-8") as tags:
|
||
tags.write(txt)
|
||
print(f"{frame} 提示词反推完成")
|
||
sys.stdout.flush()
|
||
|
||
# 过滤出现次数 > 30% 的 tags 作为 common tags
|
||
threshold_count = max(int(len(frame_files) * 0.3), 1)
|
||
common_tag_list = []
|
||
for tag in common_tags:
|
||
if common_tags[tag] > threshold_count:
|
||
common_tag_list.append(tag)
|
||
# save common tag
|
||
# txt_file = os.path.join(workspace.input_tag, f'common.txt')
|
||
# with open(txt_file, 'w', encoding='utf-8') as tags:
|
||
# txt = ",".join(common_tag_list) if len(common_tag_list) > 0 else ""
|
||
# tags.write(txt)
|