174 lines
4.9 KiB
Python
174 lines
4.9 KiB
Python
|
|
import io
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
from typing import Union
|
|||
|
|
import cv2
|
|||
|
|
import torch
|
|||
|
|
import numpy as np
|
|||
|
|
from PIL import Image
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Force UTF-8 on stdout so non-ASCII paths print correctly regardless of the
# console code page. reconfigure() changes the existing stream in place;
# wrapping sys.stdout.buffer in a *second* TextIOWrapper leaves two text
# wrappers over one binary buffer, and whichever is finalized first closes the
# shared buffer. sys.stdout can be None (e.g. a windowed frozen build), in which
# case there is nothing to reconfigure and print() below is a no-op.
if sys.stdout is not None:
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding="utf-8")
    else:
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")

# Echo the raw command line; useful when diagnosing how the script was invoked.
print(sys.argv)
|
|||
|
|
|
|||
|
|
# Directory holding the bundled resources: the folder of the frozen executable
# when running as a packaged build, otherwise the folder of this script.
if getattr(sys, "frozen", False):
    cript_directory = os.path.dirname(sys.executable)
else:
    cript_directory = os.path.dirname(os.path.abspath(__file__))

# The model is loaded through a link in the user's home directory that points
# at the bundled model file. NOTE(review): the link presumably exists to give
# torch.jit.load a path without non-ASCII characters when the install
# directory contains them — confirm.
link_name = os.path.join(os.path.expanduser("~"), "big_lama.pt")
cu_name = os.path.join(cript_directory, "model", "big-lama.pt")

# Prefer the link, but fall back to the bundled file when the link was never
# created (or is dangling) instead of failing later inside torch.jit.load.
mode_pa = link_name if os.path.exists(link_name) else cu_name

if len(sys.argv) < 2:
    # Setup mode: no arguments were given, so create the link (if missing),
    # print the usage line and exit successfully.
    if not os.path.exists(link_name):
        try:
            os.symlink(cu_name, link_name)
        except OSError as err:
            # On Windows, creating a symlink typically needs administrator
            # rights or Developer Mode; report the failure instead of hiding it.
            print(f"Failed to create link {link_name} -> {cu_name}: {err}")
    print("Params: <flag> <image_path> <mask_path> <output_path>")
    sys.exit(0)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_image(image):
    """Convert a PIL image or NumPy array into a float32 CHW array scaled by 1/255.

    Args:
        image: ``PIL.Image.Image`` or ``np.ndarray``. Arrays are expected in
            HxW (grayscale) or HxWxC layout. Values are divided by 255
            unconditionally, so input is presumably uint8 in [0, 255].

    Returns:
        ``np.float32`` array of shape ``(C, H, W)``; 2-D input yields ``C == 1``.
        The input object is never modified.

    Raises:
        TypeError: ``image`` is neither a NumPy array nor a PIL image.
        ValueError: the image data is not 2-D or 3-D.
    """
    # Check ndarray first so the array path never needs to touch PIL.
    if isinstance(image, np.ndarray):
        # No defensive copy needed: astype() below always allocates a new array.
        img = image
    elif isinstance(image, Image.Image):
        img = np.array(image)
    else:
        raise TypeError("Input image should be either PIL Image or numpy array!")

    if img.ndim == 3:
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    elif img.ndim == 2:
        img = img[np.newaxis, ...]  # HW -> 1HW
    else:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        raise ValueError(f"Expected a 2-D or 3-D image, got {img.ndim}-D")

    return img.astype(np.float32) / 255
|
|||
|
|
|
|||
|
|
|
|||
|
|
def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    Values that are already a multiple of ``mod`` are returned unchanged.
    """
    quotient, remainder = divmod(x, mod)
    return x if remainder == 0 else (quotient + 1) * mod
|
|||
|
|
|
|||
|
|
|
|||
|
|
def scale_image(img, factor, interpolation=cv2.INTER_AREA):
    """Resize a CHW array by ``factor`` along both spatial axes.

    OpenCV works on HxW / HxWxC images, so the array is converted to that
    layout, resized, and converted back to CHW. A single-channel input comes
    back with a leading channel axis of length 1.
    """
    single_channel = img.shape[0] == 1
    hwc = img[0] if single_channel else img.transpose(1, 2, 0)

    resized = cv2.resize(hwc, dsize=None, fx=factor, fy=factor, interpolation=interpolation)

    # cv2 returns a 2-D array for single-channel data; restore the channel axis.
    if resized.ndim == 2:
        return resized[np.newaxis, ...]
    return resized.transpose(2, 0, 1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def pad_img_to_modulo(img, mod):
    """Pad a CHW array on the bottom/right so H and W become multiples of ``mod``.

    Padding mirrors the edge pixels (``np.pad`` mode ``"symmetric"``); the
    channel axis is never padded.
    """
    _, rows, cols = img.shape
    # (-n) % mod is exactly the distance from n up to the next multiple of mod.
    extra_rows = (-rows) % mod
    extra_cols = (-cols) % mod
    widths = ((0, 0), (0, extra_rows), (0, extra_cols))
    return np.pad(img, widths, mode="symmetric")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _resize_chw(img, height, width, interpolation):
    """Resize a CHW array to exactly ``height`` x ``width`` with OpenCV."""
    hwc = img[0] if img.shape[0] == 1 else np.transpose(img, (1, 2, 0))
    # cv2.resize takes dsize as (width, height).
    resized = cv2.resize(hwc, dsize=(width, height), interpolation=interpolation)
    if resized.ndim == 2:
        return resized[np.newaxis, ...]
    return np.transpose(resized, (2, 0, 1))


def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None):
    """Convert an image/mask pair into batched tensors for the inpainting model.

    Args:
        image: PIL image or array accepted by ``get_image``.
        mask: PIL image or array accepted by ``get_image``; nonzero = region to fill.
        device: Torch device the tensors are moved to.
        pad_out_to_modulo: If > 1, pad H and W up to a multiple of this value.
        scale_factor: Optional factor applied to the *mask only* (the image is
            kept at its original resolution).

    Returns:
        ``(image_tensor, mask_tensor)``, each with a leading batch axis of 1.
        The mask is binarized to an integer 0/1 tensor and always has the same
        spatial size as the image.
    """
    out_image = get_image(image)
    out_mask = get_image(mask)

    if scale_factor is not None:
        # Only the mask is rescaled; resizing the image by a factor of 1 (as
        # done previously) was a no-op copy.
        out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)

    # A single (width-derived) factor plus cv2's size rounding can leave the
    # mask a pixel or more off the image in either axis, and the model needs
    # identical spatial sizes. Snap the mask onto the image grid.
    img_height, img_width = out_image.shape[1:]
    if out_mask.shape[1:] != (img_height, img_width):
        out_mask = _resize_chw(out_mask, img_height, img_width, cv2.INTER_NEAREST)

    if pad_out_to_modulo is not None and pad_out_to_modulo > 1:
        out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
        out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)

    out_image = torch.from_numpy(out_image).unsqueeze(0).to(device)
    out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device)

    # Binarize: any nonzero mask value marks a pixel to inpaint.
    out_mask = (out_mask > 0) * 1

    return out_image, out_mask
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LamaInpaint:
    """Wrapper around a TorchScript LaMa inpainting model.

    The model is loaded once in ``__init__`` and applied to an image/mask pair
    in :meth:`run`.
    """

    def __init__(
        self,
        device,
        model_path=None,
    ) -> None:
        """Load the TorchScript model onto ``device`` in eval mode.

        Args:
            device: Torch device spec, e.g. ``"cpu"`` or ``"cuda"``.
            model_path: Path to the TorchScript ``.pt`` file. Defaults to
                ``model/big-lama.pt`` under the script/executable directory.
        """
        if model_path is None:
            model_path = os.path.join(cript_directory, "model", "big-lama.pt")

        self.model = torch.jit.load(model_path, map_location=device)
        self.model.eval()
        self.model.to(device)
        self.device = device

    @staticmethod
    def _height_width(img):
        """Return ``(height, width)`` of a PIL image or an HxW[xC] array."""
        if isinstance(img, np.ndarray):
            return img.shape[0], img.shape[1]
        return img.height, img.width

    def run(
        self,
        image: Union[Image.Image, np.ndarray],
        mask: Union[Image.Image, np.ndarray],
    ) -> np.ndarray:
        """Inpaint the masked region of ``image``.

        Args:
            image: Input picture (PIL image or HxW[xC] array).
            mask: Mask of the region to fill (PIL image or HxW[xC] array); it
                may differ in resolution from ``image`` and is rescaled to it.

        Returns:
            uint8 array of shape ``(H, W, C)`` at the image's original size.
        """
        # Works for both PIL images and ndarrays (ndarrays have no .width).
        orig_height, orig_width = self._height_width(image)
        _, mask_width = self._height_width(mask)
        # Factor that maps the mask's width onto the image's width.
        scale = orig_width / mask_width

        image, mask = prepare_img_and_mask(image, mask, self.device, 8, scale)
        with torch.inference_mode():
            inpainted = self.model(image, mask)
            cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
            cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
            # Crop off the bottom/right padding added to reach a multiple of 8.
            return cur_res[:orig_height, :orig_width]
|
|||
|
|
|
|||
|
|
|
|||
|
|
try:
    # Run on the GPU when one is available, otherwise on the CPU.
    de = "cuda" if torch.cuda.is_available() else "cpu"

    # argv[1] is not used by this script; argv[2:5] are the input image,
    # the mask image and the output path. Validate before loading the model
    # so a bad invocation fails fast with a readable message.
    if len(sys.argv) < 5:
        raise ValueError(
            "Expected arguments: <flag> <image_path> <mask_path> <output_path>"
        )
    image_path, mask_path, output_path = sys.argv[2:5]

    lama = LamaInpaint(de, mode_pa)

    # NOTE: a mask path is currently required; computing the mask region
    # automatically when none is given is not implemented.
    # convert() returns a new, fully loaded image, so the source files can be
    # closed immediately instead of lingering until garbage collection.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    with Image.open(mask_path) as src:
        mask = src.convert("L")

    res = lama.run(image, mask)

    # Save the inpainted result (an HxWxC uint8 array) to the output path.
    Image.fromarray(res).save(output_path)
    sys.exit(0)

except Exception as e:
    # Report the error on stdout, then exit via sys.exit(str), which prints
    # the message to stderr and sets exit status 1.
    print(e)
    sys.exit(str(e))
|