# -*- coding: utf-8 -*-
import io
import os
import re
import subprocess
import sys

import pandas as pd
import numpy as np
from typing import Tuple, List, Dict
import cv2
from PIL import Image
from pathlib import Path

import public_tools

# sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")

TAG_MODE_ACTION_COMMON = "action_common"
TAG_MODE_ACTION = "action"

if getattr(sys, "frozen", False):
    script_directory = os.path.dirname(sys.executable)
elif __file__:
    script_directory = os.path.dirname(__file__)


def make_square(img, target_size):
    # Pad the image with white borders so it becomes a square of at least target_size.
    old_size = img.shape[:2]
    desired_size = max(old_size)
    desired_size = max(desired_size, target_size)

    delta_w = desired_size - old_size[1]
    delta_h = desired_size - old_size[0]
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)

    color = [255, 255, 255]
    new_im = cv2.copyMakeBorder(
        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
    )
    return new_im


def smart_resize(img, size):
    # Assumes the image has already been padded to a square by make_square.
    if img.shape[0] > size:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
    elif img.shape[0] < size:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
    return img


# Running on the GPU rarely works; the environment requirements are too strict.
use_cpu = False

if use_cpu:
    tf_device_name = "/cpu:0"
else:
    tf_device_name = "/gpu:0"
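
# Minimal usage sketch for the two preprocessing helpers above. The 200x300
# dummy image and the 448-pixel input size are assumptions for illustration
# only (the real input size is read from the ONNX model at runtime), and this
# function is never called by the pipeline.
def _preprocessing_example():
    dummy = np.full((200, 300, 3), 255, dtype=np.uint8)  # H x W x C test image
    squared = make_square(dummy, 448)     # pad with white to 448 x 448
    resized = smart_resize(squared, 448)  # resize to exactly 448 x 448
    print(squared.shape, resized.shape)   # (448, 448, 3) (448, 448, 3)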


class Interrogator:
    @staticmethod
    def postprocess_tags(
        tags: Dict[str, float],
        threshold=0.35,  # confidence threshold, default 0.35
        additional_tags: List[str] = [],
        exclude_tags: List[str] = [],
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=False,
        replace_underscore_excludes: List[str] = [],
        escape_tag=False,
    ) -> Dict[str, float]:
        for t in additional_tags:
            tags[t] = 1.0

        tags = {
            t: c
            # sort by tag name or by confidence
            for t, c in sorted(
                tags.items(),
                key=lambda i: i[0 if sort_by_alphabetical_order else 1],
                reverse=not sort_by_alphabetical_order,
            )
            # keep only tags at or above the threshold
            if (c >= threshold and t not in exclude_tags)
        }

        new_tags = []
        for tag in list(tags):
            new_tag = tag

            if replace_underscore and tag not in replace_underscore_excludes:
                new_tag = new_tag.replace("_", " ")
            """
            if escape_tag:
                new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)
            """
            if add_confident_as_weight:
                new_tag = f"({new_tag}:{tags[tag]})"
            new_tags.append((new_tag, tags[tag]))
        tags = dict(new_tags)

        return tags

    def __init__(self, name: str) -> None:
        self.name = name

    def load(self):
        raise NotImplementedError()

    def unload(self) -> bool:
        unloaded = False

        if hasattr(self, "model") and self.model is not None:
            del self.model
            unloaded = True
            print(f"Unloaded {self.name}")

        if hasattr(self, "tags"):
            del self.tags

        return unloaded

    def interrogate(
        self, image: Image.Image
    ) -> Tuple[Dict[str, float], Dict[str, float]]:
        raise NotImplementedError()


class WaifuDiffusionInterrogator(Interrogator):
    def __init__(
        self,
        name: str,
        model_path="model.onnx",
        tags_path="selected_tags.csv",
        **kwargs,
    ) -> None:
        super().__init__(name)
        self.model_path = model_path
        self.tags_path = tags_path
        self.kwargs = kwargs

    def interrogate(
        self, image: Image.Image
    ) -> Tuple[
        Dict[str, float],  # rating confidences
        Dict[str, float],  # tag confidences
    ]:
        # init model
        if not hasattr(self, "model") or self.model is None:
            model_path = os.path.join(script_directory, "model/tag/model.onnx")
            tags_path = os.path.join(script_directory, "model/tag/selected_tags.csv")

            from onnxruntime import InferenceSession

            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
            self.model = InferenceSession(str(model_path), providers=providers)

            print(f"Loaded {self.name} model from {model_path}")

            self.tags = pd.read_csv(tags_path)

        _, height, _, _ = self.model.get_inputs()[0].shape

        # flatten transparency onto a white background
        image = image.convert("RGBA")
        new_image = Image.new("RGBA", image.size, "WHITE")
        new_image.paste(image, mask=image)
        image = new_image.convert("RGB")
        image = np.asarray(image)

        # RGB -> BGR channel order
        image = image[:, :, ::-1]

        image = make_square(image, height)
        image = smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)

        # run inference
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        confidents = self.model.run([label_name], {input_name: image})[0]

        tags = self.tags[:][["name"]]
        tags["confidents"] = confidents[0]

        # the first 4 entries are rating tags (general, sensitive, questionable, explicit)
        ratings = dict(tags[:4].values)

        # the rest are regular tags
        tags = dict(tags[4:].values)

        return ratings, tags


def getTags(model, img_path):
    img = Image.open(img_path)
    ratings, tags = model.interrogate(img)
    img.close()
    tags = model.postprocess_tags(tags)
    return ",".join(tags.keys())


pattern_word_split = re.compile(r"\W+")


def is_tag_in_list(tag, rule_list):
    words = pattern_word_split.split(tag)
    for word in words:
        if word in rule_list:
            return True
    return False


def filter_action(tag_actions: List[str], tags: List[str]):
    action_tags = []
    other_tags = []
    for tag in tags:
        if public_tools.is_empty(tag):
            continue
        if is_tag_in_list(tag, tag_actions):
            action_tags.append(tag)
        else:
            other_tags.append(tag)
    return action_tags, other_tags


def getAssignTxt(txtPath):
    if not os.path.exists(txtPath):
        os.makedirs(txtPath)

    # load model
    model = WaifuDiffusionInterrogator(
        "wd14-convnextv2-v2",
        repo_id="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        revision="v2.0",
    )
    frame_files = []
    with open(txtPath, "r", encoding="utf-8") as file:
        for line in file:
            frame_files.append(line.strip())  # strip() removes the newline and extra whitespace

    # iterate over the files and write output
    frame_files.sort()
    for frame_file in frame_files:
        txt = getTags(model, frame_file)
        # tags = txt.split(",")
        # save tag
        txt_file = os.path.join(
            os.path.dirname(frame_file), f"{Path(frame_file).stem}.txt"
        )
        with open(txt_file, "w", encoding="utf-8") as tags:
            tags.write(txt)
        print(f"{frame_file} prompt interrogation complete")
        sys.stdout.flush()


def getAssignImage(imagePath):
    if not os.path.exists(imagePath):
        os.makedirs(imagePath)

    # load model
    model = WaifuDiffusionInterrogator(
        "wd14-convnextv2-v2",
        repo_id="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        revision="v2.0",
    )
    txt = getTags(model, imagePath)
    # tags = txt.split(",")
    # save tag
    txt_file = os.path.join(os.path.dirname(imagePath), f"{Path(imagePath).stem}.txt")
    with open(txt_file, "w", encoding="utf-8") as tags:
        tags.write(txt)
    print(f"{imagePath} prompt interrogation complete")
    sys.stdout.flush()


def getAssignDir(imagePath):
    if not os.path.exists(imagePath):
        os.makedirs(imagePath)

    # load model
    model = WaifuDiffusionInterrogator(
        "wd14-convnextv2-v2",
        repo_id="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        revision="v2.0",
    )

    # iterate over the frames and write output
    frame_files = [f for f in os.listdir(imagePath) if f.endswith(".png")]
    frame_files.sort()
    for frame in frame_files:
        frame_file = os.path.join(imagePath, frame)
        txt = getTags(model, frame_file)
        # tags = txt.split(",")
        # save tag
        txt_file = os.path.join(imagePath, f"{Path(frame_file).stem}.txt")
        with open(txt_file, "w", encoding="utf-8") as tags:
            tags.write(txt)
        print(f"{frame} prompt interrogation complete")
        sys.stdout.flush()
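
# Hedged example of how is_tag_in_list / filter_action split a tag list. The
# rule list and tags are made up, and it assumes public_tools.is_empty()
# returns False for non-empty strings; the function is never called by the
# pipeline and only documents the expected behaviour.
def _filter_action_example():
    rule_list = ["standing", "running"]  # hypothetical action words
    tags = ["standing", "long hair", "running shoes", "smile"]
    actions, others = filter_action(rule_list, tags)
    # "running shoes" matches because each tag is split on non-word
    # characters (\W+) and compared word by word against the rule list.
    print(actions)  # ['standing', 'running shoes']
    print(others)   # ['long hair', 'smile']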


def init(sd_setting, m, project_path):
    try:
        setting_json = public_tools.read_config(sd_setting, webui=False)
    except Exception as e:
        print("Error: read config", e)
        exit(0)
    setting_config = public_tools.SettingConfig(setting_json, project_path)

    # workspace path config
    workspace = setting_config.get_workspace_config()
    if not os.path.exists(workspace.input_tag):
        os.makedirs(workspace.input_tag)

    # optional feature
    if setting_config.enable_tag():
        # load model
        model = WaifuDiffusionInterrogator(
            "wd14-convnextv2-v2",
            repo_id="SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
            revision="v2.0",
        )

        tag_mode = setting_config.get_tag_mode()
        tag_actions = setting_config.get_tag_actions()

        # iterate over the frames and write output
        frame_files = [
            f for f in os.listdir(workspace.input_crop) if f.endswith(".png")
        ]
        frame_files.sort()
        common_tags = dict()
        for frame in frame_files:
            frame_file = os.path.join(workspace.input_crop, frame)
            txt = getTags(model, frame_file)
            tags = txt.split(",")
            if tag_mode == TAG_MODE_ACTION:
                actions, others = filter_action(tag_actions, tags)
                # replace txt with the action-only tags
                txt = ",".join(actions) if len(actions) > 0 else ""
            elif tag_mode == TAG_MODE_ACTION_COMMON:
                actions, others = filter_action(tag_actions, tags)
                txt = ",".join(actions) if len(actions) > 0 else ""
                # count tag occurrences
                for tag in others:
                    if tag in common_tags:
                        common_tags[tag] = common_tags[tag] + 1
                    else:
                        common_tags[tag] = 1

            # save tag
            txt_file = os.path.join(
                workspace.input_tag, f"{Path(frame_file).stem}.txt"
            )
            with open(txt_file, "w", encoding="utf-8") as tags:
                tags.write(txt)
            print(f"{frame} prompt interrogation complete")
            sys.stdout.flush()

        # tags that appear in more than 30% of the frames become common tags
        threshold_count = max(int(len(frame_files) * 0.3), 1)
        common_tag_list = []
        for tag in common_tags:
            if common_tags[tag] > threshold_count:
                common_tag_list.append(tag)

        # save common tag
        # txt_file = os.path.join(workspace.input_tag, f'common.txt')
        # with open(txt_file, 'w', encoding='utf-8') as tags:
        #     txt = ",".join(common_tag_list) if len(common_tag_list) > 0 else ""
        #     tags.write(txt)
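
# Hypothetical command-line entry point: the module itself never calls init(),
# so the launcher is assumed to pass <sd_setting> <mode> <project_path> as
# positional arguments. This guard is only a sketch of that assumed usage and
# runs solely when the file is executed directly.
if __name__ == "__main__":
    if len(sys.argv) >= 4:
        init(sys.argv[1], sys.argv[2], sys.argv[3])
    else:
        print(f"usage: {sys.argv[0]} <sd_setting> <mode> <project_path>")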