hirundo 0.1.18__py3-none-any.whl → 0.2.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,12 +1,9 @@
 import datetime
-import json
 import typing
 from collections.abc import AsyncGenerator, Generator
 from enum import Enum
 from typing import overload
 
-import httpx
-import requests
 from pydantic import BaseModel, Field, model_validator
 from tqdm import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
@@ -14,12 +11,21 @@ from tqdm.contrib.logging import logging_redirect_tqdm
 from hirundo._constraints import validate_labeling_info, validate_url
 from hirundo._env import API_HOST
 from hirundo._headers import get_headers
-from hirundo._http import raise_for_status_with_reason
-from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
+from hirundo._http import raise_for_status_with_reason, requests
+from hirundo._run_checking import (
+    STATUS_TO_PROGRESS_MAP,
+    RunStatus,
+    aiter_run_events,
+    build_status_text_map,
+    get_state,
+    handle_run_failure,
+    iter_run_events,
+    update_progress_from_result,
+)
 from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
 from hirundo._urls import HirundoUrl
 from hirundo.dataset_enum import DatasetMetadataType, LabelingType
-from hirundo.dataset_optimization_results import DatasetOptimizationResults
+from hirundo.dataset_qa_results import DatasetQAResults
 from hirundo.labeling import YOLO, LabelingInfo
 from hirundo.logger import get_logger
 from hirundo.storage import ResponseStorageConfig, StorageConfig
@@ -30,75 +36,63 @@ logger = get_logger(__name__)
 
 class HirundoError(Exception):
     """
-    Custom exception used to indicate errors in `hirundo` dataset optimization runs
+    Custom exception used to indicate errors in `hirundo` dataset QA runs
     """
 
     pass
 
 
-MAX_RETRIES = 200  # Max 200 retries for HTTP SSE connection
+STATUS_TO_TEXT_MAP = build_status_text_map(
+    "Dataset QA",
+    started_detail="Dataset QA run in progress. Downloading dataset",
+)
 
 
-class RunStatus(Enum):
-    PENDING = "PENDING"
-    STARTED = "STARTED"
-    SUCCESS = "SUCCESS"
-    FAILURE = "FAILURE"
-    AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
-    REVOKED = "REVOKED"
-    REJECTED = "REJECTED"
-    RETRY = "RETRY"
-
-
-STATUS_TO_TEXT_MAP = {
-    RunStatus.STARTED.value: "Optimization run in progress. Downloading dataset",
-    RunStatus.PENDING.value: "Optimization run queued and not yet started",
-    RunStatus.SUCCESS.value: "Optimization run completed successfully",
-    RunStatus.FAILURE.value: "Optimization run failed",
-    RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
-    RunStatus.RETRY.value: "Optimization run failed. Retrying",
-    RunStatus.REVOKED.value: "Optimization run was cancelled",
-    RunStatus.REJECTED.value: "Optimization run was rejected",
-}
-STATUS_TO_PROGRESS_MAP = {
-    RunStatus.STARTED.value: 0.0,
-    RunStatus.PENDING.value: 0.0,
-    RunStatus.SUCCESS.value: 100.0,
-    RunStatus.FAILURE.value: 100.0,
-    RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
-    RunStatus.RETRY.value: 0.0,
-    RunStatus.REVOKED.value: 100.0,
-    RunStatus.REJECTED.value: 0.0,
-}
-
-
-class VisionRunArgs(BaseModel):
-    upsample: bool = False
+class ClassificationRunArgs(BaseModel):
+    image_size: tuple[int, int] | None = (224, 224)
+    """
+    Size (width, height) to which to resize classification images.
+    It is recommended to keep this value at (224, 224) unless your classes are differentiated by very small differences.
+    """
+    upsample: bool | None = False
     """
     Whether to upsample the dataset to attempt to balance the classes.
     """
-    min_abs_bbox_size: int = 0
+
+
+class ObjectDetectionRunArgs(ClassificationRunArgs):
+    min_abs_bbox_size: int | None = None
     """
-    Minimum valid size (in pixels) of a bounding box to keep it in the dataset for optimization.
+    Minimum valid size (in pixels) of a bounding box to keep it in the dataset for QA.
     """
-    min_abs_bbox_area: int = 0
+    min_abs_bbox_area: int | None = None
     """
-    Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for optimization.
+    Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for QA.
    """
-    min_rel_bbox_size: float = 0.0
+    min_rel_bbox_size: float | None = None
     """
     Minimum valid size (as a fraction of both image height and width) for a bounding box
-    to keep it in the dataset for optimization, relative to the corresponding dimension size,
+    to keep it in the dataset for QA, relative to the corresponding dimension size,
     i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
     value is 0.06 (since both width and height are checked).
     """
-    min_rel_bbox_area: float = 0.0
+    min_rel_bbox_area: float | None = None
+    """
+    Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for QA.
+    """
+    crop_ratio: float | None = None
     """
-    Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for optimization.
+    Ratio of the bounding box to crop.
+    Change this value at your own risk. It is recommended to keep it at 1.0 unless you know what you are doing.
+    """
+    add_mask_channel: bool | None = None
+    """
+    Whether to add a mask channel to the image.
+    Change at your own risk. It is recommended to keep it at False unless you know what you are doing.
     """
 
 
-RunArgs = typing.Union[VisionRunArgs]
+RunArgs = ClassificationRunArgs | ObjectDetectionRunArgs
 
 
 class AugmentationName(str, Enum):
@@ -111,14 +105,32 @@ class AugmentationName(str, Enum):
     GAUSSIAN_BLUR = "GaussianBlur"
 
 
-class Modality(str, Enum):
-    IMAGE = "Image"
-    RADAR = "Radar"
-    EKG = "EKG"
+class ModalityType(str, Enum):
+    RADAR = "RADAR"
+    VISION = "VISION"
+    SPEECH = "SPEECH"
+    TABULAR = "TABULAR"
+
+
+MODALITY_TO_SUPPORTED_LABELING_TYPES = {
+    ModalityType.RADAR: [
+        LabelingType.SINGLE_LABEL_CLASSIFICATION,
+        LabelingType.OBJECT_DETECTION,
+    ],
+    ModalityType.VISION: [
+        LabelingType.SINGLE_LABEL_CLASSIFICATION,
+        LabelingType.OBJECT_DETECTION,
+        LabelingType.OBJECT_SEGMENTATION,
+        LabelingType.SEMANTIC_SEGMENTATION,
+        LabelingType.PANOPTIC_SEGMENTATION,
+    ],
+    ModalityType.SPEECH: [LabelingType.SPEECH_TO_TEXT],
+    ModalityType.TABULAR: [LabelingType.SINGLE_LABEL_CLASSIFICATION],
+}
 
 
-class OptimizationDataset(BaseModel):
-    id: typing.Optional[int] = Field(default=None)
+class QADataset(BaseModel):
+    id: int | None = Field(default=None)
     """
     The ID of the dataset created on the server.
     """
@@ -134,17 +146,15 @@ class OptimizationDataset(BaseModel):
     - `LabelingType.OBJECT_DETECTION`: Indicates that the dataset is for object detection tasks
     - `LabelingType.SPEECH_TO_TEXT`: Indicates that the dataset is for speech-to-text tasks
     """
-    language: typing.Optional[str] = None
+    language: str | None = None
     """
     Language of the Speech-to-Text audio dataset. This is required for Speech-to-Text datasets.
     """
-    storage_config_id: typing.Optional[int] = None
+    storage_config_id: int | None = None
     """
     The ID of the storage config used to store the dataset and metadata.
     """
-    storage_config: typing.Optional[
-        typing.Union[StorageConfig, ResponseStorageConfig]
-    ] = None
+    storage_config: StorageConfig | ResponseStorageConfig | None = None
     """
     The `StorageConfig` instance to link to.
     """
@@ -158,34 +168,45 @@ class OptimizationDataset(BaseModel):
     Note: All CSV `image_path` entries in the metadata file should be relative to this folder.
     """
 
-    classes: typing.Optional[list[str]] = None
+    classes: list[str] | None = None
     """
     A full list of possible classes used in classification / object detection.
     It is currently required for clarity and performance.
     """
-    labeling_info: typing.Union[LabelingInfo, list[LabelingInfo]]
+    labeling_info: LabelingInfo | list[LabelingInfo]
 
-    augmentations: typing.Optional[list[AugmentationName]] = None
+    augmentations: list[AugmentationName] | None = None
     """
     Used to define which augmentations are apply to a vision dataset.
     For audio datasets, this field is ignored.
     If no value is provided, all augmentations are applied to vision datasets.
     """
-    modality: Modality = Modality.IMAGE
+    modality: ModalityType = ModalityType.VISION
     """
     Used to define the modality of the dataset.
     Defaults to Image.
     """
 
-    run_id: typing.Optional[str] = Field(default=None, init=False)
+    run_id: str | None = Field(default=None, init=False)
     """
-    The ID of the Dataset Optimization run created on the server.
+    The ID of the Dataset QA run created on the server.
     """
 
-    status: typing.Optional[RunStatus] = None
+    status: RunStatus | None = None
 
     @model_validator(mode="after")
     def validate_dataset(self):
+        if self.modality not in MODALITY_TO_SUPPORTED_LABELING_TYPES:
+            raise ValueError(
+                f"Modality {self.modality} is not supported. Supported modalities are: {list(MODALITY_TO_SUPPORTED_LABELING_TYPES.keys())}"
+            )
+        if (
+            self.labeling_type
+            not in MODALITY_TO_SUPPORTED_LABELING_TYPES[self.modality]
+        ):
+            raise ValueError(
+                f"Labeling type {self.labeling_type} is not supported for modality {self.modality}. Supported labeling types are: {MODALITY_TO_SUPPORTED_LABELING_TYPES[self.modality]}"
+            )
         if self.storage_config is None and self.storage_config_id is None:
             raise ValueError(
                 "No dataset storage has been provided. Provide one via `storage_config` or `storage_config_id`"
@@ -229,52 +250,52 @@ class OptimizationDataset(BaseModel):
         return self
 
     @staticmethod
-    def get_by_id(dataset_id: int) -> "OptimizationDataset":
+    def get_by_id(dataset_id: int) -> "QADataset":
         """
-        Get a `OptimizationDataset` instance from the server by its ID
+        Get a `QADataset` instance from the server by its ID
 
         Args:
-            dataset_id: The ID of the `OptimizationDataset` instance to get
+            dataset_id: The ID of the `QADataset` instance to get
         """
         response = requests.get(
-            f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
+            f"{API_HOST}/dataset-qa/dataset/{dataset_id}",
             headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(response)
         dataset = response.json()
-        return OptimizationDataset(**dataset)
+        return QADataset(**dataset)
 
     @staticmethod
-    def get_by_name(name: str) -> "OptimizationDataset":
+    def get_by_name(name: str) -> "QADataset":
         """
-        Get a `OptimizationDataset` instance from the server by its name
+        Get a `QADataset` instance from the server by its name
 
         Args:
-            name: The name of the `OptimizationDataset` instance to get
+            name: The name of the `QADataset` instance to get
         """
         response = requests.get(
-            f"{API_HOST}/dataset-optimization/dataset/by-name/{name}",
+            f"{API_HOST}/dataset-qa/dataset/by-name/{name}",
             headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(response)
         dataset = response.json()
-        return OptimizationDataset(**dataset)
+        return QADataset(**dataset)
 
     @staticmethod
     def list_datasets(
-        organization_id: typing.Optional[int] = None,
-    ) -> list["DataOptimizationDatasetOut"]:
+        organization_id: int | None = None,
+    ) -> list["QADatasetOut"]:
         """
-        Lists all the optimization datasets created by user's default organization
+        Lists all the datasets created by user's default organization
         or the `organization_id` passed
 
         Args:
             organization_id: The ID of the organization to list the datasets for.
         """
         response = requests.get(
-            f"{API_HOST}/dataset-optimization/dataset/",
+            f"{API_HOST}/dataset-qa/dataset/",
             params={"dataset_organization_id": organization_id},
             headers=get_headers(),
             timeout=READ_TIMEOUT,
@@ -282,7 +303,7 @@ class OptimizationDataset(BaseModel):
         raise_for_status_with_reason(response)
         datasets = response.json()
         return [
-            DataOptimizationDatasetOut(
+            QADatasetOut(
                 **ds,
             )
             for ds in datasets
@@ -290,26 +311,28 @@ class OptimizationDataset(BaseModel):
 
     @staticmethod
     def list_runs(
-        organization_id: typing.Optional[int] = None,
-    ) -> list["DataOptimizationRunOut"]:
+        organization_id: int | None = None,
+        archived: bool | None = False,
+    ) -> list["DataQARunOut"]:
         """
-        Lists all the `OptimizationDataset` instances created by user's default organization
+        Lists all the `QADataset` instances created by user's default organization
         or the `organization_id` passed
-        Note: The return type is `list[dict]` and not `list[OptimizationDataset]`
+        Note: The return type is `list[dict]` and not `list[QADataset]`
 
         Args:
             organization_id: The ID of the organization to list the datasets for.
+            archived: Whether to list archived runs.
         """
         response = requests.get(
-            f"{API_HOST}/dataset-optimization/run/list",
-            params={"dataset_organization_id": organization_id},
+            f"{API_HOST}/dataset-qa/run/list",
+            params={"dataset_organization_id": organization_id, "archived": archived},
             headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(response)
         runs = response.json()
         return [
-            DataOptimizationRunOut(
+            DataQARunOut(
                 **run,
             )
             for run in runs
@@ -318,13 +341,13 @@ class OptimizationDataset(BaseModel):
     @staticmethod
     def delete_by_id(dataset_id: int) -> None:
         """
-        Deletes a `OptimizationDataset` instance from the server by its ID
+        Deletes a `QADataset` instance from the server by its ID
 
         Args:
-            dataset_id: The ID of the `OptimizationDataset` instance to delete
+            dataset_id: The ID of the `QADataset` instance to delete
         """
         response = requests.delete(
-            f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
+            f"{API_HOST}/dataset-qa/dataset/{dataset_id}",
             headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
@@ -333,14 +356,14 @@ class OptimizationDataset(BaseModel):
 
     def delete(self, storage_config=True) -> None:
         """
-        Deletes the active `OptimizationDataset` instance from the server.
-        It can only be used on a `OptimizationDataset` instance that has been created.
+        Deletes the active `QADataset` instance from the server.
+        It can only be used on a `QADataset` instance that has been created.
 
         Args:
-            storage_config: If True, the `OptimizationDataset`'s `StorageConfig` will also be deleted
+            storage_config: If True, the `QADataset`'s `StorageConfig` will also be deleted
 
         Note: If `storage_config` is not set to `False` then the `storage_config_id` must be set
-        This can either be set manually or by creating the `StorageConfig` instance via the `OptimizationDataset`'s
+        This can either be set manually or by creating the `StorageConfig` instance via the `QADataset`'s
         `create` method
         """
         if storage_config:
@@ -353,11 +376,11 @@ class OptimizationDataset(BaseModel):
 
     def create(
         self,
-        organization_id: typing.Optional[int] = None,
+        organization_id: int | None = None,
         replace_if_exists: bool = False,
     ) -> int:
         """
-        Create a `OptimizationDataset` instance on the server.
+        Create a `QADataset` instance on the server.
         If the `storage_config_id` field is not set, the storage config will also be created and the field will be set.
 
         Args:
@@ -366,7 +389,7 @@ class OptimizationDataset(BaseModel):
             (this is determined by a dataset of the same name in the same organization).
 
         Returns:
-            The ID of the created `OptimizationDataset` instance
+            The ID of the created `QADataset` instance
         """
         if self.storage_config is None and self.storage_config_id is None:
             raise ValueError("No dataset storage has been provided")
@@ -391,7 +414,7 @@ class OptimizationDataset(BaseModel):
         model_dict = self.model_dump(mode="json")
         # ⬆️ Get dict of model fields from Pydantic model instance
         dataset_response = requests.post(
-            f"{API_HOST}/dataset-optimization/dataset/",
+            f"{API_HOST}/dataset-qa/dataset/",
             json={
                 **{k: model_dict[k] for k in model_dict.keys() - {"storage_config"}},
                 "organization_id": organization_id,
@@ -408,17 +431,17 @@ class OptimizationDataset(BaseModel):
         return self.id
 
     @staticmethod
-    def launch_optimization_run(
+    def launch_qa_run(
         dataset_id: int,
-        organization_id: typing.Optional[int] = None,
-        run_args: typing.Optional[RunArgs] = None,
+        organization_id: int | None = None,
+        run_args: RunArgs | None = None,
     ) -> str:
         """
-        Run the dataset optimization process on the server using the dataset with the given ID
+        Run the dataset QA process on the server using the dataset with the given ID
         i.e. `dataset_id`.
 
         Args:
-            dataset_id: The ID of the dataset to run optimization on.
+            dataset_id: The ID of the dataset to run QA on.
 
         Returns:
             ID of the run (`run_id`).
@@ -429,7 +452,7 @@ class OptimizationDataset(BaseModel):
         if run_args:
             run_info["run_args"] = run_args.model_dump(mode="json")
         run_response = requests.post(
-            f"{API_HOST}/dataset-optimization/run/{dataset_id}",
+            f"{API_HOST}/dataset-qa/run/{dataset_id}",
             json=run_info if len(run_info) > 0 else None,
             headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
@@ -440,12 +463,16 @@ class OptimizationDataset(BaseModel):
     def _validate_run_args(self, run_args: RunArgs) -> None:
         if self.labeling_type == LabelingType.SPEECH_TO_TEXT:
             raise Exception("Speech to text cannot have `run_args` set")
-        if self.labeling_type != LabelingType.OBJECT_DETECTION and any(
-            (
-                run_args.min_abs_bbox_size != 0,
-                run_args.min_abs_bbox_area != 0,
-                run_args.min_rel_bbox_size != 0,
-                run_args.min_rel_bbox_area != 0,
+        if (
+            self.labeling_type != LabelingType.OBJECT_DETECTION
+            and isinstance(run_args, ObjectDetectionRunArgs)
+            and any(
+                (
+                    run_args.min_abs_bbox_size != 0,
+                    run_args.min_abs_bbox_area != 0,
+                    run_args.min_rel_bbox_size != 0,
+                    run_args.min_rel_bbox_area != 0,
+                )
             )
         ):
             raise Exception(
@@ -454,21 +481,21 @@ class OptimizationDataset(BaseModel):
                 + f"labeling type {self.labeling_type}"
             )
 
-    def run_optimization(
+    def run_qa(
         self,
-        organization_id: typing.Optional[int] = None,
+        organization_id: int | None = None,
         replace_dataset_if_exists: bool = False,
-        run_args: typing.Optional[RunArgs] = None,
+        run_args: RunArgs | None = None,
     ) -> str:
         """
         If the dataset was not created on the server yet, it is created.
-        Run the dataset optimization process on the server using the active `OptimizationDataset` instance
+        Run the dataset QA process on the server using the active `QADataset` instance
 
         Args:
-            organization_id: The ID of the organization to run the optimization for.
+            organization_id: The ID of the organization to run the QA for.
             replace_dataset_if_exists: If True, the dataset will be replaced if it already exists
             (this is determined by a dataset of the same name in the same organization).
-            run_args: The run arguments to use for the optimization run
+            run_args: The run arguments to use for the QA run
 
         Returns:
             An ID of the run (`run_id`) and stores that `run_id` on the instance
@@ -478,7 +505,7 @@ class OptimizationDataset(BaseModel):
             self.id = self.create(replace_if_exists=replace_dataset_if_exists)
         if run_args is not None:
             self._validate_run_args(run_args)
-        run_id = self.launch_optimization_run(self.id, organization_id, run_args)
+        run_id = self.launch_qa_run(self.id, organization_id, run_args)
         self.run_id = run_id
         logger.info("Started the run with ID: %s", run_id)
         return run_id
@@ -509,83 +536,46 @@ class OptimizationDataset(BaseModel):
 
     @staticmethod
     def _check_run_by_id(run_id: str, retry=0) -> Generator[dict, None, None]:
-        if retry > MAX_RETRIES:
-            raise HirundoError("Max retries reached")
-        last_event = None
-        with httpx.Client(timeout=httpx.Timeout(None, connect=5.0)) as client:
-            for sse in iter_sse_retrying(
-                client,
-                "GET",
-                f"{API_HOST}/dataset-optimization/run/{run_id}",
-                headers=get_headers(),
-            ):
-                if sse.event == "ping":
-                    continue
-                logger.debug(
-                    "[SYNC] received event: %s with data: %s and ID: %s and retry: %s",
-                    sse.event,
-                    sse.data,
-                    sse.id,
-                    sse.retry,
-                )
-                last_event = json.loads(sse.data)
-                if not last_event:
-                    continue
-                if "data" in last_event:
-                    data = last_event["data"]
-                else:
-                    if "detail" in last_event:
-                        raise HirundoError(last_event["detail"])
-                    elif "reason" in last_event:
-                        raise HirundoError(last_event["reason"])
-                    else:
-                        raise HirundoError("Unknown error")
-                yield data
-        if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
-            OptimizationDataset._check_run_by_id(run_id, retry + 1)
-
-    @staticmethod
-    def _handle_failure(iteration: dict):
-        if iteration["result"]:
-            raise HirundoError(
-                f"Optimization run failed with error: {iteration['result']}"
-            )
-        else:
-            raise HirundoError(
-                "Optimization run failed with an unknown error in _handle_failure"
-            )
+        yield from iter_run_events(
+            f"{API_HOST}/dataset-qa/run/{run_id}",
+            headers=get_headers(),
+            retry=retry,
+            status_keys=("state",),
+            error_cls=HirundoError,
+            log=logger,
+        )
 
     @staticmethod
     @overload
     def check_run_by_id(
         run_id: str, stop_on_manual_approval: typing.Literal[True]
-    ) -> typing.Optional[DatasetOptimizationResults]: ...
+    ) -> DatasetQAResults | None: ...
 
     @staticmethod
     @overload
     def check_run_by_id(
         run_id: str, stop_on_manual_approval: typing.Literal[False] = False
-    ) -> DatasetOptimizationResults: ...
+    ) -> DatasetQAResults: ...
 
     @staticmethod
     @overload
     def check_run_by_id(
         run_id: str, stop_on_manual_approval: bool
-    ) -> typing.Optional[DatasetOptimizationResults]: ...
+    ) -> DatasetQAResults | None: ...
 
     @staticmethod
     def check_run_by_id(
         run_id: str, stop_on_manual_approval: bool = False
-    ) -> typing.Optional[DatasetOptimizationResults]:
+    ) -> DatasetQAResults | None:
         """
         Check the status of a run given its ID
 
         Args:
-            run_id: The `run_id` produced by a `run_optimization` call
+            run_id: The `run_id` produced by a `run_qa` call
             stop_on_manual_approval: If True, the function will return `None` if the run is awaiting manual approval
 
         Returns:
-            A DatasetOptimizationResults object with the results of the optimization run
+            A DatasetQAResults object with the results of the QA run
 
         Raises:
             HirundoError: If the maximum number of retries is reached or if the run fails
@@ -593,87 +583,67 @@ class OptimizationDataset(BaseModel):
         logger.debug("Checking run with ID: %s", run_id)
         with logging_redirect_tqdm():
             t = tqdm(total=100.0)
-            for iteration in OptimizationDataset._check_run_by_id(run_id):
-                if iteration["state"] in STATUS_TO_PROGRESS_MAP:
-                    t.set_description(STATUS_TO_TEXT_MAP[iteration["state"]])
-                    t.n = STATUS_TO_PROGRESS_MAP[iteration["state"]]
+            for iteration in QADataset._check_run_by_id(run_id):
+                state = get_state(iteration, ("state",))
+                if state in STATUS_TO_PROGRESS_MAP:
+                    t.set_description(STATUS_TO_TEXT_MAP[state])
+                    t.n = STATUS_TO_PROGRESS_MAP[state]
                     logger.debug("Setting progress to %s", t.n)
                     t.refresh()
-                if iteration["state"] in [
+                if state in [
                     RunStatus.FAILURE.value,
                     RunStatus.REJECTED.value,
                     RunStatus.REVOKED.value,
                 ]:
                     logger.error(
                         "State is failure, rejected, or revoked: %s",
-                        iteration["state"],
+                        state,
                     )
-                    OptimizationDataset._handle_failure(iteration)
-                elif iteration["state"] == RunStatus.SUCCESS.value:
+                    handle_run_failure(
+                        iteration, error_cls=HirundoError, run_label="QA"
+                    )
+                elif state == RunStatus.SUCCESS.value:
                     t.close()
                     zip_temporary_url = iteration["result"]
-                    logger.debug("Optimization run completed. Downloading results")
+                    logger.debug("QA run completed. Downloading results")
 
                     return download_and_extract_zip(
                         run_id,
                         zip_temporary_url,
                     )
                 elif (
-                    iteration["state"] == RunStatus.AWAITING_MANUAL_APPROVAL.value
+                    state == RunStatus.AWAITING_MANUAL_APPROVAL.value
                     and stop_on_manual_approval
                 ):
                     t.close()
                     return None
-                elif iteration["state"] is None:
-                    if (
-                        iteration["result"]
-                        and isinstance(iteration["result"], dict)
-                        and iteration["result"]["result"]
-                        and isinstance(iteration["result"]["result"], str)
-                    ):
-                        result_info = iteration["result"]["result"].split(":")
-                        if len(result_info) > 1:
-                            stage = result_info[0]
-                            current_progress_percentage = float(
-                                result_info[1].removeprefix(" ").removesuffix("% done")
-                            )
-                        elif len(result_info) == 1:
-                            stage = result_info[0]
-                            current_progress_percentage = t.n  # Keep the same progress
-                        else:
-                            stage = "Unknown progress state"
-                            current_progress_percentage = t.n  # Keep the same progress
-                        desc = (
-                            "Optimization run completed. Uploading results"
-                            if current_progress_percentage == 100.0
-                            else stage
-                        )
-                        t.set_description(desc)
-                        t.n = current_progress_percentage
-                        logger.debug("Setting progress to %s", t.n)
-                        t.refresh()
-            raise HirundoError(
-                "Optimization run failed with an unknown error in check_run_by_id"
-            )
+                elif state is None:
+                    update_progress_from_result(
+                        iteration,
+                        t,
+                        uploading_text="QA run completed. Uploading results",
+                        log=logger,
+                    )
+            raise HirundoError("QA run failed with an unknown error in check_run_by_id")
 
     @overload
     def check_run(
         self, stop_on_manual_approval: typing.Literal[True]
-    ) -> typing.Optional[DatasetOptimizationResults]: ...
+    ) -> DatasetQAResults | None: ...
 
     @overload
     def check_run(
         self, stop_on_manual_approval: typing.Literal[False] = False
-    ) -> DatasetOptimizationResults: ...
+    ) -> DatasetQAResults: ...
 
     def check_run(
         self, stop_on_manual_approval: bool = False
-    ) -> typing.Optional[DatasetOptimizationResults]:
+    ) -> DatasetQAResults | None:
         """
         Check the status of the current active instance's run.
 
         Returns:
-            A pandas DataFrame with the results of the optimization run
+            A pandas DataFrame with the results of the QA run
 
         """
         if not self.run_id:
@@ -690,7 +660,7 @@ class OptimizationDataset(BaseModel):
         This generator will produce values to show progress of the run.
 
         Args:
-            run_id: The `run_id` produced by a `run_optimization` call
+            run_id: The `run_id` produced by a `run_qa` call
             retry: A number used to track the number of retries to limit re-checks. *Do not* provide this value manually.
 
         Yields:
@@ -700,32 +670,15 @@ class OptimizationDataset(BaseModel):
 
         """
         logger.debug("Checking run with ID: %s", run_id)
-        if retry > MAX_RETRIES:
-            raise HirundoError("Max retries reached")
-        last_event = None
-        async with httpx.AsyncClient(
-            timeout=httpx.Timeout(None, connect=5.0)
-        ) as client:
-            async_iterator = await aiter_sse_retrying(
-                client,
-                "GET",
-                f"{API_HOST}/dataset-optimization/run/{run_id}",
-                headers=get_headers(),
-            )
-            async for sse in async_iterator:
-                if sse.event == "ping":
-                    continue
-                logger.debug(
-                    "[ASYNC] Received event: %s with data: %s and ID: %s and retry: %s",
-                    sse.event,
-                    sse.data,
-                    sse.id,
-                    sse.retry,
-                )
-                last_event = json.loads(sse.data)
-                yield last_event["data"]
-        if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
-            OptimizationDataset.acheck_run_by_id(run_id, retry + 1)
+        async for iteration in aiter_run_events(
+            f"{API_HOST}/dataset-qa/run/{run_id}",
+            headers=get_headers(),
+            retry=retry,
+            status_keys=("state",),
+            error_cls=HirundoError,
+            log=logger,
+        ):
+            yield iteration
 
     async def acheck_run(self) -> AsyncGenerator[dict, None]:
         """
@@ -735,6 +688,8 @@ class OptimizationDataset(BaseModel):
 
         This generator will produce values to show progress of the run.
 
+        Note: This function does not handle errors nor show progress. It is expected that you do that.
+
         Yields:
             Each event will be a dict, where:
             - `"state"` is PENDING, STARTED, RETRY, FAILURE or SUCCESS
@@ -749,14 +704,14 @@ class OptimizationDataset(BaseModel):
     @staticmethod
     def cancel_by_id(run_id: str) -> None:
         """
-        Cancel the dataset optimization run for the given `run_id`.
+        Cancel the dataset QA run for the given `run_id`.
 
         Args:
             run_id: The ID of the run to cancel
         """
         logger.info("Cancelling run with ID: %s", run_id)
         response = requests.delete(
-            f"{API_HOST}/dataset-optimization/run/{run_id}",
+            f"{API_HOST}/dataset-qa/run/{run_id}",
             headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
@@ -773,14 +728,14 @@ class OptimizationDataset(BaseModel):
     @staticmethod
    def archive_run_by_id(run_id: str) -> None:
         """
-        Archive the dataset optimization run for the given `run_id`.
+        Archive the dataset QA run for the given `run_id`.
 
         Args:
             run_id: The ID of the run to archive
         """
         logger.info("Archiving run with ID: %s", run_id)
         response = requests.patch(
-            f"{API_HOST}/dataset-optimization/run/archive/{run_id}",
+            f"{API_HOST}/dataset-qa/run/archive/{run_id}",
             headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
@@ -795,7 +750,7 @@ class OptimizationDataset(BaseModel):
         self.archive_run_by_id(self.run_id)
 
 
-class DataOptimizationDatasetOut(BaseModel):
+class QADatasetOut(BaseModel):
     id: int
 
     name: str
@@ -805,16 +760,16 @@ class DataOptimizationDatasetOut(BaseModel):
 
     data_root_url: HirundoUrl
 
-    classes: typing.Optional[list[str]] = None
-    labeling_info: typing.Union[LabelingInfo, list[LabelingInfo]]
+    classes: list[str] | None = None
+    labeling_info: LabelingInfo | list[LabelingInfo]
 
-    organization_id: typing.Optional[int]
-    creator_id: typing.Optional[int]
+    organization_id: int | None
+    creator_id: int | None
     created_at: datetime.datetime
     updated_at: datetime.datetime
 
 
-class DataOptimizationRunOut(BaseModel):
+class DataQARunOut(BaseModel):
     id: int
     name: str
     dataset_id: int
@@ -822,4 +777,6 @@ class DataOptimizationRunOut(BaseModel):
     status: RunStatus
     approved: bool
     created_at: datetime.datetime
-    run_args: typing.Optional[RunArgs]
+    run_args: RunArgs | None
+
+    deleted_at: datetime.datetime | None = None
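
Taken together, this release renames the SDK's public surface from "dataset optimization" to "dataset QA": `OptimizationDataset` becomes `QADataset`, `run_optimization` becomes `run_qa`, `launch_optimization_run` becomes `launch_qa_run`, `VisionRunArgs` splits into `ClassificationRunArgs`/`ObjectDetectionRunArgs`, `DatasetOptimizationResults` becomes `DatasetQAResults`, and every REST path moves from `/dataset-optimization/...` to `/dataset-qa/...`. Below is a minimal calling sketch under the new names, assuming they are importable from the package root (the diff covers only this module, not the package's exports); the dataset name and bounding-box threshold are illustrative placeholders, not values taken from this diff:

# Hypothetical migration sketch for an existing object-detection dataset.
from hirundo import ObjectDetectionRunArgs, QADataset  # import path assumed; was OptimizationDataset / VisionRunArgs

dataset = QADataset.get_by_name("my-dataset")  # placeholder name
run_id = dataset.run_qa(  # was dataset.run_optimization(...)
    run_args=ObjectDetectionRunArgs(min_abs_bbox_size=8),  # fields left as None keep server defaults
)
results = QADataset.check_run_by_id(run_id)  # blocks with a tqdm progress bar; returns DatasetQAResults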