hirundo 0.1.8__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,27 +1,28 @@
+ import datetime
  import json
  import typing
+ from abc import ABC, abstractmethod
  from collections.abc import AsyncGenerator, Generator
  from enum import Enum
- from io import StringIO
  from typing import overload

  import httpx
- import numpy as np
- import pandas as pd
  import requests
- from pandas._typing import DtypeArg
  from pydantic import BaseModel, Field, model_validator
  from tqdm import tqdm
  from tqdm.contrib.logging import logging_redirect_tqdm

+ from hirundo._constraints import HirundoUrl
  from hirundo._env import API_HOST
- from hirundo._headers import get_auth_headers, json_headers
+ from hirundo._headers import get_headers
  from hirundo._http import raise_for_status_with_reason
  from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
  from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
- from hirundo.enum import DatasetMetadataType, LabellingType
+ from hirundo.dataset_enum import DatasetMetadataType, LabelingType
+ from hirundo.dataset_optimization_results import DatasetOptimizationResults
  from hirundo.logger import get_logger
- from hirundo.storage import StorageIntegration, StorageLink
+ from hirundo.storage import ResponseStorageConfig, StorageConfig
+ from hirundo.unzip import download_and_extract_zip

  logger = get_logger(__name__)

@@ -38,12 +39,14 @@ MAX_RETRIES = 200 # Max 200 retries for HTTP SSE connection


  class RunStatus(Enum):
- STARTED = "STARTED"
  PENDING = "PENDING"
+ STARTED = "STARTED"
  SUCCESS = "SUCCESS"
  FAILURE = "FAILURE"
  AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
- RETRYING = "RETRYING"
+ REVOKED = "REVOKED"
+ REJECTED = "REJECTED"
+ RETRY = "RETRY"


  STATUS_TO_TEXT_MAP = {
@@ -52,7 +55,9 @@ STATUS_TO_TEXT_MAP = {
  RunStatus.SUCCESS.value: "Optimization run completed successfully",
  RunStatus.FAILURE.value: "Optimization run failed",
  RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
- RunStatus.RETRYING.value: "Optimization run failed. Retrying",
+ RunStatus.RETRY.value: "Optimization run failed. Retrying",
+ RunStatus.REVOKED.value: "Optimization run was cancelled",
+ RunStatus.REJECTED.value: "Optimization run was rejected",
  }
  STATUS_TO_PROGRESS_MAP = {
  RunStatus.STARTED.value: 0.0,
@@ -60,100 +65,284 @@ STATUS_TO_PROGRESS_MAP = {
  RunStatus.SUCCESS.value: 100.0,
  RunStatus.FAILURE.value: 100.0,
  RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
- RunStatus.RETRYING.value: 0.0,
+ RunStatus.RETRY.value: 0.0,
+ RunStatus.REVOKED.value: 100.0,
+ RunStatus.REJECTED.value: 0.0,
  }


- class DatasetOptimizationResults(BaseModel):
- model_config = {"arbitrary_types_allowed": True}
+ class Metadata(BaseModel, ABC):
+ type: DatasetMetadataType
+
+ @property
+ @abstractmethod
+ def metadata_url(self) -> HirundoUrl:
+ raise NotImplementedError()

- suspects: pd.DataFrame
+
+ class HirundoCSV(Metadata):
  """
- A pandas DataFrame containing the results of the optimization run
+ A dataset metadata file in the Hirundo CSV format
  """
- warnings_and_errors: pd.DataFrame
+
+ type: DatasetMetadataType = DatasetMetadataType.HIRUNDO_CSV
+ csv_url: HirundoUrl
  """
- A pandas DataFrame containing the warnings and errors of the optimization run
+ The URL to access the dataset metadata CSV file.
+ e.g. `s3://my-bucket-name/my-folder/my-metadata.csv`, `gs://my-bucket-name/my-folder/my-metadata.csv`,
+ or `ssh://my-username@my-repo-name/my-folder/my-metadata.csv`
+ (or `file:///datasets/my-folder/my-metadata.csv` if using LOCAL storage type with on-premises installation)
  """

+ @property
+ def metadata_url(self) -> HirundoUrl:
+ return self.csv_url
+
+
+ class COCO(Metadata):
+ """
+ A dataset metadata file in the COCO format
+ """
+
+ type: DatasetMetadataType = DatasetMetadataType.COCO
+ json_url: HirundoUrl
+ """
+ The URL to access the dataset metadata JSON file.
+ e.g. `s3://my-bucket-name/my-folder/my-metadata.json`, `gs://my-bucket-name/my-folder/my-metadata.json`,
+ or `ssh://my-username@my-repo-name/my-folder/my-metadata.json`
+ (or `file:///datasets/my-folder/my-metadata.json` if using LOCAL storage type with on-premises installation)
+ """
+
+ @property
+ def metadata_url(self) -> HirundoUrl:
+ return self.json_url
+
+
+ class YOLO(Metadata):
+ type: DatasetMetadataType = DatasetMetadataType.YOLO
+ data_yaml_url: typing.Optional[HirundoUrl] = None
+ labels_dir_url: HirundoUrl
+
+ @property
+ def metadata_url(self) -> HirundoUrl:
+ return self.labels_dir_url
+
+
+ LabelingInfo = typing.Union[HirundoCSV, COCO, YOLO]
+ """
133
+ The dataset labeling info. The dataset labeling info can be one of the following:
134
+ - `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format
135
+
136
+ Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
137
+ """
+
+
+ class VisionRunArgs(BaseModel):
+ upsample: bool = False
+ """
+ Whether to upsample the dataset to attempt to balance the classes.
+ """
+ min_abs_bbox_size: int = 0
+ """
+ Minimum valid size (in pixels) of a bounding box to keep it in the dataset for optimization.
+ """
+ min_abs_bbox_area: int = 0
+ """
+ Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for optimization.
+ """
+ min_rel_bbox_size: float = 0.0
+ """
+ Minimum valid size (as a fraction of both image height and width) for a bounding box
+ to keep it in the dataset for optimization, relative to the corresponding dimension size,
+ i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
+ value is 0.06 (since both width and height are checked).
+ """
+ min_rel_bbox_area: float = 0.0
+ """
+ Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for optimization.
+ """
+
+
+ RunArgs = typing.Union[VisionRunArgs]
+
+
+ class AugmentationName(str, Enum):
+ RANDOM_HORIZONTAL_FLIP = "RandomHorizontalFlip"
+ RANDOM_VERTICAL_FLIP = "RandomVerticalFlip"
+ RANDOM_ROTATION = "RandomRotation"
+ RANDOM_PERSPECTIVE = "RandomPerspective"
+ GAUSSIAN_NOISE = "GaussianNoise"
+ RANDOM_GRAYSCALE = "RandomGrayscale"
+ GAUSSIAN_BLUR = "GaussianBlur"

- CUSTOMER_INTERCHANGE_DTYPES: DtypeArg = {
- "image_path": str,
- "label_path": str,
- "segments_mask_path": str,
- "segment_id": np.int32,
- "label": str,
- "bbox_id": str,
- "xmin": np.int32,
- "ymin": np.int32,
- "xmax": np.int32,
- "ymax": np.int32,
- "suspect_level": np.float32, # If exists, must be one of the values in the enum below
- "suggested_label": str,
- "suggested_label_conf": np.float32,
- "status": str,
- # ⬆️ If exists, must be one of the following:
- # NO_LABELS/MISSING_IMAGE/INVALID_IMAGE/INVALID_BBOX/INVALID_BBOX_SIZE/INVALID_SEG/INVALID_SEG_SIZE
- }
+
+ class Modality(str, Enum):
+ IMAGE = "Image"
+ RADAR = "Radar"
+ EKG = "EKG"


  class OptimizationDataset(BaseModel):
+ id: typing.Optional[int] = Field(default=None)
+ """
+ The ID of the dataset created on the server.
+ """
  name: str
  """
  The name of the dataset. Used to identify it amongst the list of datasets
  belonging to your organization in `hirundo`.
  """
- labelling_type: LabellingType
+ labeling_type: LabelingType
  """
- Indicates the labelling type of the dataset. The labelling type can be one of the following:
- - `LabellingType.SingleLabelClassification`: Indicates that the dataset is for classification tasks
- - `LabellingType.ObjectDetection`: Indicates that the dataset is for object detection tasks
+ Indicates the labeling type of the dataset. The labeling type can be one of the following:
+ - `LabelingType.SINGLE_LABEL_CLASSIFICATION`: Indicates that the dataset is for classification tasks
+ - `LabelingType.OBJECT_DETECTION`: Indicates that the dataset is for object detection tasks
+ - `LabelingType.SPEECH_TO_TEXT`: Indicates that the dataset is for speech-to-text tasks
  """
- dataset_storage: typing.Optional[StorageLink]
+ language: typing.Optional[str] = None
  """
- The storage link to the dataset. This can be a link to a file or a directory containing the dataset.
- If `None`, the `dataset_id` field must be set.
+ Language of the Speech-to-Text audio dataset. This is required for Speech-to-Text datasets.
  """
-
- classes: typing.Optional[list[str]] = None
+ storage_config_id: typing.Optional[int] = None
  """
- A full list of possible classes used in classification / object detection.
- It is currently required for clarity and performance.
+ The ID of the storage config used to store the dataset and metadata.
+ """
+ storage_config: typing.Optional[
+ typing.Union[StorageConfig, ResponseStorageConfig]
+ ] = None
  """
- dataset_metadata_path: str = "metadata.csv"
+ The `StorageConfig` instance to link to.
  """
- The path to the dataset metadata file within storage integration, e.g. S3 Bucket / GCP Bucket / Azure Blob storage / Git repo.
- Note: This path will be prefixed with the `StorageLink`'s `path`.
+ data_root_url: HirundoUrl
  """
- dataset_metadata_type: DatasetMetadataType = DatasetMetadataType.HirundoCSV
+ URL for data (e.g. images) within the `StorageConfig` instance,
+ e.g. `s3://my-bucket-name/my-images-folder`, `gs://my-bucket-name/my-images-folder`,
+ or `ssh://my-username@my-repo-name/my-images-folder`
+ (or `file:///datasets/my-images-folder` if using LOCAL storage type with on-premises installation)
+
+ Note: All CSV `image_path` entries in the metadata file should be relative to this folder.
  """
- The type of dataset metadata file. The dataset metadata file can be one of the following:
- - `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format

- Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
+ classes: typing.Optional[list[str]] = None
+ """
+ A full list of possible classes used in classification / object detection.
+ It is currently required for clarity and performance.
  """
+ labeling_info: LabelingInfo

- storage_integration_id: typing.Optional[int] = Field(default=None, init=False)
+ augmentations: typing.Optional[list[AugmentationName]] = None
  """
- The ID of the storage integration used to store the dataset and metadata.
+ Used to define which augmentations to apply to a vision dataset.
+ For audio datasets, this field is ignored.
+ If no value is provided, all augmentations are applied to vision datasets.
  """
- dataset_id: typing.Optional[int] = Field(default=None, init=False)
+ modality: Modality = Modality.IMAGE
  """
- The ID of the dataset created on the server.
+ Used to define the modality of the dataset.
+ Defaults to Image.
  """
+
  run_id: typing.Optional[str] = Field(default=None, init=False)
  """
  The ID of the Dataset Optimization run created on the server.
  """

+ status: typing.Optional[RunStatus] = None
+
  @model_validator(mode="after")
  def validate_dataset(self):
- if self.dataset_storage is None and self.storage_integration_id is None:
- raise ValueError("No dataset storage has been provided")
+ if self.storage_config is None and self.storage_config_id is None:
+ raise ValueError(
+ "No dataset storage has been provided. Provide one via `storage_config` or `storage_config_id`"
+ )
+ elif self.storage_config is not None and self.storage_config_id is not None:
+ raise ValueError(
+ "Both `storage_config` and `storage_config_id` have been provided. Pick one."
+ )
+ if self.labeling_type == LabelingType.SPEECH_TO_TEXT and self.language is None:
+ raise ValueError("Language is required for Speech-to-Text datasets.")
+ elif (
+ self.labeling_type != LabelingType.SPEECH_TO_TEXT
+ and self.language is not None
+ ):
+ raise ValueError("Language is only allowed for Speech-to-Text datasets.")
+ if (
+ self.labeling_info.type == DatasetMetadataType.YOLO
+ and isinstance(self.labeling_info, YOLO)
+ and (
+ self.labeling_info.data_yaml_url is not None
+ and self.classes is not None
+ )
+ ):
+ raise ValueError(
+ "Only one of `classes` or `labeling_info.data_yaml_url` should be provided for YOLO datasets"
+ )
  return self

  @staticmethod
- def list(organization_id: typing.Optional[int] = None) -> list[dict]:
+ def get_by_id(dataset_id: int) -> "OptimizationDataset":
+ """
+ Get an `OptimizationDataset` instance from the server by its ID
+
+ Args:
+ dataset_id: The ID of the `OptimizationDataset` instance to get
+ """
+ response = requests.get(
+ f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
+ headers=get_headers(),
+ timeout=READ_TIMEOUT,
+ )
+ raise_for_status_with_reason(response)
+ dataset = response.json()
+ return OptimizationDataset(**dataset)
+
+ @staticmethod
+ def get_by_name(name: str) -> "OptimizationDataset":
+ """
+ Get an `OptimizationDataset` instance from the server by its name
+
+ Args:
+ name: The name of the `OptimizationDataset` instance to get
+ """
+ response = requests.get(
+ f"{API_HOST}/dataset-optimization/dataset/by-name/{name}",
+ headers=get_headers(),
+ timeout=READ_TIMEOUT,
+ )
+ raise_for_status_with_reason(response)
+ dataset = response.json()
+ return OptimizationDataset(**dataset)
+
+ @staticmethod
+ def list_datasets(
+ organization_id: typing.Optional[int] = None,
+ ) -> list["DataOptimizationDatasetOut"]:
+ """
+ Lists all the optimization datasets created by the user's default organization
+ or the `organization_id` passed
+
+ Args:
+ organization_id: The ID of the organization to list the datasets for.
+ """
+ response = requests.get(
+ f"{API_HOST}/dataset-optimization/dataset/",
+ params={"dataset_organization_id": organization_id},
+ headers=get_headers(),
+ timeout=READ_TIMEOUT,
+ )
+ raise_for_status_with_reason(response)
+ datasets = response.json()
+ return [
+ DataOptimizationDatasetOut(
+ **ds,
+ )
+ for ds in datasets
+ ]
+
+ @staticmethod
+ def list_runs(
+ organization_id: typing.Optional[int] = None,
+ ) -> list["DataOptimizationRunOut"]:
  """
  Lists all the `OptimizationDataset` instances created by user's default organization
  or the `organization_id` passed
@@ -163,13 +352,19 @@ class OptimizationDataset(BaseModel):
  organization_id: The ID of the organization to list the datasets for.
  """
  response = requests.get(
- f"{API_HOST}/dataset-optimization/dataset/",
+ f"{API_HOST}/dataset-optimization/run/list",
  params={"dataset_organization_id": organization_id},
- headers=get_auth_headers(),
+ headers=get_headers(),
  timeout=READ_TIMEOUT,
  )
  raise_for_status_with_reason(response)
- return response.json()
+ runs = response.json()
+ return [
+ DataOptimizationRunOut(
+ **run,
+ )
+ for run in runs
+ ]

  @staticmethod
  def delete_by_id(dataset_id: int) -> None:
@@ -181,73 +376,94 @@ class OptimizationDataset(BaseModel):
  """
  response = requests.delete(
  f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
- headers=get_auth_headers(),
+ headers=get_headers(),
  timeout=MODIFY_TIMEOUT,
  )
  raise_for_status_with_reason(response)
  logger.info("Deleted dataset with ID: %s", dataset_id)

- def delete(self, storage_integration=True) -> None:
+ def delete(self, storage_config=True) -> None:
  """
  Deletes the active `OptimizationDataset` instance from the server.
  It can only be used on a `OptimizationDataset` instance that has been created.

  Args:
- storage_integration: If True, the `OptimizationDataset`'s `StorageIntegration` will also be deleted
+ storage_config: If True, the `OptimizationDataset`'s `StorageConfig` will also be deleted

- Note: If `storage_integration` is not set to `False` then the `storage_integration_id` must be set
- This can either be set manually or by creating the `StorageIntegration` instance via the `OptimizationDataset`'s
+ Note: If `storage_config` is not set to `False` then the `storage_config_id` must be set
+ This can either be set manually or by creating the `StorageConfig` instance via the `OptimizationDataset`'s
  `create` method
  """
- if storage_integration:
- if not self.storage_integration_id:
- raise ValueError("No storage integration has been created")
- StorageIntegration.delete_by_id(self.storage_integration_id)
- if not self.dataset_id:
+ if storage_config:
+ if not self.storage_config_id:
+ raise ValueError("No storage config has been created")
+ StorageConfig.delete_by_id(self.storage_config_id)
+ if not self.id:
  raise ValueError("No dataset has been created")
- self.delete_by_id(self.dataset_id)
+ self.delete_by_id(self.id)

- def create(self) -> int:
+ def create(
+ self,
+ organization_id: typing.Optional[int] = None,
+ replace_if_exists: bool = False,
+ ) -> int:
  """
  Create a `OptimizationDataset` instance on the server.
- If `storage_integration_id` is not set, it will be created.
+ If the `storage_config_id` field is not set, the storage config will also be created and the field will be set.
+
+ Args:
+ organization_id: The ID of the organization to create the dataset for.
+ replace_if_exists: If True, the dataset will be replaced if it already exists
+ (this is determined by a dataset of the same name in the same organization).
+
+ Returns:
+ The ID of the created `OptimizationDataset` instance
  """
- if not self.dataset_storage:
+ if self.storage_config is None and self.storage_config_id is None:
  raise ValueError("No dataset storage has been provided")
- if (
- self.dataset_storage
- and self.dataset_storage.storage_integration
- and not self.storage_integration_id
+ elif self.storage_config and self.storage_config_id is None:
+ if isinstance(self.storage_config, ResponseStorageConfig):
+ self.storage_config_id = self.storage_config.id
+ elif isinstance(self.storage_config, StorageConfig):
+ self.storage_config_id = self.storage_config.create(
+ replace_if_exists=replace_if_exists,
+ )
+ elif (
+ self.storage_config is not None
+ and self.storage_config_id is not None
+ and (
+ not isinstance(self.storage_config, ResponseStorageConfig)
+ or self.storage_config.id != self.storage_config_id
+ )
  ):
- self.storage_integration_id = (
- self.dataset_storage.storage_integration.create()
+ raise ValueError(
+ "Both `storage_config` and `storage_config_id` have been provided. Storage config IDs do not match."
  )
- model_dict = self.model_dump()
+ model_dict = self.model_dump(mode="json")
  # ⬆️ Get dict of model fields from Pydantic model instance
  dataset_response = requests.post(
  f"{API_HOST}/dataset-optimization/dataset/",
  json={
- "dataset_storage": {
- "storage_integration_id": self.storage_integration_id,
- "path": self.dataset_storage.path,
- },
- **{k: model_dict[k] for k in model_dict.keys() - {"dataset_storage"}},
- },
- headers={
- **json_headers,
- **get_auth_headers(),
+ **{k: model_dict[k] for k in model_dict.keys() - {"storage_config"}},
+ "organization_id": organization_id,
+ "replace_if_exists": replace_if_exists,
  },
+ headers=get_headers(),
  timeout=MODIFY_TIMEOUT,
  )
  raise_for_status_with_reason(dataset_response)
- self.dataset_id = dataset_response.json()["id"]
- if not self.dataset_id:
- raise HirundoError("Failed to create the dataset")
- logger.info("Created dataset with ID: %s", self.dataset_id)
- return self.dataset_id
+ self.id = dataset_response.json()["id"]
+ if not self.id:
+ raise HirundoError("An error ocurred while trying to create the dataset")
+ logger.info("Created dataset with ID: %s", self.id)
+ return self.id

  @staticmethod
- def launch_optimization_run(dataset_id: int) -> str:
+ def launch_optimization_run(
+ dataset_id: int,
+ organization_id: typing.Optional[int] = None,
+ run_args: typing.Optional[RunArgs] = None,
+ ) -> str:
  """
  Run the dataset optimization process on the server using the dataset with the given ID
  i.e. `dataset_id`.
@@ -258,26 +474,62 @@ class OptimizationDataset(BaseModel):
  Returns:
  ID of the run (`run_id`).
  """
+ run_info = {}
+ if organization_id:
+ run_info["organization_id"] = organization_id
+ if run_args:
+ run_info["run_args"] = run_args.model_dump(mode="json")
  run_response = requests.post(
  f"{API_HOST}/dataset-optimization/run/{dataset_id}",
- headers=get_auth_headers(),
+ json=run_info if len(run_info) > 0 else None,
+ headers=get_headers(),
  timeout=MODIFY_TIMEOUT,
  )
  raise_for_status_with_reason(run_response)
  return run_response.json()["run_id"]

- def run_optimization(self) -> str:
+ def _validate_run_args(self, run_args: RunArgs) -> None:
+ if self.labeling_type == LabelingType.SPEECH_TO_TEXT:
+ raise Exception("Speech to text cannot have `run_args` set")
+ if self.labeling_type != LabelingType.OBJECT_DETECTION and any(
+ (
+ run_args.min_abs_bbox_size != 0,
+ run_args.min_abs_bbox_area != 0,
+ run_args.min_rel_bbox_size != 0,
+ run_args.min_rel_bbox_area != 0,
+ )
+ ):
+ raise Exception(
+ "Cannot set `min_abs_bbox_size`, `min_abs_bbox_area`, "
+ + "`min_rel_bbox_size`, or `min_rel_bbox_area` for "
+ + f"labeling type {self.labeling_type}"
+ )
+
+ def run_optimization(
+ self,
+ organization_id: typing.Optional[int] = None,
+ replace_dataset_if_exists: bool = False,
+ run_args: typing.Optional[RunArgs] = None,
+ ) -> str:
  """
  If the dataset was not created on the server yet, it is created.
  Run the dataset optimization process on the server using the active `OptimizationDataset` instance

+ Args:
+ organization_id: The ID of the organization to run the optimization for.
+ replace_dataset_if_exists: If True, the dataset will be replaced if it already exists
+ (this is determined by a dataset of the same name in the same organization).
+ run_args: The run arguments to use for the optimization run
+
  Returns:
  An ID of the run (`run_id`) and stores that `run_id` on the instance
  """
  try:
- if not self.dataset_id:
- self.dataset_id = self.create()
- run_id = self.launch_optimization_run(self.dataset_id)
+ if not self.id:
+ self.id = self.create(replace_if_exists=replace_dataset_if_exists)
+ if run_args is not None:
+ self._validate_run_args(run_args)
+ run_id = self.launch_optimization_run(self.id, organization_id, run_args)
  self.run_id = run_id
  logger.info("Started the run with ID: %s", run_id)
  return run_id
@@ -293,59 +545,19 @@ class OptimizationDataset(BaseModel):
  except Exception:
  content = error.response.text
  raise HirundoError(
- f"Failed to start the run. Status code: {error.response.status_code} Content: {content}"
+ f"Unable to start the run. Status code: {error.response.status_code} Content: {content}"
  ) from error
  except Exception as error:
- raise HirundoError(f"Failed to start the run: {error}") from error
+ raise HirundoError(f"Unable to start the run: {error}") from error

  def clean_ids(self):
  """
- Reset `dataset_id`, `storage_integration_id`, and `run_id` values on the instance to default value of `None`
+ Reset `id`, `storage_config_id`, and `run_id` values on the instance to the default value of `None`
  """
- self.storage_integration_id = None
- self.dataset_id = None
+ self.storage_config_id = None
+ self.id = None
  self.run_id = None

- @staticmethod
- def _clean_df_index(df: "pd.DataFrame") -> "pd.DataFrame":
- """
- Clean the index of a dataframe in case it has unnamed columns.
-
- Args:
- df (DataFrame): Dataframe to clean
-
- Returns:
- DataFrame: Cleaned dataframe
- """
- index_cols = sorted(
- [col for col in df.columns if col.startswith("Unnamed")], reverse=True
- )
- if len(index_cols) > 0:
- df.set_index(index_cols.pop(), inplace=True)
- df.rename_axis(index=None, columns=None, inplace=True)
- if len(index_cols) > 0:
- df.drop(columns=index_cols, inplace=True)
-
- return df
-
- @staticmethod
- def _read_csvs_to_df(data: dict):
- if data["state"] == RunStatus.SUCCESS.value:
- data["result"]["suspects"] = OptimizationDataset._clean_df_index(
- pd.read_csv(
- StringIO(data["result"]["suspects"]),
- dtype=CUSTOMER_INTERCHANGE_DTYPES,
- )
- )
- data["result"]["warnings_and_errors"] = OptimizationDataset._clean_df_index(
- pd.read_csv(
- StringIO(data["result"]["warnings_and_errors"]),
- dtype=CUSTOMER_INTERCHANGE_DTYPES,
- )
- )
- else:
- pass
-
  @staticmethod
  def _check_run_by_id(run_id: str, retry=0) -> Generator[dict, None, None]:
  if retry > MAX_RETRIES:
@@ -356,7 +568,7 @@ class OptimizationDataset(BaseModel):
  client,
  "GET",
  f"{API_HOST}/dataset-optimization/run/{run_id}",
- headers=get_auth_headers(),
+ headers=get_headers(),
  ):
  if sse.event == "ping":
  continue
@@ -370,8 +582,15 @@ class OptimizationDataset(BaseModel):
  last_event = json.loads(sse.data)
  if not last_event:
  continue
- data = last_event["data"]
- OptimizationDataset._read_csvs_to_df(data)
+ if "data" in last_event:
+ data = last_event["data"]
+ else:
+ if "detail" in last_event:
+ raise HirundoError(last_event["detail"])
+ elif "reason" in last_event:
+ raise HirundoError(last_event["reason"])
+ else:
+ raise HirundoError("Unknown error")
  yield data
  if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
  OptimizationDataset._check_run_by_id(run_id, retry + 1)
@@ -420,17 +639,22 @@ class OptimizationDataset(BaseModel):
  t.n = STATUS_TO_PROGRESS_MAP[iteration["state"]]
  logger.debug("Setting progress to %s", t.n)
  t.refresh()
- if iteration["state"] == RunStatus.FAILURE.value:
+ if iteration["state"] in [
+ RunStatus.FAILURE.value,
+ RunStatus.REJECTED.value,
+ RunStatus.REVOKED.value,
+ ]:
  raise HirundoError(
  f"Optimization run failed with error: {iteration['result']}"
  )
  elif iteration["state"] == RunStatus.SUCCESS.value:
  t.close()
- return DatasetOptimizationResults(
- suspects=iteration["result"]["suspects"],
- warnings_and_errors=iteration["result"][
- "warnings_and_errors"
- ],
+ zip_temporary_url = iteration["result"]
+ logger.debug("Optimization run completed. Downloading results")
+
+ return download_and_extract_zip(
+ run_id,
+ zip_temporary_url,
  )
  elif (
  iteration["state"] == RunStatus.AWAITING_MANUAL_APPROVAL.value
@@ -445,13 +669,22 @@ class OptimizationDataset(BaseModel):
  and iteration["result"]["result"]
  and isinstance(iteration["result"]["result"], str)
  ):
- current_progress_percentage = float(
- iteration["result"]["result"].removesuffix("% done")
- )
+ result_info = iteration["result"]["result"].split(":")
+ if len(result_info) > 1:
+ stage = result_info[0]
+ current_progress_percentage = float(
+ result_info[1].removeprefix(" ").removesuffix("% done")
+ )
+ elif len(result_info) == 1:
+ stage = result_info[0]
+ current_progress_percentage = t.n # Keep the same progress
+ else:
+ stage = "Unknown progress state"
+ current_progress_percentage = t.n # Keep the same progress
  desc = (
  "Optimization run completed. Uploading results"
  if current_progress_percentage == 100.0
- else "Optimization run in progress"
+ else stage
  )
  t.set_description(desc)
  t.n = current_progress_percentage
@@ -513,7 +746,7 @@ class OptimizationDataset(BaseModel):
  client,
  "GET",
  f"{API_HOST}/dataset-optimization/run/{run_id}",
- headers=get_auth_headers(),
+ headers=get_headers(),
  )
  async for sse in async_iterator:
  if sse.event == "ping":
@@ -562,7 +795,7 @@ class OptimizationDataset(BaseModel):
  logger.info("Cancelling run with ID: %s", run_id)
  response = requests.delete(
  f"{API_HOST}/dataset-optimization/run/{run_id}",
- headers=get_auth_headers(),
+ headers=get_headers(),
  timeout=MODIFY_TIMEOUT,
  )
  raise_for_status_with_reason(response)
@@ -574,3 +807,33 @@ class OptimizationDataset(BaseModel):
  if not self.run_id:
  raise ValueError("No run has been started")
  self.cancel_by_id(self.run_id)
+
+
+ class DataOptimizationDatasetOut(BaseModel):
+ id: int
+
+ name: str
+ labeling_type: LabelingType
+
+ storage_config: ResponseStorageConfig
+
+ data_root_url: HirundoUrl
+
+ classes: typing.Optional[list[str]] = None
+ labeling_info: LabelingInfo
+
+ organization_id: typing.Optional[int]
+ creator_id: typing.Optional[int]
+ created_at: datetime.datetime
+ updated_at: datetime.datetime
+
+
+ class DataOptimizationRunOut(BaseModel):
+ id: int
+ name: str
+ dataset_id: int
+ run_id: str
+ status: RunStatus
+ approved: bool
+ created_at: datetime.datetime
+ run_args: typing.Optional[RunArgs]
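
In practical terms, the changes above rename `LabellingType` to `LabelingType` and `StorageIntegration`/`StorageLink` to `StorageConfig`/`ResponseStorageConfig`, replace the `dataset_storage`/`dataset_metadata_path`/`dataset_metadata_type` fields with `data_root_url` plus a `labeling_info` object, and return run results as a downloaded ZIP (via `download_and_extract_zip`) instead of in-memory pandas DataFrames. The following is a minimal usage sketch against the 0.1.16 surface shown in this diff; the top-level import path for `OptimizationDataset` and `HirundoCSV` is an assumption (the diff does not show the module name), and it assumes a storage config with ID 1 already exists on the server.

```python
# Hypothetical sketch based solely on the symbols visible in this diff.
# Assumption: OptimizationDataset and HirundoCSV are importable from the top-level package.
from hirundo import HirundoCSV, OptimizationDataset
from hirundo.dataset_enum import LabelingType

dataset = OptimizationDataset(
    name="my-dataset",
    labeling_type=LabelingType.SINGLE_LABEL_CLASSIFICATION,
    storage_config_id=1,  # exactly one of storage_config / storage_config_id may be set
    data_root_url="s3://my-bucket-name/my-images-folder",
    labeling_info=HirundoCSV(
        csv_url="s3://my-bucket-name/my-folder/my-metadata.csv",
    ),
    classes=["cat", "dog"],
)

# run_optimization() creates the dataset server-side if needed (see `create` above),
# launches the run, and stores the returned run_id on the instance.
run_id = dataset.run_optimization()
print(f"Launched dataset optimization run {run_id}")
```

Note that the progress-tracking helper (whose definition lies outside the hunks shown) now resolves a successful run to a temporary ZIP URL and returns the extracted `DatasetOptimizationResults`, rather than parsing `suspects` and `warnings_and_errors` CSVs into DataFrames as in 0.1.8.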