hafnia 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two package versions.
Files changed (40)
  1. hafnia/__init__.py +1 -1
  2. hafnia/dataset/dataset_names.py +128 -15
  3. hafnia/dataset/dataset_recipe/dataset_recipe.py +3 -3
  4. hafnia/dataset/dataset_upload_helper.py +31 -26
  5. hafnia/dataset/format_conversions/{image_classification_from_directory.py → format_image_classification_folder.py} +14 -10
  6. hafnia/dataset/format_conversions/format_yolo.py +164 -0
  7. hafnia/dataset/format_conversions/torchvision_datasets.py +10 -4
  8. hafnia/dataset/hafnia_dataset.py +246 -72
  9. hafnia/dataset/operations/dataset_stats.py +82 -70
  10. hafnia/dataset/operations/dataset_transformations.py +102 -37
  11. hafnia/dataset/operations/table_transformations.py +132 -15
  12. hafnia/dataset/primitives/bbox.py +3 -5
  13. hafnia/dataset/primitives/bitmask.py +2 -7
  14. hafnia/dataset/primitives/classification.py +3 -3
  15. hafnia/dataset/primitives/polygon.py +2 -4
  16. hafnia/dataset/primitives/primitive.py +1 -1
  17. hafnia/dataset/primitives/segmentation.py +2 -2
  18. hafnia/platform/datasets.py +4 -8
  19. hafnia/platform/download.py +1 -72
  20. hafnia/torch_helpers.py +12 -12
  21. hafnia/utils.py +1 -1
  22. hafnia/visualizations/image_visualizations.py +2 -0
  23. {hafnia-0.4.0.dist-info → hafnia-0.4.2.dist-info}/METADATA +4 -4
  24. hafnia-0.4.2.dist-info/RECORD +57 -0
  25. hafnia-0.4.2.dist-info/entry_points.txt +2 -0
  26. {cli → hafnia_cli}/__main__.py +2 -2
  27. {cli → hafnia_cli}/config.py +2 -2
  28. {cli → hafnia_cli}/dataset_cmds.py +2 -2
  29. {cli → hafnia_cli}/dataset_recipe_cmds.py +1 -1
  30. {cli → hafnia_cli}/experiment_cmds.py +1 -1
  31. {cli → hafnia_cli}/profile_cmds.py +2 -2
  32. {cli → hafnia_cli}/runc_cmds.py +1 -1
  33. {cli → hafnia_cli}/trainer_package_cmds.py +2 -2
  34. hafnia-0.4.0.dist-info/RECORD +0 -56
  35. hafnia-0.4.0.dist-info/entry_points.txt +0 -2
  36. {hafnia-0.4.0.dist-info → hafnia-0.4.2.dist-info}/WHEEL +0 -0
  37. {hafnia-0.4.0.dist-info → hafnia-0.4.2.dist-info}/licenses/LICENSE +0 -0
  38. {cli → hafnia_cli}/__init__.py +0 -0
  39. {cli → hafnia_cli}/consts.py +0 -0
  40. {cli → hafnia_cli}/keychain.py +0 -0
hafnia/dataset/operations/table_transformations.py CHANGED
@@ -1,5 +1,5 @@
  from pathlib import Path
- from typing import List, Optional, Type
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Type
 
  import polars as pl
  from rich.progress import track
@@ -7,8 +7,8 @@ from rich.progress import track
  from hafnia.dataset.dataset_names import (
      FILENAME_ANNOTATIONS_JSONL,
      FILENAME_ANNOTATIONS_PARQUET,
-     ColumnName,
-     FieldName,
+     PrimitiveField,
+     SampleField,
  )
  from hafnia.dataset.operations import table_transformations
  from hafnia.dataset.primitives import PRIMITIVE_TYPES
@@ -16,9 +16,15 @@ from hafnia.dataset.primitives.classification import Classification
  from hafnia.dataset.primitives.primitive import Primitive
  from hafnia.log import user_logger
 
+ if TYPE_CHECKING:
+     from hafnia.dataset.hafnia_dataset import TaskInfo
+
 
  def create_primitive_table(
-     samples_table: pl.DataFrame, PrimitiveType: Type[Primitive], keep_sample_data: bool = False
+     samples_table: pl.DataFrame,
+     PrimitiveType: Type[Primitive],
+     keep_sample_data: bool = False,
+     task_name: Optional[str] = None,
  ) -> Optional[pl.DataFrame]:
      """
      Returns a DataFrame with objects of the specified primitive type.
@@ -48,6 +54,9 @@ def create_primitive_table(
          objects_df = remove_no_object_frames.explode(column_name).unnest(column_name)
      else:
          objects_df = remove_no_object_frames.select(pl.col(column_name).explode().struct.unnest())
+
+     if task_name is not None:
+         objects_df = objects_df.filter(pl.col(PrimitiveField.TASK_NAME) == task_name)
      return objects_df
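The new `task_name` parameter filters the exploded objects down to a single task before the table is returned. A minimal sketch of a call (toy table; real primitive structs carry ~11 fields, only the two the filter touches are included here):

```python
import polars as pl

from hafnia.dataset.operations.table_transformations import create_primitive_table
from hafnia.dataset.primitives.bbox import Bbox

# Toy samples table: one image whose "bboxes" column mixes two tasks.
samples = pl.DataFrame(
    {
        "bboxes": [
            [
                {"class_name": "car", "task_name": "object_detection"},
                {"class_name": "face", "task_name": "face_detection"},
            ]
        ]
    }
)

# Only objects whose task_name matches survive the new filter.
objects = create_primitive_table(samples, PrimitiveType=Bbox, task_name="object_detection")
```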
@@ -55,11 +64,12 @@ def merge_samples(samples0: pl.DataFrame, samples1: pl.DataFrame) -> pl.DataFrame:
      has_same_schema = samples0.schema == samples1.schema
      if not has_same_schema:
          shared_columns = []
-         for column_name, column_type in samples0.schema.items():
+         for column_name, s0_column_type in samples0.schema.items():
              if column_name not in samples1.schema:
                  continue
+             samples0, samples1 = correction_of_list_struct_primitives(samples0, samples1, column_name)
 
-             if column_type != samples1.schema[column_name]:
+             if samples0.schema[column_name] != samples1.schema[column_name]:
                  continue
              shared_columns.append(column_name)
@@ -79,16 +89,58 @@ def merge_samples(samples0: pl.DataFrame, samples1: pl.DataFrame) -> pl.DataFrame:
      samples0 = samples0.select(list(shared_columns))
      samples1 = samples1.select(list(shared_columns))
      merged_samples = pl.concat([samples0, samples1], how="vertical")
-     merged_samples = merged_samples.drop(ColumnName.SAMPLE_INDEX).with_row_index(name=ColumnName.SAMPLE_INDEX)
+     merged_samples = add_sample_index(merged_samples)
      return merged_samples
 
 
+ def correction_of_list_struct_primitives(
+     samples0: pl.DataFrame,
+     samples1: pl.DataFrame,
+     column_name: str,
+ ) -> Tuple[pl.DataFrame, pl.DataFrame]:
+     """
+     Corrects primitive columns (bboxes, polygons, etc. of type 'list[struct]') by removing non-matching
+     struct fields between two datasets. This is useful when merging two datasets with the same primitive
+     (e.g. Bbox), where some (less important) field types in the struct differ between the two datasets.
+     This issue often occurs with the 'meta' field, as different dataset formats may store different metadata.
+     """
+     s0_column_type = samples0.schema[column_name]
+     s1_column_type = samples1.schema[column_name]
+     is_list_structs = s1_column_type == pl.List(pl.Struct) and s0_column_type == pl.List(pl.Struct)
+     is_non_matching_types = s1_column_type != s0_column_type
+     if is_list_structs and is_non_matching_types:  # Only perform correction for list[struct] types that do not match
+         s0_fields = set(s0_column_type.inner.fields)
+         s1_fields = set(s1_column_type.inner.fields)
+         similar_fields = s0_fields.intersection(s1_fields)
+         s0_dropped_fields = s0_fields - similar_fields
+         if len(s0_dropped_fields) > 0:
+             samples0 = samples0.with_columns(
+                 pl.col(column_name)
+                 .list.eval(pl.struct([pl.element().struct.field(k.name) for k in similar_fields]))
+                 .alias(column_name)
+             )
+         s1_dropped_fields = s1_fields - similar_fields
+         if len(s1_dropped_fields) > 0:
+             samples1 = samples1.with_columns(
+                 pl.col(column_name)
+                 .list.eval(pl.struct([pl.element().struct.field(k.name) for k in similar_fields]))
+                 .alias(column_name)
+             )
+         user_logger.warning(
+             f"Primitive column '{column_name}' has non-matching fields in the two datasets. "
+             f"Dropping fields in samples0: {[f.name for f in s0_dropped_fields]}. "
+             f"Dropping fields in samples1: {[f.name for f in s1_dropped_fields]}."
+         )
+
+     return samples0, samples1
+
+
  def filter_table_for_class_names(
      samples_table: pl.DataFrame, class_names: List[str], PrimitiveType: Type[Primitive]
  ) -> Optional[pl.DataFrame]:
      table_with_selected_class_names = samples_table.filter(
          pl.col(PrimitiveType.column_name())
-         .list.eval(pl.element().struct.field(FieldName.CLASS_NAME).is_in(class_names))
+         .list.eval(pl.element().struct.field(PrimitiveField.CLASS_NAME).is_in(class_names))
          .list.any()
      )
@@ -100,20 +152,20 @@ def split_primitive_columns_by_task_name(
      coordinate_types: Optional[List[Type[Primitive]]] = None,
  ) -> pl.DataFrame:
      """
-     Convert Primitive columns such as "objects" (Bbox) into a column for each task name.
-     For example, if the "objects" column (containing Bbox objects) has tasks "task1" and "task2".
+     Convert Primitive columns such as "bboxes" (Bbox) into a column for each task name.
+     For example, if the "bboxes" column (containing Bbox objects) has tasks "task1" and "task2".
 
      This:
      ─┬────────────┬─
-      ┆ objects    ┆
+      ┆ bboxes     ┆
       ┆ ---        ┆
       ┆ list[struc ┆
       ┆ t[11]]     ┆
      ═╪════════════╪═
      becomes this:
      ─┬────────────┬────────────┬─
-      ┆ objects.   ┆ objects.   ┆
+      ┆ bboxes.    ┆ bboxes.    ┆
       ┆ task1      ┆ task2      ┆
       ┆ ---        ┆ ---        ┆
       ┆ list[struc ┆ list[struc ┆
@@ -131,11 +183,11 @@ def split_primitive_columns_by_task_name(
          if samples_table[col_name].dtype != pl.List(pl.Struct):
              continue
 
-         task_names = samples_table[col_name].explode().struct.field(FieldName.TASK_NAME).unique().to_list()
+         task_names = samples_table[col_name].explode().struct.field(PrimitiveField.TASK_NAME).unique().to_list()
          samples_table = samples_table.with_columns(
              [
                  pl.col(col_name)
-                 .list.filter(pl.element().struct.field(FieldName.TASK_NAME).eq(task_name))
+                 .list.filter(pl.element().struct.field(PrimitiveField.TASK_NAME).eq(task_name))
                  .alias(f"{col_name}.{task_name}")
                  for task_name in task_names
              ]
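The per-task split is driven by `list.filter` over `pl.element()` (available in recent polars). A self-contained sketch with made-up task names:

```python
import polars as pl

table = pl.DataFrame(
    {
        "bboxes": [
            [
                {"class_name": "car", "task_name": "task1"},
                {"class_name": "sign", "task_name": "task2"},
            ]
        ]
    }
)

task_names = table["bboxes"].explode().struct.field("task_name").unique().to_list()
table = table.with_columns(
    [
        pl.col("bboxes")
        .list.filter(pl.element().struct.field("task_name").eq(task_name))
        .alias(f"bboxes.{task_name}")
        for task_name in task_names
    ]
)
print(table.columns)  # ['bboxes', 'bboxes.task1', 'bboxes.task2'] (task order may vary)
```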
@@ -162,7 +214,7 @@ def read_samples_from_path(path: Path) -> pl.DataFrame:
 
  def check_image_paths(table: pl.DataFrame) -> bool:
      missing_files = []
-     org_paths = table[ColumnName.FILE_PATH].to_list()
+     org_paths = table[SampleField.FILE_PATH].to_list()
      for org_path in track(org_paths, description="Check image paths"):
          org_path = Path(org_path)
          if not org_path.exists():
@@ -219,3 +271,68 @@ def unnest_classification_tasks(table: pl.DataFrame, strict: bool = True) -> pl.DataFrame:
 
      table_out = table_out.with_columns([pl.col(c).list.first() for c in classification_columns])
      return table_out
+
+
+ def update_class_indices(samples: pl.DataFrame, task: "TaskInfo") -> pl.DataFrame:
+     if task.class_names is None or len(task.class_names) == 0:
+         raise ValueError(f"Task '{task.name}' does not have defined class names to update class indices.")
+
+     objs = (
+         samples[task.primitive.column_name()]
+         .explode()
+         .struct.unnest()
+         .filter(pl.col(PrimitiveField.TASK_NAME) == task.name)
+     )
+     expected_class_names = set(objs[PrimitiveField.CLASS_NAME].unique())
+     missing_class_names = expected_class_names - set(task.class_names)
+     if len(missing_class_names) > 0:
+         raise ValueError(
+             f"Task '{task.name}' is missing class names: {missing_class_names}. Cannot update class indices."
+         )
+
+     name_2_idx_mapping = {name: idx for idx, name in enumerate(task.class_names)}
+
+     samples_updated = samples.with_columns(
+         pl.col(task.primitive.column_name())
+         .list.eval(
+             pl.element().struct.with_fields(
+                 pl.when(pl.field(PrimitiveField.TASK_NAME) == task.name)
+                 .then(pl.field(PrimitiveField.CLASS_NAME).replace_strict(name_2_idx_mapping, default=-1))
+                 .otherwise(pl.field(PrimitiveField.CLASS_IDX))
+                 .alias(PrimitiveField.CLASS_IDX)
+             )
+         )
+         .alias(task.primitive.column_name())
+     )
+
+     return samples_updated
+
+
+ def add_sample_index(samples: pl.DataFrame) -> pl.DataFrame:
+     """
+     Adds a sample index column to the samples DataFrame.
+
+     Note: Unlike the built-in 'polars.DataFrame.with_row_count', this function
+     always guarantees 'pl.UInt64' type for the index column.
+     """
+     if SampleField.SAMPLE_INDEX in samples.columns:
+         samples = samples.drop(SampleField.SAMPLE_INDEX)
+     samples = samples.select(
+         pl.int_range(0, pl.count(), dtype=pl.UInt64).alias(SampleField.SAMPLE_INDEX),
+         pl.all(),
+     )
+     return samples
+
+
+ def add_dataset_name_if_missing(table: pl.DataFrame, dataset_name: str) -> pl.DataFrame:
+     if SampleField.DATASET_NAME not in table.columns:
+         table = table.with_columns(pl.lit(dataset_name).alias(SampleField.DATASET_NAME))
+     else:
+         table = table.with_columns(
+             pl.when(pl.col(SampleField.DATASET_NAME).is_null())
+             .then(pl.lit(dataset_name))
+             .otherwise(pl.col(SampleField.DATASET_NAME))
+             .alias(SampleField.DATASET_NAME)
+         )
+
+     return table
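Two details of the new helpers are easy to miss: polars' built-in row index is `UInt32`, which is why `add_sample_index` constructs the column via `pl.int_range(..., dtype=pl.UInt64)`, and `update_class_indices` maps class names to indices with `replace_strict`. A standalone sketch of both patterns in plain polars (newer polars spells `pl.count()` as `pl.len()`):

```python
import polars as pl

df = pl.DataFrame({"file_name": ["a.jpg", "b.jpg"]})

# The built-in row index comes back as UInt32 ...
print(df.with_row_index("sample_index").schema["sample_index"])  # UInt32

# ... while the int_range construction pins the dtype to UInt64.
df = df.select(pl.int_range(0, pl.len(), dtype=pl.UInt64).alias("sample_index"), pl.all())
print(df.schema["sample_index"])  # UInt64

# replace_strict maps class names to indices; unknown names fall back to -1.
name_2_idx = {"car": 0, "person": 1}
labels = pl.DataFrame({"class_name": ["car", "person", "bike"]})
print(labels.select(pl.col("class_name").replace_strict(name_2_idx, default=-1)))  # 0, 1, -1
```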
hafnia/dataset/primitives/bbox.py CHANGED
@@ -33,9 +33,7 @@ class Bbox(Primitive):
      class_name: Optional[str] = Field(default=None, description="Class name, e.g. 'car'")
      class_idx: Optional[int] = Field(default=None, description="Class index, e.g. 0 for 'car' if it is the first class")
      object_id: Optional[str] = Field(default=None, description="Unique identifier for the object, e.g. '12345123'")
-     confidence: Optional[float] = Field(
-         default=None, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox"
-     )
+     confidence: float = Field(default=1.0, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox")
      ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")
 
      task_name: str = Field(
@@ -45,11 +43,11 @@ class Bbox(Primitive):
 
      @staticmethod
      def default_task_name() -> str:
-         return "bboxes"
+         return "object_detection"
 
      @staticmethod
      def column_name() -> str:
-         return "objects"
+         return "bboxes"
 
      def calculate_area(self) -> float:
          return self.height * self.width
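For downstream code, the visible changes are the new defaults and names. A hedged sketch (the geometry fields `height`, `width`, `top_left_x`, `top_left_y` are assumptions inferred from `calculate_area`; this hunk only shows the metadata fields):

```python
from hafnia.dataset.primitives.bbox import Bbox

# Geometry field names below are assumed, not shown in this hunk.
bbox = Bbox(height=0.2, width=0.1, top_left_x=0.3, top_left_y=0.5)

print(bbox.confidence)           # 1.0 in 0.4.2 (was None in 0.4.0)
print(Bbox.column_name())        # "bboxes" (was "objects")
print(Bbox.default_task_name())  # "object_detection" (was "bboxes")
```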
hafnia/dataset/primitives/bitmask.py CHANGED
@@ -7,7 +7,6 @@ import numpy as np
  import pycocotools.mask as coco_mask
  from pydantic import Field
 
- from hafnia.dataset.dataset_names import FieldName
  from hafnia.dataset.primitives.primitive import Primitive
  from hafnia.dataset.primitives.utils import (
      anonymize_by_resizing,
@@ -16,8 +15,6 @@ from hafnia.dataset.primitives.utils import (
      text_org_from_left_bottom_to_centered,
  )
 
- FieldName
-
 
  class Bitmask(Primitive):
      # Names should match names in FieldName
@@ -34,9 +31,7 @@ class Bitmask(Primitive):
      class_name: Optional[str] = Field(default=None, description="Class name of the object represented by the bitmask")
      class_idx: Optional[int] = Field(default=None, description="Class index of the object represented by the bitmask")
      object_id: Optional[str] = Field(default=None, description="Object ID of the instance represented by the bitmask")
-     confidence: Optional[float] = Field(
-         default=None, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox"
-     )
+     confidence: float = Field(default=1.0, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox")
      ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")
 
      task_name: str = Field(
@@ -46,7 +41,7 @@ class Bitmask(Primitive):
 
      @staticmethod
      def default_task_name() -> str:
-         return "bitmask"
+         return "mask_detection"
 
      @staticmethod
      def column_name() -> str:
hafnia/dataset/primitives/classification.py CHANGED
@@ -12,8 +12,8 @@ class Classification(Primitive):
      class_name: Optional[str] = Field(default=None, description="Class name, e.g. 'car'")
      class_idx: Optional[int] = Field(default=None, description="Class index, e.g. 0 for 'car' if it is the first class")
      object_id: Optional[str] = Field(default=None, description="Unique identifier for the object, e.g. '12345123'")
-     confidence: Optional[float] = Field(
-         default=None, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Classification"
+     confidence: float = Field(
+         default=1.0, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Classification"
      )
      ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")
 
@@ -27,7 +27,7 @@ class Classification(Primitive):
 
      @staticmethod
      def default_task_name() -> str:
-         return "classification"
+         return "image_classification"
 
      @staticmethod
      def column_name() -> str:
hafnia/dataset/primitives/polygon.py CHANGED
@@ -16,9 +16,7 @@ class Polygon(Primitive):
      class_name: Optional[str] = Field(default=None, description="Class name of the polygon")
      class_idx: Optional[int] = Field(default=None, description="Class index of the polygon")
      object_id: Optional[str] = Field(default=None, description="Object ID of the polygon")
-     confidence: Optional[float] = Field(
-         default=None, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox"
-     )
+     confidence: float = Field(default=1.0, description="Confidence score (0-1.0) for the primitive, e.g. 0.95 for Bbox")
      ground_truth: bool = Field(default=True, description="Whether this is ground truth or a prediction")
 
      task_name: str = Field(
@@ -40,7 +38,7 @@ class Polygon(Primitive):
 
      @staticmethod
      def default_task_name() -> str:
-         return "polygon"
+         return "polygon_detection"
 
      @staticmethod
      def column_name() -> str:
hafnia/dataset/primitives/primitive.py CHANGED
@@ -22,7 +22,7 @@ class Primitive(BaseModel, metaclass=ABCMeta):
      def column_name() -> str:
          """
          Name of field used in hugging face datasets for storing annotations
-         E.g. "objects" for Bbox.
+         E.g. "bboxes" for Bbox.
          """
          pass
 
hafnia/dataset/primitives/segmentation.py CHANGED
@@ -24,11 +24,11 @@ class Segmentation(Primitive):
 
      @staticmethod
      def default_task_name() -> str:
-         return "segmentation"
+         return "semantic_segmentation"
 
      @staticmethod
      def column_name() -> str:
-         return "segmentation"
+         return "segmentations"
 
      def calculate_area(self) -> float:
          raise NotImplementedError()
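Across the primitives, 0.4.2 renames every default task name and two column names. A quick sketch to inspect the new values at runtime, assuming `PRIMITIVE_TYPES` iterates over the primitive classes (it is imported as a registry in table_transformations.py above):

```python
from hafnia.dataset.primitives import PRIMITIVE_TYPES

# Assumption: PRIMITIVE_TYPES is an iterable of the primitive classes.
for PrimitiveType in PRIMITIVE_TYPES:
    print(PrimitiveType.__name__, PrimitiveType.column_name(), PrimitiveType.default_task_name())

# Default task names per this diff:
#   Bbox           -> "object_detection"      (column "bboxes")
#   Bitmask        -> "mask_detection"
#   Classification -> "image_classification"
#   Polygon        -> "polygon_detection"
#   Segmentation   -> "semantic_segmentation" (column "segmentations")
```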
hafnia/platform/datasets.py CHANGED
@@ -11,9 +11,8 @@ import rich
  from rich import print as rprint
  from rich.progress import track
 
- from cli.config import Config
  from hafnia import http, utils
- from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED, ColumnName
+ from hafnia.dataset.dataset_names import DATASET_FILENAMES_REQUIRED
  from hafnia.dataset.dataset_recipe.dataset_recipe import (
      DatasetRecipe,
      get_dataset_path_from_recipe,
@@ -23,6 +22,7 @@ from hafnia.http import fetch
  from hafnia.log import sys_logger, user_logger
  from hafnia.platform.download import get_resource_credentials
  from hafnia.utils import timed
+ from hafnia_cli.config import Config
 
 
  @timed("Fetching dataset list.")
@@ -120,15 +120,11 @@ def download_dataset_from_access_endpoint(
          return
      dataset = HafniaDataset.from_path(path_dataset, check_for_images=False)
      try:
-         fast_copy_files_s3(
-             src_paths=dataset.samples[ColumnName.REMOTE_PATH].to_list(),
-             dst_paths=dataset.samples[ColumnName.FILE_PATH].to_list(),
-             append_envs=envs,
-             description="Downloading images",
-         )
+         dataset = dataset.download_files_aws(path_dataset, aws_credentials=resource_credentials, force_redownload=True)
      except ValueError as e:
          user_logger.error(f"Failed to download images: {e}")
          return
+     dataset.write_annotations(path_folder=path_dataset)  # Overwrite annotations as files have been re-downloaded
 
 
  def fast_copy_files_s3(
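The replacement path pushes image syncing into the dataset object itself and then rewrites the annotations to match the re-downloaded files. A sketch of the new flow with placeholder endpoint, key, and path (mirroring the calls in this hunk, not a verified end-to-end recipe):

```python
from pathlib import Path

from hafnia.dataset.hafnia_dataset import HafniaDataset
from hafnia.platform.download import get_resource_credentials

path_dataset = Path("data/my-dataset")  # placeholder path
credentials = get_resource_credentials(endpoint="https://example.invalid/access", api_key="...")  # placeholders

dataset = HafniaDataset.from_path(path_dataset, check_for_images=False)
dataset = dataset.download_files_aws(path_dataset, aws_credentials=credentials, force_redownload=True)
dataset.write_annotations(path_folder=path_dataset)  # overwrite annotations after re-download
```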
hafnia/platform/download.py CHANGED
@@ -3,83 +3,12 @@ from typing import Dict, Optional
 
  import boto3
  from botocore.exceptions import ClientError
- from pydantic import BaseModel, field_validator
  from rich.progress import Progress
 
+ from hafnia.dataset.dataset_names import ResourceCredentials
  from hafnia.http import fetch
  from hafnia.log import sys_logger, user_logger
 
- ARN_PREFIX = "arn:aws:s3:::"
-
-
- class ResourceCredentials(BaseModel):
-     access_key: str
-     secret_key: str
-     session_token: str
-     s3_arn: str
-     region: str
-
-     @staticmethod
-     def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
-         """
-         The endpoint returns a payload with a key called 's3_path', but it
-         is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
-         """
-         if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
-             payload["s3_arn"] = payload.pop("s3_path")
-
-         if "region" not in payload:
-             payload["region"] = "eu-west-1"
-         return ResourceCredentials(**payload)
-
-     @field_validator("s3_arn")
-     @classmethod
-     def validate_s3_arn(cls, value: str) -> str:
-         """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
-         if not value.startswith("arn:aws:s3:::"):
-             raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
-         return value
-
-     def s3_path(self) -> str:
-         """
-         Extracts the S3 path from the ARN.
-         Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
-         """
-         return self.s3_arn[len(ARN_PREFIX) :]
-
-     def s3_uri(self) -> str:
-         """
-         Converts the S3 ARN to a URI format.
-         Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
-         """
-         return f"s3://{self.s3_path()}"
-
-     def bucket_name(self) -> str:
-         """
-         Extracts the bucket name from the S3 ARN.
-         Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
-         """
-         return self.s3_path().split("/")[0]
-
-     def object_key(self) -> str:
-         """
-         Extracts the object key from the S3 ARN.
-         Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
-         """
-         return "/".join(self.s3_path().split("/")[1:])
-
-     def aws_credentials(self) -> Dict[str, str]:
-         """
-         Returns the AWS credentials as a dictionary.
-         """
-         environment_vars = {
-             "AWS_ACCESS_KEY_ID": self.access_key,
-             "AWS_SECRET_ACCESS_KEY": self.secret_key,
-             "AWS_SESSION_TOKEN": self.session_token,
-             "AWS_REGION": self.region,
-         }
-         return environment_vars
-
 
  def get_resource_credentials(endpoint: str, api_key: str) -> ResourceCredentials:
      """
hafnia/torch_helpers.py CHANGED
@@ -9,7 +9,7 @@ from torchvision import tv_tensors
  from torchvision import utils as tv_utils
  from torchvision.transforms import v2
 
- from hafnia.dataset.dataset_names import FieldName
+ from hafnia.dataset.dataset_names import PrimitiveField
  from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
  from hafnia.dataset.primitives import (
      PRIMITIVE_COLUMN_NAMES,
@@ -68,8 +68,8 @@ class TorchvisionDataset(torch.utils.data.Dataset):
          for task_name, classifications in class_tasks.items():
              assert len(classifications) == 1, "Expected exactly one classification task per sample"
              target_flat[f"{Classification.column_name()}.{task_name}"] = {
-                 FieldName.CLASS_IDX: classifications[0].class_idx,
-                 FieldName.CLASS_NAME: classifications[0].class_name,
+                 PrimitiveField.CLASS_IDX: classifications[0].class_idx,
+                 PrimitiveField.CLASS_NAME: classifications[0].class_name,
              }
 
          bbox_tasks: Dict[str, List[Bbox]] = get_primitives_per_task_name_for_primitive(sample, Bbox)
@@ -77,8 +77,8 @@ class TorchvisionDataset(torch.utils.data.Dataset):
              bboxes_list = [bbox.to_coco(image_height=h, image_width=w) for bbox in bboxes]
              bboxes_tensor = torch.as_tensor(bboxes_list).reshape(-1, 4)
              target_flat[f"{Bbox.column_name()}.{task_name}"] = {
-                 FieldName.CLASS_IDX: [bbox.class_idx for bbox in bboxes],
-                 FieldName.CLASS_NAME: [bbox.class_name for bbox in bboxes],
+                 PrimitiveField.CLASS_IDX: [bbox.class_idx for bbox in bboxes],
+                 PrimitiveField.CLASS_NAME: [bbox.class_name for bbox in bboxes],
                  "bbox": tv_tensors.BoundingBoxes(bboxes_tensor, format="XYWH", canvas_size=(h, w)),
              }
 
@@ -86,8 +86,8 @@ class TorchvisionDataset(torch.utils.data.Dataset):
          for task_name, bitmasks in bitmask_tasks.items():
              bitmasks_np = np.array([bitmask.to_mask(img_height=h, img_width=w) for bitmask in bitmasks])
              target_flat[f"{Bitmask.column_name()}.{task_name}"] = {
-                 FieldName.CLASS_IDX: [bitmask.class_idx for bitmask in bitmasks],
-                 FieldName.CLASS_NAME: [bitmask.class_name for bitmask in bitmasks],
+                 PrimitiveField.CLASS_IDX: [bitmask.class_idx for bitmask in bitmasks],
+                 PrimitiveField.CLASS_NAME: [bitmask.class_name for bitmask in bitmasks],
                  "mask": tv_tensors.Mask(bitmasks_np),
              }
 
@@ -161,7 +161,7 @@ def draw_image_and_targets(
      if Bitmask.column_name() in targets:
          primitive_annotations = targets[Bitmask.column_name()]
          for task_name, task_annotations in primitive_annotations.items():
-             colors = [class_color_by_name(class_name) for class_name in task_annotations[FieldName.CLASS_NAME]]
+             colors = [class_color_by_name(class_name) for class_name in task_annotations[PrimitiveField.CLASS_NAME]]
              visualize_image = tv_utils.draw_segmentation_masks(
                  image=visualize_image,
                  masks=task_annotations["mask"],
@@ -172,11 +172,11 @@ def draw_image_and_targets(
          primitive_annotations = targets[Bbox.column_name()]
          for task_name, task_annotations in primitive_annotations.items():
              bboxes = torchvision.ops.box_convert(task_annotations["bbox"], in_fmt="xywh", out_fmt="xyxy")
-             colors = [class_color_by_name(class_name) for class_name in task_annotations[FieldName.CLASS_NAME]]
+             colors = [class_color_by_name(class_name) for class_name in task_annotations[PrimitiveField.CLASS_NAME]]
              visualize_image = tv_utils.draw_bounding_boxes(
                  image=visualize_image,
                  boxes=bboxes,
-                 labels=task_annotations[FieldName.CLASS_NAME],
+                 labels=task_annotations[PrimitiveField.CLASS_NAME],
                  width=2,
                  colors=colors,
              )
@@ -187,9 +187,9 @@ def draw_image_and_targets(
          text_labels = []
          for task_name, task_annotations in primitive_annotations.items():
              if task_name == Classification.default_task_name():
-                 text_label = task_annotations[FieldName.CLASS_NAME]
+                 text_label = task_annotations[PrimitiveField.CLASS_NAME]
              else:
-                 text_label = f"{task_name}: {task_annotations[FieldName.CLASS_NAME]}"
+                 text_label = f"{task_name}: {task_annotations[PrimitiveField.CLASS_NAME]}"
              text_labels.append(text_label)
          visualize_image = draw_image_classification(visualize_image, text_labels)
      return visualize_image
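Because target keys are built as `f"{column_name}.{task_name}"`, the primitive renames change the keys a dataloader sees. A small sketch of the default detection key in each version:

```python
from hafnia.dataset.primitives.bbox import Bbox

key = f"{Bbox.column_name()}.{Bbox.default_task_name()}"
print(key)  # "bboxes.object_detection" in 0.4.2 (the same expression gave "objects.bboxes" in 0.4.0)
```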
hafnia/utils.py CHANGED
@@ -207,7 +207,7 @@ def is_hafnia_configured() -> bool:
      """
      Check if Hafnia is configured by verifying if the API key is set.
      """
-     from cli.config import Config
+     from hafnia_cli.config import Config
 
      return Config().is_configured()
 
hafnia/visualizations/image_visualizations.py CHANGED
@@ -193,6 +193,8 @@ def save_dataset_sample_set_visualizations(
          image = draw_annotations(image, annotations, draw_settings=draw_settings)
 
          pil_image = Image.fromarray(image)
+         if sample.file_path is None:
+             raise ValueError("Sample has no file_path defined.")
          path_image = path_output_folder / Path(sample.file_path).name
          pil_image.save(path_image)
          paths.append(path_image)
{hafnia-0.4.0.dist-info → hafnia-0.4.2.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hafnia
- Version: 0.4.0
+ Version: 0.4.2
  Summary: Python SDK for communication with Hafnia platform.
  Author-email: Milestone Systems <hafniaplatform@milestone.dk>
  License-File: LICENSE
@@ -158,7 +158,7 @@ and `dataset.samples` with annotations as a polars DataFrame
  print(dataset.samples.head(2))
  shape: (2, 14)
  ┌──────────────┬───────────┬────────┬───────┬───┬──────────────────┬──────────┬──────────┬───────────┐
- │ sample_index ┆ file_name ┆ height ┆ width ┆ … ┆ objects          ┆ bitmasks ┆ polygons ┆ meta      │
+ │ sample_index ┆ file_name ┆ height ┆ width ┆ … ┆ bboxes           ┆ bitmasks ┆ polygons ┆ meta      │
  │ ---          ┆ ---       ┆ ---    ┆ ---   ┆   ┆ ---              ┆ ---      ┆ ---      ┆ ---       │
  │ u32          ┆ str       ┆ i64    ┆ i64   ┆   ┆ list[struct[11]] ┆ null     ┆ null     ┆ struct[5] │
  ╞══════════════╪═══════════╪════════╪═══════╪═══╪══════════════════╪══════════╪══════════╪═══════════╡
@@ -218,7 +218,7 @@ sample_dict = dataset[0]
 
  for sample_dict in dataset:
      sample = Sample(**sample_dict)
-     print(sample.sample_id, sample.objects)
+     print(sample.sample_id, sample.bboxes)
      break
  ```
  Note that it is possible to create a `Sample` object from the sample dictionary.
@@ -421,7 +421,7 @@ pil_image.save("visualized_labels.png")
 
  # Create DataLoaders - using TorchVisionCollateFn
  collate_fn = torch_helpers.TorchVisionCollateFn(
-     skip_stacking=["objects.bbox", "objects.class_idx"]
+     skip_stacking=["bboxes.bbox", "bboxes.class_idx"]
  )
  train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
  ```