hafnia 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. cli/__main__.py +13 -2
  2. cli/config.py +2 -1
  3. cli/consts.py +1 -1
  4. cli/dataset_cmds.py +6 -14
  5. cli/dataset_recipe_cmds.py +78 -0
  6. cli/experiment_cmds.py +226 -43
  7. cli/profile_cmds.py +6 -5
  8. cli/runc_cmds.py +5 -5
  9. cli/trainer_package_cmds.py +65 -0
  10. hafnia/__init__.py +2 -0
  11. hafnia/data/factory.py +1 -2
  12. hafnia/dataset/dataset_helpers.py +0 -12
  13. hafnia/dataset/dataset_names.py +8 -4
  14. hafnia/dataset/dataset_recipe/dataset_recipe.py +119 -33
  15. hafnia/dataset/dataset_recipe/recipe_transforms.py +32 -4
  16. hafnia/dataset/dataset_recipe/recipe_types.py +1 -1
  17. hafnia/dataset/dataset_upload_helper.py +206 -53
  18. hafnia/dataset/hafnia_dataset.py +432 -194
  19. hafnia/dataset/license_types.py +63 -0
  20. hafnia/dataset/operations/dataset_stats.py +260 -3
  21. hafnia/dataset/operations/dataset_transformations.py +325 -4
  22. hafnia/dataset/operations/table_transformations.py +39 -2
  23. hafnia/dataset/primitives/__init__.py +8 -0
  24. hafnia/dataset/primitives/classification.py +1 -1
  25. hafnia/experiment/hafnia_logger.py +112 -0
  26. hafnia/http.py +16 -2
  27. hafnia/platform/__init__.py +9 -3
  28. hafnia/platform/builder.py +12 -10
  29. hafnia/platform/dataset_recipe.py +99 -0
  30. hafnia/platform/datasets.py +44 -6
  31. hafnia/platform/download.py +2 -1
  32. hafnia/platform/experiment.py +51 -56
  33. hafnia/platform/trainer_package.py +57 -0
  34. hafnia/utils.py +64 -13
  35. hafnia/visualizations/image_visualizations.py +3 -3
  36. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/METADATA +34 -30
  37. hafnia-0.3.0.dist-info/RECORD +53 -0
  38. cli/recipe_cmds.py +0 -45
  39. hafnia-0.2.4.dist-info/RECORD +0 -49
  40. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/WHEEL +0 -0
  41. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/entry_points.txt +0 -0
  42. {hafnia-0.2.4.dist-info → hafnia-0.3.0.dist-info}/licenses/LICENSE +0 -0
hafnia/dataset/dataset_upload_helper.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import base64
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
@@ -7,13 +8,9 @@ from typing import Dict, List, Optional, Tuple, Type, Union
 
 import boto3
 import polars as pl
-from pydantic import BaseModel, ConfigDict
+from PIL import Image
+from pydantic import BaseModel, ConfigDict, field_validator
 
-import hafnia.dataset.primitives.bbox
-import hafnia.dataset.primitives.bitmask
-import hafnia.dataset.primitives.classification
-import hafnia.dataset.primitives.polygon
-import hafnia.dataset.primitives.segmentation
 from cli.config import Config
 from hafnia.dataset import primitives
 from hafnia.dataset.dataset_names import (
@@ -23,11 +20,19 @@ from hafnia.dataset.dataset_names import (
     FieldName,
     SplitName,
 )
-from hafnia.dataset.hafnia_dataset import HafniaDataset, TaskInfo
+from hafnia.dataset.hafnia_dataset import Attribution, HafniaDataset, Sample, TaskInfo
+from hafnia.dataset.operations import table_transformations
+from hafnia.dataset.primitives import (
+    Bbox,
+    Bitmask,
+    Classification,
+    Polygon,
+    Segmentation,
+)
 from hafnia.dataset.primitives.primitive import Primitive
 from hafnia.http import post
 from hafnia.log import user_logger
-from hafnia.platform import get_dataset_id
+from hafnia.platform.datasets import get_dataset_id
 
 
 def generate_bucket_name(dataset_name: str, deployment_stage: DeploymentStage) -> str:
@@ -53,7 +58,7 @@ class DbDataset(BaseModel, validate_assignment=True):  # type: ignore[call-arg]
     annotation_ontology: Optional[str] = None
     dataset_variants: Optional[List[DbDatasetVariant]] = None
     split_annotations_reports: Optional[List[DbSplitAnnotationsReport]] = None
-    dataset_images: Optional[List[DatasetImage]] = None
+    imgs: Optional[List[DatasetImage]] = None
 
 
 class DbDatasetVariant(BaseModel, validate_assignment=True):  # type: ignore[call-arg]
@@ -75,6 +80,8 @@ class DbAnnotatedObject(BaseModel, validate_assignment=True):  # type: ignore[ca
     model_config = ConfigDict(use_enum_values=True)  # To parse Enum values as strings
     name: str
     entity_type: EntityTypeChoices
+    annotation_type: DbAnnotationType
+    task_name: Optional[str] = None  # Not sure if adding task_name makes sense.
 
 
 class DbAnnotatedObjectReport(BaseModel, validate_assignment=True):  # type: ignore[call-arg]
@@ -82,10 +89,34 @@ class DbAnnotatedObjectReport(BaseModel, validate_assignment=True):  # type: ign
     obj: DbAnnotatedObject
     unique_obj_ids: Optional[int] = None
     obj_instances: Optional[int] = None
+    images_with_obj: Optional[int] = None
+
     average_count_per_image: Optional[float] = None
-    avg_area: Optional[float] = None
-    min_area: Optional[float] = None
-    max_area: Optional[float] = None
+
+    area_avg_ratio: Optional[float] = None
+    area_min_ratio: Optional[float] = None
+    area_max_ratio: Optional[float] = None
+
+    height_avg_ratio: Optional[float] = None
+    height_min_ratio: Optional[float] = None
+    height_max_ratio: Optional[float] = None
+
+    width_avg_ratio: Optional[float] = None
+    width_min_ratio: Optional[float] = None
+    width_max_ratio: Optional[float] = None
+
+    area_avg_px: Optional[float] = None
+    area_min_px: Optional[int] = None
+    area_max_px: Optional[int] = None
+
+    height_avg_px: Optional[float] = None
+    height_min_px: Optional[int] = None
+    height_max_px: Optional[int] = None
+
+    width_avg_px: Optional[float] = None
+    width_min_px: Optional[int] = None
+    width_max_px: Optional[int] = None
+
     annotation_type: Optional[List[DbAnnotationType]] = None
 
 
@@ -155,8 +186,29 @@ class EntityTypeChoices(str, Enum):  # Should match `EntityTypeChoices` in `dipd
     EVENT = "EVENT"
 
 
-class DatasetImage(BaseModel, validate_assignment=True):  # type: ignore[call-arg]
-    img: str
+class DatasetImage(Attribution, validate_assignment=True):  # type: ignore[call-arg]
+    img: str  # Base64-encoded image string
+    order: Optional[int] = None
+
+    @field_validator("img", mode="before")
+    def validate_image_path(cls, v: Union[str, Path]) -> str:
+        if isinstance(v, Path):
+            v = path_image_to_base64_str(path_image=v)
+
+        if not isinstance(v, str):
+            raise ValueError("Image must be a string or Path object representing the image path.")
+
+        if not v.startswith("data:image/"):
+            raise ValueError("Image must be a base64-encoded data URL.")
+
+        return v
+
+
+def path_image_to_base64_str(path_image: Path) -> str:
+    image = Image.open(path_image)
+    mime_format = Image.MIME[image.format]
+    as_b64 = base64.b64encode(path_image.read_bytes()).decode("ascii")
+    return f"data:{mime_format};base64,{as_b64}"
 
 
 class DbDistributionType(BaseModel, validate_assignment=True):  # type: ignore[call-arg]
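The reworked `img` field accepts either a ready-made data URL or a `Path`; the `mode="before"` validator converts a `Path` by base64-encoding the file via `path_image_to_base64_str`. A minimal sketch of that round trip, assuming the module's definitions are in scope, that the inherited `Attribution` fields are optional, and using a hypothetical file path:

    from pathlib import Path

    # Hypothetical gallery image on disk; any PIL-readable format works.
    path_preview = Path("gallery/preview_0001.jpg")

    # Passing a Path triggers the "before" validator, which encodes the bytes
    # into a data URL such as "data:image/jpeg;base64,/9j/4AAQ...".
    gallery_image = DatasetImage(img=path_preview, order=0)
    assert gallery_image.img.startswith("data:image/")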
@@ -185,7 +237,10 @@ def get_folder_size(path: Path) -> int:
     return sum([path.stat().st_size for path in path.rglob("*")])
 
 
-def upload_to_hafnia_dataset_detail_page(dataset_update: DbDataset) -> dict:
+def upload_to_hafnia_dataset_detail_page(dataset_update: DbDataset, upload_gallery_images: bool) -> dict:
+    if not upload_gallery_images:
+        dataset_update.imgs = None
+
     cfg = Config()
     dataset_details = dataset_update.model_dump_json()
     data = upload_dataset_details(cfg=cfg, data=dataset_details, dataset_name=dataset_update.name)
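The new `upload_gallery_images` flag makes gallery upload an explicit caller decision: when it is `False`, the `imgs` payload is stripped before serialization. A hedged sketch, assuming `dataset_info` is a `DbDataset` built elsewhere:

    # Drop the (potentially large) base64 gallery payload from the request.
    response = upload_to_hafnia_dataset_detail_page(dataset_update=dataset_info, upload_gallery_images=False)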
@@ -200,8 +255,8 @@ def upload_dataset_details(cfg: Config, data: str, dataset_name: str) -> dict:
     headers = {"Authorization": cfg.api_key}
 
     user_logger.info("Importing dataset details. This may take up to 30 seconds...")
-    data = post(endpoint=import_endpoint, headers=headers, data=data)  # type: ignore[assignment]
-    return data  # type: ignore[return-value]
+    response = post(endpoint=import_endpoint, headers=headers, data=data)  # type: ignore[assignment]
+    return response  # type: ignore[return-value]
 
 
 def get_resolutions(dataset: HafniaDataset, max_resolutions_selected: int = 8) -> List[DbResolution]:
@@ -219,7 +274,6 @@ def has_primitive(dataset: Union[HafniaDataset, pl.DataFrame], PrimitiveType: Ty
     col_name = PrimitiveType.column_name()
     table = dataset.samples if isinstance(dataset, HafniaDataset) else dataset
     if col_name not in table.columns:
-        user_logger.warning(f"Warning: No field called '{col_name}' was found for '{PrimitiveType.__name__}'.")
         return False
 
     if table[col_name].dtype == pl.Null:
@@ -235,7 +289,7 @@ def calculate_distribution_values(
 
     if len(distribution_tasks) == 0:
         return []
-    classification_column = hafnia.dataset.primitives.classification.Classification.column_name()
+    classification_column = Classification.column_name()
     classifications = dataset_split.select(pl.col(classification_column).explode())
     classifications = classifications.filter(pl.col(classification_column).is_not_null()).unnest(classification_column)
     classifications = classifications.filter(
@@ -277,6 +331,8 @@ def dataset_info_from_dataset(
     deployment_stage: DeploymentStage,
     path_sample: Optional[Path],
     path_hidden: Optional[Path],
+    path_gallery_images: Optional[Path] = None,
+    gallery_image_names: Optional[List[str]] = None,
 ) -> DbDataset:
     dataset_variants = []
     dataset_reports = []
@@ -292,6 +348,12 @@ def dataset_info_from_dataset(
     if len(path_and_variant) == 0:
         raise ValueError("At least one path must be provided for sample or hidden dataset.")
 
+    gallery_images = create_gallery_images(
+        dataset=dataset,
+        path_gallery_images=path_gallery_images,
+        gallery_image_names=gallery_image_names,
+    )
+
     for path_dataset, variant_type in path_and_variant:
         if variant_type == DatasetVariant.SAMPLE:
             dataset_variant = dataset.create_sample_dataset()
@@ -331,19 +393,26 @@
         )
 
         object_reports: List[DbAnnotatedObjectReport] = []
-        primitive_columns = [tPrimtive.column_name() for tPrimtive in primitives.PRIMITIVE_TYPES]
-        if has_primitive(dataset_split, PrimitiveType=hafnia.dataset.primitives.bbox.Bbox):
-            bbox_column_name = hafnia.dataset.primitives.bbox.Bbox.column_name()
-            drop_columns = [col for col in primitive_columns if col != bbox_column_name]
-            drop_columns.append(FieldName.META)
-            df_per_instance = dataset_split.rename({"height": "image.height", "width": "image.width"})
-            df_per_instance = df_per_instance.explode(bbox_column_name).drop(drop_columns).unnest(bbox_column_name)
-
+        primitive_columns = [primitive.column_name() for primitive in primitives.PRIMITIVE_TYPES]
+        if has_primitive(dataset_split, PrimitiveType=Bbox):
+            df_per_instance = table_transformations.create_primitive_table(
+                dataset_split, PrimitiveType=Bbox, keep_sample_data=True
+            )
+            if df_per_instance is None:
+                raise ValueError(f"Expected {Bbox.__name__} primitive column to be present in the dataset split.")
             # Calculate area of bounding boxes
-            df_per_instance = df_per_instance.with_columns((pl.col("height") * pl.col("width")).alias("area"))
+            df_per_instance = df_per_instance.with_columns(
+                (pl.col("height") * pl.col("width")).alias("area"),
+            ).with_columns(
+                (pl.col("height") * pl.col("image.height")).alias("height_px"),
+                (pl.col("width") * pl.col("image.width")).alias("width_px"),
+                (pl.col("area") * (pl.col("image.height") * pl.col("image.width"))).alias("area_px"),
+            )
 
             annotation_type = DbAnnotationType(name=AnnotationType.ObjectDetection.value)
-            for (class_name,), class_group in df_per_instance.group_by(FieldName.CLASS_NAME):
+            for (class_name, task_name), class_group in df_per_instance.group_by(
+                FieldName.CLASS_NAME, FieldName.TASK_NAME
+            ):
                 if class_name is None:
                     continue
                 object_reports.append(
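The bbox branch now keeps `height`/`width`/`area` as ratios of the image size and derives pixel values from the per-sample `image.height`/`image.width` columns. A self-contained polars sketch of the same expressions on toy data (the values are illustrative):

    import polars as pl

    # Toy per-instance table: bbox height/width are ratios of the image size.
    df = pl.DataFrame(
        {
            "height": [0.5, 0.25],
            "width": [0.4, 0.1],
            "image.height": [1080, 720],
            "image.width": [1920, 1280],
        }
    )
    df = df.with_columns(
        (pl.col("height") * pl.col("width")).alias("area"),
    ).with_columns(
        (pl.col("height") * pl.col("image.height")).alias("height_px"),
        (pl.col("width") * pl.col("image.width")).alias("width_px"),
        (pl.col("area") * (pl.col("image.height") * pl.col("image.width"))).alias("area_px"),
    )
    # First row: height_px = 540, width_px = 768, area_px = 540 * 768 = 414720.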
@@ -351,25 +420,39 @@
                     obj=DbAnnotatedObject(
                         name=class_name,
                         entity_type=EntityTypeChoices.OBJECT.value,
+                        annotation_type=annotation_type,
+                        task_name=task_name,
                     ),
                     unique_obj_ids=class_group[FieldName.OBJECT_ID].n_unique(),
                     obj_instances=len(class_group),
                     annotation_type=[annotation_type],
-                    avg_area=class_group["area"].mean(),
-                    min_area=class_group["area"].min(),
-                    max_area=class_group["area"].max(),
+                    images_with_obj=class_group[ColumnName.SAMPLE_INDEX].n_unique(),
+                    area_avg_ratio=class_group["area"].mean(),
+                    area_min_ratio=class_group["area"].min(),
+                    area_max_ratio=class_group["area"].max(),
+                    height_avg_ratio=class_group["height"].mean(),
+                    height_min_ratio=class_group["height"].min(),
+                    height_max_ratio=class_group["height"].max(),
+                    width_avg_ratio=class_group["width"].mean(),
+                    width_min_ratio=class_group["width"].min(),
+                    width_max_ratio=class_group["width"].max(),
+                    area_avg_px=class_group["area_px"].mean(),
+                    area_min_px=int(class_group["area_px"].min()),
+                    area_max_px=int(class_group["area_px"].max()),
+                    height_avg_px=class_group["height_px"].mean(),
+                    height_min_px=int(class_group["height_px"].min()),
+                    height_max_px=int(class_group["height_px"].max()),
+                    width_avg_px=class_group["width_px"].mean(),
+                    width_min_px=int(class_group["width_px"].min()),
+                    width_max_px=int(class_group["width_px"].max()),
                     average_count_per_image=len(class_group) / class_group[ColumnName.SAMPLE_INDEX].n_unique(),
                 )
             )
 
-        if has_primitive(dataset_split, PrimitiveType=hafnia.dataset.primitives.classification.Classification):
+        if has_primitive(dataset_split, PrimitiveType=Classification):
             annotation_type = DbAnnotationType(name=AnnotationType.ImageClassification.value)
-            col_name = hafnia.dataset.primitives.classification.Classification.column_name()
-            classification_tasks = [
-                task.name
-                for task in dataset.info.tasks
-                if task.primitive == hafnia.dataset.primitives.classification.Classification
-            ]
+            col_name = Classification.column_name()
+            classification_tasks = [task.name for task in dataset.info.tasks if task.primitive == Classification]
             has_classification_data = dataset_split[col_name].dtype != pl.List(pl.Null)
             if has_classification_data:
                 classification_df = dataset_split.select(col_name).explode(col_name).unnest(col_name)
@@ -385,7 +468,7 @@
                 ), class_group in classification_df.group_by(FieldName.TASK_NAME, FieldName.CLASS_NAME):
                     if class_name is None:
                         continue
-                    if task_name == hafnia.dataset.primitives.classification.Classification.default_task_name():
+                    if task_name == Classification.default_task_name():
                         display_name = class_name  # Prefix class name with task name
                     else:
                         display_name = f"{task_name}.{class_name}"
@@ -394,6 +477,8 @@
                             obj=DbAnnotatedObject(
                                 name=display_name,
                                 entity_type=EntityTypeChoices.EVENT.value,
+                                annotation_type=annotation_type,
+                                task_name=task_name,
                             ),
                             unique_obj_ids=len(
                                 class_group
@@ -403,22 +488,32 @@
                     )
                 )
 
-        if has_primitive(dataset_split, PrimitiveType=hafnia.dataset.primitives.segmentation.Segmentation):
+        if has_primitive(dataset_split, PrimitiveType=Segmentation):
            raise NotImplementedError("Not Implemented yet")
 
-        if has_primitive(dataset_split, PrimitiveType=hafnia.dataset.primitives.bitmask.Bitmask):
-            col_name = hafnia.dataset.primitives.bitmask.Bitmask.column_name()
+        if has_primitive(dataset_split, PrimitiveType=Bitmask):
+            col_name = Bitmask.column_name()
             drop_columns = [col for col in primitive_columns if col != col_name]
             drop_columns.append(FieldName.META)
-            df_per_instance = dataset_split.rename({"height": "image.height", "width": "image.width"})
-            df_per_instance = df_per_instance.explode(col_name).drop(drop_columns).unnest(col_name)
 
-            min_area = df_per_instance["area"].min() if "area" in df_per_instance.columns else None
-            max_area = df_per_instance["area"].max() if "area" in df_per_instance.columns else None
-            avg_area = df_per_instance["area"].mean() if "area" in df_per_instance.columns else None
+            df_per_instance = table_transformations.create_primitive_table(
+                dataset_split, PrimitiveType=Bitmask, keep_sample_data=True
+            )
+            if df_per_instance is None:
+                raise ValueError(
+                    f"Expected {Bitmask.__name__} primitive column to be present in the dataset split."
+                )
+            df_per_instance = df_per_instance.rename({"height": "height_px", "width": "width_px"})
+            df_per_instance = df_per_instance.with_columns(
+                (pl.col("image.height") * pl.col("image.width") * pl.col("area")).alias("area_px"),
+                (pl.col("height_px") / pl.col("image.height")).alias("height"),
+                (pl.col("width_px") / pl.col("image.width")).alias("width"),
+            )
 
             annotation_type = DbAnnotationType(name=AnnotationType.InstanceSegmentation)
-            for (class_name,), class_group in df_per_instance.group_by(FieldName.CLASS_NAME):
+            for (class_name, task_name), class_group in df_per_instance.group_by(
+                FieldName.CLASS_NAME, FieldName.TASK_NAME
+            ):
                 if class_name is None:
                     continue
                 object_reports.append(
@@ -426,18 +521,36 @@
                     obj=DbAnnotatedObject(
                         name=class_name,
                         entity_type=EntityTypeChoices.OBJECT.value,
+                        annotation_type=annotation_type,
+                        task_name=task_name,
                     ),
                     unique_obj_ids=class_group[FieldName.OBJECT_ID].n_unique(),
                     obj_instances=len(class_group),
                     annotation_type=[annotation_type],
                     average_count_per_image=len(class_group) / class_group[ColumnName.SAMPLE_INDEX].n_unique(),
-                    avg_area=avg_area,
-                    min_area=min_area,
-                    max_area=max_area,
+                    images_with_obj=class_group[ColumnName.SAMPLE_INDEX].n_unique(),
+                    area_avg_ratio=class_group["area"].mean(),
+                    area_min_ratio=class_group["area"].min(),
+                    area_max_ratio=class_group["area"].max(),
+                    height_avg_ratio=class_group["height"].mean(),
+                    height_min_ratio=class_group["height"].min(),
+                    height_max_ratio=class_group["height"].max(),
+                    width_avg_ratio=class_group["width"].mean(),
+                    width_min_ratio=class_group["width"].min(),
+                    width_max_ratio=class_group["width"].max(),
+                    area_avg_px=class_group["area_px"].mean(),
+                    area_min_px=int(class_group["area_px"].min()),
+                    area_max_px=int(class_group["area_px"].max()),
+                    height_avg_px=class_group["height_px"].mean(),
+                    height_min_px=int(class_group["height_px"].min()),
+                    height_max_px=int(class_group["height_px"].max()),
+                    width_avg_px=class_group["width_px"].mean(),
+                    width_min_px=int(class_group["width_px"].min()),
+                    width_max_px=int(class_group["width_px"].max()),
                 )
             )
 
-        if has_primitive(dataset_split, PrimitiveType=hafnia.dataset.primitives.polygon.Polygon):
+        if has_primitive(dataset_split, PrimitiveType=Polygon):
             raise NotImplementedError("Not Implemented yet")
 
         # Sort object reports by name to more easily compare between versions
@@ -463,6 +576,46 @@
         data_received_end=dataset_meta_info.get("data_received_end", None),
         annotation_project_id=dataset_meta_info.get("annotation_project_id", None),
         annotation_dataset_id=dataset_meta_info.get("annotation_dataset_id", None),
+        imgs=gallery_images,
     )
 
     return dataset_info
+
+
+def create_gallery_images(
+    dataset: HafniaDataset,
+    path_gallery_images: Optional[Path],
+    gallery_image_names: Optional[List[str]],
+) -> Optional[List[DatasetImage]]:
+    gallery_images = None
+    if (gallery_image_names is not None) and (len(gallery_image_names) > 0):
+        if path_gallery_images is None:
+            raise ValueError("Path to gallery images must be provided.")
+        path_gallery_images.mkdir(parents=True, exist_ok=True)
+        COL_IMAGE_NAME = "image_name"
+        samples = dataset.samples.with_columns(
+            dataset.samples[ColumnName.FILE_NAME].str.split("/").list.last().alias(COL_IMAGE_NAME)
+        )
+        gallery_samples = samples.filter(pl.col(COL_IMAGE_NAME).is_in(gallery_image_names))
+
+        missing_gallery_samples = set(gallery_image_names) - set(gallery_samples[COL_IMAGE_NAME])
+        if len(missing_gallery_samples):
+            raise ValueError(f"Gallery images not found in dataset: {missing_gallery_samples}")
+        gallery_images = []
+        for gallery_sample in gallery_samples.iter_rows(named=True):
+            sample = Sample(**gallery_sample)
+            image = sample.draw_annotations()
+
+            path_gallery_image = path_gallery_images / gallery_sample[COL_IMAGE_NAME]
+            Image.fromarray(image).save(path_gallery_image)
+
+            dataset_image_dict = {
+                "img": path_gallery_image,
+            }
+            if sample.attribution is not None:
+                sample.attribution.changes = "Annotations have been visualized"
+                dataset_image_dict.update(sample.attribution.model_dump(exclude_none=True))
+            gallery_img = DatasetImage(**dataset_image_dict)
+            gallery_img.licenses = gallery_img.licenses or []
+            gallery_images.append(gallery_img)
+    return gallery_images
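End to end, the gallery plumbing flows from `create_gallery_images` through `dataset_info_from_dataset` into the upload call: selected samples are rendered with annotations, written to disk, and re-read as base64 data URLs by the `DatasetImage` validator. A sketch, assuming `dataset` is an already-loaded `HafniaDataset`, that the remaining `dataset_info_from_dataset` arguments are filled in as before, and using hypothetical paths, image names, and a `DeploymentStage` member:

    dataset_info = dataset_info_from_dataset(
        dataset=dataset,
        deployment_stage=DeploymentStage.STAGING,  # assumed enum member
        path_sample=Path("out/sample"),
        path_hidden=None,
        path_gallery_images=Path("out/gallery"),  # annotated previews are written here
        gallery_image_names=["frame_000123.jpg"],  # must match file names in the dataset
    )
    upload_to_hafnia_dataset_detail_page(dataset_update=dataset_info, upload_gallery_images=True)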