hirundo 0.1.9__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hirundo/labeling.py ADDED
@@ -0,0 +1,140 @@
+ import typing
+ from abc import ABC
+
+ from pydantic import BaseModel, Field
+
+ from hirundo.dataset_enum import DatasetMetadataType
+
+ if typing.TYPE_CHECKING:
+     from hirundo._urls import HirundoUrl
+
+
+ class Metadata(BaseModel, ABC, frozen=True):
+     type: DatasetMetadataType
+
+
+ class HirundoCSV(Metadata, frozen=True):
+     """
+     A dataset metadata file in the Hirundo CSV format
+     """
+
+     type: typing.Literal[DatasetMetadataType.HIRUNDO_CSV] = (
+         DatasetMetadataType.HIRUNDO_CSV
+     )
+     csv_url: "HirundoUrl"
+     """
+     The URL to access the dataset metadata CSV file,
+     e.g. `s3://my-bucket-name/my-folder/my-metadata.csv`, `gs://my-bucket-name/my-folder/my-metadata.csv`,
+     or `ssh://my-username@my-repo-name/my-folder/my-metadata.csv`
+     (or `file:///datasets/my-folder/my-metadata.csv` if using the LOCAL storage type with an on-premises installation)
+     """
+
+
+ class COCO(Metadata, frozen=True):
+     """
+     A dataset metadata file in the COCO format
+     """
+
+     type: typing.Literal[DatasetMetadataType.COCO] = DatasetMetadataType.COCO
+     json_url: "HirundoUrl"
+     """
+     The URL to access the dataset metadata JSON file,
+     e.g. `s3://my-bucket-name/my-folder/my-metadata.json`, `gs://my-bucket-name/my-folder/my-metadata.json`,
+     or `ssh://my-username@my-repo-name/my-folder/my-metadata.json`
+     (or `file:///datasets/my-folder/my-metadata.json` if using the LOCAL storage type with an on-premises installation)
+     """
+
+
+ class YOLO(Metadata, frozen=True):
+     type: typing.Literal[DatasetMetadataType.YOLO] = DatasetMetadataType.YOLO
+     data_yaml_url: "typing.Optional[HirundoUrl]" = None
+     labels_dir_url: "HirundoUrl"
+
+
+ class KeylabsAuth(BaseModel):
+     username: str
+     password: str
+     instance: str
+
+
+ class Keylabs(Metadata, frozen=True):
+     project_id: str
+     """
+     Keylabs project ID.
+     """
+
+     labels_dir_url: "HirundoUrl"
+     """
+     URL to the directory containing the Keylabs labels.
+     """
+
+     with_attributes: bool = True
+     """
+     Whether to include attributes in the class name.
+     """
+
+     project_name: typing.Optional[str] = None
+     """
+     Keylabs project name (optional; added to the output CSV if provided).
+     """
+     keylabs_auth: typing.Optional[KeylabsAuth] = None
+     """
+     Keylabs authentication credentials (optional; if provided, used to generate links to each sample).
+     """
+
+
+ class KeylabsObjDetImages(Keylabs, frozen=True):
+     type: typing.Literal[DatasetMetadataType.KeylabsObjDetImages] = (
+         DatasetMetadataType.KeylabsObjDetImages
+     )
+
+
+ class KeylabsObjDetVideo(Keylabs, frozen=True):
+     type: typing.Literal[DatasetMetadataType.KeylabsObjDetVideo] = (
+         DatasetMetadataType.KeylabsObjDetVideo
+     )
+
+
+ class KeylabsObjSegImages(Keylabs, frozen=True):
+     type: typing.Literal[DatasetMetadataType.KeylabsObjSegImages] = (
+         DatasetMetadataType.KeylabsObjSegImages
+     )
+
+
+ class KeylabsObjSegVideo(Keylabs, frozen=True):
+     type: typing.Literal[DatasetMetadataType.KeylabsObjSegVideo] = (
+         DatasetMetadataType.KeylabsObjSegVideo
+     )
+
+
+ KeylabsInfo = typing.Union[
+     KeylabsObjDetImages, KeylabsObjDetVideo, KeylabsObjSegImages, KeylabsObjSegVideo
+ ]
+ """
+ The dataset labeling info for Keylabs; one of the following:
+ - `DatasetMetadataType.KeylabsObjDetImages`: the dataset metadata is in the Keylabs object detection image format
+ - `DatasetMetadataType.KeylabsObjDetVideo`: the dataset metadata is in the Keylabs object detection video format
+ - `DatasetMetadataType.KeylabsObjSegImages`: the dataset metadata is in the Keylabs object segmentation image format
+ - `DatasetMetadataType.KeylabsObjSegVideo`: the dataset metadata is in the Keylabs object segmentation video format
+ """
+ LabelingInfo = typing.Annotated[
+     typing.Union[
+         HirundoCSV,
+         COCO,
+         YOLO,
+         KeylabsInfo,
+     ],
+     Field(discriminator="type"),
+ ]
+ """
+ The dataset labeling info; one of the following:
+ - `DatasetMetadataType.HIRUNDO_CSV`: the dataset metadata file is a CSV file in the Hirundo format
+ - `DatasetMetadataType.COCO`: the dataset metadata file is a JSON file in the COCO format
+ - `DatasetMetadataType.YOLO`: the dataset metadata file is in the YOLO format
+ - `DatasetMetadataType.KeylabsObjDetImages`: the dataset metadata is in the Keylabs object detection image format
+ - `DatasetMetadataType.KeylabsObjDetVideo`: the dataset metadata is in the Keylabs object detection video format
+ - `DatasetMetadataType.KeylabsObjSegImages`: the dataset metadata is in the Keylabs object segmentation image format
+ - `DatasetMetadataType.KeylabsObjSegVideo`: the dataset metadata is in the Keylabs object segmentation video format
+
+ Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
+ """
hirundo/storage.py CHANGED
@@ -1,5 +1,4 @@
  import typing
- from enum import Enum
  from pathlib import Path

  import pydantic
@@ -7,11 +6,12 @@ import requests
  from pydantic import BaseModel, model_validator
  from pydantic_core import Url

- from hirundo._constraints import S3BucketUrl, StorageConfigName
  from hirundo._env import API_HOST
- from hirundo._headers import get_auth_headers, json_headers
+ from hirundo._headers import get_headers
  from hirundo._http import raise_for_status_with_reason
  from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
+ from hirundo._urls import S3BucketUrl, StorageConfigName
+ from hirundo.dataset_enum import StorageTypes
  from hirundo.git import GitRepo, GitRepoOut
  from hirundo.logger import get_logger

@@ -34,11 +34,11 @@ class StorageS3Base(BaseModel):
          Chains the bucket URL with the path, ensuring that the path is formatted correctly

          Args:
-             - path: The path to the file in the S3 bucket, e.g. `my-file.txt` or `/my-folder/my-file.txt`
+             path: The path to the file in the S3 bucket, e.g. :file:`my-file.txt` or :file:`/my-folder/my-file.txt`

          Returns:
-             The full URL to the file in the S3 bucket, e.g. `s3://my-bucket/my-file.txt` or `s3://my-bucket/my-folder/my-file.txt`,
-             where `s3://my-bucket` is the bucket URL provided in the S3 storage config
+             The full URL to the file in the S3 bucket, e.g. :file:`s3://my-bucket/my-file.txt` or :file:`s3://my-bucket/my-folder/my-file.txt`,
+             where :file:`s3://my-bucket` is the bucket URL provided in the S3 storage config
          """
          return Url(
              f"{S3_PREFIX}{self.bucket_url.removeprefix(S3_PREFIX).removesuffix('/')}/{str(path).removeprefix('/')}"
@@ -64,11 +64,11 @@ class StorageGCPBase(BaseModel):
          Chains the bucket URL with the path, ensuring that the path is formatted correctly

          Args:
-             - path: The path to the file in the GCP bucket, e.g. `my-file.txt` or `/my-folder/my-file.txt`
+             path: The path to the file in the GCP bucket, e.g. :file:`my-file.txt` or :file:`/my-folder/my-file.txt`

          Returns:
-             The full URL to the file in the GCP bucket, e.g. `gs://my-bucket/my-file.txt` or `gs://my-bucket/my-folder/my-file.txt`,
-             where `my-bucket` is the bucket name provided in the GCP storage config
+             The full URL to the file in the GCP bucket, e.g. :file:`gs://my-bucket/my-file.txt` or :file:`gs://my-bucket/my-folder/my-file.txt`,
+             where :file:`my-bucket` is the bucket name provided in the GCP storage config
          """
          return Url(f"gs://{self.bucket_name}/{str(path).removeprefix('/')}")

@@ -94,7 +94,7 @@ class StorageGCPOut(StorageGCPBase):
      # Chains the container URL with the path, ensuring that the path is formatted correctly

      # Args:
-     #     - path: The path to the file in the Azure container, e.g. `my-file.txt` or `/my-folder/my-file.txt`
+     #     path: The path to the file in the Azure container, e.g. :file:`my-file.txt` or :file:`/my-folder/my-file.txt`

      # Returns:
      #     The full URL to the file in the Azure container
@@ -114,11 +114,11 @@ def get_git_repo_url(
      Chains the repository URL with the path, ensuring that the path is formatted correctly

      Args:
-         - repo_url: The URL of the git repository, e.g. `https://my-git-repository.com`
-         - path: The path to the file in the git repository, e.g. `my-file.txt` or `/my-folder/my-file.txt`
+         repo_url: The URL of the git repository, e.g. :file:`https://my-git-repository.com`
+         path: The path to the file in the git repository, e.g. :file:`my-file.txt` or :file:`/my-folder/my-file.txt`

      Returns:
-         The full URL to the file in the git repository, e.g. `https://my-git-repository.com/my-file.txt` or `https://my-git-repository.com/my-folder/my-file.txt`
+         The full URL to the file in the git repository, e.g. :file:`https://my-git-repository.com/my-file.txt` or :file:`https://my-git-repository.com/my-folder/my-file.txt`
      """
      if not isinstance(repo_url, Url):
          repo_url = Url(repo_url)
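
`get_git_repo_url` accepts either a plain string or a pydantic `Url` for `repo_url` and coerces before joining. A quick sketch, assuming the returned value is a `Url` whose string form matches the docstring (values are illustrative):

    from hirundo.storage import get_git_repo_url

    # Per the docstring, a leading slash on the path does not change the result
    url = get_git_repo_url("https://my-git-repository.com", "/my-folder/my-file.txt")
    assert str(url) == "https://my-git-repository.com/my-folder/my-file.txt"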
@@ -131,12 +131,12 @@ class StorageGit(BaseModel):
      repo_id: typing.Optional[int] = None
      """
      The ID of the Git repository in the Hirundo system.
-     Either `repo_id` or `repo` must be provided.
+     Either :code:`repo_id` or :code:`repo` must be provided.
      """
      repo: typing.Optional[GitRepo] = None
      """
      The Git repository to link to.
-     Either `repo_id` or `repo` must be provided.
+     Either :code:`repo_id` or :code:`repo` must be provided.
      """
      branch: str
      """
@@ -156,11 +156,11 @@ class StorageGit(BaseModel):
          Chains the repository URL with the path, ensuring that the path is formatted correctly

          Args:
-             - path: The path to the file in the git repository, e.g. `my-file.txt` or `/my-folder/my-file.txt`
+             path: The path to the file in the git repository, e.g. :file:`my-file.txt` or :file:`/my-folder/my-file.txt`

          Returns:
-             The full URL to the file in the git repository, e.g. `https://my-git-repository.com/my-file.txt` or `https://my-git-repository.com/my-folder/my-file.txt`,
-             where `https://my-git-repository.com` is the repository URL provided in the git storage config's git repo
+             The full URL to the file in the git repository, e.g. :file:`https://my-git-repository.com/my-file.txt` or :file:`https://my-git-repository.com/my-folder/my-file.txt`,
+             where :file:`https://my-git-repository.com` is the repository URL provided in the git storage config's git repo
          """
          if not self.repo:
              raise ValueError("Repo must be provided to use `get_url`")
@@ -179,47 +179,31 @@ class StorageGitOut(BaseModel):
          Chains the repository URL with the path, ensuring that the path is formatted correctly

          Args:
-             - path: The path to the file in the git repository, e.g. `my-file.txt` or `/my-folder/my-file.txt`
+             path: The path to the file in the git repository, e.g. :file:`my-file.txt` or :file:`/my-folder/my-file.txt`

          Returns:
-             The full URL to the file in the git repository, e.g. `https://my-git-repository.com/my-file.txt` or `https://my-git-repository.com/my-folder/my-file.txt`,
-             where `https://my-git-repository.com` is the repository URL provided in the git storage config's git repo
+             The full URL to the file in the git repository, e.g. :file:`https://my-git-repository.com/my-file.txt` or :file:`https://my-git-repository.com/my-folder/my-file.txt`,
+             where :file:`https://my-git-repository.com` is the repository URL provided in the git storage config's git repo
          """
          repo_url = self.repo.repository_url
          return get_git_repo_url(repo_url, path)


- class StorageTypes(str, Enum):
-     """
-     Enum for the different types of storage configs.
-     Supported types are:
-     """
-
-     S3 = "S3"
-     GCP = "GCP"
-     # AZURE = "Azure" TODO: Azure storage config is coming soon
-     GIT = "Git"
-     LOCAL = "Local"
-     """
-     Local storage config is only supported for on-premises installations.
-     """
-
-
  class StorageConfig(BaseModel):
      id: typing.Optional[int] = None
      """
-     The ID of the `StorageConfig` in the Hirundo system.
+     The ID of the :code:`StorageConfig` in the Hirundo system.
      """

      organization_id: typing.Optional[int] = None
      """
-     The ID of the organization that the `StorageConfig` belongs to.
+     The ID of the organization that the :code:`StorageConfig` belongs to.
      If not provided, it will be assigned to your default organization.
      """

      name: StorageConfigName
      """
-     A name to identify the `StorageConfig` in the Hirundo system.
+     A name to identify the :code:`StorageConfig` in the Hirundo system.
      """
      type: typing.Optional[StorageTypes] = pydantic.Field(
          examples=[
@@ -230,12 +214,12 @@ class StorageConfig(BaseModel):
          ]
      )
      """
-     The type of the `StorageConfig`.
+     The type of the :code:`StorageConfig`.
      Supported types are:
-     - `S3`
-     - `GCP`
-     - `Azure` (coming soon)
-     - `Git`
+     - :code:`S3`
+     - :code:`GCP`
+     - :code:`Azure` (coming soon)
+     - :code:`Git`
      """
      s3: typing.Optional[StorageS3] = pydantic.Field(
          default=None,
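
With `StorageTypes` now living in `hirundo.dataset_enum`, a minimal construction sketch looks like this (field values are illustrative, and `StorageS3` may require credential fields this diff doesn't show):

    from hirundo.dataset_enum import StorageTypes
    from hirundo.storage import StorageConfig, StorageS3

    config = StorageConfig(
        name="my-s3-storage",  # illustrative name
        type=StorageTypes.S3,
        s3=StorageS3(bucket_url="s3://my-bucket"),  # other StorageS3 fields omitted
    )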
@@ -323,14 +307,14 @@ class StorageConfig(BaseModel):
      @staticmethod
      def get_by_id(storage_config_id: int) -> "ResponseStorageConfig":
          """
-         Retrieves a `StorageConfig` instance from the server by its ID
+         Retrieves a :code:`StorageConfig` instance from the server by its ID

          Args:
-             storage_config_id: The ID of the `StorageConfig` to retrieve
+             storage_config_id: The ID of the :code:`StorageConfig` to retrieve
          """
          storage_config = requests.get(
              f"{API_HOST}/storage-config/{storage_config_id}",
-             headers=get_auth_headers(),
+             headers=get_headers(),
              timeout=READ_TIMEOUT,
          )
          raise_for_status_with_reason(storage_config)
@@ -339,17 +323,17 @@ class StorageConfig(BaseModel):
      @staticmethod
      def get_by_name(name: str, storage_type: StorageTypes) -> "ResponseStorageConfig":
          """
-         Retrieves a `StorageConfig` instance from the server by its name
+         Retrieves a :code:`StorageConfig` instance from the server by its name

          Args:
-             name: The name of the `StorageConfig` to retrieve
-             storage_type: The type of the `StorageConfig` to retrieve
+             name: The name of the :code:`StorageConfig` to retrieve
+             storage_type: The type of the :code:`StorageConfig` to retrieve

          Note: The type is required because the name is not unique across different storage types
          """
          storage_config = requests.get(
              f"{API_HOST}/storage-config/by-name/{name}?storage_type={storage_type.value}",
-             headers=get_auth_headers(),
+             headers=get_headers(),
              timeout=READ_TIMEOUT,
          )
          raise_for_status_with_reason(storage_config)
@@ -360,17 +344,17 @@ class StorageConfig(BaseModel):
          organization_id: typing.Optional[int] = None,
      ) -> list["ResponseStorageConfig"]:
          """
-         Lists all the `StorageConfig`'s created by user's default organization
-         Note: The return type is `list[dict]` and not `list[StorageConfig]`
+         Lists all of the :code:`StorageConfig` instances created by the user's default organization
+         Note: The return type is :code:`list[dict]` and not :code:`list[StorageConfig]`

          Args:
-             organization_id: The ID of the organization to list `StorageConfig`'s for.
-             If not provided, it will list `StorageConfig`'s for the default organization.
+             organization_id: The ID of the organization to list :code:`StorageConfig` instances for.
+                 If not provided, it will list :code:`StorageConfig` instances for the default organization.
          """
          storage_configs = requests.get(
              f"{API_HOST}/storage-config/",
              params={"storage_config_organization_id": organization_id},
-             headers=get_auth_headers(),
+             headers=get_headers(),
              timeout=READ_TIMEOUT,
          )
          raise_for_status_with_reason(storage_configs)
@@ -379,14 +363,14 @@ class StorageConfig(BaseModel):
      @staticmethod
      def delete_by_id(storage_config_id) -> None:
          """
-         Deletes a `StorageConfig` instance from the server by its ID
+         Deletes a :code:`StorageConfig` instance from the server by its ID

          Args:
-             storage_config_id: The ID of the `StorageConfig` to delete
+             storage_config_id: The ID of the :code:`StorageConfig` to delete
          """
          storage_config = requests.delete(
              f"{API_HOST}/storage-config/{storage_config_id}",
-             headers=get_auth_headers(),
+             headers=get_headers(),
              timeout=MODIFY_TIMEOUT,
          )
          raise_for_status_with_reason(storage_config)
@@ -394,7 +378,7 @@ class StorageConfig(BaseModel):

      def delete(self) -> None:
          """
-         Deletes the `StorageConfig` instance from the server
+         Deletes the :code:`StorageConfig` instance from the server
          """
          if not self.id:
              raise ValueError("No StorageConfig has been created")
@@ -402,10 +386,10 @@ class StorageConfig(BaseModel):

      def create(self, replace_if_exists: bool = False) -> int:
          """
-         Create a `StorageConfig` instance on the server
+         Create a :code:`StorageConfig` instance on the server

          Args:
-             replace_if_exists: If a `StorageConfig` with the same name and type already exists, replace it.
+             replace_if_exists: If a :code:`StorageConfig` with the same name and type already exists, replace it.
          """
          if self.git and self.git.repo:
              self.git.repo_id = self.git.repo.create(replace_if_exists=replace_if_exists)
@@ -415,10 +399,7 @@ class StorageConfig(BaseModel):
                  **self.model_dump(mode="json"),
                  "replace_if_exists": replace_if_exists,
              },
-             headers={
-                 **json_headers,
-                 **get_auth_headers(),
-             },
+             headers=get_headers(),
              timeout=MODIFY_TIMEOUT,
          )
          raise_for_status_with_reason(storage_config)
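
Reusing the `config` sketched after the supported-types list above, the server round trip these methods describe looks roughly like this (`create` returns the new ID per its annotation; the type argument to `get_by_name` is required because names are only unique per storage type):

    storage_config_id = config.create(replace_if_exists=True)

    fetched = StorageConfig.get_by_name("my-s3-storage", StorageTypes.S3)

    StorageConfig.delete_by_id(storage_config_id)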
hirundo/unzip.py ADDED
@@ -0,0 +1,247 @@
+ import typing
+ import zipfile
+ from collections.abc import Mapping
+ from pathlib import Path
+ from typing import IO, cast
+
+ import requests
+ from pydantic_core import Url
+
+ from hirundo._dataframe import (
+     float32,
+     has_pandas,
+     has_polars,
+     int32,
+     pd,
+     pl,
+     string,
+ )
+ from hirundo._env import API_HOST
+ from hirundo._headers import _get_auth_headers
+ from hirundo._timeouts import DOWNLOAD_READ_TIMEOUT
+ from hirundo.dataset_optimization_results import (
+     DataFrameType,
+     DatasetOptimizationResults,
+ )
+ from hirundo.logger import get_logger
+
+ ZIP_FILE_CHUNK_SIZE = 50 * 1024 * 1024  # 50 MB
+
+ Dtype = typing.Union[type[int32], type[float32], type[string]]
+
+
+ CUSTOMER_INTERCHANGE_DTYPES: Mapping[str, Dtype] = {
+     "image_path": string,
+     "label_path": string,
+     "segments_mask_path": string,
+     "segment_id": int32,
+     "label": string,
+     "bbox_id": string,
+     "xmin": float32,
+     "ymin": float32,
+     "xmax": float32,
+     "ymax": float32,
+     "suspect_level": float32,  # If present, must be one of the values in the enum below
+     "suggested_label": string,
+     "suggested_label_conf": float32,
+     "status": string,
+     # ⬆️ If present, must be one of the following:
+     # NO_LABELS/MISSING_IMAGE/INVALID_IMAGE/INVALID_BBOX/INVALID_BBOX_SIZE/INVALID_SEG/INVALID_SEG_SIZE
+ }
+
+ logger = get_logger(__name__)
+
+
+ def _clean_df_index(df: "pd.DataFrame") -> "pd.DataFrame":
+     """
+     Clean the index of a DataFrame in case it has unnamed columns.
+
+     Args:
+         df (DataFrame): DataFrame to clean
+
+     Returns:
+         Cleaned Pandas DataFrame
+     """
+     index_cols = sorted(
+         [col for col in df.columns if col.startswith("Unnamed")], reverse=True
+     )
+     if len(index_cols) > 0:
+         df.set_index(index_cols.pop(), inplace=True)
+         df.rename_axis(index=None, columns=None, inplace=True)
+     if len(index_cols) > 0:
+         df.drop(columns=index_cols, inplace=True)
+
+     return df
+
+
+ def load_df(
+     file: "typing.Union[str, IO[bytes]]",
+ ) -> "DataFrameType":
+     """
+     Load a DataFrame from a CSV file.
+
+     Args:
+         file: The CSV file to load, given as a path (`str`)
+             or a binary file object (`IO[bytes]`).
+
+     Returns:
+         The loaded DataFrame or `None` if neither Polars nor Pandas is available.
+     """
+     if has_polars:
+         return pl.read_csv(file, schema_overrides=CUSTOMER_INTERCHANGE_DTYPES)
+     elif has_pandas:
+         if typing.TYPE_CHECKING:
+             from pandas._typing import DtypeArg
+
+         dtype = cast("DtypeArg", CUSTOMER_INTERCHANGE_DTYPES)
+         # ⬆️ Casting since CUSTOMER_INTERCHANGE_DTYPES is a Mapping[str, Dtype] in this case
+         df = pd.read_csv(file, dtype=dtype)
+         return cast("DataFrameType", _clean_df_index(df))
+         # ⬆️ Casting since the return type is pd.DataFrame, but this is what DataFrameType is in this case
+     else:
+         return None
+
+
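
`load_df` prefers Polars and silently falls back to Pandas, so callers should treat the result generically. A minimal usage sketch (the CSV path is hypothetical):

    from hirundo.unzip import load_df

    # polars.DataFrame if polars is installed, else pandas.DataFrame, else None;
    # column dtypes follow CUSTOMER_INTERCHANGE_DTYPES in either case.
    df = load_df("suspects.csv")  # hypothetical local CSV in the interchange format
    if df is not None:
        print(df.head())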
+ def get_mislabel_suspect_filename(filenames: list[str]):
+     mislabel_suspect_filename = "mislabel_suspects.csv"
+     if mislabel_suspect_filename not in filenames:
+         mislabel_suspect_filename = "image_mislabel_suspects.csv"
+     if mislabel_suspect_filename not in filenames:
+         mislabel_suspect_filename = "suspects.csv"
+     if mislabel_suspect_filename not in filenames:
+         raise ValueError(
+             "None of mislabel_suspects.csv, image_mislabel_suspects.csv or suspects.csv were found in the zip file"
+         )
+     return mislabel_suspect_filename
+
+
+ def download_and_extract_zip(
+     run_id: str, zip_url: str
+ ) -> DatasetOptimizationResults[DataFrameType]:
+     """
+     Download and extract the zip file from the given URL.
+
+     Note: It will only extract the `mislabel_suspects.csv` (vision - classification)
+     or `image_mislabel_suspects.csv` & `object_mislabel_suspects.csv` (vision - OD)
+     or `suspects.csv` (STT)
+     and `warnings_and_errors.csv` files from the zip file.
+
+     Args:
+         run_id: The ID of the optimization run.
+         zip_url: The URL of the zip file to download.
+
+     Returns:
+         The dataset optimization results object.
+     """
+     # Define the local file path
+     cache_dir = Path.home() / ".hirundo" / "cache"
+     cache_dir.mkdir(parents=True, exist_ok=True)
+     zip_file_path = cache_dir / f"{run_id}.zip"
+
+     headers = None
+     if Url(zip_url).scheme == "file":
+         zip_url = (
+             f"{API_HOST}/dataset-optimization/run/local-download"
+             + zip_url.replace("file://", "")
+         )
+         headers = _get_auth_headers()
+     # Stream the zip file download
+     with requests.get(
+         zip_url,
+         headers=headers,
+         timeout=DOWNLOAD_READ_TIMEOUT,
+         stream=True,
+     ) as r:
+         r.raise_for_status()
+         with open(zip_file_path, "wb") as f:
+             for chunk in r.iter_content(chunk_size=ZIP_FILE_CHUNK_SIZE):
+                 f.write(chunk)
+         logger.info(
+             "Successfully downloaded the result zip file for run ID %s to %s",
+             run_id,
+             zip_file_path,
+         )
+
+     with zipfile.ZipFile(zip_file_path, "r") as z:
+         # Extract suspects file
+         suspects_df = None
+         object_suspects_df = None
+         warnings_and_errors_df = None
+
+         filenames = []
+         try:
+             filenames = [file.filename for file in z.filelist]
+         except Exception as e:
+             logger.error("Failed to get filenames from ZIP", exc_info=e)
+
+         try:
+             mislabel_suspect_filename = get_mislabel_suspect_filename(filenames)
+             with z.open(mislabel_suspect_filename) as suspects_file:
+                 suspects_df = load_df(suspects_file)
+                 logger.debug(
+                     "Successfully loaded mislabel suspects into DataFrame for run ID %s",
+                     run_id,
+                 )
+         except Exception as e:
+             logger.error(
+                 "Failed to load mislabel suspects into DataFrame", exc_info=e
+             )
+
+         object_mislabel_suspects_filename = "object_mislabel_suspects.csv"
+         if object_mislabel_suspects_filename in filenames:
+             try:
+                 with z.open(
+                     object_mislabel_suspects_filename
+                 ) as object_suspects_file:
+                     object_suspects_df = load_df(object_suspects_file)
+                     logger.debug(
+                         "Successfully loaded object mislabel suspects into DataFrame for run ID %s",
+                         run_id,
+                     )
+             except Exception as e:
+                 logger.error(
+                     "Failed to load object mislabel suspects into DataFrame",
+                     exc_info=e,
+                 )
+
+         try:
+             # Extract warnings_and_errors file
+             with z.open("warnings_and_errors.csv") as warnings_file:
+                 warnings_and_errors_df = load_df(warnings_file)
+                 logger.debug(
+                     "Successfully loaded warnings and errors into DataFrame for run ID %s",
+                     run_id,
+                 )
+         except Exception as e:
+             logger.error(
+                 "Failed to load warnings and errors into DataFrame", exc_info=e
+             )
+
+     return DatasetOptimizationResults[DataFrameType](
+         cached_zip_path=zip_file_path,
+         suspects=suspects_df,
+         object_suspects=object_suspects_df,
+         warnings_and_errors=warnings_and_errors_df,
+     )
+
+
+ def load_from_zip(
+     zip_path: Path, file_name: str
+ ) -> "typing.Union[pd.DataFrame, pl.DataFrame, None]":
+     """
+     Load a given file from a given zip file.
+
+     Args:
+         zip_path: The path to the zip file.
+         file_name: The name of the file to load.
+
+     Returns:
+         The loaded DataFrame or `None` if neither Polars nor Pandas is available.
+     """
+     with zipfile.ZipFile(zip_path, "r") as z:
+         try:
+             with z.open(file_name) as file:
+                 return load_df(file)
+         except Exception as e:
+             logger.error("Failed to load %s from zip file", file_name, exc_info=e)
+             return None
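
End to end, the two entry points compose like this (`run_id` and the URL are illustrative; per the note above, `object_mislabel_suspects.csv` only exists for object-detection runs):

    from hirundo.unzip import download_and_extract_zip, load_from_zip

    results = download_and_extract_zip(
        "my-run-id",  # illustrative run ID
        "https://example.com/results.zip",  # illustrative result URL
    )
    if results.suspects is not None:
        print(results.suspects.head())

    # Anything kept in the cached archive can be re-read later:
    warnings_df = load_from_zip(results.cached_zip_path, "warnings_and_errors.csv")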