hafnia 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Dict, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
 import polars as pl
 import rich
@@ -8,7 +8,7 @@ from rich import print as rprint
 from rich.progress import track
 from rich.table import Table
 
-from hafnia.dataset.dataset_names import ColumnName, FieldName, SplitName
+from hafnia.dataset.dataset_names import PrimitiveField, SampleField, SplitName
 from hafnia.dataset.operations.table_transformations import create_primitive_table
 from hafnia.dataset.primitives import PRIMITIVE_TYPES
 from hafnia.log import user_logger
@@ -18,14 +18,14 @@ if TYPE_CHECKING: # Using 'TYPE_CHECKING' to avoid circular imports during type
     from hafnia.dataset.primitives.primitive import Primitive
 
 
-def split_counts(dataset: HafniaDataset) -> Dict[str, int]:
+def calculate_split_counts(dataset: HafniaDataset) -> Dict[str, int]:
     """
     Returns a dictionary with the counts of samples in each split of the dataset.
     """
-    return dict(dataset.samples[ColumnName.SPLIT].value_counts().iter_rows())
+    return dict(dataset.samples[SampleField.SPLIT].value_counts().iter_rows())
 
 
-def class_counts_for_task(
+def calculate_task_class_counts(
     dataset: HafniaDataset,
     primitive: Optional[Type[Primitive]] = None,
     task_name: Optional[str] = None,
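Note: this release renames the statistics helpers: `split_counts` becomes `calculate_split_counts` and `class_counts_for_task` becomes `calculate_task_class_counts`. A minimal migration sketch, assuming a loaded `HafniaDataset` named `dataset`; the task name "detection" is illustrative:

    # 0.4.0
    counts = split_counts(dataset)       # e.g. {"train": 800, "val": 100, "test": 100}
    per_class = class_counts_for_task(dataset, task_name="detection")

    # 0.4.1
    counts = calculate_split_counts(dataset)
    per_class = calculate_task_class_counts(dataset, task_name="detection")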
@@ -53,7 +53,7 @@ def class_counts_for_task(
         dataset.samples[task.primitive.column_name()]
         .explode()
         .struct.unnest()
-        .filter(pl.col(FieldName.TASK_NAME) == task.name)[FieldName.CLASS_NAME]
+        .filter(pl.col(PrimitiveField.TASK_NAME) == task.name)[PrimitiveField.CLASS_NAME]
         .value_counts()
     )
 
@@ -65,7 +65,7 @@ def class_counts_for_task(
     return class_counts
 
 
-def class_counts_all(dataset: HafniaDataset) -> Dict[str, int]:
+def calculate_class_counts(dataset: HafniaDataset) -> List[Dict[str, Any]]:
     """
     Get class counts for all tasks in the dataset.
     The counts are returned as a dictionary where keys are in the format
@@ -74,25 +74,59 @@ def class_counts_all(dataset: HafniaDataset) -> Dict[str, int]:
     Example:
         >>> counts = dataset.class_counts_all()
         >>> print(counts)
-        {
-            objects/bboxes/car: 500
-            objects/bboxes/person: 0
-            classifications/weather/sunny: 300
-            classifications/weather/rainy: 0
-            ...
-        }
+        [
+            {'Primitive': 'Bbox', 'Task Name': 'detection', 'Class Name': 'car', 'Count': 500},
+            {'Primitive': 'Bbox', 'Task Name': 'detection', 'Class Name': 'bus', 'Count': 100},
+            {'Primitive': 'Classification', 'Task Name': 'scene', 'Class Name': 'indoor', 'Count': 300},
+            {'Primitive': 'Classification', 'Task Name': 'scene', 'Class Name': 'outdoor', 'Count': 700},
+        ]
     """
-    class_counts = {}
+    count_info = []
     for task in dataset.info.tasks:
-        if task.class_names is None:
-            raise ValueError(f"Task '{task.name}' does not have class names defined.")
-        class_counts_task = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+        class_name_counts = dataset.calculate_task_class_counts(task_name=task.name)
+        for name, counts in class_name_counts.items():
+            count_info.append(
+                {
+                    "Primitive": task.primitive.__name__,
+                    "Task Name": task.name,
+                    "Class Name": name,
+                    "Count": counts,
+                }
+            )
+    return count_info
 
-        for class_idx, (class_name, count) in enumerate(class_counts_task.items()):
-            count_name = f"{task.primitive.__name__}/{task.name}/{class_name}"
-            class_counts[count_name] = count
 
-    return class_counts
+def calculate_primitive_counts(dataset: HafniaDataset) -> Dict[str, int]:
+    annotation_counts = {}
+    for task in dataset.info.tasks:
+        objects = dataset.create_primitive_table(task.primitive, task_name=task.name)
+        name = task.primitive.__name__
+        if task.name != task.primitive.default_task_name():
+            name = f"{name}.{task.name}"
+        annotation_counts[name] = len(objects)
+    return annotation_counts
+
+
+def calculate_split_counts_extended(dataset: HafniaDataset) -> List[Dict[str, Any]]:
+    splits_sets = {
+        "All": SplitName.valid_splits(),
+        "Train": [SplitName.TRAIN],
+        "Validation": [SplitName.VAL],
+        "Test": [SplitName.TEST],
+    }
+    rows = []
+    for split_name, splits in splits_sets.items():
+        dataset_split = dataset.create_split_dataset(splits)
+        table = dataset_split.samples
+        row: Dict[str, Any] = {}
+        row["Split"] = split_name
+        row["Samples "] = str(len(table))
+
+        primitive_counts = calculate_primitive_counts(dataset_split)
+        row.update(primitive_counts)
+        rows.append(row)
+
+    return rows
 
 
 def print_stats(dataset: HafniaDataset) -> None:
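Note: `class_counts_all` is superseded by `calculate_class_counts`, which returns a list of row dictionaries rather than a flat `"Primitive/task/class" -> count` mapping, and the per-split table logic moves into the new `calculate_primitive_counts` and `calculate_split_counts_extended` helpers. A sketch of the new row shape, assuming a dataset with a single Bbox task (values illustrative; note the trailing space in the `'Samples '` key, which is present in the code):

    rows = calculate_split_counts_extended(dataset)
    # e.g. [{'Split': 'All',   'Samples ': '1000', 'Bbox': 4500},
    #       {'Split': 'Train', 'Samples ': '800',  'Bbox': 3600}, ...]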
@@ -118,10 +152,13 @@ def print_class_distribution(dataset: HafniaDataset) -> None:
     for task in dataset.info.tasks:
         if task.class_names is None:
             raise ValueError(f"Task '{task.name}' does not have class names defined.")
-        class_counts = dataset.class_counts_for_task(primitive=task.primitive, task_name=task.name)
+        class_counts = dataset.calculate_task_class_counts(primitive=task.primitive, task_name=task.name)
 
         # Print class distribution
-        rich_table = Table(title=f"Class Count for '{task.primitive.__name__}/{task.name}'", show_lines=False)
+        rich_table = Table(
+            title=f"Class Count for '{task.primitive.__name__}/{task.name}'",
+            show_lines=False,
+        )
         rich_table.add_column("Class Name", style="cyan")
         rich_table.add_column("Class Idx", style="cyan")
         rich_table.add_column("Count", justify="right")
@@ -136,32 +173,7 @@ def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
     Prints a table with sample counts and task counts for each primitive type
     in total and for each split (train, val, test).
     """
-    from hafnia.dataset.operations.table_transformations import create_primitive_table
-    from hafnia.dataset.primitives import PRIMITIVE_TYPES
-
-    splits_sets = {
-        "All": SplitName.valid_splits(),
-        "Train": [SplitName.TRAIN],
-        "Validation": [SplitName.VAL],
-        "Test": [SplitName.TEST],
-    }
-    rows = []
-    for split_name, splits in splits_sets.items():
-        dataset_split = dataset.create_split_dataset(splits)
-        table = dataset_split.samples
-        row = {}
-        row["Split"] = split_name
-        row["Sample "] = str(len(table))
-        for PrimitiveType in PRIMITIVE_TYPES:
-            column_name = PrimitiveType.column_name()
-            objects_df = create_primitive_table(table, PrimitiveType=PrimitiveType, keep_sample_data=False)
-            if objects_df is None:
-                continue
-            for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
-                count = len(object_group[FieldName.CLASS_NAME])
-                row[f"{PrimitiveType.__name__}\n{task_name}"] = str(count)
-        rows.append(row)
-
+    rows = calculate_split_counts_extended(dataset)
     rich_table = Table(title="Dataset Statistics", show_lines=True, box=rich.box.SIMPLE)
     for i_row, row in enumerate(rows):
         if i_row == 0:
@@ -171,7 +183,7 @@ def print_sample_and_task_counts(dataset: HafniaDataset) -> None:
     rprint(rich_table)
 
 
-def check_dataset(dataset: HafniaDataset):
+def check_dataset(dataset: HafniaDataset, check_splits: bool = True):
     """
     Performs various checks on the dataset to ensure its integrity and consistency.
     Raises errors if any issues are found.
@@ -181,21 +193,21 @@ def check_dataset(dataset: HafniaDataset):
     user_logger.info("Checking Hafnia dataset...")
     assert isinstance(dataset.info.dataset_name, str) and len(dataset.info.dataset_name) > 0
 
-    sample_dataset = dataset.create_sample_dataset()
-    if len(sample_dataset) == 0:
-        raise ValueError("The dataset does not include a sample dataset")
+    if check_splits:
+        sample_dataset = dataset.create_sample_dataset()
+        if len(sample_dataset) == 0:
+            raise ValueError("The dataset does not include a sample dataset")
+
+        actual_splits = dataset.samples.select(pl.col(SampleField.SPLIT)).unique().to_series().to_list()
+        required_splits = SplitName.valid_splits()
 
-    actual_splits = dataset.samples.select(pl.col(ColumnName.SPLIT)).unique().to_series().to_list()
-    expected_splits = SplitName.valid_splits()
-    if set(actual_splits) != set(expected_splits):
-        raise ValueError(f"Expected all splits '{expected_splits}' in dataset, but got '{actual_splits}'. ")
+        if not set(required_splits).issubset(set(actual_splits)):
+            raise ValueError(f"Expected all splits '{required_splits}' in dataset, but got '{actual_splits}'. ")
 
     dataset.check_dataset_tasks()
 
     expected_tasks = dataset.info.tasks
-    distribution = dataset.info.distributions or []
-    distribution_names = [task.name for task in distribution]
-    # Check that tasks found in the 'dataset.table' matches the tasks defined in 'dataset.info.tasks'
+    # Check that tasks found in the 'dataset.samples' matches the tasks defined in 'dataset.info.tasks'
     for PrimitiveType in PRIMITIVE_TYPES:
         column_name = PrimitiveType.column_name()
         if column_name not in dataset.samples.columns:
@@ -203,14 +215,14 @@ def check_dataset(dataset: HafniaDataset):
         objects_df = create_primitive_table(dataset.samples, PrimitiveType=PrimitiveType, keep_sample_data=False)
         if objects_df is None:
             continue
-        for (task_name,), object_group in objects_df.group_by(FieldName.TASK_NAME):
+        for (task_name,), object_group in objects_df.group_by(PrimitiveField.TASK_NAME):
             has_task = any([t for t in expected_tasks if t.name == task_name and t.primitive == PrimitiveType])
-            if has_task or (task_name in distribution_names):
+            if has_task:
                 continue
-            class_names = object_group[FieldName.CLASS_NAME].unique().to_list()
+            class_names = object_group[PrimitiveField.CLASS_NAME].unique().to_list()
             raise ValueError(
                 f"Task name '{task_name}' for the '{PrimitiveType.__name__}' primitive is missing in "
-                f"'dataset.info.tasks' for dataset '{task_name}'. Missing task has the following "
+                f"'dataset.info.tasks' for dataset '{dataset.info.dataset_name}'. Missing task has the following "
                 f"classes: {class_names}. "
             )
 
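Note: `check_dataset` gains a `check_splits` flag (default `True`). The sample-dataset and split checks become optional, and the split check is relaxed from exact set equality to a subset test, so datasets with extra splits no longer fail validation. A minimal usage sketch, assuming a loaded `dataset`:

    check_dataset(dataset)                      # full validation, as in 0.4.0
    check_dataset(dataset, check_splits=False)  # skip sample-dataset and split checks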
@@ -237,7 +249,7 @@ def check_dataset_tasks(dataset: HafniaDataset):
 
         if len(dataset) > 0:  # Check only performed for non-empty datasets
             primitive_table = (
-                primitive_column.explode().struct.unnest().filter(pl.col(FieldName.TASK_NAME) == task.name)
+                primitive_column.explode().struct.unnest().filter(pl.col(PrimitiveField.TASK_NAME) == task.name)
             )
             if primitive_table.is_empty():
                 raise ValueError(
@@ -245,7 +257,7 @@ def check_dataset_tasks(dataset: HafniaDataset):
                     + f"the column '{column_name}' has no {task.name=} objects. Please check the dataset."
                 )
 
-            actual_classes = set(primitive_table[FieldName.CLASS_NAME].unique().to_list())
+            actual_classes = set(primitive_table[PrimitiveField.CLASS_NAME].unique().to_list())
             if task.class_names is None:
                 raise ValueError(
                     msg_something_wrong
@@ -260,12 +272,12 @@ def check_dataset_tasks(dataset: HafniaDataset):
                     f"to be a subset of the defined classes\n\t{actual_classes=} \n\t{defined_classes=}."
                 )
             # Check class_indices
-            mapped_indices = primitive_table[FieldName.CLASS_NAME].map_elements(
+            mapped_indices = primitive_table[PrimitiveField.CLASS_NAME].map_elements(
                 lambda x: task.class_names.index(x), return_dtype=pl.Int64
             )
-            table_indices = primitive_table[FieldName.CLASS_IDX]
+            table_indices = primitive_table[PrimitiveField.CLASS_IDX]
 
             error_msg = msg_something_wrong + (
-                f"class indices in '{FieldName.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
+                f"class indices in '{PrimitiveField.CLASS_IDX}' column does not match classes ordering in 'task.class_names'"
             )
             assert mapped_indices.equals(table_indices), error_msg
@@ -31,6 +31,7 @@ that the signatures match.
 import json
 import re
+import shutil
 import textwrap
 from pathlib import Path
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union
@@ -39,17 +40,23 @@ import cv2
 import more_itertools
 import numpy as np
 import polars as pl
-from PIL import Image
 from rich.progress import track
 
 from hafnia.dataset import dataset_helpers
-from hafnia.dataset.dataset_names import OPS_REMOVE_CLASS, ColumnName, FieldName
+from hafnia.dataset.dataset_names import (
+    OPS_REMOVE_CLASS,
+    PrimitiveField,
+    SampleField,
+    StorageFormat,
+)
+from hafnia.dataset.operations.table_transformations import update_class_indices
 from hafnia.dataset.primitives import get_primitive_type_from_string
 from hafnia.dataset.primitives.primitive import Primitive
+from hafnia.log import user_logger
 from hafnia.utils import remove_duplicates_preserve_order
 
 if TYPE_CHECKING:  # Using 'TYPE_CHECKING' to avoid circular imports during type checking
-    from hafnia.dataset.hafnia_dataset import HafniaDataset, TaskInfo
+    from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample, TaskInfo
 
 
 ### Image transformations ###
@@ -57,7 +64,7 @@ class AnonymizeByPixelation:
     def __init__(self, resize_factor: float = 0.10):
         self.resize_factor = resize_factor
 
-    def __call__(self, frame: np.ndarray) -> np.ndarray:
+    def __call__(self, frame: np.ndarray, sample: "Sample") -> np.ndarray:
         org_size = frame.shape[:2]
         frame = cv2.resize(frame, (0, 0), fx=self.resize_factor, fy=self.resize_factor)
         frame = cv2.resize(frame, org_size[::-1], interpolation=cv2.INTER_NEAREST)
@@ -66,31 +73,100 @@ class AnonymizeByPixelation:
 
 def transform_images(
     dataset: "HafniaDataset",
-    transform: Callable[[np.ndarray], np.ndarray],
+    transform: Callable[[np.ndarray, "Sample"], np.ndarray],
     path_output: Path,
+    description: str = "Transform images",
 ) -> "HafniaDataset":
+    from hafnia.dataset.hafnia_dataset import Sample
+
     new_paths = []
     path_image_folder = path_output / "data"
     path_image_folder.mkdir(parents=True, exist_ok=True)
 
-    org_paths = dataset.samples[ColumnName.FILE_PATH].to_list()
-    for org_path in track(org_paths, description="Transform images"):
-        org_path = Path(org_path)
-        if not org_path.exists():
-            raise FileNotFoundError(f"File {org_path} does not exist in the dataset.")
-
-        image = np.array(Image.open(org_path))
-        image_transformed = transform(image)
+    for sample_dict in track(dataset, description=description):
+        sample = Sample(**sample_dict)
+        image = sample.read_image()
+        image_transformed = transform(image, sample)
         new_path = dataset_helpers.save_image_with_hash_name(image_transformed, path_image_folder)
 
         if not new_path.exists():
             raise FileNotFoundError(f"Transformed file {new_path} does not exist in the dataset.")
         new_paths.append(str(new_path))
 
-    table = dataset.samples.with_columns(pl.Series(new_paths).alias(ColumnName.FILE_PATH))
+    table = dataset.samples.with_columns(pl.Series(new_paths).alias(SampleField.FILE_PATH))
     return dataset.update_samples(table)
 
 
+def convert_to_image_storage_format(
+    dataset: "HafniaDataset",
+    path_output_folder: Path,
+    reextract_frames: bool,
+    image_format: str = "png",
+    transform: Optional[Callable[[np.ndarray, "Sample"], np.ndarray]] = None,
+) -> "HafniaDataset":
+    """
+    Convert a video-based dataset ("storage_format" == "video", FieldName.STORAGE_FORMAT == StorageFormat.VIDEO)
+    to an image-based dataset by extracting frames.
+    """
+    from hafnia.dataset.hafnia_dataset import HafniaDataset, Sample
+
+    path_images = path_output_folder / "data"
+    path_images.mkdir(parents=True, exist_ok=True)
+
+    # Only video format dataset samples are processed
+    video_based_samples = dataset.samples.filter(pl.col(SampleField.STORAGE_FORMAT) == StorageFormat.VIDEO)
+
+    if video_based_samples.is_empty():
+        user_logger.info("Dataset has no video-based samples. Returning dataset unchanged.")
+        return dataset
+
+    update_list = []
+    for (path_video,), video_samples in video_based_samples.group_by(SampleField.FILE_PATH):
+        assert Path(path_video).exists(), (
+            f"'{path_video}' not found. We expect the video to be downloaded to '{path_output_folder}'"
+        )
+        video = cv2.VideoCapture(str(path_video))
+
+        video_samples = video_samples.sort(SampleField.COLLECTION_INDEX)
+        for sample_dict in track(
+            video_samples.iter_rows(named=True),
+            total=video_samples.height,
+            description=f"Extracting frames from '{Path(path_video).name}'",
+        ):
+            frame_number = sample_dict[SampleField.COLLECTION_INDEX]
+            image_name = f"{Path(path_video).stem}_F{frame_number:06d}.{image_format}"
+            path_image = path_images / image_name
+
+            update_list.append(
+                {
+                    SampleField.SAMPLE_INDEX: sample_dict[SampleField.SAMPLE_INDEX],
+                    SampleField.COLLECTION_ID: sample_dict[SampleField.COLLECTION_ID],
+                    SampleField.COLLECTION_INDEX: frame_number,
+                    SampleField.FILE_PATH: path_image.as_posix(),
+                    SampleField.STORAGE_FORMAT: StorageFormat.IMAGE,
+                }
+            )
+            if reextract_frames:
+                shutil.rmtree(path_image, ignore_errors=True)
+            if path_image.exists():
+                continue
+
+            video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
+            ret, frame_org = video.read()
+            if not ret:
+                raise RuntimeError(f"Could not read frame {frame_number} from video '{path_video}'")
+
+            if transform is not None:
+                frame_org = transform(frame_org, Sample(**sample_dict))
+
+            cv2.imwrite(str(path_image), frame_org)
+    df_updates = pl.DataFrame(update_list)
+    samples_as_images = dataset.samples.update(df_updates, on=[SampleField.COLLECTION_ID, SampleField.COLLECTION_INDEX])
+    hafnia_dataset = HafniaDataset(samples=samples_as_images, info=dataset.info)
+
+    return hafnia_dataset
+
+
 def get_task_info_from_task_name_and_primitive(
     tasks: List["TaskInfo"],
     task_name: Optional[str] = None,
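Note: image-transform callbacks now receive the owning `Sample` as a second argument (`Callable[[np.ndarray, "Sample"], np.ndarray]`), and the new `convert_to_image_storage_format` materializes video-based samples as image files. A minimal sketch of a 0.4.1-style callback and conversion call, assuming a video-backed dataset; the callback body and output path are illustrative:

    def pixelate(frame: np.ndarray, sample: "Sample") -> np.ndarray:
        # sample metadata is now available to the transform; this body ignores it
        return AnonymizeByPixelation(resize_factor=0.05)(frame, sample)

    dataset_images = convert_to_image_storage_format(
        dataset,
        path_output_folder=Path("output"),
        reextract_frames=False,
        transform=pixelate,
    )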
@@ -223,25 +299,10 @@ def class_mapper(
         pl.col(task.primitive.column_name())
         .list.eval(
             pl.element().struct.with_fields(
-                pl.when(pl.field(FieldName.TASK_NAME) == task.name)
-                .then(pl.field(FieldName.CLASS_NAME).replace_strict(class_mapping))
-                .otherwise(pl.field(FieldName.CLASS_NAME))
-                .alias(FieldName.CLASS_NAME)
-            )
-        )
-        .alias(task.primitive.column_name())
-    )
-
-    # Update class indices too
-    name_2_idx_mapping: Dict[str, int] = {name: idx for idx, name in enumerate(new_class_names)}
-    samples_updated = samples_updated.with_columns(
-        pl.col(task.primitive.column_name())
-        .list.eval(
-            pl.element().struct.with_fields(
-                pl.when(pl.field(FieldName.TASK_NAME) == task.name)
-                .then(pl.field(FieldName.CLASS_NAME).replace_strict(name_2_idx_mapping))
-                .otherwise(pl.field(FieldName.CLASS_IDX))
-                .alias(FieldName.CLASS_IDX)
+                pl.when(pl.field(PrimitiveField.TASK_NAME) == task.name)
+                .then(pl.field(PrimitiveField.CLASS_NAME).replace_strict(class_mapping, default="Missing"))
+                .otherwise(pl.field(PrimitiveField.CLASS_NAME))
+                .alias(PrimitiveField.CLASS_NAME)
             )
         )
         .alias(task.primitive.column_name())
@@ -250,7 +311,7 @@ def class_mapper(
     if OPS_REMOVE_CLASS in new_class_names:  # Remove class_names that are mapped to REMOVE_CLASS
         samples_updated = samples_updated.with_columns(
             pl.col(task.primitive.column_name())
-            .list.filter(pl.element().struct.field(FieldName.CLASS_NAME) != OPS_REMOVE_CLASS)
+            .list.filter(pl.element().struct.field(PrimitiveField.CLASS_NAME) != OPS_REMOVE_CLASS)
             .alias(task.primitive.column_name())
         )
 
@@ -259,6 +320,10 @@ def class_mapper(
     new_task = task.model_copy(deep=True)
     new_task.class_names = new_class_names
     dataset_info = dataset.info.replace_task(old_task=task, new_task=new_task)
+
+    # Update class indices to match new class names
+    samples_updated = update_class_indices(samples_updated, new_task)
+
     return HafniaDataset(info=dataset_info, samples=samples_updated)
 
 
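Note: `class_mapper` no longer recomputes class indices inline; it delegates to the new `update_class_indices` helper after remapping names, and `replace_strict(class_mapping, default="Missing")` now maps class names absent from the mapping to "Missing" instead of raising. A hedged sketch, assuming a signature along the lines of `class_mapper(dataset, class_mapping)` (argument names are illustrative):

    # "car" and "bus" merge into "vehicle"; unmapped class names become "Missing"
    dataset = class_mapper(dataset, class_mapping={"car": "vehicle", "bus": "vehicle"})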
@@ -317,7 +382,7 @@ def rename_task(
         pl.col(old_task.primitive.column_name())
         .list.eval(
             pl.element().struct.with_fields(
-                pl.field(FieldName.TASK_NAME).replace(old_task.name, new_task.name).alias(FieldName.TASK_NAME)
+                pl.field(PrimitiveField.TASK_NAME).replace(old_task.name, new_task.name).alias(PrimitiveField.TASK_NAME)
             )
         )
         .alias(new_task.primitive.column_name())
@@ -343,8 +408,8 @@ def select_samples_by_class_name(
     samples = dataset.samples.filter(
         pl.col(task.primitive.column_name())
         .list.eval(
-            pl.element().struct.field(FieldName.CLASS_NAME).is_in(class_names)
-            & (pl.element().struct.field(FieldName.TASK_NAME) == task.name)
+            pl.element().struct.field(PrimitiveField.CLASS_NAME).is_in(class_names)
+            & (pl.element().struct.field(PrimitiveField.TASK_NAME) == task.name)
         )
         .list.any()
     )