vlm-dataset-captioner 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlm_dataset_captioner/vlm_caption.py +219 -0
- vlm_dataset_captioner/vlm_caption_cli.py +83 -0
- vlm_dataset_captioner-0.0.1.dist-info/METADATA +77 -0
- vlm_dataset_captioner-0.0.1.dist-info/RECORD +6 -0
- vlm_dataset_captioner-0.0.1.dist-info/WHEEL +4 -0
- vlm_dataset_captioner-0.0.1.dist-info/entry_points.txt +2 -0
vlm_dataset_captioner/vlm_caption.py
@@ -0,0 +1,219 @@
import os
import re

from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

DEFAULT_MODEL = "Qwen/Qwen2.5-VL-32B-Instruct"
IMAGE_FILE_EXTENSIONS = (
    ".png",
    ".jpg",
    ".jpeg",
    ".gif",
    ".bmp",
    ".tif",
)


def init_model(model_name=None):
    if model_name is None:
        model_name = DEFAULT_MODEL
        print(f"INFO: No model name provided. Initializing default model {model_name}.")

    print(f"INFO: Initializing model {model_name}.", flush=True)

    # Load the model on the available device(s)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto"
    )

    print(f"INFO: Model {model_name} loaded successfully.", flush=True)
    print(f"INFO: Loading processor for model {model_name}.", flush=True)

    # Load the default processor
    processor = AutoProcessor.from_pretrained(model_name)

    print(f"INFO: Processor for model {model_name} loaded successfully.", flush=True)

    return model, processor


def get_prompt_for_directory(directory_path):
    # Each image directory may ship its own prompt.txt; fall back to a
    # generic prompt when the file is absent.
    prompt_file_path = os.path.join(directory_path, "prompt.txt")
    try:
        with open(prompt_file_path) as prompt_file:
            prompt = prompt_file.read()
    except FileNotFoundError:
        print(
            f"WARN: Prompt file not found for directory {prompt_file_path}. Using default prompt.",
            flush=True,
        )
        prompt = "Describe the image in detail."

    print(f"INFO: Using prompt: '{prompt}'", flush=True)

    return prompt


def is_image_file(filename):
    return filename.lower().endswith(IMAGE_FILE_EXTENSIONS)


def is_image_directory(directory_path):
    # A directory counts as an image directory if it directly contains at
    # least one image file.
    return any(is_image_file(filename) for filename in os.listdir(directory_path))


def get_messages(prompt, image):
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]


def contains_chinese(text_string):
    # The Unicode range for common CJK Unified Ideographs (Han characters)
    # is U+4E00 to U+9FFF.
    chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]")
    return bool(chinese_char_pattern.search(text_string))


def caption_image(prompt, image, model, processor, max_new_tokens=None):
    messages = get_messages(prompt, image)

    # Prepare inputs for the model
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")  # assumes a CUDA device is available

    # Generation is capped at 128 new tokens; max_new_tokens only truncates
    # the result further below.
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        top_p=1.0,
        temperature=0.7,
        top_k=50,
    )
    # Strip the prompt tokens, keeping only the newly generated ones.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    # Truncate caption if it exceeds max_new_tokens
    if max_new_tokens is not None and len(generated_ids_trimmed[0]) > max_new_tokens:
        print(
            f"WARN: Generated tokens for {image} exceed max_new_tokens={max_new_tokens}. Truncating caption.",
            flush=True,
        )
        generated_ids_trimmed[0] = generated_ids_trimmed[0][:max_new_tokens]

    # Decode the generated tokens to text
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return output_text[0]


def write_caption_to_file(image_file, caption, output_directory):
    os.makedirs(output_directory, exist_ok=True)

    caption_file = os.path.join(
        output_directory, f"{os.path.splitext(image_file)[0]}.txt"
    )
    with open(caption_file, "w") as f:
        f.write(caption)


def ignore_file(filename, ignore_substring):
    if ignore_substring is None:
        return False
    return ignore_substring in filename


def requires_caption(image_file, output_directory, overwrite):
    caption_file = os.path.join(
        output_directory, f"{os.path.splitext(image_file)[0]}.txt"
    )
    return overwrite or not os.path.exists(caption_file)


def caption_entire_directory(
    directory_path,
    output_directory,
    model,
    processor,
    max_new_tokens=None,
    ignore_substring=None,
    num_captions=None,
    overwrite=False,
):
    print(
        f"INFO: Processing directory {directory_path} for image captions.", flush=True
    )

    if not is_image_directory(directory_path):
        # No images at this level: recurse into subdirectories, mirroring
        # the structure under output_directory.
        for subdir in os.listdir(directory_path):
            if not ignore_file(subdir, ignore_substring):
                subdir_path = os.path.join(directory_path, subdir)
                if os.path.isdir(subdir_path):
                    caption_entire_directory(
                        subdir_path,
                        os.path.join(output_directory, subdir),
                        model,
                        processor,
                        max_new_tokens,
                        ignore_substring,
                        num_captions,
                        overwrite,
                    )
    else:
        prompt = get_prompt_for_directory(directory_path)
        for image_file in os.listdir(directory_path):
            if (
                not ignore_file(image_file, ignore_substring)
                and requires_caption(image_file, output_directory, overwrite)
                and is_image_file(image_file)
            ):
                try:
                    caption = ""
                    for i in range(int(num_captions) if num_captions else 1):
                        if i != 0:
                            caption += "\n"

                        # Retry until the model produces a caption free of
                        # Chinese characters (Qwen models occasionally
                        # code-switch). Generate into a temporary variable so
                        # a rejected attempt is discarded rather than
                        # accumulated, which would make the check fail forever.
                        while True:
                            candidate = caption_image(
                                prompt,
                                os.path.join(directory_path, image_file),
                                model,
                                processor,
                                max_new_tokens,
                            )
                            if not contains_chinese(candidate):
                                break
                        caption += candidate
                    write_caption_to_file(image_file, caption, output_directory)
                except Exception as e:
                    print(
                        f"WARN: Error processing image {image_file} in {directory_path}: {e}",
                        flush=True,
                    )
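A minimal usage sketch for the module above, assuming the package is installed and a CUDA device is available; `cat.png` is a hypothetical local image path:

```python
from vlm_dataset_captioner.vlm_caption import init_model, caption_image

# Loads Qwen/Qwen2.5-VL-32B-Instruct when no model name is given.
model, processor = init_model()

# Caption one image; a caption longer than 77 tokens is truncated.
caption = caption_image(
    "Describe the image in detail.",  # prompt
    "cat.png",                        # hypothetical image path
    model,
    processor,
    max_new_tokens=77,
)
print(caption)
```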
vlm_dataset_captioner/vlm_caption_cli.py
@@ -0,0 +1,83 @@
import argparse

# Package-qualified import so the module also works when installed
# (RECORD places vlm_caption.py under the vlm_dataset_captioner package).
from vlm_dataset_captioner.vlm_caption import caption_entire_directory, init_model


def parse_args():
    parser = argparse.ArgumentParser(
        description="Caption images from a dataset using a VLM."
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="The path of the input directory containing images to be captioned.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="The HuggingFace model used to generate captions.",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=None,
        help="The maximum number of tokens to be generated in any given caption.",
    )
    parser.add_argument(
        "--ignore_substring",
        type=str,
        default=None,
        help="Ignore files and subdirectories that contain this substring in their names.",
    )
    parser.add_argument(
        "--num_captions",
        type=int,
        default=None,
        help="Number of captions to be generated per image.",
    )
    parser.add_argument(
        "--overwrite",
        # store_true, not type=bool: argparse would treat any non-empty
        # string, including "False", as True.
        action="store_true",
        help="If set, overwrites existing captions.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The directory to act as the root of the caption file structure. Defaults to `<input_dir>_caption`.",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    model, processor = init_model(args.model)

    output_dir = (
        args.output_dir if args.output_dir is not None else f"{args.input_dir}_caption"
    )

    if args.model is not None:
        print(f"INFO: Using model {args.model} for captioning.", flush=True)
    if args.max_length is not None:
        print(f"INFO: Setting max length to {args.max_length} tokens.", flush=True)
    if args.ignore_substring is not None:
        print(
            f"INFO: Ignoring files/directories containing substring '{args.ignore_substring}'.",
            flush=True,
        )

    caption_entire_directory(
        args.input_dir,
        output_dir,
        model,
        processor,
        max_new_tokens=args.max_length,
        ignore_substring=args.ignore_substring,
        num_captions=args.num_captions,
        overwrite=args.overwrite,
    )


# Guarded so merely importing the module (e.g. via the console-script entry
# point) does not run the captioner.
if __name__ == "__main__":
    main()
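A note on the `--overwrite` flag: it is declared with `action="store_true"` rather than `type=bool` because argparse applies `type` to the raw argument string, and any non-empty string, including "False", is truthy. A quick demonstration of the pitfall:

```python
# argparse's type=bool would call bool() on the raw argument string:
print(bool("False"))  # True  -- so --overwrite=False would still enable overwriting
print(bool(""))       # False -- only the empty string is falsy
```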
vlm_dataset_captioner-0.0.1.dist-info/METADATA
@@ -0,0 +1,77 @@
Metadata-Version: 2.4
Name: vlm-dataset-captioner
Version: 0.0.1
Summary: Uses a VLM to caption images from a dataset.
Author: Alex Senden
Maintainer: Alex Senden
License: MIT
Keywords: computer-vision,image-captioning,machine-learning,vision-language-model
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Image Processing
Requires-Python: >=3.8
Requires-Dist: accelerate
Requires-Dist: huggingface-hub
Requires-Dist: qwen-vl-utils
Requires-Dist: torch
Requires-Dist: torchvision
Requires-Dist: transformers
Description-Content-Type: text/markdown

# VLM Captioner

Uses a VLM (initially configured to use `Qwen2.5-VL-32B-Instruct`) to caption images from a dataset.

### Dataset Structure

One VLM prompt is used for each entire image directory.
The input file structure is mirrored under a root directory with the suffix `_caption`. The mirror contains individual `.txt` caption files whose filenames match those of their image counterparts.

```
dataset/
└── top_level_folder_1/
    ├── image_folder_1 (contains prompt for entire folder)/
    │   ├── prompt.txt
    │   ├── image_1.png
    │   ├── image_2.png
    │   └── ...
    └── ...
```
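Captioning the tree above would yield a mirrored structure along these lines (an illustrative sketch; actual names depend on the input):

```
dataset_caption/
└── top_level_folder_1/
    ├── image_folder_1/
    │   ├── image_1.txt
    │   ├── image_2.txt
    │   └── ...
    └── ...
```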
### Running

First, install the required packages:

```
pip install -r requirements.txt
```

Then, run the script:

```
python vlm_caption_cli.py --input_dir=<input_dir> [--model=<vlm_model>]
```

### Command Line Args

##### Required Args:

```
--input_dir=<input_dir> || The path of the input directory containing images to be captioned.
```

##### Optional Args:

```
--model=<vlm_model> || VLM to use to generate captions
--max_length=<max_new_tokens> || Maximum number of new tokens before truncation
--ignore_substring=<ignore_substring> || Ignore files/directories containing this substring
--num_captions=<number_of_captions> || Number of captions to generate per image
--overwrite || If set, overwrites captions that already exist
--output_dir=<output_dir> || The directory to act as the root of the caption file structure. Defaults to `<input_dir>_caption`.
```
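A hypothetical invocation combining several optional args (all values illustrative):

```
python vlm_caption_cli.py --input_dir=dataset --max_length=77 --ignore_substring=draft --num_captions=3 --overwrite
```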
vlm_dataset_captioner-0.0.1.dist-info/RECORD
@@ -0,0 +1,6 @@
vlm_dataset_captioner/vlm_caption.py,sha256=k711kghgWmWXZIYva8t7v2ew519BjwcchZt3vwzmfZc,6854
vlm_dataset_captioner/vlm_caption_cli.py,sha256=i1SS43ga2hpxCAQ2XtOkzNFBfI0zKZ5y-aKWI6djt4M,2341
vlm_dataset_captioner-0.0.1.dist-info/METADATA,sha256=b7B8SwZAIIs2DPsfoY0nUb3ZomPujDl0-LEYDGco-x8,2430
vlm_dataset_captioner-0.0.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
vlm_dataset_captioner-0.0.1.dist-info/entry_points.txt,sha256=k-zH3SWvcplaeDuGV4J6OyHKLr9GieWcOhRB5sF2pEI,75
vlm_dataset_captioner-0.0.1.dist-info/RECORD,,
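The `sha256=` values in RECORD are urlsafe-base64 SHA-256 digests with the `=` padding stripped, per the wheel spec. A minimal sketch for checking one entry (the path is assumed relative to the installed site-packages):

```python
import base64
import hashlib

def record_hash(path):
    # RECORD stores sha256 digests as urlsafe base64 without '=' padding.
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

print(record_hash("vlm_dataset_captioner/vlm_caption.py"))
```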