lq1405 22cfe65dde 3.1.9 2024.10.28
1. (Aggregated Tweets) MJ reverse-prompt and SD reverse-prompt: added Jianying (CapCut) storyboard shots
2. (Aggregated Tweets) Refined the SD reverse-prompt categories (UI matches MJ reverse-prompt, with minor differences)
3. (Aggregated Tweets) Finished one-click video compositing (single and batch)
4. Restyled the novel batch-task table on the Aggregated Tweets landing screen
5. (Aggregated Tweets) Improved one-click reset
6. (Aggregated Tweets) Improved one-click delete
2024-10-28 18:38:11 +08:00


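The script below captions a single image with JoyCaption (fancyfeast/llama-joycaption-alpha-two-hf-llava), an image-to-caption model of the kind that presumably backs the MJ/SD reverse-prompt features noted above; it follows the usage example from the model's Hugging Face page.
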
import torch
import torch.amp
import torchvision.transforms.functional as TVF
from PIL import Image
from transformers import AutoTokenizer, LlavaForConditionalGeneration

IMAGE_PATH = "C:/Users/27698/Desktop/node/12/00001.png"
PROMPT = "Write a long descriptive caption for this image in a formal tone."
MODEL_NAME = "fancyfeast/llama-joycaption-alpha-two-hf-llava"

# Load JoyCaption
# bfloat16 is the native dtype of the LLM used in JoyCaption (Llama 3.1)
# device_map="cuda:0" loads the model onto the first GPU
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype="bfloat16", device_map="cuda:0")
llava_model.eval()

with torch.no_grad():
    # Load and preprocess image
    # Normally you would use the Processor here, but the image module's processor
    # has some buggy behavior and a simple resize in Pillow yields higher quality results
    image = Image.open(IMAGE_PATH)
    if image.size != (384, 384):
        image = image.resize((384, 384), Image.LANCZOS)
    image = image.convert("RGB")
    pixel_values = TVF.pil_to_tensor(image)

    # Normalize the image
    pixel_values = pixel_values / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    pixel_values = pixel_values.to(torch.bfloat16).unsqueeze(0)

    # Build the conversation
    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": PROMPT,
        },
    ]

    # Format the conversation
    convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)

    # Tokenize the conversation
    convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

    # Repeat the image tokens
    input_tokens = []
    for token in convo_tokens:
        if token == llava_model.config.image_token_index:
            input_tokens.extend([llava_model.config.image_token_index] * llava_model.config.image_seq_length)
        else:
            input_tokens.append(token)

    input_ids = torch.tensor(input_tokens, dtype=torch.long).unsqueeze(0)
    attention_mask = torch.ones_like(input_ids)

    # Generate the caption
    generate_ids = llava_model.generate(
        input_ids=input_ids.to("cuda"),
        pixel_values=pixel_values.to("cuda"),
        attention_mask=attention_mask.to("cuda"),
        max_new_tokens=300,
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
    )[0]

    # Trim off the prompt
    generate_ids = generate_ids[input_ids.shape[1]:]

    # Decode the caption
    caption = tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    caption = caption.strip()
    print(caption)
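
Since the changelog mentions batch processing, a natural extension is to run the same captioner over a whole frame folder. The sketch below is illustrative and not part of the original script: it factors the per-image steps above into a hypothetical caption_image() helper and loops it over the directory that IMAGE_PATH points into, writing one .txt caption per frame; tokenizer, llava_model, and PROMPT are assumed to be loaded as above.

from pathlib import Path

import torch
import torchvision.transforms.functional as TVF
from PIL import Image

@torch.no_grad()
def caption_image(path: str) -> str:
    # Same preprocessing as the single-image script: resize to 384x384,
    # scale to [0, 1], normalize to [-1, 1], cast to bfloat16
    image = Image.open(path)
    if image.size != (384, 384):
        image = image.resize((384, 384), Image.LANCZOS)
    image = image.convert("RGB")
    pixel_values = TVF.pil_to_tensor(image) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    pixel_values = pixel_values.to(torch.bfloat16).unsqueeze(0)

    convo = [
        {"role": "system", "content": "You are a helpful image captioner."},
        {"role": "user", "content": PROMPT},
    ]
    convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

    # Expand each image placeholder token to image_seq_length copies,
    # exactly as in the single-image script
    input_tokens = []
    for token in convo_tokens:
        if token == llava_model.config.image_token_index:
            input_tokens.extend([token] * llava_model.config.image_seq_length)
        else:
            input_tokens.append(token)
    input_ids = torch.tensor(input_tokens, dtype=torch.long).unsqueeze(0)
    attention_mask = torch.ones_like(input_ids)

    generate_ids = llava_model.generate(
        input_ids=input_ids.to("cuda"),
        pixel_values=pixel_values.to("cuda"),
        attention_mask=attention_mask.to("cuda"),
        max_new_tokens=300,
        do_sample=True,
        use_cache=True,
    )[0]
    generate_ids = generate_ids[input_ids.shape[1]:]
    return tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False).strip()

# Caption every frame in the folder and save a sidecar .txt per image;
# the directory is the parent of IMAGE_PATH from the script above
frame_dir = Path("C:/Users/27698/Desktop/node/12")
for frame in sorted(frame_dir.glob("*.png")):
    caption = caption_image(str(frame))
    frame.with_suffix(".txt").write_text(caption, encoding="utf-8")
    print(frame.name, "->", caption[:60])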