hafnia 0.4.3__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as published.
@@ -4,7 +4,7 @@ import base64
  from datetime import datetime
  from enum import Enum
  from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple, Type, Union
+ from typing import Any, Dict, List, Optional, Type, Union

  import boto3
  import polars as pl
@@ -13,7 +13,6 @@ from pydantic import BaseModel, ConfigDict, field_validator

  from hafnia.dataset.dataset_names import (
  DatasetVariant,
- DeploymentStage,
  PrimitiveField,
  SampleField,
  SplitName,
@@ -29,26 +28,21 @@ from hafnia.dataset.primitives import (
  Segmentation,
  )
  from hafnia.dataset.primitives.primitive import Primitive
- from hafnia.http import post
- from hafnia.log import user_logger
- from hafnia.platform.datasets import get_dataset_id
+ from hafnia.platform.datasets import upload_dataset_details
+ from hafnia.utils import get_path_dataset_gallery_images
  from hafnia_cli.config import Config


- def generate_bucket_name(dataset_name: str, deployment_stage: DeploymentStage) -> str:
- # TODO: When moving to versioning we do NOT need 'staging' and 'production' specific buckets
- # and the new name convention should be: f"hafnia-dataset-{dataset_name}"
- return f"mdi-{deployment_stage.value}-{dataset_name}"
-
-
  class DatasetDetails(BaseModel, validate_assignment=True): # type: ignore[call-arg]
  model_config = ConfigDict(use_enum_values=True) # To parse Enum values as strings
  name: str
+ title: Optional[str] = None
+ overview: Optional[str] = None
  data_captured_start: Optional[datetime] = None
  data_captured_end: Optional[datetime] = None
  data_received_start: Optional[datetime] = None
  data_received_end: Optional[datetime] = None
- latest_update: Optional[datetime] = None
+ dataset_updated_at: Optional[datetime] = None
  license_citation: Optional[str] = None
  version: Optional[str] = None
  s3_bucket_name: Optional[str] = None
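To make the `DatasetDetails` field changes above concrete (new `title` and `overview` fields, and `latest_update` renamed to `dataset_updated_at`), here is a minimal sketch using a stand-in pydantic model. The real `DatasetDetails` has many more fields than the ones mirrored below.

```python
from datetime import datetime, timezone
from typing import Optional

from pydantic import BaseModel, ConfigDict


# Stand-in model mirroring only the fields touched in this hunk; the real
# DatasetDetails is defined inside hafnia and carries many additional fields.
class DatasetDetailsSketch(BaseModel):
    model_config = ConfigDict(use_enum_values=True)
    name: str
    title: Optional[str] = None                      # new in 0.5.x
    overview: Optional[str] = None                   # new in 0.5.x
    dataset_updated_at: Optional[datetime] = None    # renamed from `latest_update`
    version: Optional[str] = None


details = DatasetDetailsSketch(
    name="mnist",
    title="MNIST",
    overview="Handwritten digit images",
    dataset_updated_at=datetime.now(timezone.utc),
    version="1.0.0",
)
print(details.model_dump_json(exclude_none=True))
```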
@@ -281,26 +275,32 @@ def get_folder_size(path: Path) -> int:
  return sum([path.stat().st_size for path in path.rglob("*")])


- def upload_to_hafnia_dataset_detail_page(dataset_update: DatasetDetails, upload_gallery_images: bool) -> dict:
- if not upload_gallery_images:
- dataset_update.imgs = None
-
- cfg = Config()
- dataset_details = dataset_update.model_dump_json()
- data = upload_dataset_details(cfg=cfg, data=dataset_details, dataset_name=dataset_update.name)
- return data
-
-
- def upload_dataset_details(cfg: Config, data: str, dataset_name: str) -> dict:
- dataset_endpoint = cfg.get_platform_endpoint("datasets")
- dataset_id = get_dataset_id(dataset_name, dataset_endpoint, cfg.api_key)
+ def upload_dataset_details_to_platform(
+ dataset: HafniaDataset,
+ path_gallery_images: Optional[Path] = None,
+ gallery_image_names: Optional[List[str]] = None,
+ distribution_task_names: Optional[List[str]] = None,
+ update_platform: bool = True,
+ cfg: Optional[Config] = None,
+ ) -> dict:
+ cfg = cfg or Config()
+ dataset_details = dataset_details_from_hafnia_dataset(
+ dataset=dataset,
+ path_gallery_images=path_gallery_images,
+ gallery_image_names=gallery_image_names,
+ distribution_task_names=distribution_task_names,
+ )

- import_endpoint = f"{dataset_endpoint}/{dataset_id}/import"
- headers = {"Authorization": cfg.api_key}
+ if update_platform:
+ dataset_details_exclude_none = dataset_details.model_dump(exclude_none=True, mode="json")
+ upload_dataset_details(
+ cfg=cfg,
+ data=dataset_details_exclude_none,
+ dataset_name=dataset_details.name,
+ )

- user_logger.info("Exporting dataset details to platform. This may take up to 30 seconds...")
- response = post(endpoint=import_endpoint, headers=headers, data=data) # type: ignore[assignment]
- return response # type: ignore[return-value]
+ dataset_details_dict = dataset_details.model_dump(exclude_none=False, mode="json")
+ return dataset_details_dict


  def get_resolutions(dataset: HafniaDataset, max_resolutions_selected: int = 8) -> List[DbResolution]:
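A hedged usage sketch of the new `upload_dataset_details_to_platform` entry point from the hunk above. The keyword arguments and the `update_platform` behaviour are taken from the diff; the import path of the function and the way the `HafniaDataset` is loaded are assumptions and may differ in your install.

```python
from pathlib import Path

from hafnia.dataset.hafnia_dataset import HafniaDataset  # import path shown elsewhere in this diff

# Hypothetical import path: the diff does not name the module that defines
# upload_dataset_details_to_platform, so adjust this to the actual package layout.
from hafnia.dataset.operations.dataset_stats import upload_dataset_details_to_platform

# Hypothetical loader; construct the HafniaDataset however you normally do.
dataset = HafniaDataset.from_path(Path("./my_dataset"))

# With update_platform=False the DatasetDetails payload is only built locally and
# returned as a dict (model_dump with exclude_none=False); nothing is uploaded.
details_dict = upload_dataset_details_to_platform(
    dataset=dataset,
    gallery_image_names=["sample_0001.jpg"],
    update_platform=False,
)
print(sorted(details_dict))

# update_platform=True additionally pushes the payload via upload_dataset_details(...)
# using Config() credentials (or a cfg you pass in).
```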
@@ -360,9 +360,6 @@ def s3_based_fields(bucket_name: str, variant_type: DatasetVariant, session: bot

  def dataset_details_from_hafnia_dataset(
  dataset: HafniaDataset,
- deployment_stage: DeploymentStage,
- path_sample: Optional[Path],
- path_hidden: Optional[Path],
  path_gallery_images: Optional[Path] = None,
  gallery_image_names: Optional[List[str]] = None,
  distribution_task_names: Optional[List[str]] = None,
@@ -371,33 +368,24 @@ def dataset_details_from_hafnia_dataset(
  dataset_reports = []
  dataset_meta_info = dataset.info.meta or {}

- path_and_variant: List[Tuple[Path, DatasetVariant]] = []
- if path_sample is not None:
- path_and_variant.append((path_sample, DatasetVariant.SAMPLE))
-
- if path_hidden is not None:
- path_and_variant.append((path_hidden, DatasetVariant.HIDDEN))
-
- if len(path_and_variant) == 0:
- raise ValueError("At least one path must be provided for sample or hidden dataset.")
-
+ path_and_variant = [DatasetVariant.SAMPLE, DatasetVariant.HIDDEN]
  gallery_images = create_gallery_images(
  dataset=dataset,
  path_gallery_images=path_gallery_images,
  gallery_image_names=gallery_image_names,
  )

- for path_dataset, variant_type in path_and_variant:
+ for variant_type in path_and_variant:
  if variant_type == DatasetVariant.SAMPLE:
  dataset_variant = dataset.create_sample_dataset()
  else:
  dataset_variant = dataset

- size_bytes = get_folder_size(path_dataset)
+ files_paths = dataset_variant.samples[SampleField.FILE_PATH].to_list()
+ size_bytes = sum([Path(file_path).stat().st_size for file_path in files_paths])
  dataset_variants.append(
  DbDatasetVariant(
  variant_type=VARIANT_TYPE_MAPPING[variant_type], # type: ignore[index]
- # upload_date: Optional[datetime] = None
  size_bytes=size_bytes,
  data_type=DataTypeChoices.images,
  number_of_data_items=len(dataset_variant),
@@ -405,7 +393,6 @@ def dataset_details_from_hafnia_dataset(
  duration=dataset_meta_info.get("duration", None),
  duration_average=dataset_meta_info.get("duration_average", None),
  frame_rate=dataset_meta_info.get("frame_rate", None),
- # bit_rate: Optional[float] = None
  n_cameras=dataset_meta_info.get("n_cameras", None),
  )
  )
@@ -435,19 +422,19 @@ def dataset_details_from_hafnia_dataset(
  object_reports = sorted(object_reports, key=lambda x: x.obj.name) # Sort object reports by name
  report.annotated_object_reports = object_reports

- if report.distribution_values is None:
- report.distribution_values = []
+ if report.distribution_values is None:
+ report.distribution_values = []

- dataset_reports.append(report)
+ dataset_reports.append(report)
  dataset_name = dataset.info.dataset_name
- bucket_sample = generate_bucket_name(dataset_name, deployment_stage=deployment_stage)
  dataset_info = DatasetDetails(
  name=dataset_name,
+ title=dataset.info.dataset_title,
+ overview=dataset.info.description,
  version=dataset.info.version,
- s3_bucket_name=bucket_sample,
  dataset_variants=dataset_variants,
  split_annotations_reports=dataset_reports,
- latest_update=dataset.info.updated_at,
+ dataset_updated_at=dataset.info.updated_at,
  dataset_format_version=dataset.info.format_version,
  license_citation=dataset.info.reference_bibtex,
  data_captured_start=dataset_meta_info.get("data_captured_start", None),
@@ -565,7 +552,7 @@ def create_gallery_images(
  gallery_images = None
  if (gallery_image_names is not None) and (len(gallery_image_names) > 0):
  if path_gallery_images is None:
- raise ValueError("Path to gallery images must be provided.")
+ path_gallery_images = get_path_dataset_gallery_images(dataset.info.dataset_name)
  path_gallery_images.mkdir(parents=True, exist_ok=True)
  COL_IMAGE_NAME = "image_name"
  samples = dataset.samples.with_columns(
@@ -3,12 +3,70 @@ import math
  import random
  import shutil
  from pathlib import Path
- from typing import Dict, List
+ from typing import Dict, List, Optional, Tuple

  import numpy as np
  import xxhash
+ from packaging.version import InvalidVersion, Version
  from PIL import Image

+ from hafnia.log import user_logger
+
+
+ def is_valid_version_string(version: Optional[str], allow_none: bool = False, allow_latest: bool = False) -> bool:
+ if allow_none and version is None:
+ return True
+ if allow_latest and version == "latest":
+ return True
+ return version_from_string(version, raise_error=False) is not None
+
+
+ def version_from_string(version: Optional[str], raise_error: bool = True) -> Optional[Version]:
+ if version is None:
+ if raise_error:
+ raise ValueError("Version is 'None'. A valid version string is required e.g '1.0.0'")
+ return None
+
+ try:
+ version_casted = Version(version)
+ except (InvalidVersion, TypeError) as e:
+ if raise_error:
+ raise ValueError(f"Invalid version string/type: {version}") from e
+ return None
+
+ # Check if version is semantic versioning (MAJOR.MINOR.PATCH)
+ if len(version_casted.release) < 3:
+ if raise_error:
+ raise ValueError(f"Version string '{version}' is not semantic versioning (MAJOR.MINOR.PATCH)")
+ return None
+ return version_casted
+
+
+ def dataset_name_and_version_from_string(
+ string: str,
+ resolve_missing_version: bool = True,
+ ) -> Tuple[str, Optional[str]]:
+ if not isinstance(string, str):
+ raise TypeError(f"'{type(string)}' for '{string}' is an unsupported type. Expected 'str' e.g 'mnist:1.0.0'")
+
+ parts = string.split(":")
+ if len(parts) == 1:
+ dataset_name = parts[0]
+ if resolve_missing_version:
+ version = "latest" # Default to 'latest' if version is missing. This will be resolved to a specific version later.
+ user_logger.info(f"Version is missing in dataset name: {string}. Defaulting to version='latest'.")
+ else:
+ raise ValueError(f"Version is missing in dataset name: {string}. Use 'name:version'")
+ elif len(parts) == 2:
+ dataset_name, version = parts
+ else:
+ raise ValueError(f"Invalid dataset name format: {string}. Use 'name' or 'name:version' ")
+
+ if not is_valid_version_string(version, allow_none=True, allow_latest=True):
+ raise ValueError(f"Invalid version string: {version}. Use semantic versioning e.g. '1.0.0' or 'latest'")
+
+ return dataset_name, version
+

  def create_split_name_list_from_ratios(split_ratios: Dict[str, float], n_items: int, seed: int = 42) -> List[str]:
  samples_per_split = split_sizes_from_ratios(split_ratios=split_ratios, n_items=n_items)
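The version helpers added in the hunk above are imported elsewhere in this diff from `hafnia.dataset.dataset_helpers`, so that module path is used in the sketch below; the example values and results follow directly from the added code.

```python
# dataset_name_and_version_from_string is imported from this module later in the diff;
# the other two helpers are defined alongside it in the same hunk.
from hafnia.dataset.dataset_helpers import (
    dataset_name_and_version_from_string,
    is_valid_version_string,
    version_from_string,
)

# "name:version" strings are split and the version is validated as MAJOR.MINOR.PATCH.
print(dataset_name_and_version_from_string("mnist:1.0.0"))     # ("mnist", "1.0.0")

# With resolve_missing_version=True (the default) a bare name falls back to "latest".
print(dataset_name_and_version_from_string("mnist"))           # ("mnist", "latest")

# Version validation: full semantic versions pass, short or malformed strings do not.
print(is_valid_version_string("1.0.0"))                         # True
print(is_valid_version_string("1.0"))                           # False (not MAJOR.MINOR.PATCH)
print(is_valid_version_string("latest", allow_latest=True))     # True
print(version_from_string("not-a-version", raise_error=False))  # None
```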
@@ -57,20 +115,6 @@ def save_pil_image_with_hash_name(image: Image.Image, path_folder: Path, allow_s
  def copy_and_rename_file_to_hash_value(path_source: Path, path_dataset_root: Path) -> Path:
  """
  Copies a file to a dataset root directory with a hash-based name and sub-directory structure.
-
- E.g. for an "image.png" with hash "dfe8f3b1c2a4f5b6c7d8e9f0a1b2c3d4", the image will be copied to
- 'path_dataset_root / "data" / "dfe" / "dfe8f3b1c2a4f5b6c7d8e9f0a1b2c3d4.png"'
- Notice that the hash is used for both the filename and the subfolder name.
-
- Placing image/video files into multiple sub-folders (instead of one large folder) is seemingly
- unnecessary, but it is actually a requirement when the dataset is later downloaded from S3.
-
- The reason is that AWS has a rate limit of 3500 ops/sec per prefix (sub-folder) in S3 - meaning we can "only"
- download 3500 files per second from a single folder (prefix) in S3.
-
- For even a single user, we found that this limit was being reached when files are stored in single folder (prefix)
- in S3. To support multiple users and concurrent experiments, we are required to separate files into
- multiple sub-folders (prefixes) in S3 to not hit the rate limit.
  """

  if not path_source.exists():
@@ -86,7 +130,7 @@ def copy_and_rename_file_to_hash_value(path_source: Path, path_dataset_root: Pat


  def relative_path_from_hash(hash: str, suffix: str) -> Path:
- path_file = Path("data") / hash[:3] / f"{hash}{suffix}"
+ path_file = Path("data") / f"{hash}{suffix}"
  return path_file

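The change above drops the three-character hash-prefix sub-folder, so hashed files now land directly under `data/`. A small sketch, assuming the helper stays in the same `hafnia.dataset.dataset_helpers` module as the version utilities above:

```python
from hafnia.dataset.dataset_helpers import relative_path_from_hash  # assumed module path

file_hash = "dfe8f3b1c2a4f5b6c7d8e9f0a1b2c3d4"

# 0.5.x layout: flat under "data/"
print(relative_path_from_hash(hash=file_hash, suffix=".png"))
# data/dfe8f3b1c2a4f5b6c7d8e9f0a1b2c3d4.png

# 0.4.x added a sub-folder from the first three hash characters:
# data/dfe/dfe8f3b1c2a4f5b6c7d8e9f0a1b2c3d4.png
```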
 
@@ -1,8 +1,5 @@
  from enum import Enum
- from typing import Dict, List, Optional
-
- import boto3
- from pydantic import BaseModel, field_validator
+ from typing import List

  FILENAME_RECIPE_JSON = "recipe.json"
  FILENAME_DATASET_INFO = "dataset_info.json"
@@ -124,93 +121,3 @@ class DatasetVariant(Enum):
  DUMP = "dump"
  SAMPLE = "sample"
  HIDDEN = "hidden"
-
-
- class AwsCredentials(BaseModel):
- access_key: str
- secret_key: str
- session_token: str
- region: Optional[str]
-
- def aws_credentials(self) -> Dict[str, str]:
- """
- Returns the AWS credentials as a dictionary.
- """
- environment_vars = {
- "AWS_ACCESS_KEY_ID": self.access_key,
- "AWS_SECRET_ACCESS_KEY": self.secret_key,
- "AWS_SESSION_TOKEN": self.session_token,
- }
- if self.region:
- environment_vars["AWS_REGION"] = self.region
-
- return environment_vars
-
- @staticmethod
- def from_session(session: boto3.Session) -> "AwsCredentials":
- """
- Creates AwsCredentials from a Boto3 session.
- """
- frozen_credentials = session.get_credentials().get_frozen_credentials()
- return AwsCredentials(
- access_key=frozen_credentials.access_key,
- secret_key=frozen_credentials.secret_key,
- session_token=frozen_credentials.token,
- region=session.region_name,
- )
-
-
- ARN_PREFIX = "arn:aws:s3:::"
-
-
- class ResourceCredentials(AwsCredentials):
- s3_arn: str
-
- @staticmethod
- def fix_naming(payload: Dict[str, str]) -> "ResourceCredentials":
- """
- The endpoint returns a payload with a key called 's3_path', but it
- is actually an ARN path (starts with arn:aws:s3::). This method renames it to 's3_arn' for consistency.
- """
- if "s3_path" in payload and payload["s3_path"].startswith(ARN_PREFIX):
- payload["s3_arn"] = payload.pop("s3_path")
-
- if "region" not in payload:
- payload["region"] = "eu-west-1"
- return ResourceCredentials(**payload)
-
- @field_validator("s3_arn")
- @classmethod
- def validate_s3_arn(cls, value: str) -> str:
- """Validate s3_arn to ensure it starts with 'arn:aws:s3:::'"""
- if not value.startswith("arn:aws:s3:::"):
- raise ValueError(f"Invalid S3 ARN: {value}. It should start with 'arn:aws:s3:::'")
- return value
-
- def s3_path(self) -> str:
- """
- Extracts the S3 path from the ARN.
- Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket/my-prefix
- """
- return self.s3_arn[len(ARN_PREFIX) :]
-
- def s3_uri(self) -> str:
- """
- Converts the S3 ARN to a URI format.
- Example: arn:aws:s3:::my-bucket/my-prefix -> s3://my-bucket/my-prefix
- """
- return f"s3://{self.s3_path()}"
-
- def bucket_name(self) -> str:
- """
- Extracts the bucket name from the S3 ARN.
- Example: arn:aws:s3:::my-bucket/my-prefix -> my-bucket
- """
- return self.s3_path().split("/")[0]
-
- def object_key(self) -> str:
- """
- Extracts the object key from the S3 ARN.
- Example: arn:aws:s3:::my-bucket/my-prefix -> my-prefix
- """
- return "/".join(self.s3_path().split("/")[1:])
@@ -11,14 +11,19 @@ from pydantic import (
  )

  from hafnia import utils
+ from hafnia.dataset.dataset_helpers import dataset_name_and_version_from_string
  from hafnia.dataset.dataset_recipe import recipe_transforms
  from hafnia.dataset.dataset_recipe.recipe_types import (
  RecipeCreation,
  RecipeTransform,
  Serializable,
  )
- from hafnia.dataset.hafnia_dataset import HafniaDataset
+ from hafnia.dataset.hafnia_dataset import (
+ HafniaDataset,
+ available_dataset_versions_from_name,
+ )
  from hafnia.dataset.primitives.primitive import Primitive
+ from hafnia.log import user_logger


  class DatasetRecipe(Serializable):
@@ -41,8 +46,31 @@ class DatasetRecipe(Serializable):

  ### Creation Methods (using the 'from_X' )###
  @staticmethod
- def from_name(name: str, force_redownload: bool = False, download_files: bool = True) -> DatasetRecipe:
- creation = FromName(name=name, force_redownload=force_redownload, download_files=download_files)
+ def from_name(
+ name: str,
+ version: Optional[str] = None,
+ force_redownload: bool = False,
+ download_files: bool = True,
+ ) -> DatasetRecipe:
+ if version == "latest":
+ user_logger.info(
+ f"The dataset '{name}' in a dataset recipe uses 'latest' as version. For dataset recipes the "
+ "version is pinned to a specific version. Consider specifying a specific version to ensure "
+ "reproducibility of your experiments. "
+ )
+ available_versions = available_dataset_versions_from_name(name)
+ version = str(max(available_versions))
+ if version is None:
+ available_versions = available_dataset_versions_from_name(name)
+ str_versions = ", ".join([str(v) for v in available_versions])
+ raise ValueError(
+ f"Version must be specified when creating a DatasetRecipe from name. "
+ f"Available versions are: {str_versions}"
+ )
+
+ creation = FromName(
+ name=name, version=version, force_redownload=force_redownload, download_files=download_files
+ )
  return DatasetRecipe(creation=creation)

  @staticmethod
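A hedged sketch of the extended `DatasetRecipe.from_name` above. The exact module defining `DatasetRecipe` is not named in this diff, so the import below is an assumption; resolving 'latest' also requires platform access because it calls `available_dataset_versions_from_name`.

```python
# Assumed import path; DatasetRecipe lives somewhere under hafnia.dataset.dataset_recipe,
# but the exact module is not shown in this diff.
from hafnia.dataset.dataset_recipe.dataset_recipe import DatasetRecipe

# Pinning an explicit version keeps the recipe reproducible.
recipe = DatasetRecipe.from_name(name="mnist", version="1.0.0")

# version="latest" is resolved to the newest available version via
# available_dataset_versions_from_name(...) and logs a reproducibility warning,
# so the stored recipe still ends up pinned to a concrete version.
recipe_latest = DatasetRecipe.from_name(name="mnist", version="latest")

# version=None raises a ValueError that lists the available versions.
```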
@@ -125,6 +153,21 @@ class DatasetRecipe(Serializable):
  recipe_id = recipe["id"]
  return DatasetRecipe.from_recipe_id(recipe_id)

+ @staticmethod
+ def from_name_and_version_string(string: str, resolve_missing_version: bool = False) -> "DatasetRecipe":
+ """
+ Validates and converts a dataset name and version string (name:version) to a DatasetRecipe.from_name recipe.
+ If version is missing and 'resolve_missing_version' is True, it will default to 'latest'.
+ If resolve_missing_version is False, it will raise an error if version is missing.
+ """
+
+ dataset_name, version = dataset_name_and_version_from_string(
+ string=string,
+ resolve_missing_version=resolve_missing_version,
+ )
+
+ return DatasetRecipe.from_name(name=dataset_name, version=version)
+
  @staticmethod
  def from_implicit_form(recipe: Any) -> DatasetRecipe:
  """
@@ -180,7 +223,7 @@ class DatasetRecipe(Serializable):
  return recipe

  if isinstance(recipe, str): # str-type is convert to DatasetFromName
- return DatasetRecipe.from_name(name=recipe)
+ return DatasetRecipe.from_name_and_version_string(string=recipe, resolve_missing_version=True)

  if isinstance(recipe, Path): # Path-type is convert to DatasetFromPath
  return DatasetRecipe.from_path(path_folder=recipe)
@@ -409,6 +452,7 @@ class FromPath(RecipeCreation):

  class FromName(RecipeCreation):
  name: str
+ version: Optional[str] = None
  force_redownload: bool = False
  download_files: bool = True

@@ -40,7 +40,7 @@ def mnist_as_hafnia_dataset(force_redownload=False, n_samples: Optional[int] = N

  dataset_info = DatasetInfo(
  dataset_name="mnist",
- version="1.1.0",
+ version="1.0.0",
  tasks=tasks,
  reference_bibtex=textwrap.dedent("""\
  @article{lecun2010mnist,
@@ -78,7 +78,7 @@ def caltech_101_as_hafnia_dataset(
  n_samples=n_samples,
  dataset_name=dataset_name,
  )
- hafnia_dataset.info.version = "1.1.0"
+ hafnia_dataset.info.version = "1.0.0"
  hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
  @article{FeiFei2004LearningGV,
  title={Learning Generative Visual Models from Few Training Examples: An Incremental Bayesian
@@ -108,7 +108,7 @@ def caltech_256_as_hafnia_dataset(
  n_samples=n_samples,
  dataset_name=dataset_name,
  )
- hafnia_dataset.info.version = "1.1.0"
+ hafnia_dataset.info.version = "1.0.0"
  hafnia_dataset.info.reference_bibtex = textwrap.dedent("""\
  @misc{griffin_2023_5sv1j-ytw97,
  author = {Griffin, Gregory and
@@ -150,7 +150,7 @@ def cifar_as_hafnia_dataset(

  dataset_info = DatasetInfo(
  dataset_name=dataset_name,
- version="1.1.0",
+ version="1.0.0",
  tasks=tasks,
  reference_bibtex=textwrap.dedent("""\
  @@TECHREPORT{Krizhevsky09learningmultiple,
@@ -268,7 +268,10 @@ def _download_and_extract_caltech_dataset(dataset_name: str, force_redownload: b
  path_output_extracted = path_tmp_output / "caltech-101"
  for gzip_file in os.listdir(path_output_extracted):
  if gzip_file.endswith(".gz"):
- extract_archive(os.path.join(path_output_extracted, gzip_file), path_output_extracted)
+ extract_archive(
+ from_path=os.path.join(path_output_extracted, gzip_file),
+ to_path=path_output_extracted,
+ )
  path_org = path_output_extracted / "101_ObjectCategories"

  elif dataset_name == "caltech-256":