LaiTool/resources/scripts/lama/lama_inpaint.py

174 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import os
import sys
from typing import Union
import cv2
import torch
import numpy as np
from PIL import Image
# Re-wrap stdout as UTF-8 so non-ASCII paths/messages print correctly regardless
# of the console's default encoding (the parent process reads this output).
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
# Echo the received arguments for debugging.
print(sys.argv)

# Directory containing this script, or the executable when the script has been
# frozen into a binary (sys.frozen is set by freezers such as PyInstaller).
if getattr(sys, "frozen", False):
    cript_directory = os.path.dirname(sys.executable)
else:
    cript_directory = os.path.dirname(__file__)

# The model is loaded through a link in the user's home directory pointing at
# the bundled model file. NOTE(review): presumably this gives torch.jit.load a
# path free of non-ASCII characters from the install location — confirm.
link_name = os.path.join(os.path.expanduser("~"), "big_lama.pt")
# Join per component so the path is valid on any OS.
cu_name = os.path.join(cript_directory, "model", "big-lama.pt")
mode_pa = link_name

if len(sys.argv) < 2:
    # Setup mode (no arguments): create the home-directory link if missing, then exit.
    # `mklink` is a Windows cmd built-in; its exit status is not checked, so a
    # failure here is silent (best-effort).
    if not os.path.exists(link_name):
        os.system(f'mklink "{link_name}" "{cu_name}"')
    print("Params: <runtime-config.json>")
    sys.exit(0)
def get_image(image):
    """Convert a PIL image or numpy array into a float32 CHW array.

    Args:
        image: a ``PIL.Image.Image`` or ``np.ndarray``. An HxWxC input is
            transposed to CxHxW; an HxW input gets a leading channel axis.

    Returns:
        np.ndarray of dtype float32 and shape (C, H, W), with values divided
        by 255 (i.e. in [0, 1] for uint8 input).

    Raises:
        TypeError: if *image* is neither a numpy array nor a PIL image.
        ValueError: if the input is not 2-D or 3-D.
    """
    # The two accepted types are disjoint, so checking ndarray first is safe.
    if isinstance(image, np.ndarray):
        img = image.copy()  # never mutate the caller's array
    elif isinstance(image, Image.Image):
        img = np.array(image)
    else:
        raise TypeError("Input image should be either PIL Image or numpy array!")
    if img.ndim == 3:
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    elif img.ndim == 2:
        img = img[np.newaxis, ...]  # HW -> 1HW
    if img.ndim != 3:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        raise ValueError(f"Expected a 2-D or 3-D image, got {img.ndim}-D input")
    return img.astype(np.float32) / 255
def ceil_modulo(x, mod):
    """Round *x* up to the nearest multiple of *mod*; multiples are returned unchanged."""
    quotient, remainder = divmod(x, mod)
    if remainder:
        return (quotient + 1) * mod
    return x
def scale_image(img, factor, interpolation=cv2.INTER_AREA):
    """Resize a CHW array by *factor* on both spatial axes using cv2.

    Single-channel input is resized as a 2-D plane; multi-channel input is
    moved to HWC for cv2 and back to CHW afterwards.

    Returns:
        The resized array in CHW layout.
    """
    is_single_channel = img.shape[0] == 1
    plane = img[0] if is_single_channel else img.transpose(1, 2, 0)
    resized = cv2.resize(
        plane, dsize=None, fx=factor, fy=factor, interpolation=interpolation
    )
    # A 2-D result means one channel: restore the leading channel axis.
    return resized[None, ...] if resized.ndim == 2 else resized.transpose(2, 0, 1)
def pad_img_to_modulo(img, mod):
    """Symmetric-pad a CHW array so its height and width are multiples of *mod*.

    Padding is appended only at the bottom and right edges; the channel axis
    is left untouched.
    """
    _, rows, cols = img.shape
    extra_rows = ceil_modulo(rows, mod) - rows
    extra_cols = ceil_modulo(cols, mod) - cols
    pad_spec = ((0, 0), (0, extra_rows), (0, extra_cols))
    return np.pad(img, pad_spec, mode="symmetric")
def _resize_chw_to(img, height, width, interpolation):
    """Resize a CHW array to exactly (height, width) with cv2, keeping the channel count."""
    plane = img[0] if img.shape[0] == 1 else np.transpose(img, (1, 2, 0))
    resized = cv2.resize(plane, (width, height), interpolation=interpolation)
    if resized.ndim == 2:
        return resized[None, ...]
    return np.transpose(resized, (2, 0, 1))


def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None):
    """Convert an image/mask pair into batched tensors for the inpainting model.

    Args:
        image: PIL image or numpy array (see ``get_image``).
        mask: PIL image or numpy array; any non-zero pixel marks a region to
            inpaint.
        device: torch device (or device string) the tensors are moved to.
        pad_out_to_modulo: when > 1, both arrays are symmetric-padded so their
            height and width are multiples of this value.
        scale_factor: when not None, the mask (only) is resized by this factor
            with nearest-neighbour interpolation; the image keeps its native
            resolution.

    Returns:
        ``(image_tensor, mask_tensor)``, each with a leading batch axis of 1.
        The mask tensor holds integer 0/1 values.
    """
    out_image = get_image(image)
    out_mask = get_image(mask)
    if scale_factor is not None:
        # Only the mask is rescaled. Rescaling the image by a factor of 1
        # would be a pure copy, so it is skipped.
        out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)
    target_hw = out_image.shape[1:]
    if out_mask.shape[1:] != target_hw:
        # A single scale factor cannot match a mask with a slightly different
        # aspect ratio, and cv2 rounds each axis separately. Snap the mask
        # onto the image grid; the model presumably needs equal spatial sizes.
        out_mask = _resize_chw_to(out_mask, target_hw[0], target_hw[1], cv2.INTER_NEAREST)
    if pad_out_to_modulo is not None and pad_out_to_modulo > 1:
        out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
        out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)
    out_image = torch.from_numpy(out_image).unsqueeze(0).to(device)
    out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device)
    # Binarize: any non-zero mask value marks a pixel to inpaint.
    out_mask = (out_mask > 0) * 1
    return out_image, out_mask
class LamaInpaint:
    """Run a TorchScript LaMa inpainting model on an image/mask pair."""

    def __init__(
        self,
        device,
        model_path=None,
    ) -> None:
        """Load the TorchScript model and move it to *device*.

        Args:
            device: torch device or device string (e.g. "cpu", "cuda").
            model_path: path to the TorchScript file; defaults to
                ``model/big-lama.pt`` next to this script/executable.
        """
        if model_path is None:
            # Join per component so the default path is valid on any OS.
            model_path = os.path.join(cript_directory, "model", "big-lama.pt")
        self.model = torch.jit.load(model_path, map_location=device)
        self.model.eval()
        self.model.to(device)
        self.device = device

    def run(
        self,
        image: Union[Image.Image, np.ndarray],
        mask: Union[Image.Image, np.ndarray],
    ) -> np.ndarray:
        """Inpaint the masked region of *image*.

        Args:
            image: PIL image or HxW[xC] array (assumed uint8 RGB — TODO confirm
                for array callers).
            mask: PIL image or array; non-zero pixels are inpainted. It may have
                a different resolution and is rescaled by the image/mask width
                ratio to match the image.

        Returns:
            uint8 array in HWC layout, cropped back to the input image's size.
        """
        # Read sizes from array shapes so ndarray inputs work as annotated
        # (`.width` exists only on PIL images).
        image_arr = np.asarray(image)
        mask_arr = np.asarray(mask)
        orig_height, orig_width = image_arr.shape[:2]
        scale = orig_width / mask_arr.shape[1]
        image_t, mask_t = prepare_img_and_mask(image_arr, mask_arr, self.device, 8, scale)
        with torch.inference_mode():
            inpainted = self.model(image_t, mask_t)
        # Assumes the model returns a (1, C, H, W) tensor in [0, 1] — TODO confirm.
        cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
        # Crop away the padding added by prepare_img_and_mask.
        return cur_res[:orig_height, :orig_width]
try:
    # Expected invocation: <exe> <flag> <image_path> <mask_path> <output_path>.
    # sys.argv[1] is not read by this script.
    # Validate arguments before the (expensive) model load.
    if len(sys.argv) < 5:
        raise ValueError(
            "Usage: lama_inpaint <flag> <image_path> <mask_path> <output_path>"
        )
    image_path = sys.argv[2]
    mask_path = sys.argv[3]
    output_path = sys.argv[4]

    de = "cuda" if torch.cuda.is_available() else "cpu"
    # Prefer the home-directory link. It is only created by a no-argument
    # (setup) run, so fall back to the bundled model file if it is missing.
    model_file = mode_pa if os.path.exists(mode_pa) else cu_name
    lama = LamaInpaint(de, model_file)

    # TODO: computing the mask automatically when no mask_path is given is not
    # implemented; mask_path is currently required.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    with Image.open(mask_path) as src:
        mask = src.convert("L")
    res = lama.run(image, mask)

    # Save the inpainted result to disk.
    img = Image.fromarray(res)
    img.save(output_path)
    # SystemExit derives from BaseException, so the handler below does not catch it.
    sys.exit(0)
except Exception as e:
    print(e)
    sys.exit(str(e))