hirundo 0.1.18__py3-none-any.whl → 0.2.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hirundo/__init__.py +28 -8
- hirundo/_constraints.py +3 -4
- hirundo/_headers.py +1 -1
- hirundo/_http.py +53 -0
- hirundo/_iter_sse_retrying.py +8 -5
- hirundo/_llm_pipeline.py +153 -0
- hirundo/_run_checking.py +283 -0
- hirundo/_urls.py +1 -0
- hirundo/cli.py +8 -11
- hirundo/dataset_enum.py +2 -0
- hirundo/{dataset_optimization.py → dataset_qa.py} +213 -256
- hirundo/{dataset_optimization_results.py → dataset_qa_results.py} +7 -7
- hirundo/git.py +8 -10
- hirundo/labeling.py +22 -19
- hirundo/storage.py +26 -26
- hirundo/unlearning_llm.py +599 -0
- hirundo/unzip.py +12 -13
- {hirundo-0.1.18.dist-info → hirundo-0.2.3.post1.dist-info}/METADATA +59 -20
- hirundo-0.2.3.post1.dist-info/RECORD +28 -0
- {hirundo-0.1.18.dist-info → hirundo-0.2.3.post1.dist-info}/WHEEL +1 -1
- hirundo-0.1.18.dist-info/RECORD +0 -25
- {hirundo-0.1.18.dist-info → hirundo-0.2.3.post1.dist-info}/entry_points.txt +0 -0
- {hirundo-0.1.18.dist-info → hirundo-0.2.3.post1.dist-info}/licenses/LICENSE +0 -0
- {hirundo-0.1.18.dist-info → hirundo-0.2.3.post1.dist-info}/top_level.txt +0 -0
|
@@ -1,12 +1,9 @@
|
|
|
1
1
|
import datetime
|
|
2
|
-
import json
|
|
3
2
|
import typing
|
|
4
3
|
from collections.abc import AsyncGenerator, Generator
|
|
5
4
|
from enum import Enum
|
|
6
5
|
from typing import overload
|
|
7
6
|
|
|
8
|
-
import httpx
|
|
9
|
-
import requests
|
|
10
7
|
from pydantic import BaseModel, Field, model_validator
|
|
11
8
|
from tqdm import tqdm
|
|
12
9
|
from tqdm.contrib.logging import logging_redirect_tqdm
|
|
@@ -14,12 +11,21 @@ from tqdm.contrib.logging import logging_redirect_tqdm
|
|
|
14
11
|
from hirundo._constraints import validate_labeling_info, validate_url
|
|
15
12
|
from hirundo._env import API_HOST
|
|
16
13
|
from hirundo._headers import get_headers
|
|
17
|
-
from hirundo._http import raise_for_status_with_reason
|
|
18
|
-
from hirundo.
|
|
14
|
+
from hirundo._http import raise_for_status_with_reason, requests
|
|
15
|
+
from hirundo._run_checking import (
|
|
16
|
+
STATUS_TO_PROGRESS_MAP,
|
|
17
|
+
RunStatus,
|
|
18
|
+
aiter_run_events,
|
|
19
|
+
build_status_text_map,
|
|
20
|
+
get_state,
|
|
21
|
+
handle_run_failure,
|
|
22
|
+
iter_run_events,
|
|
23
|
+
update_progress_from_result,
|
|
24
|
+
)
|
|
19
25
|
from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
|
|
20
26
|
from hirundo._urls import HirundoUrl
|
|
21
27
|
from hirundo.dataset_enum import DatasetMetadataType, LabelingType
|
|
22
|
-
from hirundo.
|
|
28
|
+
from hirundo.dataset_qa_results import DatasetQAResults
|
|
23
29
|
from hirundo.labeling import YOLO, LabelingInfo
|
|
24
30
|
from hirundo.logger import get_logger
|
|
25
31
|
from hirundo.storage import ResponseStorageConfig, StorageConfig
|
|
@@ -30,75 +36,63 @@ logger = get_logger(__name__)
|
|
|
30
36
|
|
|
31
37
|
class HirundoError(Exception):
|
|
32
38
|
"""
|
|
33
|
-
Custom exception used to indicate errors in `hirundo` dataset
|
|
39
|
+
Custom exception used to indicate errors in `hirundo` dataset QA runs
|
|
34
40
|
"""
|
|
35
41
|
|
|
36
42
|
pass
|
|
37
43
|
|
|
38
44
|
|
|
39
|
-
|
|
45
|
+
STATUS_TO_TEXT_MAP = build_status_text_map(
|
|
46
|
+
"Dataset QA",
|
|
47
|
+
started_detail="Dataset QA run in progress. Downloading dataset",
|
|
48
|
+
)
|
|
40
49
|
|
|
41
50
|
|
|
42
|
-
class
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
REJECTED = "REJECTED"
|
|
50
|
-
RETRY = "RETRY"
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
STATUS_TO_TEXT_MAP = {
|
|
54
|
-
RunStatus.STARTED.value: "Optimization run in progress. Downloading dataset",
|
|
55
|
-
RunStatus.PENDING.value: "Optimization run queued and not yet started",
|
|
56
|
-
RunStatus.SUCCESS.value: "Optimization run completed successfully",
|
|
57
|
-
RunStatus.FAILURE.value: "Optimization run failed",
|
|
58
|
-
RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
|
|
59
|
-
RunStatus.RETRY.value: "Optimization run failed. Retrying",
|
|
60
|
-
RunStatus.REVOKED.value: "Optimization run was cancelled",
|
|
61
|
-
RunStatus.REJECTED.value: "Optimization run was rejected",
|
|
62
|
-
}
|
|
63
|
-
STATUS_TO_PROGRESS_MAP = {
|
|
64
|
-
RunStatus.STARTED.value: 0.0,
|
|
65
|
-
RunStatus.PENDING.value: 0.0,
|
|
66
|
-
RunStatus.SUCCESS.value: 100.0,
|
|
67
|
-
RunStatus.FAILURE.value: 100.0,
|
|
68
|
-
RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
|
|
69
|
-
RunStatus.RETRY.value: 0.0,
|
|
70
|
-
RunStatus.REVOKED.value: 100.0,
|
|
71
|
-
RunStatus.REJECTED.value: 0.0,
|
|
72
|
-
}
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
class VisionRunArgs(BaseModel):
|
|
76
|
-
upsample: bool = False
|
|
51
|
+
class ClassificationRunArgs(BaseModel):
|
|
52
|
+
image_size: tuple[int, int] | None = (224, 224)
|
|
53
|
+
"""
|
|
54
|
+
Size (width, height) to which to resize classification images.
|
|
55
|
+
It is recommended to keep this value at (224, 224) unless your classes are differentiated by very small differences.
|
|
56
|
+
"""
|
|
57
|
+
upsample: bool | None = False
|
|
77
58
|
"""
|
|
78
59
|
Whether to upsample the dataset to attempt to balance the classes.
|
|
79
60
|
"""
|
|
80
|
-
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ObjectDetectionRunArgs(ClassificationRunArgs):
|
|
64
|
+
min_abs_bbox_size: int | None = None
|
|
81
65
|
"""
|
|
82
|
-
Minimum valid size (in pixels) of a bounding box to keep it in the dataset for
|
|
66
|
+
Minimum valid size (in pixels) of a bounding box to keep it in the dataset for QA.
|
|
83
67
|
"""
|
|
84
|
-
min_abs_bbox_area: int =
|
|
68
|
+
min_abs_bbox_area: int | None = None
|
|
85
69
|
"""
|
|
86
|
-
Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for
|
|
70
|
+
Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for QA.
|
|
87
71
|
"""
|
|
88
|
-
min_rel_bbox_size: float =
|
|
72
|
+
min_rel_bbox_size: float | None = None
|
|
89
73
|
"""
|
|
90
74
|
Minimum valid size (as a fraction of both image height and width) for a bounding box
|
|
91
|
-
to keep it in the dataset for
|
|
75
|
+
to keep it in the dataset for QA, relative to the corresponding dimension size,
|
|
92
76
|
i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
|
|
93
77
|
value is 0.06 (since both width and height are checked).
|
|
94
78
|
"""
|
|
95
|
-
min_rel_bbox_area: float =
|
|
79
|
+
min_rel_bbox_area: float | None = None
|
|
80
|
+
"""
|
|
81
|
+
Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for QA.
|
|
82
|
+
"""
|
|
83
|
+
crop_ratio: float | None = None
|
|
96
84
|
"""
|
|
97
|
-
|
|
85
|
+
Ratio of the bounding box to crop.
|
|
86
|
+
Change this value at your own risk. It is recommended to keep it at 1.0 unless you know what you are doing.
|
|
87
|
+
"""
|
|
88
|
+
add_mask_channel: bool | None = None
|
|
89
|
+
"""
|
|
90
|
+
Whether to add a mask channel to the image.
|
|
91
|
+
Change at your own risk. It is recommended to keep it at False unless you know what you are doing.
|
|
98
92
|
"""
|
|
99
93
|
|
|
100
94
|
|
|
101
|
-
RunArgs =
|
|
95
|
+
RunArgs = ClassificationRunArgs | ObjectDetectionRunArgs
|
|
102
96
|
|
|
103
97
|
|
|
104
98
|
class AugmentationName(str, Enum):
|
|
@@ -111,14 +105,32 @@ class AugmentationName(str, Enum):
|
|
|
111
105
|
GAUSSIAN_BLUR = "GaussianBlur"
|
|
112
106
|
|
|
113
107
|
|
|
114
|
-
class
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
108
|
+
class ModalityType(str, Enum):
|
|
109
|
+
RADAR = "RADAR"
|
|
110
|
+
VISION = "VISION"
|
|
111
|
+
SPEECH = "SPEECH"
|
|
112
|
+
TABULAR = "TABULAR"
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
MODALITY_TO_SUPPORTED_LABELING_TYPES = {
|
|
116
|
+
ModalityType.RADAR: [
|
|
117
|
+
LabelingType.SINGLE_LABEL_CLASSIFICATION,
|
|
118
|
+
LabelingType.OBJECT_DETECTION,
|
|
119
|
+
],
|
|
120
|
+
ModalityType.VISION: [
|
|
121
|
+
LabelingType.SINGLE_LABEL_CLASSIFICATION,
|
|
122
|
+
LabelingType.OBJECT_DETECTION,
|
|
123
|
+
LabelingType.OBJECT_SEGMENTATION,
|
|
124
|
+
LabelingType.SEMANTIC_SEGMENTATION,
|
|
125
|
+
LabelingType.PANOPTIC_SEGMENTATION,
|
|
126
|
+
],
|
|
127
|
+
ModalityType.SPEECH: [LabelingType.SPEECH_TO_TEXT],
|
|
128
|
+
ModalityType.TABULAR: [LabelingType.SINGLE_LABEL_CLASSIFICATION],
|
|
129
|
+
}
|
|
118
130
|
|
|
119
131
|
|
|
120
|
-
class
|
|
121
|
-
id:
|
|
132
|
+
class QADataset(BaseModel):
|
|
133
|
+
id: int | None = Field(default=None)
|
|
122
134
|
"""
|
|
123
135
|
The ID of the dataset created on the server.
|
|
124
136
|
"""
|
|
@@ -134,17 +146,15 @@ class OptimizationDataset(BaseModel):
|
|
|
134
146
|
- `LabelingType.OBJECT_DETECTION`: Indicates that the dataset is for object detection tasks
|
|
135
147
|
- `LabelingType.SPEECH_TO_TEXT`: Indicates that the dataset is for speech-to-text tasks
|
|
136
148
|
"""
|
|
137
|
-
language:
|
|
149
|
+
language: str | None = None
|
|
138
150
|
"""
|
|
139
151
|
Language of the Speech-to-Text audio dataset. This is required for Speech-to-Text datasets.
|
|
140
152
|
"""
|
|
141
|
-
storage_config_id:
|
|
153
|
+
storage_config_id: int | None = None
|
|
142
154
|
"""
|
|
143
155
|
The ID of the storage config used to store the dataset and metadata.
|
|
144
156
|
"""
|
|
145
|
-
storage_config:
|
|
146
|
-
typing.Union[StorageConfig, ResponseStorageConfig]
|
|
147
|
-
] = None
|
|
157
|
+
storage_config: StorageConfig | ResponseStorageConfig | None = None
|
|
148
158
|
"""
|
|
149
159
|
The `StorageConfig` instance to link to.
|
|
150
160
|
"""
|
|
@@ -158,34 +168,45 @@ class OptimizationDataset(BaseModel):
|
|
|
158
168
|
Note: All CSV `image_path` entries in the metadata file should be relative to this folder.
|
|
159
169
|
"""
|
|
160
170
|
|
|
161
|
-
classes:
|
|
171
|
+
classes: list[str] | None = None
|
|
162
172
|
"""
|
|
163
173
|
A full list of possible classes used in classification / object detection.
|
|
164
174
|
It is currently required for clarity and performance.
|
|
165
175
|
"""
|
|
166
|
-
labeling_info:
|
|
176
|
+
labeling_info: LabelingInfo | list[LabelingInfo]
|
|
167
177
|
|
|
168
|
-
augmentations:
|
|
178
|
+
augmentations: list[AugmentationName] | None = None
|
|
169
179
|
"""
|
|
170
180
|
Used to define which augmentations are apply to a vision dataset.
|
|
171
181
|
For audio datasets, this field is ignored.
|
|
172
182
|
If no value is provided, all augmentations are applied to vision datasets.
|
|
173
183
|
"""
|
|
174
|
-
modality:
|
|
184
|
+
modality: ModalityType = ModalityType.VISION
|
|
175
185
|
"""
|
|
176
186
|
Used to define the modality of the dataset.
|
|
177
187
|
Defaults to Image.
|
|
178
188
|
"""
|
|
179
189
|
|
|
180
|
-
run_id:
|
|
190
|
+
run_id: str | None = Field(default=None, init=False)
|
|
181
191
|
"""
|
|
182
|
-
The ID of the Dataset
|
|
192
|
+
The ID of the Dataset QA run created on the server.
|
|
183
193
|
"""
|
|
184
194
|
|
|
185
|
-
status:
|
|
195
|
+
status: RunStatus | None = None
|
|
186
196
|
|
|
187
197
|
@model_validator(mode="after")
|
|
188
198
|
def validate_dataset(self):
|
|
199
|
+
if self.modality not in MODALITY_TO_SUPPORTED_LABELING_TYPES:
|
|
200
|
+
raise ValueError(
|
|
201
|
+
f"Modality {self.modality} is not supported. Supported modalities are: {list(MODALITY_TO_SUPPORTED_LABELING_TYPES.keys())}"
|
|
202
|
+
)
|
|
203
|
+
if (
|
|
204
|
+
self.labeling_type
|
|
205
|
+
not in MODALITY_TO_SUPPORTED_LABELING_TYPES[self.modality]
|
|
206
|
+
):
|
|
207
|
+
raise ValueError(
|
|
208
|
+
f"Labeling type {self.labeling_type} is not supported for modality {self.modality}. Supported labeling types are: {MODALITY_TO_SUPPORTED_LABELING_TYPES[self.modality]}"
|
|
209
|
+
)
|
|
189
210
|
if self.storage_config is None and self.storage_config_id is None:
|
|
190
211
|
raise ValueError(
|
|
191
212
|
"No dataset storage has been provided. Provide one via `storage_config` or `storage_config_id`"
|
|
@@ -229,52 +250,52 @@ class OptimizationDataset(BaseModel):
|
|
|
229
250
|
return self
|
|
230
251
|
|
|
231
252
|
@staticmethod
|
|
232
|
-
def get_by_id(dataset_id: int) -> "
|
|
253
|
+
def get_by_id(dataset_id: int) -> "QADataset":
|
|
233
254
|
"""
|
|
234
|
-
Get a `
|
|
255
|
+
Get a `QADataset` instance from the server by its ID
|
|
235
256
|
|
|
236
257
|
Args:
|
|
237
|
-
dataset_id: The ID of the `
|
|
258
|
+
dataset_id: The ID of the `QADataset` instance to get
|
|
238
259
|
"""
|
|
239
260
|
response = requests.get(
|
|
240
|
-
f"{API_HOST}/dataset-
|
|
261
|
+
f"{API_HOST}/dataset-qa/dataset/{dataset_id}",
|
|
241
262
|
headers=get_headers(),
|
|
242
263
|
timeout=READ_TIMEOUT,
|
|
243
264
|
)
|
|
244
265
|
raise_for_status_with_reason(response)
|
|
245
266
|
dataset = response.json()
|
|
246
|
-
return
|
|
267
|
+
return QADataset(**dataset)
|
|
247
268
|
|
|
248
269
|
@staticmethod
|
|
249
|
-
def get_by_name(name: str) -> "
|
|
270
|
+
def get_by_name(name: str) -> "QADataset":
|
|
250
271
|
"""
|
|
251
|
-
Get a `
|
|
272
|
+
Get a `QADataset` instance from the server by its name
|
|
252
273
|
|
|
253
274
|
Args:
|
|
254
|
-
name: The name of the `
|
|
275
|
+
name: The name of the `QADataset` instance to get
|
|
255
276
|
"""
|
|
256
277
|
response = requests.get(
|
|
257
|
-
f"{API_HOST}/dataset-
|
|
278
|
+
f"{API_HOST}/dataset-qa/dataset/by-name/{name}",
|
|
258
279
|
headers=get_headers(),
|
|
259
280
|
timeout=READ_TIMEOUT,
|
|
260
281
|
)
|
|
261
282
|
raise_for_status_with_reason(response)
|
|
262
283
|
dataset = response.json()
|
|
263
|
-
return
|
|
284
|
+
return QADataset(**dataset)
|
|
264
285
|
|
|
265
286
|
@staticmethod
|
|
266
287
|
def list_datasets(
|
|
267
|
-
organization_id:
|
|
268
|
-
) -> list["
|
|
288
|
+
organization_id: int | None = None,
|
|
289
|
+
) -> list["QADatasetOut"]:
|
|
269
290
|
"""
|
|
270
|
-
Lists all the
|
|
291
|
+
Lists all the datasets created by user's default organization
|
|
271
292
|
or the `organization_id` passed
|
|
272
293
|
|
|
273
294
|
Args:
|
|
274
295
|
organization_id: The ID of the organization to list the datasets for.
|
|
275
296
|
"""
|
|
276
297
|
response = requests.get(
|
|
277
|
-
f"{API_HOST}/dataset-
|
|
298
|
+
f"{API_HOST}/dataset-qa/dataset/",
|
|
278
299
|
params={"dataset_organization_id": organization_id},
|
|
279
300
|
headers=get_headers(),
|
|
280
301
|
timeout=READ_TIMEOUT,
|
|
@@ -282,7 +303,7 @@ class OptimizationDataset(BaseModel):
|
|
|
282
303
|
raise_for_status_with_reason(response)
|
|
283
304
|
datasets = response.json()
|
|
284
305
|
return [
|
|
285
|
-
|
|
306
|
+
QADatasetOut(
|
|
286
307
|
**ds,
|
|
287
308
|
)
|
|
288
309
|
for ds in datasets
|
|
@@ -290,26 +311,28 @@ class OptimizationDataset(BaseModel):
|
|
|
290
311
|
|
|
291
312
|
@staticmethod
|
|
292
313
|
def list_runs(
|
|
293
|
-
organization_id:
|
|
294
|
-
|
|
314
|
+
organization_id: int | None = None,
|
|
315
|
+
archived: bool | None = False,
|
|
316
|
+
) -> list["DataQARunOut"]:
|
|
295
317
|
"""
|
|
296
|
-
Lists all the `
|
|
318
|
+
Lists all the `QADataset` instances created by user's default organization
|
|
297
319
|
or the `organization_id` passed
|
|
298
|
-
Note: The return type is `list[dict]` and not `list[
|
|
320
|
+
Note: The return type is `list[dict]` and not `list[QADataset]`
|
|
299
321
|
|
|
300
322
|
Args:
|
|
301
323
|
organization_id: The ID of the organization to list the datasets for.
|
|
324
|
+
archived: Whether to list archived runs.
|
|
302
325
|
"""
|
|
303
326
|
response = requests.get(
|
|
304
|
-
f"{API_HOST}/dataset-
|
|
305
|
-
params={"dataset_organization_id": organization_id},
|
|
327
|
+
f"{API_HOST}/dataset-qa/run/list",
|
|
328
|
+
params={"dataset_organization_id": organization_id, "archived": archived},
|
|
306
329
|
headers=get_headers(),
|
|
307
330
|
timeout=READ_TIMEOUT,
|
|
308
331
|
)
|
|
309
332
|
raise_for_status_with_reason(response)
|
|
310
333
|
runs = response.json()
|
|
311
334
|
return [
|
|
312
|
-
|
|
335
|
+
DataQARunOut(
|
|
313
336
|
**run,
|
|
314
337
|
)
|
|
315
338
|
for run in runs
|
|
@@ -318,13 +341,13 @@ class OptimizationDataset(BaseModel):
|
|
|
318
341
|
@staticmethod
|
|
319
342
|
def delete_by_id(dataset_id: int) -> None:
|
|
320
343
|
"""
|
|
321
|
-
Deletes a `
|
|
344
|
+
Deletes a `QADataset` instance from the server by its ID
|
|
322
345
|
|
|
323
346
|
Args:
|
|
324
|
-
dataset_id: The ID of the `
|
|
347
|
+
dataset_id: The ID of the `QADataset` instance to delete
|
|
325
348
|
"""
|
|
326
349
|
response = requests.delete(
|
|
327
|
-
f"{API_HOST}/dataset-
|
|
350
|
+
f"{API_HOST}/dataset-qa/dataset/{dataset_id}",
|
|
328
351
|
headers=get_headers(),
|
|
329
352
|
timeout=MODIFY_TIMEOUT,
|
|
330
353
|
)
|
|
@@ -333,14 +356,14 @@ class OptimizationDataset(BaseModel):
|
|
|
333
356
|
|
|
334
357
|
def delete(self, storage_config=True) -> None:
|
|
335
358
|
"""
|
|
336
|
-
Deletes the active `
|
|
337
|
-
It can only be used on a `
|
|
359
|
+
Deletes the active `QADataset` instance from the server.
|
|
360
|
+
It can only be used on a `QADataset` instance that has been created.
|
|
338
361
|
|
|
339
362
|
Args:
|
|
340
|
-
storage_config: If True, the `
|
|
363
|
+
storage_config: If True, the `QADataset`'s `StorageConfig` will also be deleted
|
|
341
364
|
|
|
342
365
|
Note: If `storage_config` is not set to `False` then the `storage_config_id` must be set
|
|
343
|
-
This can either be set manually or by creating the `StorageConfig` instance via the `
|
|
366
|
+
This can either be set manually or by creating the `StorageConfig` instance via the `QADataset`'s
|
|
344
367
|
`create` method
|
|
345
368
|
"""
|
|
346
369
|
if storage_config:
|
|
@@ -353,11 +376,11 @@ class OptimizationDataset(BaseModel):
|
|
|
353
376
|
|
|
354
377
|
def create(
|
|
355
378
|
self,
|
|
356
|
-
organization_id:
|
|
379
|
+
organization_id: int | None = None,
|
|
357
380
|
replace_if_exists: bool = False,
|
|
358
381
|
) -> int:
|
|
359
382
|
"""
|
|
360
|
-
Create a `
|
|
383
|
+
Create a `QADataset` instance on the server.
|
|
361
384
|
If the `storage_config_id` field is not set, the storage config will also be created and the field will be set.
|
|
362
385
|
|
|
363
386
|
Args:
|
|
@@ -366,7 +389,7 @@ class OptimizationDataset(BaseModel):
|
|
|
366
389
|
(this is determined by a dataset of the same name in the same organization).
|
|
367
390
|
|
|
368
391
|
Returns:
|
|
369
|
-
The ID of the created `
|
|
392
|
+
The ID of the created `QADataset` instance
|
|
370
393
|
"""
|
|
371
394
|
if self.storage_config is None and self.storage_config_id is None:
|
|
372
395
|
raise ValueError("No dataset storage has been provided")
|
|
@@ -391,7 +414,7 @@ class OptimizationDataset(BaseModel):
|
|
|
391
414
|
model_dict = self.model_dump(mode="json")
|
|
392
415
|
# ⬆️ Get dict of model fields from Pydantic model instance
|
|
393
416
|
dataset_response = requests.post(
|
|
394
|
-
f"{API_HOST}/dataset-
|
|
417
|
+
f"{API_HOST}/dataset-qa/dataset/",
|
|
395
418
|
json={
|
|
396
419
|
**{k: model_dict[k] for k in model_dict.keys() - {"storage_config"}},
|
|
397
420
|
"organization_id": organization_id,
|
|
@@ -408,17 +431,17 @@ class OptimizationDataset(BaseModel):
|
|
|
408
431
|
return self.id
|
|
409
432
|
|
|
410
433
|
@staticmethod
|
|
411
|
-
def
|
|
434
|
+
def launch_qa_run(
|
|
412
435
|
dataset_id: int,
|
|
413
|
-
organization_id:
|
|
414
|
-
run_args:
|
|
436
|
+
organization_id: int | None = None,
|
|
437
|
+
run_args: RunArgs | None = None,
|
|
415
438
|
) -> str:
|
|
416
439
|
"""
|
|
417
|
-
Run the dataset
|
|
440
|
+
Run the dataset QA process on the server using the dataset with the given ID
|
|
418
441
|
i.e. `dataset_id`.
|
|
419
442
|
|
|
420
443
|
Args:
|
|
421
|
-
dataset_id: The ID of the dataset to run
|
|
444
|
+
dataset_id: The ID of the dataset to run QA on.
|
|
422
445
|
|
|
423
446
|
Returns:
|
|
424
447
|
ID of the run (`run_id`).
|
|
@@ -429,7 +452,7 @@ class OptimizationDataset(BaseModel):
|
|
|
429
452
|
if run_args:
|
|
430
453
|
run_info["run_args"] = run_args.model_dump(mode="json")
|
|
431
454
|
run_response = requests.post(
|
|
432
|
-
f"{API_HOST}/dataset-
|
|
455
|
+
f"{API_HOST}/dataset-qa/run/{dataset_id}",
|
|
433
456
|
json=run_info if len(run_info) > 0 else None,
|
|
434
457
|
headers=get_headers(),
|
|
435
458
|
timeout=MODIFY_TIMEOUT,
|
|
@@ -440,12 +463,16 @@ class OptimizationDataset(BaseModel):
|
|
|
440
463
|
def _validate_run_args(self, run_args: RunArgs) -> None:
|
|
441
464
|
if self.labeling_type == LabelingType.SPEECH_TO_TEXT:
|
|
442
465
|
raise Exception("Speech to text cannot have `run_args` set")
|
|
443
|
-
if
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
466
|
+
if (
|
|
467
|
+
self.labeling_type != LabelingType.OBJECT_DETECTION
|
|
468
|
+
and isinstance(run_args, ObjectDetectionRunArgs)
|
|
469
|
+
and any(
|
|
470
|
+
(
|
|
471
|
+
run_args.min_abs_bbox_size != 0,
|
|
472
|
+
run_args.min_abs_bbox_area != 0,
|
|
473
|
+
run_args.min_rel_bbox_size != 0,
|
|
474
|
+
run_args.min_rel_bbox_area != 0,
|
|
475
|
+
)
|
|
449
476
|
)
|
|
450
477
|
):
|
|
451
478
|
raise Exception(
|
|
@@ -454,21 +481,21 @@ class OptimizationDataset(BaseModel):
|
|
|
454
481
|
+ f"labeling type {self.labeling_type}"
|
|
455
482
|
)
|
|
456
483
|
|
|
457
|
-
def
|
|
484
|
+
def run_qa(
|
|
458
485
|
self,
|
|
459
|
-
organization_id:
|
|
486
|
+
organization_id: int | None = None,
|
|
460
487
|
replace_dataset_if_exists: bool = False,
|
|
461
|
-
run_args:
|
|
488
|
+
run_args: RunArgs | None = None,
|
|
462
489
|
) -> str:
|
|
463
490
|
"""
|
|
464
491
|
If the dataset was not created on the server yet, it is created.
|
|
465
|
-
Run the dataset
|
|
492
|
+
Run the dataset QA process on the server using the active `QADataset` instance
|
|
466
493
|
|
|
467
494
|
Args:
|
|
468
|
-
organization_id: The ID of the organization to run the
|
|
495
|
+
organization_id: The ID of the organization to run the QA for.
|
|
469
496
|
replace_dataset_if_exists: If True, the dataset will be replaced if it already exists
|
|
470
497
|
(this is determined by a dataset of the same name in the same organization).
|
|
471
|
-
run_args: The run arguments to use for the
|
|
498
|
+
run_args: The run arguments to use for the QA run
|
|
472
499
|
|
|
473
500
|
Returns:
|
|
474
501
|
An ID of the run (`run_id`) and stores that `run_id` on the instance
|
|
@@ -478,7 +505,7 @@ class OptimizationDataset(BaseModel):
|
|
|
478
505
|
self.id = self.create(replace_if_exists=replace_dataset_if_exists)
|
|
479
506
|
if run_args is not None:
|
|
480
507
|
self._validate_run_args(run_args)
|
|
481
|
-
run_id = self.
|
|
508
|
+
run_id = self.launch_qa_run(self.id, organization_id, run_args)
|
|
482
509
|
self.run_id = run_id
|
|
483
510
|
logger.info("Started the run with ID: %s", run_id)
|
|
484
511
|
return run_id
|
|
@@ -509,83 +536,46 @@ class OptimizationDataset(BaseModel):
|
|
|
509
536
|
|
|
510
537
|
@staticmethod
|
|
511
538
|
def _check_run_by_id(run_id: str, retry=0) -> Generator[dict, None, None]:
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
headers=get_headers(),
|
|
521
|
-
):
|
|
522
|
-
if sse.event == "ping":
|
|
523
|
-
continue
|
|
524
|
-
logger.debug(
|
|
525
|
-
"[SYNC] received event: %s with data: %s and ID: %s and retry: %s",
|
|
526
|
-
sse.event,
|
|
527
|
-
sse.data,
|
|
528
|
-
sse.id,
|
|
529
|
-
sse.retry,
|
|
530
|
-
)
|
|
531
|
-
last_event = json.loads(sse.data)
|
|
532
|
-
if not last_event:
|
|
533
|
-
continue
|
|
534
|
-
if "data" in last_event:
|
|
535
|
-
data = last_event["data"]
|
|
536
|
-
else:
|
|
537
|
-
if "detail" in last_event:
|
|
538
|
-
raise HirundoError(last_event["detail"])
|
|
539
|
-
elif "reason" in last_event:
|
|
540
|
-
raise HirundoError(last_event["reason"])
|
|
541
|
-
else:
|
|
542
|
-
raise HirundoError("Unknown error")
|
|
543
|
-
yield data
|
|
544
|
-
if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
|
|
545
|
-
OptimizationDataset._check_run_by_id(run_id, retry + 1)
|
|
546
|
-
|
|
547
|
-
@staticmethod
|
|
548
|
-
def _handle_failure(iteration: dict):
|
|
549
|
-
if iteration["result"]:
|
|
550
|
-
raise HirundoError(
|
|
551
|
-
f"Optimization run failed with error: {iteration['result']}"
|
|
552
|
-
)
|
|
553
|
-
else:
|
|
554
|
-
raise HirundoError(
|
|
555
|
-
"Optimization run failed with an unknown error in _handle_failure"
|
|
556
|
-
)
|
|
539
|
+
yield from iter_run_events(
|
|
540
|
+
f"{API_HOST}/dataset-qa/run/{run_id}",
|
|
541
|
+
headers=get_headers(),
|
|
542
|
+
retry=retry,
|
|
543
|
+
status_keys=("state",),
|
|
544
|
+
error_cls=HirundoError,
|
|
545
|
+
log=logger,
|
|
546
|
+
)
|
|
557
547
|
|
|
558
548
|
@staticmethod
|
|
559
549
|
@overload
|
|
560
550
|
def check_run_by_id(
|
|
561
551
|
run_id: str, stop_on_manual_approval: typing.Literal[True]
|
|
562
|
-
) ->
|
|
552
|
+
) -> DatasetQAResults | None: ...
|
|
563
553
|
|
|
564
554
|
@staticmethod
|
|
565
555
|
@overload
|
|
566
556
|
def check_run_by_id(
|
|
567
557
|
run_id: str, stop_on_manual_approval: typing.Literal[False] = False
|
|
568
|
-
) ->
|
|
558
|
+
) -> DatasetQAResults: ...
|
|
569
559
|
|
|
570
560
|
@staticmethod
|
|
571
561
|
@overload
|
|
572
562
|
def check_run_by_id(
|
|
573
563
|
run_id: str, stop_on_manual_approval: bool
|
|
574
|
-
) ->
|
|
564
|
+
) -> DatasetQAResults | None: ...
|
|
575
565
|
|
|
576
566
|
@staticmethod
|
|
577
567
|
def check_run_by_id(
|
|
578
568
|
run_id: str, stop_on_manual_approval: bool = False
|
|
579
|
-
) ->
|
|
569
|
+
) -> DatasetQAResults | None:
|
|
580
570
|
"""
|
|
581
571
|
Check the status of a run given its ID
|
|
582
572
|
|
|
583
573
|
Args:
|
|
584
|
-
run_id: The `run_id` produced by a `
|
|
574
|
+
run_id: The `run_id` produced by a `run_qa` call
|
|
585
575
|
stop_on_manual_approval: If True, the function will return `None` if the run is awaiting manual approval
|
|
586
576
|
|
|
587
577
|
Returns:
|
|
588
|
-
A
|
|
578
|
+
A DatasetQAResults object with the results of the QA run
|
|
589
579
|
|
|
590
580
|
Raises:
|
|
591
581
|
HirundoError: If the maximum number of retries is reached or if the run fails
|
|
@@ -593,87 +583,67 @@ class OptimizationDataset(BaseModel):
|
|
|
593
583
|
logger.debug("Checking run with ID: %s", run_id)
|
|
594
584
|
with logging_redirect_tqdm():
|
|
595
585
|
t = tqdm(total=100.0)
|
|
596
|
-
for iteration in
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
t.
|
|
586
|
+
for iteration in QADataset._check_run_by_id(run_id):
|
|
587
|
+
state = get_state(iteration, ("state",))
|
|
588
|
+
if state in STATUS_TO_PROGRESS_MAP:
|
|
589
|
+
t.set_description(STATUS_TO_TEXT_MAP[state])
|
|
590
|
+
t.n = STATUS_TO_PROGRESS_MAP[state]
|
|
600
591
|
logger.debug("Setting progress to %s", t.n)
|
|
601
592
|
t.refresh()
|
|
602
|
-
if
|
|
593
|
+
if state in [
|
|
603
594
|
RunStatus.FAILURE.value,
|
|
604
595
|
RunStatus.REJECTED.value,
|
|
605
596
|
RunStatus.REVOKED.value,
|
|
606
597
|
]:
|
|
607
598
|
logger.error(
|
|
608
599
|
"State is failure, rejected, or revoked: %s",
|
|
609
|
-
|
|
600
|
+
state,
|
|
610
601
|
)
|
|
611
|
-
|
|
612
|
-
|
|
602
|
+
handle_run_failure(
|
|
603
|
+
iteration, error_cls=HirundoError, run_label="QA"
|
|
604
|
+
)
|
|
605
|
+
elif state == RunStatus.SUCCESS.value:
|
|
613
606
|
t.close()
|
|
614
607
|
zip_temporary_url = iteration["result"]
|
|
615
|
-
logger.debug("
|
|
608
|
+
logger.debug("QA run completed. Downloading results")
|
|
616
609
|
|
|
617
610
|
return download_and_extract_zip(
|
|
618
611
|
run_id,
|
|
619
612
|
zip_temporary_url,
|
|
620
613
|
)
|
|
621
614
|
elif (
|
|
622
|
-
|
|
615
|
+
state == RunStatus.AWAITING_MANUAL_APPROVAL.value
|
|
623
616
|
and stop_on_manual_approval
|
|
624
617
|
):
|
|
625
618
|
t.close()
|
|
626
619
|
return None
|
|
627
|
-
elif
|
|
628
|
-
|
|
629
|
-
iteration
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
)
|
|
634
|
-
|
|
635
|
-
if len(result_info) > 1:
|
|
636
|
-
stage = result_info[0]
|
|
637
|
-
current_progress_percentage = float(
|
|
638
|
-
result_info[1].removeprefix(" ").removesuffix("% done")
|
|
639
|
-
)
|
|
640
|
-
elif len(result_info) == 1:
|
|
641
|
-
stage = result_info[0]
|
|
642
|
-
current_progress_percentage = t.n # Keep the same progress
|
|
643
|
-
else:
|
|
644
|
-
stage = "Unknown progress state"
|
|
645
|
-
current_progress_percentage = t.n # Keep the same progress
|
|
646
|
-
desc = (
|
|
647
|
-
"Optimization run completed. Uploading results"
|
|
648
|
-
if current_progress_percentage == 100.0
|
|
649
|
-
else stage
|
|
650
|
-
)
|
|
651
|
-
t.set_description(desc)
|
|
652
|
-
t.n = current_progress_percentage
|
|
653
|
-
logger.debug("Setting progress to %s", t.n)
|
|
654
|
-
t.refresh()
|
|
655
|
-
raise HirundoError(
|
|
656
|
-
"Optimization run failed with an unknown error in check_run_by_id"
|
|
657
|
-
)
|
|
620
|
+
elif state is None:
|
|
621
|
+
update_progress_from_result(
|
|
622
|
+
iteration,
|
|
623
|
+
t,
|
|
624
|
+
uploading_text="QA run completed. Uploading results",
|
|
625
|
+
log=logger,
|
|
626
|
+
)
|
|
627
|
+
raise HirundoError("QA run failed with an unknown error in check_run_by_id")
|
|
658
628
|
|
|
659
629
|
@overload
|
|
660
630
|
def check_run(
|
|
661
631
|
self, stop_on_manual_approval: typing.Literal[True]
|
|
662
|
-
) ->
|
|
632
|
+
) -> DatasetQAResults | None: ...
|
|
663
633
|
|
|
664
634
|
@overload
|
|
665
635
|
def check_run(
|
|
666
636
|
self, stop_on_manual_approval: typing.Literal[False] = False
|
|
667
|
-
) ->
|
|
637
|
+
) -> DatasetQAResults: ...
|
|
668
638
|
|
|
669
639
|
def check_run(
|
|
670
640
|
self, stop_on_manual_approval: bool = False
|
|
671
|
-
) ->
|
|
641
|
+
) -> DatasetQAResults | None:
|
|
672
642
|
"""
|
|
673
643
|
Check the status of the current active instance's run.
|
|
674
644
|
|
|
675
645
|
Returns:
|
|
676
|
-
A pandas DataFrame with the results of the
|
|
646
|
+
A pandas DataFrame with the results of the QA run
|
|
677
647
|
|
|
678
648
|
"""
|
|
679
649
|
if not self.run_id:
|
|
@@ -690,7 +660,7 @@ class OptimizationDataset(BaseModel):
|
|
|
690
660
|
This generator will produce values to show progress of the run.
|
|
691
661
|
|
|
692
662
|
Args:
|
|
693
|
-
run_id: The `run_id` produced by a `
|
|
663
|
+
run_id: The `run_id` produced by a `run_qa` call
|
|
694
664
|
retry: A number used to track the number of retries to limit re-checks. *Do not* provide this value manually.
|
|
695
665
|
|
|
696
666
|
Yields:
|
|
@@ -700,32 +670,15 @@ class OptimizationDataset(BaseModel):
|
|
|
700
670
|
|
|
701
671
|
"""
|
|
702
672
|
logger.debug("Checking run with ID: %s", run_id)
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
f"{API_HOST}/dataset-optimization/run/{run_id}",
|
|
713
|
-
headers=get_headers(),
|
|
714
|
-
)
|
|
715
|
-
async for sse in async_iterator:
|
|
716
|
-
if sse.event == "ping":
|
|
717
|
-
continue
|
|
718
|
-
logger.debug(
|
|
719
|
-
"[ASYNC] Received event: %s with data: %s and ID: %s and retry: %s",
|
|
720
|
-
sse.event,
|
|
721
|
-
sse.data,
|
|
722
|
-
sse.id,
|
|
723
|
-
sse.retry,
|
|
724
|
-
)
|
|
725
|
-
last_event = json.loads(sse.data)
|
|
726
|
-
yield last_event["data"]
|
|
727
|
-
if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
|
|
728
|
-
OptimizationDataset.acheck_run_by_id(run_id, retry + 1)
|
|
673
|
+
async for iteration in aiter_run_events(
|
|
674
|
+
f"{API_HOST}/dataset-qa/run/{run_id}",
|
|
675
|
+
headers=get_headers(),
|
|
676
|
+
retry=retry,
|
|
677
|
+
status_keys=("state",),
|
|
678
|
+
error_cls=HirundoError,
|
|
679
|
+
log=logger,
|
|
680
|
+
):
|
|
681
|
+
yield iteration
|
|
729
682
|
|
|
730
683
|
async def acheck_run(self) -> AsyncGenerator[dict, None]:
|
|
731
684
|
"""
|
|
@@ -735,6 +688,8 @@ class OptimizationDataset(BaseModel):
|
|
|
735
688
|
|
|
736
689
|
This generator will produce values to show progress of the run.
|
|
737
690
|
|
|
691
|
+
Note: This function does not handle errors nor show progress. It is expected that you do that.
|
|
692
|
+
|
|
738
693
|
Yields:
|
|
739
694
|
Each event will be a dict, where:
|
|
740
695
|
- `"state"` is PENDING, STARTED, RETRY, FAILURE or SUCCESS
|
|
@@ -749,14 +704,14 @@ class OptimizationDataset(BaseModel):
|
|
|
749
704
|
@staticmethod
|
|
750
705
|
def cancel_by_id(run_id: str) -> None:
|
|
751
706
|
"""
|
|
752
|
-
Cancel the dataset
|
|
707
|
+
Cancel the dataset QA run for the given `run_id`.
|
|
753
708
|
|
|
754
709
|
Args:
|
|
755
710
|
run_id: The ID of the run to cancel
|
|
756
711
|
"""
|
|
757
712
|
logger.info("Cancelling run with ID: %s", run_id)
|
|
758
713
|
response = requests.delete(
|
|
759
|
-
f"{API_HOST}/dataset-
|
|
714
|
+
f"{API_HOST}/dataset-qa/run/{run_id}",
|
|
760
715
|
headers=get_headers(),
|
|
761
716
|
timeout=MODIFY_TIMEOUT,
|
|
762
717
|
)
|
|
@@ -773,14 +728,14 @@ class OptimizationDataset(BaseModel):
|
|
|
773
728
|
@staticmethod
|
|
774
729
|
def archive_run_by_id(run_id: str) -> None:
|
|
775
730
|
"""
|
|
776
|
-
Archive the dataset
|
|
731
|
+
Archive the dataset QA run for the given `run_id`.
|
|
777
732
|
|
|
778
733
|
Args:
|
|
779
734
|
run_id: The ID of the run to archive
|
|
780
735
|
"""
|
|
781
736
|
logger.info("Archiving run with ID: %s", run_id)
|
|
782
737
|
response = requests.patch(
|
|
783
|
-
f"{API_HOST}/dataset-
|
|
738
|
+
f"{API_HOST}/dataset-qa/run/archive/{run_id}",
|
|
784
739
|
headers=get_headers(),
|
|
785
740
|
timeout=MODIFY_TIMEOUT,
|
|
786
741
|
)
|
|
@@ -795,7 +750,7 @@ class OptimizationDataset(BaseModel):
|
|
|
795
750
|
self.archive_run_by_id(self.run_id)
|
|
796
751
|
|
|
797
752
|
|
|
798
|
-
class
|
|
753
|
+
class QADatasetOut(BaseModel):
|
|
799
754
|
id: int
|
|
800
755
|
|
|
801
756
|
name: str
|
|
@@ -805,16 +760,16 @@ class DataOptimizationDatasetOut(BaseModel):
|
|
|
805
760
|
|
|
806
761
|
data_root_url: HirundoUrl
|
|
807
762
|
|
|
808
|
-
classes:
|
|
809
|
-
labeling_info:
|
|
763
|
+
classes: list[str] | None = None
|
|
764
|
+
labeling_info: LabelingInfo | list[LabelingInfo]
|
|
810
765
|
|
|
811
|
-
organization_id:
|
|
812
|
-
creator_id:
|
|
766
|
+
organization_id: int | None
|
|
767
|
+
creator_id: int | None
|
|
813
768
|
created_at: datetime.datetime
|
|
814
769
|
updated_at: datetime.datetime
|
|
815
770
|
|
|
816
771
|
|
|
817
|
-
class
|
|
772
|
+
class DataQARunOut(BaseModel):
|
|
818
773
|
id: int
|
|
819
774
|
name: str
|
|
820
775
|
dataset_id: int
|
|
@@ -822,4 +777,6 @@ class DataOptimizationRunOut(BaseModel):
|
|
|
822
777
|
status: RunStatus
|
|
823
778
|
approved: bool
|
|
824
779
|
created_at: datetime.datetime
|
|
825
|
-
run_args:
|
|
780
|
+
run_args: RunArgs | None
|
|
781
|
+
|
|
782
|
+
deleted_at: datetime.datetime | None = None
|