hafnia 0.1.27__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. cli/__main__.py +2 -2
  2. cli/config.py +17 -4
  3. cli/dataset_cmds.py +60 -0
  4. cli/runc_cmds.py +1 -1
  5. hafnia/data/__init__.py +2 -2
  6. hafnia/data/factory.py +12 -56
  7. hafnia/dataset/dataset_helpers.py +91 -0
  8. hafnia/dataset/dataset_names.py +72 -0
  9. hafnia/dataset/dataset_recipe/dataset_recipe.py +327 -0
  10. hafnia/dataset/dataset_recipe/recipe_transforms.py +53 -0
  11. hafnia/dataset/dataset_recipe/recipe_types.py +140 -0
  12. hafnia/dataset/dataset_upload_helper.py +468 -0
  13. hafnia/dataset/hafnia_dataset.py +624 -0
  14. hafnia/dataset/operations/dataset_stats.py +15 -0
  15. hafnia/dataset/operations/dataset_transformations.py +82 -0
  16. hafnia/dataset/operations/table_transformations.py +183 -0
  17. hafnia/dataset/primitives/__init__.py +16 -0
  18. hafnia/dataset/primitives/bbox.py +137 -0
  19. hafnia/dataset/primitives/bitmask.py +182 -0
  20. hafnia/dataset/primitives/classification.py +56 -0
  21. hafnia/dataset/primitives/point.py +25 -0
  22. hafnia/dataset/primitives/polygon.py +100 -0
  23. hafnia/dataset/primitives/primitive.py +44 -0
  24. hafnia/dataset/primitives/segmentation.py +51 -0
  25. hafnia/dataset/primitives/utils.py +51 -0
  26. hafnia/experiment/hafnia_logger.py +7 -7
  27. hafnia/helper_testing.py +108 -0
  28. hafnia/http.py +5 -3
  29. hafnia/platform/__init__.py +2 -2
  30. hafnia/platform/datasets.py +197 -0
  31. hafnia/platform/download.py +85 -23
  32. hafnia/torch_helpers.py +180 -95
  33. hafnia/utils.py +21 -2
  34. hafnia/visualizations/colors.py +267 -0
  35. hafnia/visualizations/image_visualizations.py +202 -0
  36. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/METADATA +209 -99
  37. hafnia-0.2.1.dist-info/RECORD +50 -0
  38. cli/data_cmds.py +0 -53
  39. hafnia-0.1.27.dist-info/RECORD +0 -27
  40. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/WHEEL +0 -0
  41. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/entry_points.txt +0 -0
  42. {hafnia-0.1.27.dist-info → hafnia-0.2.1.dist-info}/licenses/LICENSE +0 -0
hafnia/platform/download.py CHANGED
@@ -1,15 +1,87 @@
  from pathlib import Path
- from typing import Any, Dict
+ from typing import Dict

  import boto3
  from botocore.exceptions import ClientError
+ from pydantic import BaseModel, field_validator
  from tqdm import tqdm

  from hafnia.http import fetch
  from hafnia.log import sys_logger, user_logger

-
- def get_resource_creds(endpoint: str, api_key: str) -> Dict[str, Any]:
+ ARN_PREFIX = "arn:aws:s3:::"
+
+
+ class ResourceCredentials(BaseModel):
+     access_key: str
+     secret_key: str
+     session_token: str
+     s3_arn: str
+     region: str
+
+     @staticmethod
+     def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
+         """
+         The endpoint returns a payload with a key called 's3_path', but it
+         is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
+         """
+         if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
+             payload["s3_arn"] = payload.pop("s3_path")
+
+         if "region" not in payload:
+             payload["region"] = "eu-west-1"
+         return ResourceCredentials(**payload)
+
+     @field_validator("s3_arn")
+     @classmethod
+     def validate_s3_arn(cls, value: str) -> str:
+         """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
+         if not value.startswith("arn:aws:s3:::"):
+             raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
+         return value
+
+     def s3_path(self) -> str:
+         """
+         Extracts the S3 path from the ARN.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
+         """
+         return self.s3_arn[len(ARN_PREFIX) :]
+
+     def s3_uri(self) -> str:
+         """
+         Converts the S3 ARN to a URI format.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
+         """
+         return f"s3://{self.s3_path()}"
+
+     def bucket_name(self) -> str:
+         """
+         Extracts the bucket name from the S3 ARN.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
+         """
+         return self.s3_path().split("/")[0]
+
+     def object_key(self) -> str:
+         """
+         Extracts the object key from the S3 ARN.
+         Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
+         """
+         return "/".join(self.s3_path().split("/")[1:])
+
+     def aws_credentials(self) -> Dict[str, str]:
+         """
+         Returns the AWS credentials as a dictionary.
+         """
+         environment_vars = {
+             "AWS_ACCESS_KEY_ID": self.access_key,
+             "AWS_SECRET_ACCESS_KEY": self.secret_key,
+             "AWS_SESSION_TOKEN": self.session_token,
+             "AWS_REGION": self.region,
+         }
+         return environment_vars
+
+
+ def get_resource_credentials(endpoint: str, api_key: str) -> ResourceCredentials:
      """
      Retrieve credentials for accessing the recipe stored in S3 (or another resource)
      by calling a DIP endpoint with the API key.
@@ -18,21 +90,16 @@ def get_resource_creds(endpoint: str, api_key: str) -> Dict[str, Any]:
          endpoint (str): The endpoint URL to fetch credentials from.

      Returns:
-         Dict[str, Any]: Dictionary containing the credentials, for example:
-             {
-                 "access_key": str,
-                 "secret_key": str,
-                 "session_token": str,
-                 "s3_path": str
-             }
+         ResourceCredentials

      Raises:
          RuntimeError: If the call to fetch the credentials fails for any reason.
      """
      try:
-         creds = fetch(endpoint, headers={"Authorization": api_key, "accept": "application/json"})
+         credentials_dict = fetch(endpoint, headers={"Authorization": api_key, "accept": "application/json"})
+         credentials = ResourceCredentials.fix_naming(credentials_dict)
          sys_logger.debug("Successfully retrieved credentials from DIP endpoint.")
-         return creds
+         return credentials
      except Exception as e:
          sys_logger.error(f"Failed to fetch credentials from endpoint: {e}")
          raise RuntimeError(f"Failed to retrieve credentials: {e}") from e
@@ -76,23 +143,18 @@ def download_resource(resource_url: str, destination: str, api_key: str) -> Dict
          ValueError: If the S3 ARN is invalid or no objects found under prefix.
          RuntimeError: If S3 calls fail with an unexpected error.
      """
-     res_creds = get_resource_creds(resource_url, api_key)
-     s3_arn = res_creds["s3_path"]
-     arn_prefix = "arn:aws:s3:::"
-     if not s3_arn.startswith(arn_prefix):
-         raise ValueError(f"Invalid S3 ARN: {s3_arn}")
+     res_credentials = get_resource_credentials(resource_url, api_key)

-     s3_path = s3_arn[len(arn_prefix) :]
-     bucket_name, *key_parts = s3_path.split("/")
-     key = "/".join(key_parts)
+     bucket_name = res_credentials.bucket_name()
+     key = res_credentials.object_key()

      output_path = Path(destination)
      output_path.mkdir(parents=True, exist_ok=True)
      s3_client = boto3.client(
          "s3",
-         aws_access_key_id=res_creds["access_key"],
-         aws_secret_access_key=res_creds["secret_key"],
-         aws_session_token=res_creds["session_token"],
+         aws_access_key_id=res_credentials.access_key,
+         aws_secret_access_key=res_credentials.secret_key,
+         aws_session_token=res_credentials.session_token,
      )
      downloaded_files = []
      try:
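Taken together, these hunks replace the untyped credentials dictionary with a validated ResourceCredentials model. A minimal usage sketch of the new accessors, assuming the class lives in hafnia.platform.download (as the file list above suggests) and using made-up credential values:

    from hafnia.platform.download import ResourceCredentials

    # Hypothetical payload shaped like the DIP endpoint response described above.
    payload = {
        "access_key": "AKIA_EXAMPLE",
        "secret_key": "example-secret",
        "session_token": "example-token",
        "s3_path": "arn:aws:s3:::my-bucket/recipes/recipe.zip",
    }
    creds = ResourceCredentials.fix_naming(payload)  # renames 's3_path' -> 's3_arn', defaults region to 'eu-west-1'
    assert creds.bucket_name() == "my-bucket"
    assert creds.object_key() == "recipes/recipe.zip"
    assert creds.s3_uri() == "s3://my-bucket/recipes/recipe.zip"
    env_vars = creds.aws_credentials()  # {"AWS_ACCESS_KEY_ID": ..., "AWS_SECRET_ACCESS_KEY": ..., ...}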
hafnia/torch_helpers.py CHANGED
@@ -1,80 +1,126 @@
- from typing import List, Optional
+ from typing import Dict, List, Optional, Tuple, Type, Union

- import datasets
+ import numpy as np
  import torch
  import torchvision
- from flatten_dict import flatten
+ from flatten_dict import flatten, unflatten
  from PIL import Image, ImageDraw, ImageFont
  from torchvision import tv_tensors
  from torchvision import utils as tv_utils
  from torchvision.transforms import v2

+ from hafnia.dataset.dataset_names import FieldName
+ from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
+ from hafnia.dataset.primitives import (
+     PRIMITIVE_COLUMN_NAMES,
+     class_color_by_name,
+ )
+ from hafnia.dataset.primitives.bbox import Bbox
+ from hafnia.dataset.primitives.bitmask import Bitmask
+ from hafnia.dataset.primitives.classification import Classification
+ from hafnia.dataset.primitives.primitive import Primitive
+ from hafnia.dataset.primitives.segmentation import Segmentation
+ from hafnia.log import user_logger
+
+
+ def get_primitives_per_task_name_for_primitive(
+     sample: Sample, PrimitiveType: Type[Primitive], split_by_task_name: bool = True
+ ) -> Dict[str, List[Primitive]]:
+     if not hasattr(sample, PrimitiveType.column_name()):
+         return {}
+
+     primitives = getattr(sample, PrimitiveType.column_name())
+     if primitives is None:
+         return {}
+
+     primitives_by_task_name: Dict[str, List[Primitive]] = {}
+     for primitive in primitives:
+         if primitive.task_name not in primitives_by_task_name:
+             primitives_by_task_name[primitive.task_name] = []
+         primitives_by_task_name[primitive.task_name].append(primitive)
+     return primitives_by_task_name
+

  class TorchvisionDataset(torch.utils.data.Dataset):
      def __init__(
          self,
-         hf_dataset: datasets.Dataset,
+         dataset: HafniaDataset,
          transforms=None,
-         classification_tasks: Optional[List] = None,
-         object_tasks: Optional[List] = None,
-         segmentation_tasks: Optional[List] = None,
          keep_metadata: bool = False,
      ):
-         self.dataset = hf_dataset
+         self.dataset = dataset
+
          self.transforms = transforms
-         self.object_tasks = object_tasks or ["objects"]
-         self.segmentation_tasks = segmentation_tasks or ["segmentation"]
-         self.classification_tasks = classification_tasks or ["classification"]
          self.keep_metadata = keep_metadata

-     def __getitem__(self, idx):
-         sample = self.dataset[idx]
-
-         # For now, we expect a dataset to always have an image field
-         image = tv_tensors.Image(sample.pop("image"))
-
-         img_shape = image.shape[-2:]
-         target = {}
+     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict]:
+         sample_dict = self.dataset[idx]
+         sample = Sample(**sample_dict)
+         image = tv_tensors.Image(sample.read_image_pillow())
+         h, w = image.shape[-2:]
+         target_flat = {}
+         mask_tasks: Dict[str, List[Segmentation]] = get_primitives_per_task_name_for_primitive(sample, Segmentation)
+         for task_name, masks in mask_tasks.items():
+             raise NotImplementedError("Segmentation tasks are not yet implemented")
+             # target[f"{mask.task_name}.mask"] = tv_tensors.Mask(mask.mask)
+
+         class_tasks: Dict[str, List] = get_primitives_per_task_name_for_primitive(sample, Classification)
+         for task_name, classifications in class_tasks.items():
+             assert len(classifications) == 1, "Expected exactly one classification task per sample"
+             target_flat[f"{Classification.column_name()}.{task_name}"] = {
+                 FieldName.CLASS_IDX: classifications[0].class_idx,
+                 FieldName.CLASS_NAME: classifications[0].class_name,
+             }
+
+         bbox_tasks: Dict[str, List[Bbox]] = get_primitives_per_task_name_for_primitive(sample, Bbox)
+         for task_name, bboxes in bbox_tasks.items():
+             bboxes_list = [bbox.to_coco(image_height=h, image_width=w) for bbox in bboxes]
+             bboxes_tensor = torch.as_tensor(bboxes_list).reshape(-1, 4)
+             target_flat[f"{Bbox.column_name()}.{task_name}"] = {
+                 FieldName.CLASS_IDX: [bbox.class_idx for bbox in bboxes],
+                 FieldName.CLASS_NAME: [bbox.class_name for bbox in bboxes],
+                 "bbox": tv_tensors.BoundingBoxes(bboxes_tensor, format="XYWH", canvas_size=(h, w)),
+             }
+
+         bitmask_tasks: Dict[str, List[Bitmask]] = get_primitives_per_task_name_for_primitive(sample, Bitmask)
+         for task_name, bitmasks in bitmask_tasks.items():
+             bitmasks_np = np.array([bitmask.to_mask(img_height=h, img_width=w) for bitmask in bitmasks])
+             target_flat[f"{Bitmask.column_name()}.{task_name}"] = {
+                 FieldName.CLASS_IDX: [bitmask.class_idx for bitmask in bitmasks],
+                 FieldName.CLASS_NAME: [bitmask.class_name for bitmask in bitmasks],
+                 "mask": tv_tensors.Mask(bitmasks_np),
+             }

-         for segmentation_task in self.segmentation_tasks:
-             if segmentation_task in sample:
-                 target[f"{segmentation_task}.mask"] = tv_tensors.Mask(sample[segmentation_task].pop("mask"))
-
-         for classification_task in self.classification_tasks:
-             if classification_task in sample:
-                 target[f"{classification_task}.class_idx"] = sample[classification_task].pop("class_idx")
-
-         for object_task in self.object_tasks:
-             if object_task in sample:
-                 bboxes_list = sample[object_task].pop("bbox")
-                 bboxes = tv_tensors.BoundingBoxes(bboxes_list, format="XYWH", canvas_size=img_shape)
-                 if bboxes.numel() == 0:
-                     bboxes = bboxes.reshape(-1, 4)
-                 target[f"{object_task}.bbox"] = bboxes
-                 target[f"{object_task}.class_idx"] = torch.tensor(sample[object_task].pop("class_idx"))
+         if self.transforms:
+             image, target_flat = self.transforms(image, target_flat)

          if self.keep_metadata:
-             target.update(flatten(sample, reducer="dot"))
-
-         if self.transforms:
-             image, target = self.transforms(image, target)
+             sample_dict = sample_dict.copy()
+             drop_columns = PRIMITIVE_COLUMN_NAMES
+             for column in drop_columns:
+                 if column in sample_dict:
+                     sample_dict.pop(column)

+         target = flatten(target_flat, reducer="dot")
          return image, target

      def __len__(self):
          return len(self.dataset)


- def draw_image_classification(visualize_image: torch.Tensor, text_label: str) -> torch.Tensor:
+ def draw_image_classification(visualize_image: torch.Tensor, text_labels: Union[str, List[str]]) -> torch.Tensor:
+     if isinstance(text_labels, str):
+         text_labels = [text_labels]
+     text = "\n".join(text_labels)
      max_dim = max(visualize_image.shape[-2:])
-     font_size = max(int(max_dim * 0.1), 10)  # Minimum font size of 10
+     font_size = max(int(max_dim * 0.06), 10)  # Minimum font size of 10
      txt_font = ImageFont.load_default(font_size)
      dummie_draw = ImageDraw.Draw(Image.new("RGB", (10, 10)))
-     _, _, w, h = dummie_draw.textbbox((0, 0), text=text_label, font=txt_font)  # type: ignore[arg-type]
+     _, _, w, h = dummie_draw.textbbox((0, 0), text=text, font=txt_font)  # type: ignore[arg-type]

      text_image = Image.new("RGB", (int(w), int(h)))
      draw = ImageDraw.Draw(text_image)
-     draw.text((0, 0), text=text_label, font=txt_font)  # type: ignore[arg-type]
+     draw.text((0, 0), text=text, font=txt_font)  # type: ignore[arg-type]
      text_tensor = v2.functional.to_image(text_image)

      height = text_tensor.shape[-2] + visualize_image.shape[-2]
@@ -94,77 +140,116 @@ def draw_image_classification(visualize_image: torch.Tensor, text_label: str) -> torch.Tensor:
  def draw_image_and_targets(
      image: torch.Tensor,
      targets,
-     detection_tasks: Optional[List[str]] = None,
-     segmentation_tasks: Optional[List[str]] = None,
-     classification_tasks: Optional[List[str]] = None,
  ) -> torch.Tensor:
-     detection_tasks = detection_tasks or ["objects"]
-     segmentation_tasks = segmentation_tasks or ["segmentation"]
-     classification_tasks = classification_tasks or ["classification"]
-
      visualize_image = image.clone()
      if visualize_image.is_floating_point():
          visualize_image = image - torch.min(image)
          visualize_image = visualize_image / visualize_image.max()

      visualize_image = v2.functional.to_dtype(visualize_image, torch.uint8, scale=True)
-
-     for object_task in detection_tasks:
-         bbox_field = f"{object_task}.bbox"
-         if bbox_field in targets:
-             hugging_face_format = "xywh"
-             bbox = torchvision.ops.box_convert(targets[bbox_field], in_fmt=hugging_face_format, out_fmt="xyxy")
-             class_names_field = f"{object_task}.class_name"
-             class_names = targets.get(class_names_field, None)
-             visualize_image = tv_utils.draw_bounding_boxes(visualize_image, bbox, labels=class_names, width=2)
-
-     for segmentation_task in segmentation_tasks:
-         mask_field = f"{segmentation_task}.mask"
-         if mask_field in targets:
-             mask = targets[mask_field].squeeze(0)
-             masks_list = [mask == value for value in mask.unique()]
-             masks = torch.stack(masks_list, dim=0).to(torch.bool)
-             visualize_image = tv_utils.draw_segmentation_masks(visualize_image, masks=masks, alpha=0.5)
-
-     for classification_task in classification_tasks:
-         classification_field = f"{classification_task}.class_idx"
-         if classification_field in targets:
-             text_label = f"[{targets[classification_field]}]"
-             classification_name_field = f"{classification_task}.class_name"
-             if classification_name_field in targets:
-                 text_label = text_label + f" {targets[classification_name_field]}"
-             visualize_image = draw_image_classification(visualize_image, text_label)
+     targets = unflatten(targets, splitter="dot")  # Nested dictionary format
+     # NOTE: Order of drawing is important so visualizations are not overlapping in an undesired way
+     if Segmentation.column_name() in targets:
+         primitive_annotations = targets[Segmentation.column_name()]
+         for task_name, task_annotations in primitive_annotations.items():
+             raise NotImplementedError("Segmentation tasks are not yet implemented")
+             # mask = targets[mask_field].squeeze(0)
+             # masks_list = [mask == value for value in mask.unique()]
+             # masks = torch.stack(masks_list, dim=0).to(torch.bool)
+             # visualize_image = tv_utils.draw_segmentation_masks(visualize_image, masks=masks, alpha=0.5)
+
+     if Bitmask.column_name() in targets:
+         primitive_annotations = targets[Bitmask.column_name()]
+         for task_name, task_annotations in primitive_annotations.items():
+             colors = [class_color_by_name(class_name) for class_name in task_annotations[FieldName.CLASS_NAME]]
+             visualize_image = tv_utils.draw_segmentation_masks(
+                 image=visualize_image,
+                 masks=task_annotations["mask"],
+                 colors=colors,
+             )
+
+     if Bbox.column_name() in targets:
+         primitive_annotations = targets[Bbox.column_name()]
+         for task_name, task_annotations in primitive_annotations.items():
+             bboxes = torchvision.ops.box_convert(task_annotations["bbox"], in_fmt="xywh", out_fmt="xyxy")
+             colors = [class_color_by_name(class_name) for class_name in task_annotations[FieldName.CLASS_NAME]]
+             visualize_image = tv_utils.draw_bounding_boxes(
+                 image=visualize_image,
+                 boxes=bboxes,
+                 labels=task_annotations[FieldName.CLASS_NAME],
+                 width=2,
+                 colors=colors,
+             )
+
+     # Important that classification is drawn last as it will change image dimensions
+     if Classification.column_name() in targets:
+         primitive_annotations = targets[Classification.column_name()]
+         text_labels = []
+         for task_name, task_annotations in primitive_annotations.items():
+             if task_name == Classification.default_task_name():
+                 text_label = task_annotations[FieldName.CLASS_NAME]
+             else:
+                 text_label = f"{task_name}: {task_annotations[FieldName.CLASS_NAME]}"
+             text_labels.append(text_label)
+         visualize_image = draw_image_classification(visualize_image, text_labels)
      return visualize_image


  class TorchVisionCollateFn:
      def __init__(self, skip_stacking: Optional[List] = None):
          if skip_stacking is None:
-             skip_stacking = []
-         self.skip_stacking_list = skip_stacking
+             skip_stacking = [f"{Bbox.column_name()}.*", f"{Bitmask.column_name()}.*"]
+
+         self.wild_card_skip_stacking = []
+         self.skip_stacking_list = []
+         for skip_name in skip_stacking:
+             if skip_name.endswith("*"):
+                 self.wild_card_skip_stacking.append(skip_name[:-1])  # Remove the trailing '*'
+             else:
+                 self.skip_stacking_list.append(skip_name)
+
+     def skip_key_name(self, key_name: str) -> bool:
+         if key_name in self.skip_stacking_list:
+             return True
+         if any(key_name.startswith(wild_card) for wild_card in self.wild_card_skip_stacking):
+             return True
+         return False

      def __call__(self, batch):
          images, targets = tuple(zip(*batch, strict=False))
          if "image" not in self.skip_stacking_list:
              images = torch.stack(images)

-         targets_modified = {k: [d[k] for d in targets] for k in targets[0]}
+         keys_min = set(targets[0])
+         keys_max = set(targets[0])
+         for target in targets:
+             keys_min = keys_min.intersection(target)
+             keys_max = keys_max.union(target)
+
+         if keys_min != keys_max:
+             user_logger.warning(
+                 "Not all images in the batch contain the same targets. To solve for missing targets "
+                 f"the following keys {keys_max - keys_min} are dropped from the batch "
+             )
+
+         targets_modified = {k: [d[k] for d in targets] for k in keys_min}
          for key_name, item_values in targets_modified.items():
-             if key_name not in self.skip_stacking_list:
-                 first_element = item_values[0]
-                 if isinstance(first_element, torch.Tensor):
-                     item_values = torch.stack(item_values)
-                 elif isinstance(first_element, (int, float)):
-                     item_values = torch.tensor(item_values)
-                 elif isinstance(first_element, (str, list)):
-                     # Skip stacking for certain types such as strings and lists
-                     pass
-                 if isinstance(first_element, tv_tensors.Mask):
-                     item_values = tv_tensors.Mask(item_values)
-                 elif isinstance(first_element, tv_tensors.Image):
-                     item_values = tv_tensors.Image(item_values)
-                 elif isinstance(first_element, tv_tensors.BoundingBoxes):
-                     item_values = tv_tensors.BoundingBoxes(item_values)
-                 targets_modified[key_name] = item_values
+             if self.skip_key_name(key_name):
+                 continue
+             first_element = item_values[0]
+             if isinstance(first_element, torch.Tensor):
+                 item_values = torch.stack(item_values)
+             elif isinstance(first_element, (int, float)):
+                 item_values = torch.tensor(item_values)
+             elif isinstance(first_element, (str, list)):
+                 # Skip stacking for certain types such as strings and lists
+                 pass
+             if isinstance(first_element, tv_tensors.Mask):
+                 item_values = tv_tensors.Mask(item_values)
+             elif isinstance(first_element, tv_tensors.Image):
+                 item_values = tv_tensors.Image(item_values)
+             elif isinstance(first_element, tv_tensors.BoundingBoxes):
+                 item_values = tv_tensors.BoundingBoxes(item_values)
+             targets_modified[key_name] = item_values

          return images, targets_modified
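For context, a hedged sketch of how the reworked dataset and collate classes are meant to be wired into a standard PyTorch loader. It assumes a HafniaDataset instance named `dataset` has already been constructed (its loading API lives in hafnia/dataset/hafnia_dataset.py, outside this excerpt); by default TorchVisionCollateFn keeps the variable-length bounding-box and bitmask targets as lists and stacks everything else:

    import torch
    from torch.utils.data import DataLoader
    from torchvision.transforms import v2

    from hafnia.torch_helpers import TorchVisionCollateFn, TorchvisionDataset

    # `dataset` is assumed to be an already-loaded HafniaDataset; construction is not shown in this diff.
    torch_dataset = TorchvisionDataset(
        dataset,
        transforms=v2.Compose([v2.ToDtype(torch.float32, scale=True)]),
        keep_metadata=False,
    )

    loader = DataLoader(torch_dataset, batch_size=4, collate_fn=TorchVisionCollateFn())
    images, targets = next(iter(loader))  # images: stacked tensor; targets: dict of per-key batches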
hafnia/utils.py CHANGED
@@ -1,3 +1,4 @@
+ import hashlib
  import os
  import time
  import zipfile
@@ -14,7 +15,7 @@ from rich import print as rprint
  from hafnia.log import sys_logger, user_logger

  PATH_DATA = Path("./.data")
- PATH_DATASET = PATH_DATA / "datasets"
+ PATH_DATASETS = PATH_DATA / "datasets"
  PATH_RECIPES = PATH_DATA / "recipes"
  FILENAME_HAFNIAIGNORE = ".hafniaignore"
  DEFAULT_IGNORE_SPECIFICATION = [
@@ -132,6 +133,24 @@ def show_recipe_content(recipe_path: Path, style: str = "emoji", depth_limit: in
      user_logger.info(f"Recipe size: {size_human_readable(os.path.getsize(recipe_path))}. Max size 800 MiB")


- def is_remote_job() -> bool:
+ def is_hafnia_cloud_job() -> bool:
      """Check if the current job is running in HAFNIA cloud environment."""
      return os.getenv("HAFNIA_CLOUD", "false").lower() == "true"
+
+
+ def pascal_to_snake_case(name: str) -> str:
+     """
+     Convert PascalCase to snake_case.
+     """
+     return "".join(["_" + char.lower() if char.isupper() else char for char in name]).lstrip("_")
+
+
+ def snake_to_pascal_case(name: str) -> str:
+     """
+     Convert snake_case to PascalCase.
+     """
+     return "".join(word.capitalize() for word in name.split("_"))
+
+
+ def hash_from_string(s: str) -> str:
+     return hashlib.md5(s.encode("utf-8")).hexdigest()
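The string helpers added at the end of hafnia/utils.py are pure functions; their expected behaviour, derived from the implementations above (outputs shown as comments):

    from hafnia.utils import hash_from_string, pascal_to_snake_case, snake_to_pascal_case

    pascal_to_snake_case("HafniaDataset")   # -> "hafnia_dataset"
    snake_to_pascal_case("hafnia_dataset")  # -> "HafniaDataset"
    hash_from_string("hafnia_dataset")      # -> 32-character MD5 hex digest (stable across runs)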