hirundo-0.1.8-py3-none-any.whl → hirundo-0.1.9-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
@@ -1,5 +1,7 @@
+ import datetime
  import json
  import typing
+ from abc import ABC, abstractmethod
  from collections.abc import AsyncGenerator, Generator
  from enum import Enum
  from io import StringIO
@@ -14,14 +16,15 @@ from pydantic import BaseModel, Field, model_validator
  from tqdm import tqdm
  from tqdm.contrib.logging import logging_redirect_tqdm

+ from hirundo._constraints import HirundoUrl
  from hirundo._env import API_HOST
  from hirundo._headers import get_auth_headers, json_headers
  from hirundo._http import raise_for_status_with_reason
  from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
  from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
- from hirundo.enum import DatasetMetadataType, LabellingType
+ from hirundo.enum import DatasetMetadataType, LabelingType
  from hirundo.logger import get_logger
- from hirundo.storage import StorageIntegration, StorageLink
+ from hirundo.storage import ResponseStorageConfig, StorageConfig

  logger = get_logger(__name__)

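The import changes above summarize this release's two renames: `LabellingType` becomes `LabelingType`, and the `StorageIntegration`/`StorageLink` pair gives way to `StorageConfig`/`ResponseStorageConfig`, alongside the new `HirundoUrl` type. A minimal before/after sketch for callers, using only names that appear in the hunks above:

    # hirundo 0.1.8
    # from hirundo.enum import DatasetMetadataType, LabellingType
    # from hirundo.storage import StorageIntegration, StorageLink

    # hirundo 0.1.9
    from hirundo.enum import DatasetMetadataType, LabelingType
    from hirundo.storage import ResponseStorageConfig, StorageConfig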
@@ -38,12 +41,14 @@ MAX_RETRIES = 200 # Max 200 retries for HTTP SSE connection


  class RunStatus(Enum):
-     STARTED = "STARTED"
      PENDING = "PENDING"
+     STARTED = "STARTED"
      SUCCESS = "SUCCESS"
      FAILURE = "FAILURE"
      AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
-     RETRYING = "RETRYING"
+     REVOKED = "REVOKED"
+     REJECTED = "REJECTED"
+     RETRY = "RETRY"


  STATUS_TO_TEXT_MAP = {
@@ -52,7 +57,9 @@ STATUS_TO_TEXT_MAP = {
      RunStatus.SUCCESS.value: "Optimization run completed successfully",
      RunStatus.FAILURE.value: "Optimization run failed",
      RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
-     RunStatus.RETRYING.value: "Optimization run failed. Retrying",
+     RunStatus.RETRY.value: "Optimization run failed. Retrying",
+     RunStatus.REVOKED.value: "Optimization run was cancelled",
+     RunStatus.REJECTED.value: "Optimization run was rejected",
  }
  STATUS_TO_PROGRESS_MAP = {
      RunStatus.STARTED.value: 0.0,
@@ -60,7 +67,9 @@ STATUS_TO_PROGRESS_MAP = {
      RunStatus.SUCCESS.value: 100.0,
      RunStatus.FAILURE.value: 100.0,
      RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
-     RunStatus.RETRYING.value: 0.0,
+     RunStatus.RETRY.value: 0.0,
+     RunStatus.REVOKED.value: 100.0,
+     RunStatus.REJECTED.value: 0.0,
  }

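The enum above replaces 0.1.8's `RETRYING` with `RETRY` and adds `REVOKED` and `REJECTED`; the two maps translate a server-reported state into display text and a progress-bar value. A small illustrative helper (hypothetical, not part of the package) showing how the maps resolve the new states:

    # Hypothetical helper: look up a server-reported state in the maps above.
    def describe_state(state: str) -> str:
        text = STATUS_TO_TEXT_MAP.get(state, "Unknown state")
        progress = STATUS_TO_PROGRESS_MAP.get(state, 0.0)
        return f"{text} ({progress:.0f}%)"

    describe_state(RunStatus.REVOKED.value)  # "Optimization run was cancelled (100%)"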
@@ -84,10 +93,10 @@ CUSTOMER_INTERCHANGE_DTYPES: DtypeArg = {
      "segment_id": np.int32,
      "label": str,
      "bbox_id": str,
-     "xmin": np.int32,
-     "ymin": np.int32,
-     "xmax": np.int32,
-     "ymax": np.int32,
+     "xmin": np.float32,
+     "ymin": np.float32,
+     "xmax": np.float32,
+     "ymax": np.float32,
      "suspect_level": np.float32, # If exists, must be one of the values in the enum below
      "suggested_label": str,
      "suggested_label_conf": np.float32,
@@ -97,63 +106,277 @@ CUSTOMER_INTERCHANGE_DTYPES: DtypeArg = {
  }


+ class Metadata(BaseModel, ABC):
+     type: DatasetMetadataType
+
+     @property
+     @abstractmethod
+     def metadata_url(self) -> HirundoUrl:
+         raise NotImplementedError()
+
+
+ class HirundoCSV(Metadata):
+     """
+     A dataset metadata file in the Hirundo CSV format
+     """
+
+     type: DatasetMetadataType = DatasetMetadataType.HIRUNDO_CSV
+     csv_url: HirundoUrl
+     """
+     The URL to access the dataset metadata CSV file.
+     e.g. `s3://my-bucket-name/my-folder/my-metadata.csv`, `gs://my-bucket-name/my-folder/my-metadata.csv`,
+     or `ssh://my-username@my-repo-name/my-folder/my-metadata.csv`
+     (or `file:///datasets/my-folder/my-metadata.csv` if using LOCAL storage type with on-premises installation)
+     """
+
+     @property
+     def metadata_url(self) -> HirundoUrl:
+         return self.csv_url
+
+
+ class COCO(Metadata):
+     """
+     A dataset metadata file in the COCO format
+     """
+
+     type: DatasetMetadataType = DatasetMetadataType.COCO
+     json_url: HirundoUrl
+     """
+     The URL to access the dataset metadata JSON file.
+     e.g. `s3://my-bucket-name/my-folder/my-metadata.json`, `gs://my-bucket-name/my-folder/my-metadata.json`,
+     or `ssh://my-username@my-repo-name/my-folder/my-metadata.json`
+     (or `file:///datasets/my-folder/my-metadata.json` if using LOCAL storage type with on-premises installation)
+     """
+
+     @property
+     def metadata_url(self) -> HirundoUrl:
+         return self.json_url
+
+
+ class YOLO(Metadata):
+     type: DatasetMetadataType = DatasetMetadataType.YOLO
+     data_yaml_url: typing.Optional[HirundoUrl] = None
+     labels_dir_url: HirundoUrl
+
+     @property
+     def metadata_url(self) -> HirundoUrl:
+         return self.labels_dir_url
+
+
+ LabelingInfo = typing.Union[HirundoCSV, COCO, YOLO]
+ """
+ The dataset labeling info. The dataset labeling info can be one of the following:
+ - `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format
+
+ Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
+ """
+
+
+ class VisionRunArgs(BaseModel):
+     upsample: bool = False
+     """
+     Whether to upsample the dataset to attempt to balance the classes.
+     """
+     min_abs_bbox_size: int = 0
+     """
+     Minimum valid size (in pixels) of a bounding box to keep it in the dataset for optimization.
+     """
+     min_abs_bbox_area: int = 0
+     """
+     Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for optimization.
+     """
+     min_rel_bbox_size: float = 0.0
+     """
+     Minimum valid size (as a fraction of both image height and width) for a bounding box
+     to keep it in the dataset for optimization, relative to the corresponding dimension size,
+     i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
+     value is 0.06 (since both width and height are checked).
+     """
+     min_rel_bbox_area: float = 0.0
+     """
+     Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for optimization.
+     """
+
+
+ RunArgs = typing.Union[VisionRunArgs]
+
+
+ class AugmentationNames(str, Enum):
+     RandomHorizontalFlip = "RandomHorizontalFlip"
+     RandomVerticalFlip = "RandomVerticalFlip"
+     RandomRotation = "RandomRotation"
+     ColorJitter = "ColorJitter"
+     RandomAffine = "RandomAffine"
+     RandomPerspective = "RandomPerspective"
+
+
+ class Modality(str, Enum):
+     IMAGE = "Image"
+     RADAR = "Radar"
+     EKG = "EKG"
+
+
  class OptimizationDataset(BaseModel):
+     id: typing.Optional[int] = Field(default=None)
+     """
+     The ID of the dataset created on the server.
+     """
      name: str
      """
      The name of the dataset. Used to identify it amongst the list of datasets
      belonging to your organization in `hirundo`.
      """
-     labelling_type: LabellingType
+     labeling_type: LabelingType
      """
-     Indicates the labelling type of the dataset. The labelling type can be one of the following:
-     - `LabellingType.SingleLabelClassification`: Indicates that the dataset is for classification tasks
-     - `LabellingType.ObjectDetection`: Indicates that the dataset is for object detection tasks
+     Indicates the labeling type of the dataset. The labeling type can be one of the following:
+     - `LabelingType.SINGLE_LABEL_CLASSIFICATION`: Indicates that the dataset is for classification tasks
+     - `LabelingType.OBJECT_DETECTION`: Indicates that the dataset is for object detection tasks
+     - `LabelingType.SPEECH_TO_TEXT`: Indicates that the dataset is for speech-to-text tasks
      """
-     dataset_storage: typing.Optional[StorageLink]
+     language: typing.Optional[str] = None
      """
-     The storage link to the dataset. This can be a link to a file or a directory containing the dataset.
-     If `None`, the `dataset_id` field must be set.
+     Language of the Speech-to-Text audio dataset. This is required for Speech-to-Text datasets.
      """
-
-     classes: typing.Optional[list[str]] = None
+     storage_config_id: typing.Optional[int] = None
      """
-     A full list of possible classes used in classification / object detection.
-     It is currently required for clarity and performance.
+     The ID of the storage config used to store the dataset and metadata.
      """
-     dataset_metadata_path: str = "metadata.csv"
+     storage_config: typing.Optional[
+         typing.Union[StorageConfig, ResponseStorageConfig]
+     ] = None
      """
-     The path to the dataset metadata file within storage integration, e.g. S3 Bucket / GCP Bucket / Azure Blob storage / Git repo.
-     Note: This path will be prefixed with the `StorageLink`'s `path`.
+     The `StorageConfig` instance to link to.
      """
-     dataset_metadata_type: DatasetMetadataType = DatasetMetadataType.HirundoCSV
+     data_root_url: HirundoUrl
      """
-     The type of dataset metadata file. The dataset metadata file can be one of the following:
-     - `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format
+     URL for data (e.g. images) within the `StorageConfig` instance,
+     e.g. `s3://my-bucket-name/my-images-folder`, `gs://my-bucket-name/my-images-folder`,
+     or `ssh://my-username@my-repo-name/my-images-folder`
+     (or `file:///datasets/my-images-folder` if using LOCAL storage type with on-premises installation)

-     Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
+     Note: All CSV `image_path` entries in the metadata file should be relative to this folder.
      """

-     storage_integration_id: typing.Optional[int] = Field(default=None, init=False)
+     classes: typing.Optional[list[str]] = None
      """
-     The ID of the storage integration used to store the dataset and metadata.
+     A full list of possible classes used in classification / object detection.
+     It is currently required for clarity and performance.
      """
-     dataset_id: typing.Optional[int] = Field(default=None, init=False)
+     labeling_info: LabelingInfo
+
+     augmentations: typing.Optional[list[AugmentationNames]] = None
      """
-     The ID of the dataset created on the server.
+     Used to define which augmentations are apply to a vision dataset.
+     For audio datasets, this field is ignored.
+     If no value is provided, all augmentations are applied to vision datasets.
+     """
+     modality: Modality = Modality.IMAGE
      """
+     Used to define the modality of the dataset.
+     Defaults to Image.
+     """
+
      run_id: typing.Optional[str] = Field(default=None, init=False)
      """
      The ID of the Dataset Optimization run created on the server.
      """

+     status: typing.Optional[RunStatus] = None
+
      @model_validator(mode="after")
      def validate_dataset(self):
-         if self.dataset_storage is None and self.storage_integration_id is None:
-             raise ValueError("No dataset storage has been provided")
+         if self.storage_config is None and self.storage_config_id is None:
+             raise ValueError(
+                 "No dataset storage has been provided. Provide one via `storage_config` or `storage_config_id`"
+             )
+         elif self.storage_config is not None and self.storage_config_id is not None:
+             raise ValueError(
+                 "Both `storage_config` and `storage_config_id` have been provided. Pick one."
+             )
+         if self.labeling_type == LabelingType.SPEECH_TO_TEXT and self.language is None:
+             raise ValueError("Language is required for Speech-to-Text datasets.")
+         elif (
+             self.labeling_type != LabelingType.SPEECH_TO_TEXT
+             and self.language is not None
+         ):
+             raise ValueError("Language is only allowed for Speech-to-Text datasets.")
+         if (
+             self.labeling_info.type == DatasetMetadataType.YOLO
+             and isinstance(self.labeling_info, YOLO)
+             and (
+                 self.labeling_info.data_yaml_url is not None
+                 and self.classes is not None
+             )
+         ):
+             raise ValueError(
+                 "Only one of `classes` or `labeling_info.data_yaml_url` should be provided for YOLO datasets"
+             )
          return self

      @staticmethod
-     def list(organization_id: typing.Optional[int] = None) -> list[dict]:
+     def get_by_id(dataset_id: int) -> "OptimizationDataset":
+         """
+         Get a `OptimizationDataset` instance from the server by its ID
+
+         Args:
+             dataset_id: The ID of the `OptimizationDataset` instance to get
+         """
+         response = requests.get(
+             f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
+             headers=get_auth_headers(),
+             timeout=READ_TIMEOUT,
+         )
+         raise_for_status_with_reason(response)
+         dataset = response.json()
+         return OptimizationDataset(**dataset)
+
+     @staticmethod
+     def get_by_name(name: str) -> "OptimizationDataset":
+         """
+         Get a `OptimizationDataset` instance from the server by its name
+
+         Args:
+             name: The name of the `OptimizationDataset` instance to get
+         """
+         response = requests.get(
+             f"{API_HOST}/dataset-optimization/dataset/by-name/{name}",
+             headers=get_auth_headers(),
+             timeout=READ_TIMEOUT,
+         )
+         raise_for_status_with_reason(response)
+         dataset = response.json()
+         return OptimizationDataset(**dataset)
+
+     @staticmethod
+     def list_datasets(
+         organization_id: typing.Optional[int] = None,
+     ) -> list["DataOptimizationDatasetOut"]:
+         """
+         Lists all the optimization datasets created by user's default organization
+         or the `organization_id` passed
+
+         Args:
+             organization_id: The ID of the organization to list the datasets for.
+         """
+         response = requests.get(
+             f"{API_HOST}/dataset-optimization/dataset/",
+             params={"dataset_organization_id": organization_id},
+             headers=get_auth_headers(),
+             timeout=READ_TIMEOUT,
+         )
+         raise_for_status_with_reason(response)
+         datasets = response.json()
+         return [
+             DataOptimizationDatasetOut(
+                 **ds,
+             )
+             for ds in datasets
+         ]
+
+     @staticmethod
+     def list_runs(
+         organization_id: typing.Optional[int] = None,
+     ) -> list["DataOptimizationRunOut"]:
          """
          Lists all the `OptimizationDataset` instances created by user's default organization
          or the `organization_id` passed
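Taken together, the hunk above replaces 0.1.8's `dataset_storage`/`dataset_metadata_path`/`dataset_metadata_type` fields with `storage_config` (or `storage_config_id`), a `data_root_url`, and a typed `labeling_info` model, and adds `get_by_id`/`get_by_name` lookups. A minimal construction sketch under the new API; the storage config ID, bucket URLs, and class list are illustrative assumptions, not values from the package:

    # Sketch only: assumes a storage config with ID 1 already exists and
    # uses illustrative bucket paths modeled on the docstrings above.
    dataset = OptimizationDataset(
        name="my-dataset",
        labeling_type=LabelingType.OBJECT_DETECTION,
        storage_config_id=1,  # alternatively: storage_config=StorageConfig(...)
        data_root_url="s3://my-bucket-name/my-images-folder",
        labeling_info=HirundoCSV(
            csv_url="s3://my-bucket-name/my-folder/my-metadata.csv",
        ),
        classes=["car", "pedestrian"],  # hypothetical class list
    )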
@@ -163,13 +386,19 @@ class OptimizationDataset(BaseModel):
              organization_id: The ID of the organization to list the datasets for.
          """
          response = requests.get(
-             f"{API_HOST}/dataset-optimization/dataset/",
+             f"{API_HOST}/dataset-optimization/run/list",
              params={"dataset_organization_id": organization_id},
              headers=get_auth_headers(),
              timeout=READ_TIMEOUT,
          )
          raise_for_status_with_reason(response)
-         return response.json()
+         runs = response.json()
+         return [
+             DataOptimizationRunOut(
+                 **run,
+             )
+             for run in runs
+         ]

      @staticmethod
      def delete_by_id(dataset_id: int) -> None:
@@ -187,51 +416,71 @@ class OptimizationDataset(BaseModel):
          raise_for_status_with_reason(response)
          logger.info("Deleted dataset with ID: %s", dataset_id)

-     def delete(self, storage_integration=True) -> None:
+     def delete(self, storage_config=True) -> None:
          """
          Deletes the active `OptimizationDataset` instance from the server.
          It can only be used on a `OptimizationDataset` instance that has been created.

          Args:
-             storage_integration: If True, the `OptimizationDataset`'s `StorageIntegration` will also be deleted
+             storage_config: If True, the `OptimizationDataset`'s `StorageConfig` will also be deleted

-         Note: If `storage_integration` is not set to `False` then the `storage_integration_id` must be set
-         This can either be set manually or by creating the `StorageIntegration` instance via the `OptimizationDataset`'s
+         Note: If `storage_config` is not set to `False` then the `storage_config_id` must be set
+         This can either be set manually or by creating the `StorageConfig` instance via the `OptimizationDataset`'s
          `create` method
          """
-         if storage_integration:
-             if not self.storage_integration_id:
-                 raise ValueError("No storage integration has been created")
-             StorageIntegration.delete_by_id(self.storage_integration_id)
-         if not self.dataset_id:
+         if storage_config:
+             if not self.storage_config_id:
+                 raise ValueError("No storage config has been created")
+             StorageConfig.delete_by_id(self.storage_config_id)
+         if not self.id:
              raise ValueError("No dataset has been created")
-         self.delete_by_id(self.dataset_id)
+         self.delete_by_id(self.id)

-     def create(self) -> int:
+     def create(
+         self,
+         organization_id: typing.Optional[int] = None,
+         replace_if_exists: bool = False,
+     ) -> int:
          """
          Create a `OptimizationDataset` instance on the server.
-         If `storage_integration_id` is not set, it will be created.
+         If the `storage_config_id` field is not set, the storage config will also be created and the field will be set.
+
+         Args:
+             organization_id: The ID of the organization to create the dataset for.
+             replace_if_exists: If True, the dataset will be replaced if it already exists
+                 (this is determined by a dataset of the same name in the same organization).
+
+         Returns:
+             The ID of the created `OptimizationDataset` instance
          """
-         if not self.dataset_storage:
+         if self.storage_config is None and self.storage_config_id is None:
              raise ValueError("No dataset storage has been provided")
-         if (
-             self.dataset_storage
-             and self.dataset_storage.storage_integration
-             and not self.storage_integration_id
+         elif self.storage_config and self.storage_config_id is None:
+             if isinstance(self.storage_config, ResponseStorageConfig):
+                 self.storage_config_id = self.storage_config.id
+             elif isinstance(self.storage_config, StorageConfig):
+                 self.storage_config_id = self.storage_config.create(
+                     replace_if_exists=replace_if_exists,
+                 )
+         elif (
+             self.storage_config is not None
+             and self.storage_config_id is not None
+             and (
+                 not isinstance(self.storage_config, ResponseStorageConfig)
+                 or self.storage_config.id != self.storage_config_id
+             )
          ):
-             self.storage_integration_id = (
-                 self.dataset_storage.storage_integration.create()
+             raise ValueError(
+                 "Both `storage_config` and `storage_config_id` have been provided. Storage config IDs do not match."
              )
-         model_dict = self.model_dump()
+         model_dict = self.model_dump(mode="json")
          # ⬆️ Get dict of model fields from Pydantic model instance
          dataset_response = requests.post(
              f"{API_HOST}/dataset-optimization/dataset/",
              json={
-                 "dataset_storage": {
-                     "storage_integration_id": self.storage_integration_id,
-                     "path": self.dataset_storage.path,
-                 },
-                 **{k: model_dict[k] for k in model_dict.keys() - {"dataset_storage"}},
+                 **{k: model_dict[k] for k in model_dict.keys() - {"storage_config"}},
+                 "organization_id": organization_id,
+                 "replace_if_exists": replace_if_exists,
              },
              headers={
                  **json_headers,
@@ -240,14 +489,18 @@ class OptimizationDataset(BaseModel):
              timeout=MODIFY_TIMEOUT,
          )
          raise_for_status_with_reason(dataset_response)
-         self.dataset_id = dataset_response.json()["id"]
-         if not self.dataset_id:
-             raise HirundoError("Failed to create the dataset")
-         logger.info("Created dataset with ID: %s", self.dataset_id)
-         return self.dataset_id
+         self.id = dataset_response.json()["id"]
+         if not self.id:
+             raise HirundoError("An error ocurred while trying to create the dataset")
+         logger.info("Created dataset with ID: %s", self.id)
+         return self.id

      @staticmethod
-     def launch_optimization_run(dataset_id: int) -> str:
+     def launch_optimization_run(
+         dataset_id: int,
+         organization_id: typing.Optional[int] = None,
+         run_args: typing.Optional[RunArgs] = None,
+     ) -> str:
          """
          Run the dataset optimization process on the server using the dataset with the given ID
          i.e. `dataset_id`.
@@ -258,26 +511,62 @@ class OptimizationDataset(BaseModel):
          Returns:
              ID of the run (`run_id`).
          """
+         run_info = {}
+         if organization_id:
+             run_info["organization_id"] = organization_id
+         if run_args:
+             run_info["run_args"] = run_args.model_dump(mode="json")
          run_response = requests.post(
              f"{API_HOST}/dataset-optimization/run/{dataset_id}",
+             json=run_info if len(run_info) > 0 else None,
              headers=get_auth_headers(),
              timeout=MODIFY_TIMEOUT,
          )
          raise_for_status_with_reason(run_response)
          return run_response.json()["run_id"]

-     def run_optimization(self) -> str:
+     def _validate_run_args(self, run_args: RunArgs) -> None:
+         if self.labeling_type == LabelingType.SPEECH_TO_TEXT:
+             raise Exception("Speech to text cannot have `run_args` set")
+         if self.labeling_type != LabelingType.OBJECT_DETECTION and any(
+             (
+                 run_args.min_abs_bbox_size != 0,
+                 run_args.min_abs_bbox_area != 0,
+                 run_args.min_rel_bbox_size != 0,
+                 run_args.min_rel_bbox_area != 0,
+             )
+         ):
+             raise Exception(
+                 "Cannot set `min_abs_bbox_size`, `min_abs_bbox_area`, "
+                 + "`min_rel_bbox_size`, or `min_rel_bbox_area` for "
+                 + f"labeling type {self.labeling_type}"
+             )
+
+     def run_optimization(
+         self,
+         organization_id: typing.Optional[int] = None,
+         replace_dataset_if_exists: bool = False,
+         run_args: typing.Optional[RunArgs] = None,
+     ) -> str:
          """
          If the dataset was not created on the server yet, it is created.
          Run the dataset optimization process on the server using the active `OptimizationDataset` instance

+         Args:
+             organization_id: The ID of the organization to run the optimization for.
+             replace_dataset_if_exists: If True, the dataset will be replaced if it already exists
+                 (this is determined by a dataset of the same name in the same organization).
+             run_args: The run arguments to use for the optimization run
+
          Returns:
              An ID of the run (`run_id`) and stores that `run_id` on the instance
          """
          try:
-             if not self.dataset_id:
-                 self.dataset_id = self.create()
-             run_id = self.launch_optimization_run(self.dataset_id)
+             if not self.id:
+                 self.id = self.create(replace_if_exists=replace_dataset_if_exists)
+             if run_args is not None:
+                 self._validate_run_args(run_args)
+             run_id = self.launch_optimization_run(self.id, organization_id, run_args)
              self.run_id = run_id
              logger.info("Started the run with ID: %s", run_id)
              return run_id
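`run_optimization` now threads through an `organization_id`, a `replace_dataset_if_exists` flag, and optional `run_args`, and `_validate_run_args` above rejects bounding-box filters for anything but object detection (and any `run_args` at all for speech-to-text). A hedged sketch continuing the object-detection dataset from the earlier example; the threshold values are illustrative, not recommendations:

    # Sketch: valid only because `dataset` above is an OBJECT_DETECTION dataset.
    run_id = dataset.run_optimization(
        replace_dataset_if_exists=True,
        run_args=VisionRunArgs(
            upsample=True,
            min_abs_bbox_size=8,       # drop boxes narrower or shorter than 8 px
            min_rel_bbox_area=0.0001,  # drop boxes under 0.01% of the image area
        ),
    )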
@@ -293,17 +582,17 @@ class OptimizationDataset(BaseModel):
              except Exception:
                  content = error.response.text
              raise HirundoError(
-                 f"Failed to start the run. Status code: {error.response.status_code} Content: {content}"
+                 f"Unable to start the run. Status code: {error.response.status_code} Content: {content}"
              ) from error
          except Exception as error:
-             raise HirundoError(f"Failed to start the run: {error}") from error
+             raise HirundoError(f"Unable to start the run: {error}") from error

      def clean_ids(self):
          """
-         Reset `dataset_id`, `storage_integration_id`, and `run_id` values on the instance to default value of `None`
+         Reset `dataset_id`, `storage_config_id`, and `run_id` values on the instance to default value of `None`
          """
-         self.storage_integration_id = None
-         self.dataset_id = None
+         self.storage_config_id = None
+         self.id = None
          self.run_id = None

      @staticmethod
@@ -370,7 +659,15 @@ class OptimizationDataset(BaseModel):
              last_event = json.loads(sse.data)
              if not last_event:
                  continue
-             data = last_event["data"]
+             if "data" in last_event:
+                 data = last_event["data"]
+             else:
+                 if "detail" in last_event:
+                     raise HirundoError(last_event["detail"])
+                 elif "reason" in last_event:
+                     raise HirundoError(last_event["reason"])
+                 else:
+                     raise HirundoError("Unknown error")
              OptimizationDataset._read_csvs_to_df(data)
              yield data
              if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
@@ -420,7 +717,11 @@ class OptimizationDataset(BaseModel):
              t.n = STATUS_TO_PROGRESS_MAP[iteration["state"]]
              logger.debug("Setting progress to %s", t.n)
              t.refresh()
-             if iteration["state"] == RunStatus.FAILURE.value:
+             if iteration["state"] in [
+                 RunStatus.FAILURE.value,
+                 RunStatus.REJECTED.value,
+                 RunStatus.REVOKED.value,
+             ]:
                  raise HirundoError(
                      f"Optimization run failed with error: {iteration['result']}"
                  )
@@ -445,13 +746,22 @@ class OptimizationDataset(BaseModel):
                  and iteration["result"]["result"]
                  and isinstance(iteration["result"]["result"], str)
              ):
-                 current_progress_percentage = float(
-                     iteration["result"]["result"].removesuffix("% done")
-                 )
+                 result_info = iteration["result"]["result"].split(":")
+                 if len(result_info) > 1:
+                     stage = result_info[0]
+                     current_progress_percentage = float(
+                         result_info[1].removeprefix(" ").removesuffix("% done")
+                     )
+                 elif len(result_info) == 1:
+                     stage = result_info[0]
+                     current_progress_percentage = t.n # Keep the same progress
+                 else:
+                     stage = "Unknown progress state"
+                     current_progress_percentage = t.n # Keep the same progress
                  desc = (
                      "Optimization run completed. Uploading results"
                      if current_progress_percentage == 100.0
-                     else "Optimization run in progress"
+                     else stage
                  )
                  t.set_description(desc)
                  t.n = current_progress_percentage
@@ -574,3 +884,31 @@ class OptimizationDataset(BaseModel):
          if not self.run_id:
              raise ValueError("No run has been started")
          self.cancel_by_id(self.run_id)
+
+
+ class DataOptimizationDatasetOut(BaseModel):
+     id: int
+
+     name: str
+     labeling_type: LabelingType
+
+     storage_config: ResponseStorageConfig
+
+     data_root_url: HirundoUrl
+
+     classes: typing.Optional[list[str]] = None
+     labeling_info: LabelingInfo
+
+     organization_id: typing.Optional[int]
+     creator_id: typing.Optional[int]
+     created_at: datetime.datetime
+     updated_at: datetime.datetime
+
+
+ class DataOptimizationRunOut(BaseModel):
+     id: int
+     name: str
+     run_id: str
+     status: RunStatus
+     approved: bool
+     created_at: datetime.datetime
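The two response models above give `list_datasets` and `list_runs` typed return values in place of 0.1.8's raw `list[dict]`. A closing sketch of iterating runs with the new models; it assumes every reported status has an entry in `STATUS_TO_TEXT_MAP`:

    # Sketch: typed access to run listings under 0.1.9.
    for run in OptimizationDataset.list_runs():
        text = STATUS_TO_TEXT_MAP.get(run.status.value, run.status.value)
        print(run.run_id, run.status.name, text)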