labelr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,54 +3,70 @@ import logging
3
3
  import pickle
4
4
  import random
5
5
  import tempfile
6
- from collections.abc import Iterator
6
+ import typing
7
7
  from pathlib import Path
8
8
 
9
9
  import datasets
10
10
  import tqdm
11
11
  from label_studio_sdk.client import LabelStudio
12
- from openfoodfacts.images import download_image, generate_image_url
13
- from openfoodfacts.types import Flavor
12
+ from openfoodfacts.images import download_image
14
13
  from PIL import Image, ImageOps
15
14
 
16
- from labelr.sample import (
17
- HF_DS_CLASSIFICATION_FEATURES,
18
- HF_DS_LLM_IMAGE_EXTRACTION_FEATURES,
19
- HF_DS_OBJECT_DETECTION_FEATURES,
20
- LLMImageExtractionSample,
15
+ from labelr.export.common import _pickle_sample_generator
16
+ from labelr.sample.object_detection import (
21
17
  format_object_detection_sample_to_hf,
18
+ get_hf_object_detection_features,
22
19
  )
23
- from labelr.types import TaskType
24
- from labelr.utils import PathWithContext
25
20
 
26
21
  logger = logging.getLogger(__name__)
27
22
 
28
23
 
29
- def _pickle_sample_generator(dir: Path):
30
- """Generator that yields samples from pickles in a directory."""
31
- for pkl in dir.glob("*.pkl"):
32
- with open(pkl, "rb") as f:
33
- yield pickle.load(f)
34
-
35
-
36
24
  def export_from_ls_to_hf_object_detection(
37
25
  ls: LabelStudio,
38
26
  repo_id: str,
39
27
  label_names: list[str],
40
28
  project_id: int,
29
+ is_openfoodfacts_dataset: bool,
30
+ image_max_size: int | None = None,
31
+ view_id: int | None = None,
41
32
  merge_labels: bool = False,
42
33
  use_aws_cache: bool = True,
43
34
  revision: str = "main",
44
- ):
35
+ ) -> None:
36
+ """Export annotations from a Label Studio project to a Hugging Face
37
+ dataset.
38
+
39
+ The Label Studio project should be an object detection project.
40
+
41
+ Args:
42
+ ls (LabelStudio): Label Studio client instance.
43
+ repo_id (str): Hugging Face repository ID to push the dataset to.
44
+ label_names (list[str]): List of label names in the project.
45
+ project_id (int): Label Studio project ID to export from.
46
+ is_openfoodfacts_dataset (bool): Whether the dataset is an Open Food
47
+ Facts dataset. If True, the dataset will include additional
48
+ metadata fields specific to Open Food Facts (`barcode` and
49
+ `off_image_id`).
50
+ image_max_size (int | None): Maximum size (in pixels) for the images.
51
+ If None, no resizing is performed. Defaults to None.
52
+ view_id (int | None): Label Studio view ID to export from. If None,
53
+ all tasks are exported. Defaults to None.
54
+ merge_labels (bool): Whether to merge all labels into a single label
55
+ named "object". Defaults to False.
56
+ use_aws_cache (bool): Whether to use the AWS image cache when
57
+ downloading images. Defaults to True.
58
+ revision (str): The dataset revision to push to. Defaults to 'main'.
59
+ """
45
60
  if merge_labels:
46
61
  label_names = ["object"]
47
62
 
48
63
  logger.info(
49
- "Project ID: %d, label names: %s, repo_id: %s, revision: %s",
64
+ "Project ID: %d, label names: %s, repo_id: %s, revision: %s, view ID: %s",
50
65
  project_id,
51
66
  label_names,
52
67
  repo_id,
53
68
  revision,
69
+ view_id,
54
70
  )
55
71
 
56
72
  for split in ["train", "val"]:
@@ -60,7 +76,9 @@ def export_from_ls_to_hf_object_detection(
60
76
  tmp_dir = Path(tmp_dir_str)
61
77
  logger.info("Saving samples to temporary directory: %s", tmp_dir)
62
78
  for i, task in tqdm.tqdm(
63
- enumerate(ls.tasks.list(project=project_id, fields="all")),
79
+ enumerate(
80
+ ls.tasks.list(project=project_id, fields="all", view=view_id)
81
+ ),
64
82
  desc="tasks",
65
83
  ):
66
84
  if task.data["split"] != split:
@@ -71,15 +89,17 @@ def export_from_ls_to_hf_object_detection(
71
89
  label_names=label_names,
72
90
  merge_labels=merge_labels,
73
91
  use_aws_cache=use_aws_cache,
92
+ image_max_size=image_max_size,
74
93
  )
75
94
  if sample is not None:
76
95
  # Save output as pickle
77
96
  with open(tmp_dir / f"{split}_{i:05}.pkl", "wb") as f:
78
97
  pickle.dump(sample, f)
79
98
 
99
+ features = get_hf_object_detection_features(is_openfoodfacts_dataset)
80
100
  hf_ds = datasets.Dataset.from_generator(
81
101
  functools.partial(_pickle_sample_generator, tmp_dir),
82
- features=HF_DS_OBJECT_DETECTION_FEATURES,
102
+ features=features,
83
103
  )
84
104
  hf_ds.push_to_hub(repo_id, split=split, revision=revision)
85
105
 
@@ -93,12 +113,32 @@ def export_from_ls_to_ultralytics_object_detection(
93
113
  error_raise: bool = True,
94
114
  merge_labels: bool = False,
95
115
  use_aws_cache: bool = True,
116
+ view_id: int | None = None,
117
+ image_max_size: int | None = None,
96
118
  ):
97
119
  """Export annotations from a Label Studio project to the Ultralytics
98
120
  format.
99
121
 
100
122
  The Label Studio project should be an object detection project with a
101
123
  single rectanglelabels annotation result per task.
124
+
125
+ Args:
126
+ ls (LabelStudio): Label Studio client instance.
127
+ output_dir (Path): Path to the output directory.
128
+ label_names (list[str]): List of label names in the project.
129
+ project_id (int): Label Studio project ID to export from.
130
+ train_ratio (float): Ratio of training samples. The rest will be used
131
+ for validation. Defaults to 0.8.
132
+ error_raise (bool): Whether to raise an error if an image fails to
133
+ download. If False, the image will be skipped. Defaults to True.
134
+ merge_labels (bool): Whether to merge all labels into a single label
135
+ named "object". Defaults to False.
136
+ use_aws_cache (bool): Whether to use the AWS image cache when
137
+ downloading images. Defaults to True.
138
+ view_id (int | None): Label Studio view ID to export from. If None,
139
+ all tasks are exported. Defaults to None.
140
+ image_max_size (int | None): Maximum size (in pixels) for the images.
141
+ If None, no resizing is performed. Defaults to None.
102
142
  """
103
143
  if merge_labels:
104
144
  label_names = ["object"]
@@ -116,7 +156,7 @@ def export_from_ls_to_ultralytics_object_detection(
116
156
  (images_dir / split).mkdir(parents=True, exist_ok=True)
117
157
 
118
158
  for task in tqdm.tqdm(
119
- ls.tasks.list(project=project_id, fields="all"),
159
+ ls.tasks.list(project=project_id, fields="all", view=view_id),
120
160
  desc="tasks",
121
161
  ):
122
162
  split = task.data.get("split")
@@ -194,18 +234,28 @@ def export_from_ls_to_ultralytics_object_detection(
194
234
  has_valid_annotation = True
195
235
 
196
236
  if has_valid_annotation:
197
- download_output = download_image(
237
+ image = download_image(
198
238
  image_url,
199
- return_struct=True,
239
+ return_struct=False,
200
240
  error_raise=error_raise,
201
241
  use_cache=use_aws_cache,
202
242
  )
203
- if download_output is None:
243
+ if image is None:
204
244
  logger.error("Failed to download image: %s", image_url)
205
245
  continue
206
246
 
207
- with (images_dir / split / f"{image_id}.jpg").open("wb") as f:
208
- f.write(download_output.image_bytes)
247
+ image = typing.cast(Image.Image, image)
248
+
249
+ # Rotate image according to exif orientation using Pillow
250
+ ImageOps.exif_transpose(image, in_place=True)
251
+ # Resize image if larger than max size
252
+ if image_max_size is not None and (
253
+ image.width > image_max_size or image.height > image_max_size
254
+ ):
255
+ image.thumbnail(
256
+ (image_max_size, image_max_size), Image.Resampling.LANCZOS
257
+ )
258
+ image.save(images_dir / split / f"{image_id}.jpg", format="JPEG")
209
259
 
210
260
  with (output_dir / "data.yaml").open("w") as f:
211
261
  f.write("path: data\n")
@@ -223,6 +273,7 @@ def export_from_hf_to_ultralytics_object_detection(
223
273
  download_images: bool = True,
224
274
  error_raise: bool = True,
225
275
  use_aws_cache: bool = True,
276
+ image_max_size: int | None = None,
226
277
  revision: str = "main",
227
278
  ):
228
279
  """Export annotations from a Hugging Face dataset project to the
@@ -243,6 +294,8 @@ def export_from_hf_to_ultralytics_object_detection(
243
294
  use_aws_cache (bool): Whether to use the AWS image cache when
244
295
  downloading images. This option is only used if `download_images`
245
296
  is True. Defaults to True.
297
+ image_max_size (int | None): Maximum size (in pixels) for the images.
298
+ If None, no resizing is performed. Defaults to None.
246
299
  revision (str): The dataset revision to load. Defaults to 'main'.
247
300
  """
248
301
  logger.info("Repo ID: %s, revision: %s", repo_id, revision)
@@ -278,21 +331,31 @@ def export_from_hf_to_ultralytics_object_detection(
278
331
  "`download_images` to False."
279
332
  )
280
333
  image_url = sample["meta"]["image_url"]
281
- download_output = download_image(
334
+ image = download_image(
282
335
  image_url,
283
- return_struct=True,
336
+ return_struct=False,
284
337
  error_raise=error_raise,
285
338
  use_cache=use_aws_cache,
286
339
  )
287
- if download_output is None:
340
+ if image is None:
288
341
  logger.error("Failed to download image: %s", image_url)
289
342
  continue
290
-
291
- with (split_images_dir / f"{image_id}.jpg").open("wb") as f:
292
- f.write(download_output.image_bytes)
293
343
  else:
294
344
  image = sample["image"]
295
- image.save(split_images_dir / f"{image_id}.jpg")
345
+
346
+ image = typing.cast(Image.Image, image)
347
+ # Rotate image according to exif orientation using Pillow
348
+ # If the image source is Hugging Face, EXIF data is not preserved,
349
+ # so this step is only useful when downloading images.
350
+ ImageOps.exif_transpose(image, in_place=True)
351
+ # Resize image if larger than max size
352
+ if image_max_size is not None and (
353
+ image.width > image_max_size or image.height > image_max_size
354
+ ):
355
+ image.thumbnail(
356
+ (image_max_size, image_max_size), Image.Resampling.LANCZOS
357
+ )
358
+ image.save(split_images_dir / f"{image_id}.jpg")
296
359
 
297
360
  objects = sample["objects"]
298
361
  bboxes = objects["bbox"]
@@ -335,186 +398,3 @@ def export_from_hf_to_ultralytics_object_detection(
335
398
  f.write("names:\n")
336
399
  for i, category_name in enumerate(category_names):
337
400
  f.write(f" {i}: {category_name}\n")
338
-
339
-
340
- def export_from_ultralytics_to_hf(
341
- task_type: TaskType,
342
- dataset_dir: Path,
343
- repo_id: str,
344
- label_names: list[str],
345
- merge_labels: bool = False,
346
- is_openfoodfacts_dataset: bool = False,
347
- openfoodfacts_flavor: Flavor = Flavor.off,
348
- ) -> None:
349
- if task_type != TaskType.classification:
350
- raise NotImplementedError(
351
- "Only classification task is currently supported for Ultralytics to HF export"
352
- )
353
-
354
- if task_type == TaskType.classification:
355
- export_from_ultralytics_to_hf_classification(
356
- dataset_dir=dataset_dir,
357
- repo_id=repo_id,
358
- label_names=label_names,
359
- merge_labels=merge_labels,
360
- is_openfoodfacts_dataset=is_openfoodfacts_dataset,
361
- openfoodfacts_flavor=openfoodfacts_flavor,
362
- )
363
-
364
-
365
- def export_from_ultralytics_to_hf_classification(
366
- dataset_dir: Path,
367
- repo_id: str,
368
- label_names: list[str],
369
- merge_labels: bool = False,
370
- is_openfoodfacts_dataset: bool = False,
371
- openfoodfacts_flavor: Flavor = Flavor.off,
372
- ) -> None:
373
- """Export an Ultralytics classification dataset to a Hugging Face dataset.
374
-
375
- The Ultralytics dataset directory should contain 'train', 'val' and/or
376
- 'test' subdirectories, each containing subdirectories for each label.
377
-
378
- Args:
379
- dataset_dir (Path): Path to the Ultralytics dataset directory.
380
- repo_id (str): Hugging Face repository ID to push the dataset to.
381
- label_names (list[str]): List of label names.
382
- merge_labels (bool): Whether to merge all labels into a single label
383
- named 'object'.
384
- is_openfoodfacts_dataset (bool): Whether the dataset is from
385
- Open Food Facts. If True, the `off_image_id` and `image_url` will
386
- be generated automatically. `off_image_id` is extracted from the
387
- image filename.
388
- openfoodfacts_flavor (Flavor): Flavor of Open Food Facts dataset. This
389
- is ignored if `is_openfoodfacts_dataset` is False.
390
- """
391
- logger.info("Repo ID: %s, dataset_dir: %s", repo_id, dataset_dir)
392
-
393
- if not any((dataset_dir / split).is_dir() for split in ["train", "val", "test"]):
394
- raise ValueError(
395
- f"Dataset directory {dataset_dir} does not contain 'train', 'val' or 'test' subdirectories"
396
- )
397
-
398
- # Save output as pickle
399
- for split in ["train", "val", "test"]:
400
- split_dir = dataset_dir / split
401
-
402
- if not split_dir.is_dir():
403
- logger.info("Skipping missing split directory: %s", split_dir)
404
- continue
405
-
406
- with tempfile.TemporaryDirectory() as tmp_dir_str:
407
- tmp_dir = Path(tmp_dir_str)
408
- for label_dir in (d for d in split_dir.iterdir() if d.is_dir()):
409
- label_name = label_dir.name
410
- if merge_labels:
411
- label_name = "object"
412
- if label_name not in label_names:
413
- raise ValueError(
414
- "Label name %s not in provided label names (label names: %s)"
415
- % (label_name, label_names),
416
- )
417
- label_id = label_names.index(label_name)
418
-
419
- for image_path in label_dir.glob("*"):
420
- if is_openfoodfacts_dataset:
421
- image_stem_parts = image_path.stem.split("_")
422
- barcode = image_stem_parts[0]
423
- off_image_id = image_stem_parts[1]
424
- image_id = f"{barcode}_{off_image_id}"
425
- image_url = generate_image_url(
426
- barcode, off_image_id, flavor=openfoodfacts_flavor
427
- )
428
- else:
429
- image_id = image_path.stem
430
- barcode = ""
431
- off_image_id = ""
432
- image_url = ""
433
- image = Image.open(image_path)
434
- image.load()
435
-
436
- if image.mode != "RGB":
437
- image = image.convert("RGB")
438
-
439
- # Rotate image according to exif orientation using Pillow
440
- ImageOps.exif_transpose(image, in_place=True)
441
- sample = {
442
- "image_id": image_id,
443
- "image": image,
444
- "width": image.width,
445
- "height": image.height,
446
- "meta": {
447
- "barcode": barcode,
448
- "off_image_id": off_image_id,
449
- "image_url": image_url,
450
- },
451
- "category_id": label_id,
452
- "category_name": label_name,
453
- }
454
- with open(tmp_dir / f"{split}_{image_id}.pkl", "wb") as f:
455
- pickle.dump(sample, f)
456
-
457
- hf_ds = datasets.Dataset.from_generator(
458
- functools.partial(_pickle_sample_generator, tmp_dir),
459
- features=HF_DS_CLASSIFICATION_FEATURES,
460
- )
461
- hf_ds.push_to_hub(repo_id, split=split)
462
-
463
-
464
- def export_to_hf_llm_image_extraction(
465
- sample_iter: Iterator[LLMImageExtractionSample],
466
- split: str,
467
- repo_id: str,
468
- revision: str = "main",
469
- tmp_dir: Path | None = None,
470
- ) -> None:
471
- """Export LLM image extraction samples to a Hugging Face dataset.
472
-
473
- Args:
474
- sample_iter (Iterator[LLMImageExtractionSample]): Iterator of samples
475
- to export.
476
- split (str): Name of the dataset split (e.g., 'train', 'val').
477
- repo_id (str): Hugging Face repository ID to push the dataset to.
478
- revision (str): Revision (branch, tag or commit) to use for the
479
- Hugging Face Datasets repository.
480
- tmp_dir (Path | None): Temporary directory to use for intermediate
481
- files. If None, a temporary directory will be created
482
- automatically.
483
- """
484
- logger.info(
485
- "Repo ID: %s, revision: %s, split: %s, tmp_dir: %s",
486
- repo_id,
487
- revision,
488
- split,
489
- tmp_dir,
490
- )
491
-
492
- tmp_dir_with_context: PathWithContext | tempfile.TemporaryDirectory
493
- if tmp_dir:
494
- tmp_dir.mkdir(parents=True, exist_ok=True)
495
- tmp_dir_with_context = PathWithContext(tmp_dir)
496
- else:
497
- tmp_dir_with_context = tempfile.TemporaryDirectory()
498
-
499
- with tmp_dir_with_context as tmp_dir_str:
500
- tmp_dir = Path(tmp_dir_str)
501
- for sample in tqdm.tqdm(sample_iter, desc="samples"):
502
- image = sample.image
503
- # Rotate image according to exif orientation using Pillow
504
- image = ImageOps.exif_transpose(image)
505
- image_id = sample.image_id
506
- sample = {
507
- "image_id": image_id,
508
- "image": image,
509
- "meta": sample.meta.model_dump(),
510
- "output": sample.output,
511
- }
512
- # Save output as pickle
513
- with open(tmp_dir / f"{split}_{image_id}.pkl", "wb") as f:
514
- pickle.dump(sample, f)
515
-
516
- hf_ds = datasets.Dataset.from_generator(
517
- functools.partial(_pickle_sample_generator, tmp_dir),
518
- features=HF_DS_LLM_IMAGE_EXTRACTION_FEATURES,
519
- )
520
- hf_ds.push_to_hub(repo_id, split=split, revision=revision)
labelr/google_genai.py CHANGED
@@ -11,10 +11,11 @@ import orjson
11
11
  import typer
12
12
  from gcloud.aio.storage import Storage
13
13
  from openfoodfacts import Flavor
14
- from openfoodfacts.images import download_image, generate_image_url
14
+ from openfoodfacts.images import generate_image_url
15
15
  from tqdm.asyncio import tqdm
16
16
 
17
- from labelr.sample import LLMImageExtractionSample, SampleMeta
17
+ from labelr.sample.common import SampleMeta
18
+ from labelr.sample.llm import LLMImageExtractionSample
18
19
  from labelr.utils import download_image_from_gcs
19
20
 
20
21
  try:
@@ -335,6 +336,7 @@ def generate_sample_iter(
335
336
  """
336
337
  skipped = 0
337
338
  invalid = 0
339
+ storage_client = storage.Client()
338
340
  with prediction_path.open("r") as f_in:
339
341
  for i, sample_str in enumerate(f_in):
340
342
  if i < skip:
@@ -349,6 +351,7 @@ def generate_sample_iter(
349
351
  sample=sample,
350
352
  is_openfoodfacts_dataset=is_openfoodfacts_dataset,
351
353
  openfoodfacts_flavor=openfoodfacts_flavor,
354
+ storage_client=storage_client,
352
355
  )
353
356
  except Exception as e:
354
357
  if raise_on_invalid_sample:
@@ -370,6 +373,7 @@ def generate_sample_from_prediction(
370
373
  sample: JSONType,
371
374
  is_openfoodfacts_dataset: bool = False,
372
375
  openfoodfacts_flavor: Flavor = Flavor.off,
376
+ storage_client: storage.Client | None = None,
373
377
  ) -> LLMImageExtractionSample:
374
378
  """Generate a LLMImageExtractionSample from a prediction sample.
375
379
  Args:
@@ -378,13 +382,15 @@ def generate_sample_from_prediction(
378
382
  is_openfoodfacts_dataset (bool): Whether the dataset is from Open Food
379
383
  Facts.
380
384
  openfoodfacts_flavor (Flavor): Flavor of the Open Food Facts dataset.
385
+ storage_client (storage.Client | None): Optional Google Cloud Storage
386
+ client. If not provided, a new client will be created.
381
387
  Returns:
382
388
  LLMImageExtractionSample: Generated sample.
383
389
  """
384
390
  image_id = sample["key"][len("key:") :]
385
391
  response_str = sample["response"]["candidates"][0]["content"]["parts"][0]["text"]
386
392
  image_uri = sample["request"]["contents"][0]["parts"][1]["file_data"]["file_uri"]
387
- image = download_image_from_gcs(image_uri=image_uri)
393
+ image = download_image_from_gcs(image_uri=image_uri, client=storage_client)
388
394
  response = orjson.loads(response_str)
389
395
  jsonschema.validate(response, json_schema)
390
396
 
labelr/main.py CHANGED
@@ -4,11 +4,13 @@ import typer
4
4
  from openfoodfacts.utils import get_logger
5
5
 
6
6
  from labelr.apps import datasets as dataset_app
7
+ from labelr.apps import directus as directus_app
7
8
  from labelr.apps import evaluate as evaluate_app
8
9
  from labelr.apps import google_batch as google_batch_app
9
10
  from labelr.apps import hugging_face as hf_app
10
11
  from labelr.apps import label_studio as ls_app
11
12
  from labelr.apps import train as train_app
13
+ from labelr import config as _config
12
14
 
13
15
  app = typer.Typer(pretty_exceptions_show_locals=False)
14
16
 
@@ -60,6 +62,17 @@ def predict(
60
62
  typer.echo(result)
61
63
 
62
64
 
65
+ @app.command()
66
+ def config(name: str, value: str):
67
+ """Set a Labelr configuration value.
68
+
69
+ The configuration is stored in a JSON file at ~/.config/.labelr/config.json.
70
+ """
71
+ typer.echo(f"Set '{name}' to '{value}'")
72
+ _config.set_file_config(name, value)
73
+ typer.echo(f"Configuration saved to {_config.CONFIG_PATH}")
74
+
75
+
63
76
  app.add_typer(
64
77
  ls_app.app,
65
78
  name="ls",
@@ -90,6 +103,9 @@ app.add_typer(
90
103
  name="google-batch",
91
104
  help="Generate datasets and launch batch jobs on Google Gemini.",
92
105
  )
106
+ app.add_typer(
107
+ directus_app.app, name="directus", help="Manage directus collections and items."
108
+ )
93
109
 
94
110
  if __name__ == "__main__":
95
111
  app()
File without changes
@@ -0,0 +1,17 @@
1
+ import datasets
2
+
3
+ HF_DS_CLASSIFICATION_FEATURES = datasets.Features(
4
+ {
5
+ "image_id": datasets.Value("string"),
6
+ "image": datasets.features.Image(),
7
+ "width": datasets.Value("int64"),
8
+ "height": datasets.Value("int64"),
9
+ "meta": {
10
+ "barcode": datasets.Value("string"),
11
+ "off_image_id": datasets.Value("string"),
12
+ "image_url": datasets.Value("string"),
13
+ },
14
+ "category_id": datasets.Value("int64"),
15
+ "category_name": datasets.Value("string"),
16
+ }
17
+ )
@@ -0,0 +1,14 @@
1
+ from pydantic import BaseModel, Field
2
+
3
+
4
+ class SampleMeta(BaseModel):
5
+ barcode: str | None = Field(
6
+ ..., description="The barcode of the product, if applicable"
7
+ )
8
+ off_image_id: str | None = Field(
9
+ ...,
10
+ description="The Open Food Facts image ID associated with the image, if applicable",
11
+ )
12
+ image_url: str | None = Field(
13
+ ..., description="The URL of the image, if applicable"
14
+ )
labelr/sample/llm.py ADDED
@@ -0,0 +1,75 @@
1
+ import typing
2
+ from collections.abc import Iterator
3
+ from pathlib import Path
4
+
5
+ import datasets
6
+ import orjson
7
+ from PIL import Image
8
+ from pydantic import BaseModel, Field
9
+
10
+ from labelr.sample.common import SampleMeta
11
+ from labelr.utils import download_image
12
+
13
+
14
+ class LLMImageExtractionSample(BaseModel):
15
+ class Config:
16
+ # required to allow PIL Image type
17
+ arbitrary_types_allowed = True
18
+
19
+ image_id: str = Field(
20
+ ...,
21
+ description="unique ID for the image. For Open Food Facts images, it follows the "
22
+ "format `barcode:imgid`",
23
+ )
24
+ image: Image.Image = Field(..., description="Image to extract information from")
25
+ output: str | None = Field(..., description="Expected response of the LLM")
26
+ meta: SampleMeta = Field(..., description="Metadata associated with the sample")
27
+
28
+
29
+ HF_DS_LLM_IMAGE_EXTRACTION_FEATURES = datasets.Features(
30
+ {
31
+ "image_id": datasets.Value("string"),
32
+ "image": datasets.features.Image(),
33
+ "output": datasets.features.Value("string"),
34
+ "meta": {
35
+ "barcode": datasets.Value("string"),
36
+ "off_image_id": datasets.Value("string"),
37
+ "image_url": datasets.Value("string"),
38
+ },
39
+ }
40
+ )
41
+
42
+
43
+ def load_llm_image_extraction_dataset_from_jsonl(
44
+ dataset_path: Path, **kwargs
45
+ ) -> Iterator[LLMImageExtractionSample]:
46
+ """Load a Hugging Face dataset for LLM image extraction from a JSONL file.
47
+
48
+ Args:
49
+ dataset_path (Path): Path to the JSONL dataset file.
50
+ **kwargs: Additional keyword arguments to pass to the image downloader.
51
+ Yields:
52
+ Iterator[LLMImageExtractionSample]: Iterator of LLM image extraction
53
+ samples.
54
+ """
55
+ with dataset_path.open("r") as f:
56
+ for line in f:
57
+ item = orjson.loads(line)
58
+ image_id = item["image_id"]
59
+ image_url = item["image_url"]
60
+ image = typing.cast(Image.Image, download_image(image_url, **kwargs))
61
+ barcode = item.pop("barcode", None)
62
+ off_image_id = item.pop("off_image_id", None)
63
+ output = item.pop("output", None)
64
+ meta = SampleMeta(
65
+ barcode=barcode,
66
+ off_image_id=off_image_id,
67
+ image_url=image_url,
68
+ )
69
+ sample = LLMImageExtractionSample(
70
+ image_id=image_id,
71
+ image=image,
72
+ output=output,
73
+ meta=meta,
74
+ )
75
+ yield sample