import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from transformers import AutoTokenizer, LlavaForConditionalGeneration

IMAGE_PATH = "C:/Users/27698/Desktop/node/12/00001.png"
PROMPT = "Write a long descriptive caption for this image in a formal tone."
MODEL_NAME = "fancyfeast/llama-joycaption-alpha-two-hf-llava"

# Load JoyCaption
# bfloat16 is the native dtype of the LLM used in JoyCaption (Llama 3.1)
# device_map="cuda:0" loads the model onto the first GPU
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype="bfloat16", device_map="cuda:0")
llava_model.eval()

with torch.no_grad():
    # Load and preprocess the image
    # Normally you would use the Processor here, but the image module's processor
    # has some buggy behavior, and a simple resize in Pillow yields higher quality results
    image = Image.open(IMAGE_PATH)
    if image.size != (384, 384):
        image = image.resize((384, 384), Image.LANCZOS)
    image = image.convert("RGB")
    pixel_values = TVF.pil_to_tensor(image)

    # Normalize the image: scale to [0, 1], then shift to [-1, 1]
    pixel_values = pixel_values / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    pixel_values = pixel_values.to(torch.bfloat16).unsqueeze(0)

    # Build the conversation
    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": PROMPT,
        },
    ]

    # Format the conversation
    convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)

    # Tokenize the conversation
    convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

    # Repeat the image tokens
    # The chat template contains a single image placeholder token, but LLaVA expects
    # one placeholder per image patch embedding, i.e. image_seq_length copies of it
    input_tokens = []
    for token in convo_tokens:
        if token == llava_model.config.image_token_index:
            input_tokens.extend([llava_model.config.image_token_index] * llava_model.config.image_seq_length)
        else:
            input_tokens.append(token)

    input_ids = torch.tensor(input_tokens, dtype=torch.long).unsqueeze(0)
    attention_mask = torch.ones_like(input_ids)

    # Generate the caption
    generate_ids = llava_model.generate(
        input_ids=input_ids.to('cuda'),
        pixel_values=pixel_values.to('cuda'),
        attention_mask=attention_mask.to('cuda'),
        max_new_tokens=300,
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
    )[0]

    # Trim off the prompt tokens, keeping only the newly generated caption
    generate_ids = generate_ids[input_ids.shape[1]:]

    # Decode the caption
    caption = tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    caption = caption.strip()
    print(caption)
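

# --- Optional: batch captioning (a minimal sketch, not part of the original script) ---
# Because the prompt is identical for every image, the expanded `input_ids` built above
# can simply be repeated along the batch dimension while the preprocessed images are
# stacked, letting generate() caption several images in one pass. This reuses the
# tokenizer and model loaded above; BATCH_PATHS and preprocess() are hypothetical
# names introduced here for illustration.
BATCH_PATHS = ["image_a.png", "image_b.png"]  # hypothetical paths, replace with your own

def preprocess(path: str) -> torch.Tensor:
    # Same Pillow-based preprocessing as above: resize to 384x384, scale, normalize
    img = Image.open(path)
    if img.size != (384, 384):
        img = img.resize((384, 384), Image.LANCZOS)
    img = img.convert("RGB")
    px = TVF.pil_to_tensor(img) / 255.0
    px = TVF.normalize(px, [0.5], [0.5])
    return px.to(torch.bfloat16)

with torch.no_grad():
    batch_pixels = torch.stack([preprocess(p) for p in BATCH_PATHS]).to('cuda')
    # All sequences share the same prompt, so no padding is needed
    batch_ids = input_ids.repeat(len(BATCH_PATHS), 1).to('cuda')
    batch_mask = torch.ones_like(batch_ids)
    out = llava_model.generate(
        input_ids=batch_ids,
        pixel_values=batch_pixels,
        attention_mask=batch_mask,
        max_new_tokens=300,
        do_sample=True,
        use_cache=True,
    )
    for row in out:
        # Trim the prompt from each row before decoding, as in the single-image path
        print(tokenizer.decode(row[batch_ids.shape[1]:], skip_special_tokens=True,
                               clean_up_tokenization_spaces=False).strip())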