vlm-dataset-captioner 0.0.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
+ __pycache__/
+ *.slurm
+ data/
+ .vscode/
@@ -0,0 +1,77 @@
+ Metadata-Version: 2.4
+ Name: vlm-dataset-captioner
+ Version: 0.0.1
+ Summary: Uses a VLM to caption images from a dataset.
+ Author: Alex Senden
+ Maintainer: Alex Senden
+ License: MIT
+ Keywords: computer-vision,image-captioning,machine-learning,vision-language-model
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Topic :: Scientific/Engineering :: Image Processing
+ Requires-Python: >=3.8
+ Requires-Dist: accelerate
+ Requires-Dist: huggingface-hub
+ Requires-Dist: qwen-vl-utils
+ Requires-Dist: torch
+ Requires-Dist: torchvision
+ Requires-Dist: transformers
+ Description-Content-Type: text/markdown
+
+ # VLM Captioner
+
+ Uses a VLM (configured by default to use `Qwen2.5-VL-32B-Instruct`) to caption images from a dataset.
+
+ ### Dataset Structure
+
+ A single VLM prompt, read from the `prompt.txt` file in each image directory, is used for every image in that directory (falling back to a generic default prompt if `prompt.txt` is missing).
+ For each image directory, a mirrored file structure is created with the suffix `_caption`. This structure contains individual `.txt` caption files with filenames matching those of their image counterparts.
+
+ ```
+ dataset/
+ └── top_level_folder_1/
+     ├── image_folder_1 (contains prompt for entire folder)/
+     │   ├── prompt.txt
+     │   ├── image_1.png
+     │   ├── image_2.png
+     │   └── ...
+     └── ...
+ ```
+
+ ### Running
+
+ First, install the required packages:
+
+ ```
+ pip install -r requirements.txt
+ ```
+
+ Then, run the script:
+
+ ```
+ python vlm_caption_cli.py --input_dir=<input_dir> [--model=<vlm_model>]
+ ```
+
+ ### Command Line Args
+
+ ##### Required Args:
+
+ ```
+ --input_dir=<input_dir> || The path of the input directory containing images to be captioned.
+ ```
+
+ ##### Optional Args:
+
+ ```
+ --model=<vlm_model> || The VLM used to generate captions. Defaults to `Qwen/Qwen2.5-VL-32B-Instruct`.
+ --max_length=<max_new_tokens> || Maximum number of new tokens to generate before truncation.
+ --ignore_substring=<ignore_substring> || Ignore files/directories whose names contain this substring.
+ --num_captions=<number_of_captions> || Number of captions to generate per image. Defaults to 1.
+ --overwrite=<True/False> || If True, overwrites captions that already exist. Defaults to False.
+ --output_dir=<output_dir> || The directory to act as the root of the caption file structure. Defaults to `<input_dir>_caption`.
+ ```
@@ -0,0 +1,76 @@
+ # VLM Captioner
+
+ Uses a VLM (configured by default to use `Qwen2.5-VL-32B-Instruct`) to caption images from a dataset.
+
+ ### Dataset Structure
+
+ A single VLM prompt, read from the `prompt.txt` file in each image directory, is used for every image in that directory (falling back to a generic default prompt if `prompt.txt` is missing).
+ For each image directory, a mirrored file structure is created with the suffix `_caption`. This structure contains individual `.txt` caption files with filenames matching those of their image counterparts.
+
+ ```
+ dataset/
+ └── top_level_folder_1/
+     ├── image_folder_1 (contains prompt for entire folder)/
+     │   ├── prompt.txt
+     │   ├── image_1.png
+     │   ├── image_2.png
+     │   └── ...
+     └── ...
+ ```
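+
+ For example, captioning the tree above with `--input_dir=dataset` would produce a mirrored tree like the following (a sketch of the expected output; folder and file names are illustrative):
+
+ ```
+ dataset_caption/
+ └── top_level_folder_1/
+     ├── image_folder_1/
+     │   ├── image_1.txt
+     │   ├── image_2.txt
+     │   └── ...
+     └── ...
+ ```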
+
+ ### Running
+
+ First, install the required packages:
+
+ ```
+ pip install -r requirements.txt
+ ```
+
+ Then, run the script:
+
+ ```
+ python vlm_caption_cli.py --input_dir=<input_dir> [--model=<vlm_model>]
+ ```
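+
+ Alternatively, if the package is installed (e.g. with `pip install .` from the repository root), the `vlm-caption` console script declared in `pyproject.toml` should expose the same interface:
+
+ ```
+ vlm-caption --input_dir=<input_dir> [--model=<vlm_model>]
+ ```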
+
+ ### Command Line Args
+
+ ##### Required Args:
+
+ ```
+ --input_dir=<input_dir> || The path of the input directory containing images to be captioned.
+ ```
+
+ ##### Optional Args:
+
+ ```
+ --model=<vlm_model> || The VLM used to generate captions. Defaults to `Qwen/Qwen2.5-VL-32B-Instruct`.
+ --max_length=<max_new_tokens> || Maximum number of new tokens to generate before truncation.
+ --ignore_substring=<ignore_substring> || Ignore files/directories whose names contain this substring.
+ --num_captions=<number_of_captions> || Number of captions to generate per image. Defaults to 1.
+ --overwrite=<True/False> || If True, overwrites captions that already exist. Defaults to False.
+ --output_dir=<output_dir> || The directory to act as the root of the caption file structure. Defaults to `<input_dir>_caption`.
+ ```
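+
+ For example, a hypothetical run combining several of the optional flags (paths and values are illustrative only):
+
+ ```
+ python vlm_caption_cli.py --input_dir=dataset --max_length=96 --num_captions=3 --overwrite=True
+ ```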
@@ -0,0 +1,51 @@
+ [build-system]
+ requires = ["hatchling", "hatch-vcs"]
+ build-backend = "hatchling.build"
+
+ [project]
+ name = "vlm-dataset-captioner"
+ description = "Uses a VLM to caption images from a dataset."
+ readme = "README.md"
+ license = { text = "MIT" }
+ requires-python = ">=3.8"
+ authors = [{ name = "Alex Senden" }]
+ maintainers = [{ name = "Alex Senden" }]
+ keywords = [
+     "vision-language-model",
+     "image-captioning",
+     "computer-vision",
+     "machine-learning",
+ ]
+ classifiers = [
+     "Development Status :: 3 - Alpha",
+     "Intended Audience :: Developers",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.8",
+     "Programming Language :: Python :: 3.9",
+     "Programming Language :: Python :: 3.10",
+     "Programming Language :: Python :: 3.11",
+     "Topic :: Scientific/Engineering :: Image Processing",
+ ]
+ dependencies = [
+     "transformers",
+     "huggingface-hub",
+     "qwen_vl_utils",
+     "torch",
+     "torchvision",
+     "accelerate",
+ ]
+
+ # Version is automatically provided by hatch-vcs
+ dynamic = ["version"]
+
+ [project.scripts]
+ vlm-caption = "vlm_dataset_captioner.vlm_caption_cli:main"
+
+ [tool.hatch.version]
+ source = "vcs"
+
+ [tool.hatch.build.targets.sdist]
+ include = ["vlm_dataset_captioner/**", "README.md"]
+
+ [tool.hatch.build.targets.wheel]
+ include = ["vlm_dataset_captioner/**"]
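Because `version` is declared dynamic with `[tool.hatch.version] source = "vcs"`, hatch-vcs derives the package version from the repository's VCS tags at build time. A minimal sketch of cutting a release under that assumption (the tag name and build commands are illustrative, not part of the package):

```
git tag v0.0.1
python -m build   # hatch-vcs reads the tag and stamps the version as 0.0.1
pip install dist/vlm_dataset_captioner-0.0.1-py3-none-any.whl
```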
@@ -0,0 +1,219 @@
+ import os
+ import re
+
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+
+ DEFAULT_MODEL = "Qwen/Qwen2.5-VL-32B-Instruct"
+ IMAGE_FILE_EXTENSIONS = (
+     ".png",
+     ".jpg",
+     ".jpeg",
+     ".gif",
+     ".bmp",
+     ".tif",
+ )
+
+
+ def init_model(model_name=None):
+     if model_name is None:
+         model_name = DEFAULT_MODEL
+         print(f"INFO: No model name provided. Initializing default model {model_name}.")
+
+     print(f"INFO: Initializing model {model_name}.", flush=True)
+
+     # Load the model on the available device(s)
+     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+         model_name, torch_dtype="auto", device_map="auto"
+     )
+
+     print(f"INFO: Model {model_name} loaded successfully.", flush=True)
+     print(f"INFO: Loading processor for model {model_name}.", flush=True)
+
+     # Load the default processor
+     processor = AutoProcessor.from_pretrained(model_name)
+
+     print(f"INFO: Processor for model {model_name} loaded successfully.", flush=True)
+
+     return model, processor
+
+
+ def get_prompt_for_directory(directory_path):
+     prompt = ""
+     prompt_file_path = os.path.join(directory_path, "prompt.txt")
+     try:
+         with open(prompt_file_path) as prompt_file:
+             prompt = prompt_file.read()
+     except FileNotFoundError:
+         print(
+             f"WARN: Prompt file not found for directory {prompt_file_path}. Using default prompt.",
+             flush=True,
+         )
+         prompt = "Describe the image in detail."
+
+     print(f"INFO: Using prompt: '{prompt}'", flush=True)
+
+     return prompt
+
+
+ def is_image_file(filename):
+     return filename.lower().endswith(IMAGE_FILE_EXTENSIONS)
+
+
+ def is_image_directory(directory_path):
+     return any(is_image_file(filename) for filename in os.listdir(directory_path))
+
+
+ def get_messages(prompt, image):
+     return [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": prompt},
+             ],
+         }
+     ]
+
+
+ def contains_chinese(text_string):
+     # The Unicode range for common CJK Unified Ideographs (Han characters)
+     # is typically from U+4E00 to U+9FFF.
+     chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]")
+     return bool(chinese_char_pattern.search(text_string))
+
+
+ def caption_image(prompt, image, model, processor, max_new_tokens=None):
+     messages = get_messages(prompt, image)
+
+     # Prepare inputs for the model
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     inputs = inputs.to(model.device)  # move inputs to wherever the model was loaded
+
+     # Generate caption, honoring the caller's token limit (128 new tokens by default)
+     generated_ids = model.generate(
+         **inputs,
+         max_new_tokens=max_new_tokens if max_new_tokens is not None else 128,
+         do_sample=True,
+         top_p=1.0,
+         temperature=0.7,
+         top_k=50,
+     )
+     generated_ids_trimmed = [
+         out_ids[len(in_ids) :]
+         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+
+     # Safety net: truncate if generation ever exceeds max_new_tokens
+     if max_new_tokens is not None and len(generated_ids_trimmed[0]) > max_new_tokens:
+         print(
+             f"WARN: Generated tokens for {image} exceed max_new_tokens={max_new_tokens}. Truncating caption.",
+             flush=True,
+         )
+         generated_ids_trimmed[0] = generated_ids_trimmed[0][:max_new_tokens]
+
+     # Decode the generated tokens to text
+     output_text = processor.batch_decode(
+         generated_ids_trimmed,
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=False,
+     )
+
+     return output_text[0]
+
+
+ def write_caption_to_file(image_file, caption, output_directory):
+     if not os.path.exists(output_directory):
+         os.makedirs(output_directory)
+
+     caption_file = os.path.join(
+         output_directory, f"{os.path.splitext(image_file)[0]}.txt"
+     )
+     with open(caption_file, "w") as f:
+         f.write(caption)
+
+
+ def ignore_file(filename, ignore_substring):
+     if ignore_substring is None:
+         return False
+     return ignore_substring in filename
+
+
+ def requires_caption(image_file, output_directory, overwrite):
+     caption_file = os.path.join(
+         output_directory, f"{os.path.splitext(image_file)[0]}.txt"
+     )
+     return overwrite or not os.path.exists(caption_file)
+
+
+ def caption_entire_directory(
+     directory_path,
+     output_directory,
+     model,
+     processor,
+     max_new_tokens=None,
+     ignore_substring=None,
+     num_captions=None,
+     overwrite=False,
+ ):
+     print(
+         f"INFO: Processing directory {directory_path} for image captions.", flush=True
+     )
+
+     if not is_image_directory(directory_path):
+         for subdir in os.listdir(directory_path):
+             if not ignore_file(subdir, ignore_substring):
+                 subdir_path = os.path.join(directory_path, subdir)
+                 if os.path.isdir(subdir_path):
+                     caption_entire_directory(
+                         subdir_path,
+                         os.path.join(output_directory, subdir),
+                         model,
+                         processor,
+                         max_new_tokens,
+                         ignore_substring,
+                         num_captions,
+                         overwrite,
+                     )
+     else:
+         prompt = get_prompt_for_directory(directory_path)
+         for image_file in os.listdir(directory_path):
+             if (
+                 is_image_file(image_file)
+                 and not ignore_file(image_file, ignore_substring)
+                 and requires_caption(image_file, output_directory, overwrite)
+             ):
+                 try:
+                     caption = ""
+                     for i in range(int(num_captions) if num_captions else 1):
+                         if i != 0:
+                             caption += "\n"
+
+                         # Retry until the model produces a caption free of Chinese
+                         while True:
+                             candidate = caption_image(
+                                 prompt,
+                                 os.path.join(directory_path, image_file),
+                                 model,
+                                 processor,
+                                 max_new_tokens,
+                             )
+                             if not contains_chinese(candidate):
+                                 caption += candidate
+                                 break
+                     write_caption_to_file(image_file, caption, output_directory)
+                 except Exception as e:
+                     print(
+                         f"WARN: Error processing image {image_file} in {directory_path}: {e}",
+                         flush=True,
+                     )
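Beyond the CLI, the two public entry points above can be driven programmatically. A minimal sketch of library-style use (the `dataset` paths and parameter values are illustrative; in practice a GPU-capable environment is assumed for a 32B model):

```
from vlm_caption import caption_entire_directory, init_model

# Load the default Qwen2.5-VL model and its processor once...
model, processor = init_model()

# ...then caption every image directory under dataset/, mirroring the
# results into dataset_caption/ with two captions per image.
caption_entire_directory(
    "dataset",
    "dataset_caption",
    model,
    processor,
    max_new_tokens=96,
    num_captions=2,
    overwrite=False,
)
```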
@@ -0,0 +1,96 @@
+ import argparse
+
+ # The absolute import is used when the package is installed (e.g. via the
+ # vlm-caption console script); the fallback keeps `python vlm_caption_cli.py`
+ # working when run directly from the source tree.
+ try:
+     from vlm_dataset_captioner.vlm_caption import caption_entire_directory, init_model
+ except ImportError:
+     from vlm_caption import caption_entire_directory, init_model
+
+
+ def str_to_bool(value):
+     # argparse's `type=bool` treats any non-empty string (including "False") as
+     # True, so the documented <True/False> values are parsed explicitly instead.
+     return str(value).lower() in ("true", "1", "yes")
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description="Caption images from a dataset using a VLM."
+     )
+     parser.add_argument(
+         "--input_dir",
+         type=str,
+         required=True,
+         help="The path of the input directory containing images to be captioned.",
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         default=None,
+         help="The HuggingFace model used to generate captions.",
+     )
+     parser.add_argument(
+         "--max_length",
+         type=int,
+         default=None,
+         help="The maximum number of tokens to be generated in any given caption.",
+     )
+     parser.add_argument(
+         "--ignore_substring",
+         type=str,
+         default=None,
+         help="Ignore files and subdirectories that contain this substring in their names.",
+     )
+     parser.add_argument(
+         "--num_captions",
+         type=int,
+         default=None,
+         help="Number of captions to generate per image.",
+     )
+     parser.add_argument(
+         "--overwrite",
+         type=str_to_bool,
+         default=False,
+         help="If True, overwrites existing captions.",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default=None,
+         help="The directory to act as the root of the caption file structure. Defaults to `<input_dir>_caption`.",
+     )
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     model, processor = init_model(args.model)
+
+     output_dir = args.output_dir if args.output_dir is not None else f"{args.input_dir}_caption"
+
+     if args.model is not None:
+         print(f"INFO: Using model {args.model} for captioning.", flush=True)
+     if args.max_length is not None:
+         print(f"INFO: Setting max length to {args.max_length} tokens.", flush=True)
+     if args.ignore_substring is not None:
+         print(
+             f"INFO: Ignoring files/directories containing substring '{args.ignore_substring}'.",
+             flush=True,
+         )
+
+     caption_entire_directory(
+         args.input_dir,
+         output_dir,
+         model,
+         processor,
+         max_new_tokens=args.max_length,
+         ignore_substring=args.ignore_substring,
+         num_captions=args.num_captions,
+         overwrite=args.overwrite,
+     )
+
+
+ if __name__ == "__main__":
+     main()
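Assuming the default model and a dataset laid out as in the README, a run should begin with log lines like these, taken directly from the `print` calls above (later output depends on the dataset contents):

```
INFO: No model name provided. Initializing default model Qwen/Qwen2.5-VL-32B-Instruct.
INFO: Initializing model Qwen/Qwen2.5-VL-32B-Instruct.
INFO: Model Qwen/Qwen2.5-VL-32B-Instruct loaded successfully.
INFO: Loading processor for model Qwen/Qwen2.5-VL-32B-Instruct.
INFO: Processor for model Qwen/Qwen2.5-VL-32B-Instruct loaded successfully.
INFO: Processing directory dataset for image captions.
```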