hirundo 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,24 +1,30 @@
+ import datetime
  import json
  import typing
+ from abc import ABC, abstractmethod
  from collections.abc import AsyncGenerator, Generator
  from enum import Enum
  from io import StringIO
- from typing import Union, overload
+ from typing import overload

  import httpx
+ import numpy as np
  import pandas as pd
  import requests
+ from pandas._typing import DtypeArg
  from pydantic import BaseModel, Field, model_validator
  from tqdm import tqdm
  from tqdm.contrib.logging import logging_redirect_tqdm

+ from hirundo._constraints import HirundoUrl
  from hirundo._env import API_HOST
  from hirundo._headers import get_auth_headers, json_headers
+ from hirundo._http import raise_for_status_with_reason
  from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
  from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
- from hirundo.enum import DatasetMetadataType, LabellingType
+ from hirundo.enum import DatasetMetadataType, LabelingType
  from hirundo.logger import get_logger
- from hirundo.storage import StorageIntegration, StorageLink
+ from hirundo.storage import ResponseStorageConfig, StorageConfig

  logger = get_logger(__name__)

@@ -35,70 +41,342 @@ MAX_RETRIES = 200 # Max 200 retries for HTTP SSE connection


  class RunStatus(Enum):
- STARTED = "STARTED"
  PENDING = "PENDING"
+ STARTED = "STARTED"
  SUCCESS = "SUCCESS"
  FAILURE = "FAILURE"
  AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
+ REVOKED = "REVOKED"
+ REJECTED = "REJECTED"
+ RETRY = "RETRY"
+
+
+ STATUS_TO_TEXT_MAP = {
+ RunStatus.STARTED.value: "Optimization run in progress. Downloading dataset",
+ RunStatus.PENDING.value: "Optimization run queued and not yet started",
+ RunStatus.SUCCESS.value: "Optimization run completed successfully",
+ RunStatus.FAILURE.value: "Optimization run failed",
+ RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
+ RunStatus.RETRY.value: "Optimization run failed. Retrying",
+ RunStatus.REVOKED.value: "Optimization run was cancelled",
+ RunStatus.REJECTED.value: "Optimization run was rejected",
+ }
+ STATUS_TO_PROGRESS_MAP = {
+ RunStatus.STARTED.value: 0.0,
+ RunStatus.PENDING.value: 0.0,
+ RunStatus.SUCCESS.value: 100.0,
+ RunStatus.FAILURE.value: 100.0,
+ RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
+ RunStatus.RETRY.value: 0.0,
+ RunStatus.REVOKED.value: 100.0,
+ RunStatus.REJECTED.value: 0.0,
+ }
+
+
+ class DatasetOptimizationResults(BaseModel):
+ model_config = {"arbitrary_types_allowed": True}
+
+ suspects: pd.DataFrame
+ """
+ A pandas DataFrame containing the results of the optimization run
+ """
+ warnings_and_errors: pd.DataFrame
+ """
+ A pandas DataFrame containing the warnings and errors of the optimization run
+ """
+
+
+ CUSTOMER_INTERCHANGE_DTYPES: DtypeArg = {
+ "image_path": str,
+ "label_path": str,
+ "segments_mask_path": str,
+ "segment_id": np.int32,
+ "label": str,
+ "bbox_id": str,
+ "xmin": np.float32,
+ "ymin": np.float32,
+ "xmax": np.float32,
+ "ymax": np.float32,
+ "suspect_level": np.float32, # If exists, must be one of the values in the enum below
+ "suggested_label": str,
+ "suggested_label_conf": np.float32,
+ "status": str,
+ # ⬆️ If exists, must be one of the following:
+ # NO_LABELS/MISSING_IMAGE/INVALID_IMAGE/INVALID_BBOX/INVALID_BBOX_SIZE/INVALID_SEG/INVALID_SEG_SIZE
+ }
+
+
+ class Metadata(BaseModel, ABC):
+ type: DatasetMetadataType
+
+ @property
+ @abstractmethod
+ def metadata_url(self) -> HirundoUrl:
+ raise NotImplementedError()
+
+
+ class HirundoCSV(Metadata):
+ """
+ A dataset metadata file in the Hirundo CSV format
+ """
+
+ type: DatasetMetadataType = DatasetMetadataType.HIRUNDO_CSV
+ csv_url: HirundoUrl
+ """
+ The URL to access the dataset metadata CSV file.
+ e.g. `s3://my-bucket-name/my-folder/my-metadata.csv`, `gs://my-bucket-name/my-folder/my-metadata.csv`,
+ or `ssh://my-username@my-repo-name/my-folder/my-metadata.csv`
+ (or `file:///datasets/my-folder/my-metadata.csv` if using LOCAL storage type with on-premises installation)
+ """
+
+ @property
+ def metadata_url(self) -> HirundoUrl:
+ return self.csv_url
+
+
+ class COCO(Metadata):
+ """
+ A dataset metadata file in the COCO format
+ """
+
+ type: DatasetMetadataType = DatasetMetadataType.COCO
+ json_url: HirundoUrl
+ """
+ The URL to access the dataset metadata JSON file.
+ e.g. `s3://my-bucket-name/my-folder/my-metadata.json`, `gs://my-bucket-name/my-folder/my-metadata.json`,
+ or `ssh://my-username@my-repo-name/my-folder/my-metadata.json`
+ (or `file:///datasets/my-folder/my-metadata.json` if using LOCAL storage type with on-premises installation)
+ """
+
+ @property
+ def metadata_url(self) -> HirundoUrl:
+ return self.json_url
+
+
+ class YOLO(Metadata):
+ type: DatasetMetadataType = DatasetMetadataType.YOLO
+ data_yaml_url: typing.Optional[HirundoUrl] = None
+ labels_dir_url: HirundoUrl
+
+ @property
+ def metadata_url(self) -> HirundoUrl:
+ return self.labels_dir_url
+
+
+ LabelingInfo = typing.Union[HirundoCSV, COCO, YOLO]
+ """
+ The dataset labeling info. The dataset labeling info can be one of the following:
+ - `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format
+
+ Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
+ """
+
+
+ class VisionRunArgs(BaseModel):
+ upsample: bool = False
+ """
+ Whether to upsample the dataset to attempt to balance the classes.
+ """
+ min_abs_bbox_size: int = 0
+ """
+ Minimum valid size (in pixels) of a bounding box to keep it in the dataset for optimization.
+ """
+ min_abs_bbox_area: int = 0
+ """
+ Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for optimization.
+ """
+ min_rel_bbox_size: float = 0.0
+ """
+ Minimum valid size (as a fraction of both image height and width) for a bounding box
+ to keep it in the dataset for optimization, relative to the corresponding dimension size,
+ i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
+ value is 0.06 (since both width and height are checked).
+ """
+ min_rel_bbox_area: float = 0.0
+ """
+ Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for optimization.
+ """
+
+
+ RunArgs = typing.Union[VisionRunArgs]
+
+
+ class AugmentationNames(str, Enum):
+ RandomHorizontalFlip = "RandomHorizontalFlip"
+ RandomVerticalFlip = "RandomVerticalFlip"
+ RandomRotation = "RandomRotation"
+ ColorJitter = "ColorJitter"
+ RandomAffine = "RandomAffine"
+ RandomPerspective = "RandomPerspective"
+
+
+ class Modality(str, Enum):
+ IMAGE = "Image"
+ RADAR = "Radar"
+ EKG = "EKG"


  class OptimizationDataset(BaseModel):
+ id: typing.Optional[int] = Field(default=None)
+ """
+ The ID of the dataset created on the server.
+ """
  name: str
  """
  The name of the dataset. Used to identify it amongst the list of datasets
  belonging to your organization in `hirundo`.
  """
- labelling_type: LabellingType
+ labeling_type: LabelingType
  """
- Indicates the labelling type of the dataset. The labelling type can be one of the following:
- - `LabellingType.SingleLabelClassification`: Indicates that the dataset is for classification tasks
- - `LabellingType.ObjectDetection`: Indicates that the dataset is for object detection tasks
+ Indicates the labeling type of the dataset. The labeling type can be one of the following:
+ - `LabelingType.SINGLE_LABEL_CLASSIFICATION`: Indicates that the dataset is for classification tasks
+ - `LabelingType.OBJECT_DETECTION`: Indicates that the dataset is for object detection tasks
+ - `LabelingType.SPEECH_TO_TEXT`: Indicates that the dataset is for speech-to-text tasks
  """
- dataset_storage: Union[StorageLink, None]
+ language: typing.Optional[str] = None
  """
- The storage link to the dataset. This can be a link to a file or a directory containing the dataset.
- If `None`, the `dataset_id` field must be set.
+ Language of the Speech-to-Text audio dataset. This is required for Speech-to-Text datasets.
  """
-
- classes: list[str]
+ storage_config_id: typing.Optional[int] = None
  """
- A full list of possible classes used in classification / object detection.
- It is currently required for clarity and performance.
+ The ID of the storage config used to store the dataset and metadata.
+ """
+ storage_config: typing.Optional[
+ typing.Union[StorageConfig, ResponseStorageConfig]
+ ] = None
  """
- dataset_metadata_path: str = "metadata.csv"
+ The `StorageConfig` instance to link to.
  """
- The path to the dataset metadata file within storage integration, e.g. S3 Bucket / GCP Bucket / Azure Blob storage / Git repo.
- Note: This path will be prefixed with the `StorageLink`'s `path`.
+ data_root_url: HirundoUrl
  """
- dataset_metadata_type: DatasetMetadataType = DatasetMetadataType.HirundoCSV
+ URL for data (e.g. images) within the `StorageConfig` instance,
+ e.g. `s3://my-bucket-name/my-images-folder`, `gs://my-bucket-name/my-images-folder`,
+ or `ssh://my-username@my-repo-name/my-images-folder`
+ (or `file:///datasets/my-images-folder` if using LOCAL storage type with on-premises installation)
+
+ Note: All CSV `image_path` entries in the metadata file should be relative to this folder.
  """
- The type of dataset metadata file. The dataset metadata file can be one of the following:
- - `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format

- Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
+ classes: typing.Optional[list[str]] = None
+ """
+ A full list of possible classes used in classification / object detection.
+ It is currently required for clarity and performance.
  """
+ labeling_info: LabelingInfo

- storage_integration_id: Union[int, None] = Field(default=None, init=False)
+ augmentations: typing.Optional[list[AugmentationNames]] = None
  """
- The ID of the storage integration used to store the dataset and metadata.
+ Used to define which augmentations are apply to a vision dataset.
+ For audio datasets, this field is ignored.
+ If no value is provided, all augmentations are applied to vision datasets.
  """
- dataset_id: Union[int, None] = Field(default=None, init=False)
+ modality: Modality = Modality.IMAGE
  """
- The ID of the dataset created on the server.
+ Used to define the modality of the dataset.
+ Defaults to Image.
  """
- run_id: Union[str, None] = Field(default=None, init=False)
+
+ run_id: typing.Optional[str] = Field(default=None, init=False)
  """
  The ID of the Dataset Optimization run created on the server.
  """

+ status: typing.Optional[RunStatus] = None
+
  @model_validator(mode="after")
  def validate_dataset(self):
- if self.dataset_storage is None and self.storage_integration_id is None:
- raise ValueError("No dataset storage has been provided")
+ if self.storage_config is None and self.storage_config_id is None:
+ raise ValueError(
+ "No dataset storage has been provided. Provide one via `storage_config` or `storage_config_id`"
+ )
+ elif self.storage_config is not None and self.storage_config_id is not None:
+ raise ValueError(
+ "Both `storage_config` and `storage_config_id` have been provided. Pick one."
+ )
+ if self.labeling_type == LabelingType.SPEECH_TO_TEXT and self.language is None:
+ raise ValueError("Language is required for Speech-to-Text datasets.")
+ elif (
+ self.labeling_type != LabelingType.SPEECH_TO_TEXT
+ and self.language is not None
+ ):
+ raise ValueError("Language is only allowed for Speech-to-Text datasets.")
+ if (
+ self.labeling_info.type == DatasetMetadataType.YOLO
+ and isinstance(self.labeling_info, YOLO)
+ and (
+ self.labeling_info.data_yaml_url is not None
+ and self.classes is not None
+ )
+ ):
+ raise ValueError(
+ "Only one of `classes` or `labeling_info.data_yaml_url` should be provided for YOLO datasets"
+ )
  return self

  @staticmethod
- def list(organization_id: Union[int, None] = None) -> list[dict]:
+ def get_by_id(dataset_id: int) -> "OptimizationDataset":
+ """
+ Get a `OptimizationDataset` instance from the server by its ID
+
+ Args:
+ dataset_id: The ID of the `OptimizationDataset` instance to get
+ """
+ response = requests.get(
+ f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
+ headers=get_auth_headers(),
+ timeout=READ_TIMEOUT,
+ )
+ raise_for_status_with_reason(response)
+ dataset = response.json()
+ return OptimizationDataset(**dataset)
+
+ @staticmethod
+ def get_by_name(name: str) -> "OptimizationDataset":
+ """
+ Get a `OptimizationDataset` instance from the server by its name
+
+ Args:
+ name: The name of the `OptimizationDataset` instance to get
+ """
+ response = requests.get(
+ f"{API_HOST}/dataset-optimization/dataset/by-name/{name}",
+ headers=get_auth_headers(),
+ timeout=READ_TIMEOUT,
+ )
+ raise_for_status_with_reason(response)
+ dataset = response.json()
+ return OptimizationDataset(**dataset)
+
+ @staticmethod
+ def list_datasets(
+ organization_id: typing.Optional[int] = None,
+ ) -> list["DataOptimizationDatasetOut"]:
+ """
+ Lists all the optimization datasets created by user's default organization
+ or the `organization_id` passed
+
+ Args:
+ organization_id: The ID of the organization to list the datasets for.
+ """
+ response = requests.get(
+ f"{API_HOST}/dataset-optimization/dataset/",
+ params={"dataset_organization_id": organization_id},
+ headers=get_auth_headers(),
+ timeout=READ_TIMEOUT,
+ )
+ raise_for_status_with_reason(response)
+ datasets = response.json()
+ return [
+ DataOptimizationDatasetOut(
+ **ds,
+ )
+ for ds in datasets
+ ]
+
+ @staticmethod
+ def list_runs(
+ organization_id: typing.Optional[int] = None,
+ ) -> list["DataOptimizationRunOut"]:
  """
  Lists all the `OptimizationDataset` instances created by user's default organization
  or the `organization_id` passed
@@ -108,13 +386,19 @@ class OptimizationDataset(BaseModel):
  organization_id: The ID of the organization to list the datasets for.
  """
  response = requests.get(
- f"{API_HOST}/dataset-optimization/dataset/",
+ f"{API_HOST}/dataset-optimization/run/list",
  params={"dataset_organization_id": organization_id},
  headers=get_auth_headers(),
  timeout=READ_TIMEOUT,
  )
- response.raise_for_status()
- return response.json()
+ raise_for_status_with_reason(response)
+ runs = response.json()
+ return [
+ DataOptimizationRunOut(
+ **run,
+ )
+ for run in runs
+ ]

  @staticmethod
  def delete_by_id(dataset_id: int) -> None:
@@ -129,54 +413,74 @@ class OptimizationDataset(BaseModel):
  headers=get_auth_headers(),
  timeout=MODIFY_TIMEOUT,
  )
- response.raise_for_status()
+ raise_for_status_with_reason(response)
  logger.info("Deleted dataset with ID: %s", dataset_id)

- def delete(self, storage_integration=True) -> None:
+ def delete(self, storage_config=True) -> None:
  """
  Deletes the active `OptimizationDataset` instance from the server.
  It can only be used on a `OptimizationDataset` instance that has been created.

  Args:
- storage_integration: If True, the `OptimizationDataset`'s `StorageIntegration` will also be deleted
+ storage_config: If True, the `OptimizationDataset`'s `StorageConfig` will also be deleted

- Note: If `storage_integration` is not set to `False` then the `storage_integration_id` must be set
- This can either be set manually or by creating the `StorageIntegration` instance via the `OptimizationDataset`'s
+ Note: If `storage_config` is not set to `False` then the `storage_config_id` must be set
+ This can either be set manually or by creating the `StorageConfig` instance via the `OptimizationDataset`'s
  `create` method
  """
- if storage_integration:
- if not self.storage_integration_id:
- raise ValueError("No storage integration has been created")
- StorageIntegration.delete_by_id(self.storage_integration_id)
- if not self.dataset_id:
+ if storage_config:
+ if not self.storage_config_id:
+ raise ValueError("No storage config has been created")
+ StorageConfig.delete_by_id(self.storage_config_id)
+ if not self.id:
  raise ValueError("No dataset has been created")
- self.delete_by_id(self.dataset_id)
+ self.delete_by_id(self.id)

- def create(self) -> int:
+ def create(
+ self,
+ organization_id: typing.Optional[int] = None,
+ replace_if_exists: bool = False,
+ ) -> int:
  """
  Create a `OptimizationDataset` instance on the server.
- If `storage_integration_id` is not set, it will be created.
+ If the `storage_config_id` field is not set, the storage config will also be created and the field will be set.
+
+ Args:
+ organization_id: The ID of the organization to create the dataset for.
+ replace_if_exists: If True, the dataset will be replaced if it already exists
+ (this is determined by a dataset of the same name in the same organization).
+
+ Returns:
+ The ID of the created `OptimizationDataset` instance
  """
- if not self.dataset_storage:
+ if self.storage_config is None and self.storage_config_id is None:
  raise ValueError("No dataset storage has been provided")
- if (
- self.dataset_storage
- and self.dataset_storage.storage_integration
- and not self.storage_integration_id
+ elif self.storage_config and self.storage_config_id is None:
+ if isinstance(self.storage_config, ResponseStorageConfig):
+ self.storage_config_id = self.storage_config.id
+ elif isinstance(self.storage_config, StorageConfig):
+ self.storage_config_id = self.storage_config.create(
+ replace_if_exists=replace_if_exists,
+ )
+ elif (
+ self.storage_config is not None
+ and self.storage_config_id is not None
+ and (
+ not isinstance(self.storage_config, ResponseStorageConfig)
+ or self.storage_config.id != self.storage_config_id
+ )
  ):
- self.storage_integration_id = (
- self.dataset_storage.storage_integration.create()
+ raise ValueError(
+ "Both `storage_config` and `storage_config_id` have been provided. Storage config IDs do not match."
  )
- model_dict = self.model_dump()
+ model_dict = self.model_dump(mode="json")
  # ⬆️ Get dict of model fields from Pydantic model instance
  dataset_response = requests.post(
  f"{API_HOST}/dataset-optimization/dataset/",
  json={
- "dataset_storage": {
- "storage_integration_id": self.storage_integration_id,
- "path": self.dataset_storage.path,
- },
- **{k: model_dict[k] for k in model_dict.keys() - {"dataset_storage"}},
+ **{k: model_dict[k] for k in model_dict.keys() - {"storage_config"}},
+ "organization_id": organization_id,
+ "replace_if_exists": replace_if_exists,
  },
  headers={
  **json_headers,
@@ -184,15 +488,19 @@ class OptimizationDataset(BaseModel):
  },
  timeout=MODIFY_TIMEOUT,
  )
- dataset_response.raise_for_status()
- self.dataset_id = dataset_response.json()["id"]
- if not self.dataset_id:
- raise HirundoError("Failed to create the dataset")
- logger.info("Created dataset with ID: %s", self.dataset_id)
- return self.dataset_id
+ raise_for_status_with_reason(dataset_response)
+ self.id = dataset_response.json()["id"]
+ if not self.id:
+ raise HirundoError("An error ocurred while trying to create the dataset")
+ logger.info("Created dataset with ID: %s", self.id)
+ return self.id

  @staticmethod
- def launch_optimization_run(dataset_id: int) -> str:
+ def launch_optimization_run(
+ dataset_id: int,
+ organization_id: typing.Optional[int] = None,
+ run_args: typing.Optional[RunArgs] = None,
+ ) -> str:
  """
  Run the dataset optimization process on the server using the dataset with the given ID
  i.e. `dataset_id`.
@@ -203,26 +511,62 @@ class OptimizationDataset(BaseModel):
  Returns:
  ID of the run (`run_id`).
  """
+ run_info = {}
+ if organization_id:
+ run_info["organization_id"] = organization_id
+ if run_args:
+ run_info["run_args"] = run_args.model_dump(mode="json")
  run_response = requests.post(
  f"{API_HOST}/dataset-optimization/run/{dataset_id}",
+ json=run_info if len(run_info) > 0 else None,
  headers=get_auth_headers(),
  timeout=MODIFY_TIMEOUT,
  )
- run_response.raise_for_status()
+ raise_for_status_with_reason(run_response)
  return run_response.json()["run_id"]

- def run_optimization(self) -> str:
+ def _validate_run_args(self, run_args: RunArgs) -> None:
+ if self.labeling_type == LabelingType.SPEECH_TO_TEXT:
+ raise Exception("Speech to text cannot have `run_args` set")
+ if self.labeling_type != LabelingType.OBJECT_DETECTION and any(
+ (
+ run_args.min_abs_bbox_size != 0,
+ run_args.min_abs_bbox_area != 0,
+ run_args.min_rel_bbox_size != 0,
+ run_args.min_rel_bbox_area != 0,
+ )
+ ):
+ raise Exception(
+ "Cannot set `min_abs_bbox_size`, `min_abs_bbox_area`, "
+ + "`min_rel_bbox_size`, or `min_rel_bbox_area` for "
+ + f"labeling type {self.labeling_type}"
+ )
+
+ def run_optimization(
+ self,
+ organization_id: typing.Optional[int] = None,
+ replace_dataset_if_exists: bool = False,
+ run_args: typing.Optional[RunArgs] = None,
+ ) -> str:
  """
  If the dataset was not created on the server yet, it is created.
  Run the dataset optimization process on the server using the active `OptimizationDataset` instance

+ Args:
+ organization_id: The ID of the organization to run the optimization for.
+ replace_dataset_if_exists: If True, the dataset will be replaced if it already exists
+ (this is determined by a dataset of the same name in the same organization).
+ run_args: The run arguments to use for the optimization run
+
  Returns:
  An ID of the run (`run_id`) and stores that `run_id` on the instance
  """
  try:
- if not self.dataset_id:
- self.dataset_id = self.create()
- run_id = self.launch_optimization_run(self.dataset_id)
+ if not self.id:
+ self.id = self.create(replace_if_exists=replace_dataset_if_exists)
+ if run_args is not None:
+ self._validate_run_args(run_args)
+ run_id = self.launch_optimization_run(self.id, organization_id, run_args)
  self.run_id = run_id
  logger.info("Started the run with ID: %s", run_id)
  return run_id
@@ -238,17 +582,17 @@ class OptimizationDataset(BaseModel):
  except Exception:
  content = error.response.text
  raise HirundoError(
- f"Failed to start the run. Status code: {error.response.status_code} Content: {content}"
+ f"Unable to start the run. Status code: {error.response.status_code} Content: {content}"
  ) from error
  except Exception as error:
- raise HirundoError(f"Failed to start the run: {error}") from error
+ raise HirundoError(f"Unable to start the run: {error}") from error

  def clean_ids(self):
  """
- Reset `dataset_id`, `storage_integration_id`, and `run_id` values on the instance to default value of `None`
+ Reset `dataset_id`, `storage_config_id`, and `run_id` values on the instance to default value of `None`
  """
- self.storage_integration_id = None
- self.dataset_id = None
+ self.storage_config_id = None
+ self.id = None
  self.run_id = None

  @staticmethod
@@ -274,10 +618,19 @@ class OptimizationDataset(BaseModel):
  return df

  @staticmethod
- def _read_csv_to_df(data: dict):
+ def _read_csvs_to_df(data: dict):
  if data["state"] == RunStatus.SUCCESS.value:
- data["result"] = OptimizationDataset._clean_df_index(
- pd.read_csv(StringIO(data["result"]))
+ data["result"]["suspects"] = OptimizationDataset._clean_df_index(
+ pd.read_csv(
+ StringIO(data["result"]["suspects"]),
+ dtype=CUSTOMER_INTERCHANGE_DTYPES,
+ )
+ )
+ data["result"]["warnings_and_errors"] = OptimizationDataset._clean_df_index(
+ pd.read_csv(
+ StringIO(data["result"]["warnings_and_errors"]),
+ dtype=CUSTOMER_INTERCHANGE_DTYPES,
+ )
  )
  else:
  pass
@@ -306,8 +659,16 @@ class OptimizationDataset(BaseModel):
  last_event = json.loads(sse.data)
  if not last_event:
  continue
- data = last_event["data"]
- OptimizationDataset._read_csv_to_df(data)
+ if "data" in last_event:
+ data = last_event["data"]
+ else:
+ if "detail" in last_event:
+ raise HirundoError(last_event["detail"])
+ elif "reason" in last_event:
+ raise HirundoError(last_event["reason"])
+ else:
+ raise HirundoError("Unknown error")
+ OptimizationDataset._read_csvs_to_df(data)
  yield data
  if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
  OptimizationDataset._check_run_by_id(run_id, retry + 1)
@@ -316,27 +677,24 @@ class OptimizationDataset(BaseModel):
  @overload
  def check_run_by_id(
  run_id: str, stop_on_manual_approval: typing.Literal[True]
- ) -> typing.Optional[pd.DataFrame]:
- ...
+ ) -> typing.Optional[DatasetOptimizationResults]: ...

  @staticmethod
  @overload
  def check_run_by_id(
  run_id: str, stop_on_manual_approval: typing.Literal[False] = False
- ) -> pd.DataFrame:
- ...
+ ) -> DatasetOptimizationResults: ...

  @staticmethod
  @overload
  def check_run_by_id(
  run_id: str, stop_on_manual_approval: bool
- ) -> typing.Optional[pd.DataFrame]:
- ...
+ ) -> typing.Optional[DatasetOptimizationResults]: ...

  @staticmethod
  def check_run_by_id(
  run_id: str, stop_on_manual_approval: bool = False
- ) -> typing.Optional[pd.DataFrame]:
+ ) -> typing.Optional[DatasetOptimizationResults]:
  """
  Check the status of a run given its ID

@@ -345,7 +703,7 @@ class OptimizationDataset(BaseModel):
  stop_on_manual_approval: If True, the function will return `None` if the run is awaiting manual approval

  Returns:
- A pandas DataFrame with the results of the optimization run
+ A DatasetOptimizationResults object with the results of the optimization run

  Raises:
  HirundoError: If the maximum number of retries is reached or if the run fails
@@ -354,22 +712,33 @@ class OptimizationDataset(BaseModel):
  with logging_redirect_tqdm():
  t = tqdm(total=100.0)
  for iteration in OptimizationDataset._check_run_by_id(run_id):
- if iteration["state"] == RunStatus.SUCCESS.value:
- t.set_description("Optimization run completed successfully")
- t.n = 100.0
- t.refresh()
- t.close()
- return iteration["result"]
- elif iteration["state"] == RunStatus.PENDING.value:
- t.set_description("Optimization run queued and not yet started")
- t.n = 0.0
- t.refresh()
- elif iteration["state"] == RunStatus.STARTED.value:
- t.set_description(
- "Optimization run in progress. Downloading dataset"
- )
- t.n = 0.0
+ if iteration["state"] in STATUS_TO_PROGRESS_MAP:
+ t.set_description(STATUS_TO_TEXT_MAP[iteration["state"]])
+ t.n = STATUS_TO_PROGRESS_MAP[iteration["state"]]
+ logger.debug("Setting progress to %s", t.n)
  t.refresh()
+ if iteration["state"] in [
+ RunStatus.FAILURE.value,
+ RunStatus.REJECTED.value,
+ RunStatus.REVOKED.value,
+ ]:
+ raise HirundoError(
+ f"Optimization run failed with error: {iteration['result']}"
+ )
+ elif iteration["state"] == RunStatus.SUCCESS.value:
+ t.close()
+ return DatasetOptimizationResults(
+ suspects=iteration["result"]["suspects"],
+ warnings_and_errors=iteration["result"][
+ "warnings_and_errors"
+ ],
+ )
+ elif (
+ iteration["state"] == RunStatus.AWAITING_MANUAL_APPROVAL.value
+ and stop_on_manual_approval
+ ):
+ t.close()
+ return None
  elif iteration["state"] is None:
  if (
  iteration["result"]
@@ -377,47 +746,42 @@ class OptimizationDataset(BaseModel):
  and iteration["result"]["result"]
  and isinstance(iteration["result"]["result"], str)
  ):
- current_progress_percentage = float(
- iteration["result"]["result"].removesuffix("% done")
- )
+ result_info = iteration["result"]["result"].split(":")
+ if len(result_info) > 1:
+ stage = result_info[0]
+ current_progress_percentage = float(
+ result_info[1].removeprefix(" ").removesuffix("% done")
+ )
+ elif len(result_info) == 1:
+ stage = result_info[0]
+ current_progress_percentage = t.n # Keep the same progress
+ else:
+ stage = "Unknown progress state"
+ current_progress_percentage = t.n # Keep the same progress
  desc = (
  "Optimization run completed. Uploading results"
  if current_progress_percentage == 100.0
- else "Optimization run in progress"
+ else stage
  )
  t.set_description(desc)
  t.n = current_progress_percentage
+ logger.debug("Setting progress to %s", t.n)
  t.refresh()
- elif iteration["state"] == RunStatus.AWAITING_MANUAL_APPROVAL.value:
- t.set_description("Awaiting manual approval")
- t.n = 100.0
- t.refresh()
- if stop_on_manual_approval:
- t.close()
- return None
- elif iteration["state"] == RunStatus.FAILURE.value:
- t.set_description("Optimization run failed")
- t.close()
- raise HirundoError(
- f"Optimization run failed with error: {iteration['result']}"
- )
  raise HirundoError("Optimization run failed with an unknown error")

  @overload
  def check_run(
  self, stop_on_manual_approval: typing.Literal[True]
- ) -> typing.Union[pd.DataFrame, None]:
- ...
+ ) -> typing.Optional[DatasetOptimizationResults]: ...

  @overload
  def check_run(
  self, stop_on_manual_approval: typing.Literal[False] = False
- ) -> pd.DataFrame:
- ...
+ ) -> DatasetOptimizationResults: ...

  def check_run(
  self, stop_on_manual_approval: bool = False
- ) -> typing.Union[pd.DataFrame, None]:
+ ) -> typing.Optional[DatasetOptimizationResults]:
  """
  Check the status of the current active instance's run.

@@ -511,7 +875,7 @@ class OptimizationDataset(BaseModel):
  headers=get_auth_headers(),
  timeout=MODIFY_TIMEOUT,
  )
- response.raise_for_status()
+ raise_for_status_with_reason(response)

  def cancel(self) -> None:
  """
@@ -520,3 +884,31 @@ class OptimizationDataset(BaseModel):
  if not self.run_id:
  raise ValueError("No run has been started")
  self.cancel_by_id(self.run_id)
+
+
+ class DataOptimizationDatasetOut(BaseModel):
+ id: int
+
+ name: str
+ labeling_type: LabelingType
+
+ storage_config: ResponseStorageConfig
+
+ data_root_url: HirundoUrl
+
+ classes: typing.Optional[list[str]] = None
+ labeling_info: LabelingInfo
+
+ organization_id: typing.Optional[int]
+ creator_id: typing.Optional[int]
+ created_at: datetime.datetime
+ updated_at: datetime.datetime
+
+
+ class DataOptimizationRunOut(BaseModel):
+ id: int
+ name: str
+ run_id: str
+ status: RunStatus
+ approved: bool
+ created_at: datetime.datetime
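
Taken together, the additions above reshape the workflow: a dataset now points at a storage configuration via `storage_config` or `storage_config_id`, its labels are described by a `labeling_info` object such as `HirundoCSV`, optional `run_args` tune the run, and results come back as a `DatasetOptimizationResults` pair of DataFrames rather than a single DataFrame. The snippet below is a minimal usage sketch inferred from the 0.1.9 definitions in this diff, not taken from the package documentation: the top-level import path for `OptimizationDataset`, `HirundoCSV`, and `VisionRunArgs` is assumed (only the `hirundo.enum` and `hirundo.storage` paths appear above), and the storage config ID, bucket names, and class list are placeholders.

from hirundo import HirundoCSV, OptimizationDataset, VisionRunArgs  # assumed export path
from hirundo.enum import LabelingType

# Assumption: a storage config with ID 1 already exists and points at s3://my-bucket-name.
dataset = OptimizationDataset(
    name="my-dataset",
    labeling_type=LabelingType.OBJECT_DETECTION,
    storage_config_id=1,
    data_root_url="s3://my-bucket-name/my-images-folder",
    labeling_info=HirundoCSV(csv_url="s3://my-bucket-name/my-folder/my-metadata.csv"),
    classes=["car", "pedestrian"],
)

# Creates the dataset server-side if it has no ID yet, validates run_args, and starts a run.
run_id = dataset.run_optimization(run_args=VisionRunArgs(min_abs_bbox_size=8))

# Polls the run with a tqdm progress bar and returns DatasetOptimizationResults on success.
results = dataset.check_run()
print(results.suspects.head())
print(results.warnings_and_errors.head())

Note that `_validate_run_args` only accepts the bounding-box thresholds for `LabelingType.OBJECT_DETECTION` datasets and rejects `run_args` entirely for speech-to-text, so the `VisionRunArgs` shown here only make sense for the object-detection case.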