hirundo 0.1.16__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hirundo/__init__.py +30 -14
- hirundo/_constraints.py +164 -53
- hirundo/_headers.py +1 -1
- hirundo/_http.py +53 -0
- hirundo/_iter_sse_retrying.py +1 -1
- hirundo/_urls.py +59 -0
- hirundo/cli.py +7 -7
- hirundo/dataset_enum.py +23 -0
- hirundo/{dataset_optimization.py → dataset_qa.py} +195 -168
- hirundo/{dataset_optimization_results.py → dataset_qa_results.py} +4 -4
- hirundo/git.py +2 -3
- hirundo/labeling.py +140 -0
- hirundo/storage.py +43 -60
- hirundo/unzip.py +9 -10
- {hirundo-0.1.16.dist-info → hirundo-0.1.21.dist-info}/METADATA +67 -53
- hirundo-0.1.21.dist-info/RECORD +25 -0
- {hirundo-0.1.16.dist-info → hirundo-0.1.21.dist-info}/WHEEL +1 -1
- hirundo-0.1.16.dist-info/RECORD +0 -23
- {hirundo-0.1.16.dist-info → hirundo-0.1.21.dist-info}/entry_points.txt +0 -0
- {hirundo-0.1.16.dist-info → hirundo-0.1.21.dist-info}/licenses/LICENSE +0 -0
- {hirundo-0.1.16.dist-info → hirundo-0.1.21.dist-info}/top_level.txt +0 -0
|
@@ -1,25 +1,25 @@
|
|
|
1
1
|
import datetime
|
|
2
2
|
import json
|
|
3
3
|
import typing
|
|
4
|
-
from abc import ABC, abstractmethod
|
|
5
4
|
from collections.abc import AsyncGenerator, Generator
|
|
6
5
|
from enum import Enum
|
|
7
6
|
from typing import overload
|
|
8
7
|
|
|
9
8
|
import httpx
|
|
10
|
-
import requests
|
|
11
9
|
from pydantic import BaseModel, Field, model_validator
|
|
12
10
|
from tqdm import tqdm
|
|
13
11
|
from tqdm.contrib.logging import logging_redirect_tqdm
|
|
14
12
|
|
|
15
|
-
from hirundo._constraints import
|
|
13
|
+
from hirundo._constraints import validate_labeling_info, validate_url
|
|
16
14
|
from hirundo._env import API_HOST
|
|
17
15
|
from hirundo._headers import get_headers
|
|
18
|
-
from hirundo._http import raise_for_status_with_reason
|
|
16
|
+
from hirundo._http import raise_for_status_with_reason, requests
|
|
19
17
|
from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
|
|
20
18
|
from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
|
|
19
|
+
from hirundo._urls import HirundoUrl
|
|
21
20
|
from hirundo.dataset_enum import DatasetMetadataType, LabelingType
|
|
22
|
-
from hirundo.
|
|
21
|
+
from hirundo.dataset_qa_results import DatasetQAResults
|
|
22
|
+
from hirundo.labeling import YOLO, LabelingInfo
|
|
23
23
|
from hirundo.logger import get_logger
|
|
24
24
|
from hirundo.storage import ResponseStorageConfig, StorageConfig
|
|
25
25
|
from hirundo.unzip import download_and_extract_zip
|
|
@@ -29,7 +29,7 @@ logger = get_logger(__name__)
|
|
|
29
29
|
|
|
30
30
|
class HirundoError(Exception):
|
|
31
31
|
"""
|
|
32
|
-
Custom exception used to indicate errors in `hirundo` dataset
|
|
32
|
+
Custom exception used to indicate errors in `hirundo` dataset QA runs
|
|
33
33
|
"""
|
|
34
34
|
|
|
35
35
|
pass
|
|
@@ -50,14 +50,14 @@ class RunStatus(Enum):
|
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
STATUS_TO_TEXT_MAP = {
|
|
53
|
-
RunStatus.STARTED.value: "
|
|
54
|
-
RunStatus.PENDING.value: "
|
|
55
|
-
RunStatus.SUCCESS.value: "
|
|
56
|
-
RunStatus.FAILURE.value: "
|
|
53
|
+
RunStatus.STARTED.value: "Dataset QA run in progress. Downloading dataset",
|
|
54
|
+
RunStatus.PENDING.value: "Dataset QA run queued and not yet started",
|
|
55
|
+
RunStatus.SUCCESS.value: "Dataset QA run completed successfully",
|
|
56
|
+
RunStatus.FAILURE.value: "Dataset QA run failed",
|
|
57
57
|
RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
|
|
58
|
-
RunStatus.RETRY.value: "
|
|
59
|
-
RunStatus.REVOKED.value: "
|
|
60
|
-
RunStatus.REJECTED.value: "
|
|
58
|
+
RunStatus.RETRY.value: "Dataset QA run failed. Retrying",
|
|
59
|
+
RunStatus.REVOKED.value: "Dataset QA run was cancelled",
|
|
60
|
+
RunStatus.REJECTED.value: "Dataset QA run was rejected",
|
|
61
61
|
}
|
|
62
62
|
STATUS_TO_PROGRESS_MAP = {
|
|
63
63
|
RunStatus.STARTED.value: 0.0,
|
|
@@ -71,99 +71,51 @@ STATUS_TO_PROGRESS_MAP = {
|
|
|
71
71
|
}
|
|
72
72
|
|
|
73
73
|
|
|
74
|
-
class
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
@property
|
|
78
|
-
@abstractmethod
|
|
79
|
-
def metadata_url(self) -> HirundoUrl:
|
|
80
|
-
raise NotImplementedError()
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
class HirundoCSV(Metadata):
|
|
74
|
+
class ClassificationRunArgs(BaseModel):
|
|
75
|
+
image_size: typing.Optional[tuple[int, int]] = (224, 224)
|
|
84
76
|
"""
|
|
85
|
-
|
|
77
|
+
Size (width, height) to which to resize classification images.
|
|
78
|
+
It is recommended to keep this value at (224, 224) unless your classes are differentiated by very small differences.
|
|
86
79
|
"""
|
|
87
|
-
|
|
88
|
-
type: DatasetMetadataType = DatasetMetadataType.HIRUNDO_CSV
|
|
89
|
-
csv_url: HirundoUrl
|
|
80
|
+
upsample: typing.Optional[bool] = False
|
|
90
81
|
"""
|
|
91
|
-
|
|
92
|
-
e.g. `s3://my-bucket-name/my-folder/my-metadata.csv`, `gs://my-bucket-name/my-folder/my-metadata.csv`,
|
|
93
|
-
or `ssh://my-username@my-repo-name/my-folder/my-metadata.csv`
|
|
94
|
-
(or `file:///datasets/my-folder/my-metadata.csv` if using LOCAL storage type with on-premises installation)
|
|
82
|
+
Whether to upsample the dataset to attempt to balance the classes.
|
|
95
83
|
"""
|
|
96
84
|
|
|
97
|
-
@property
|
|
98
|
-
def metadata_url(self) -> HirundoUrl:
|
|
99
|
-
return self.csv_url
|
|
100
|
-
|
|
101
85
|
|
|
102
|
-
class
|
|
86
|
+
class ObjectDetectionRunArgs(ClassificationRunArgs):
|
|
87
|
+
min_abs_bbox_size: typing.Optional[int] = None
|
|
103
88
|
"""
|
|
104
|
-
|
|
89
|
+
Minimum valid size (in pixels) of a bounding box to keep it in the dataset for QA.
|
|
105
90
|
"""
|
|
106
|
-
|
|
107
|
-
type: DatasetMetadataType = DatasetMetadataType.COCO
|
|
108
|
-
json_url: HirundoUrl
|
|
109
|
-
"""
|
|
110
|
-
The URL to access the dataset metadata JSON file.
|
|
111
|
-
e.g. `s3://my-bucket-name/my-folder/my-metadata.json`, `gs://my-bucket-name/my-folder/my-metadata.json`,
|
|
112
|
-
or `ssh://my-username@my-repo-name/my-folder/my-metadata.json`
|
|
113
|
-
(or `file:///datasets/my-folder/my-metadata.json` if using LOCAL storage type with on-premises installation)
|
|
114
|
-
"""
|
|
115
|
-
|
|
116
|
-
@property
|
|
117
|
-
def metadata_url(self) -> HirundoUrl:
|
|
118
|
-
return self.json_url
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
class YOLO(Metadata):
|
|
122
|
-
type: DatasetMetadataType = DatasetMetadataType.YOLO
|
|
123
|
-
data_yaml_url: typing.Optional[HirundoUrl] = None
|
|
124
|
-
labels_dir_url: HirundoUrl
|
|
125
|
-
|
|
126
|
-
@property
|
|
127
|
-
def metadata_url(self) -> HirundoUrl:
|
|
128
|
-
return self.labels_dir_url
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
LabelingInfo = typing.Union[HirundoCSV, COCO, YOLO]
|
|
132
|
-
"""
|
|
133
|
-
The dataset labeling info. The dataset labeling info can be one of the following:
|
|
134
|
-
- `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format
|
|
135
|
-
|
|
136
|
-
Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
|
|
137
|
-
"""
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
class VisionRunArgs(BaseModel):
|
|
141
|
-
upsample: bool = False
|
|
91
|
+
min_abs_bbox_area: typing.Optional[int] = None
|
|
142
92
|
"""
|
|
143
|
-
|
|
93
|
+
Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for QA.
|
|
144
94
|
"""
|
|
145
|
-
|
|
95
|
+
min_rel_bbox_size: typing.Optional[float] = None
|
|
146
96
|
"""
|
|
147
|
-
Minimum valid size (
|
|
97
|
+
Minimum valid size (as a fraction of both image height and width) for a bounding box
|
|
98
|
+
to keep it in the dataset for QA, relative to the corresponding dimension size,
|
|
99
|
+
i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
|
|
100
|
+
value is 0.06 (since both width and height are checked).
|
|
148
101
|
"""
|
|
149
|
-
|
|
102
|
+
min_rel_bbox_area: typing.Optional[float] = None
|
|
150
103
|
"""
|
|
151
|
-
Minimum valid
|
|
104
|
+
Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for QA.
|
|
152
105
|
"""
|
|
153
|
-
|
|
106
|
+
crop_ratio: typing.Optional[float] = None
|
|
154
107
|
"""
|
|
155
|
-
|
|
156
|
-
to keep it
|
|
157
|
-
i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
|
|
158
|
-
value is 0.06 (since both width and height are checked).
|
|
108
|
+
Ratio of the bounding box to crop.
|
|
109
|
+
Change this value at your own risk. It is recommended to keep it at 1.0 unless you know what you are doing.
|
|
159
110
|
"""
|
|
160
|
-
|
|
111
|
+
add_mask_channel: typing.Optional[bool] = None
|
|
161
112
|
"""
|
|
162
|
-
|
|
113
|
+
Whether to add a mask channel to the image.
|
|
114
|
+
Change at your own risk. It is recommended to keep it at False unless you know what you are doing.
|
|
163
115
|
"""
|
|
164
116
|
|
|
165
117
|
|
|
166
|
-
RunArgs = typing.Union[
|
|
118
|
+
RunArgs = typing.Union[ClassificationRunArgs, ObjectDetectionRunArgs]
|
|
167
119
|
|
|
168
120
|
|
|
169
121
|
class AugmentationName(str, Enum):
|
|
@@ -176,13 +128,31 @@ class AugmentationName(str, Enum):
|
|
|
176
128
|
GAUSSIAN_BLUR = "GaussianBlur"
|
|
177
129
|
|
|
178
130
|
|
|
179
|
-
class
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
131
|
+
class Domain(str, Enum):
|
|
132
|
+
RADAR = "RADAR"
|
|
133
|
+
VISION = "VISION"
|
|
134
|
+
SPEECH = "SPEECH"
|
|
135
|
+
TABULAR = "TABULAR"
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
DOMAIN_TO_SUPPORTED_LABELING_TYPES = {
|
|
139
|
+
Domain.RADAR: [
|
|
140
|
+
LabelingType.SINGLE_LABEL_CLASSIFICATION,
|
|
141
|
+
LabelingType.OBJECT_DETECTION,
|
|
142
|
+
],
|
|
143
|
+
Domain.VISION: [
|
|
144
|
+
LabelingType.SINGLE_LABEL_CLASSIFICATION,
|
|
145
|
+
LabelingType.OBJECT_DETECTION,
|
|
146
|
+
LabelingType.OBJECT_SEGMENTATION,
|
|
147
|
+
LabelingType.SEMANTIC_SEGMENTATION,
|
|
148
|
+
LabelingType.PANOPTIC_SEGMENTATION,
|
|
149
|
+
],
|
|
150
|
+
Domain.SPEECH: [LabelingType.SPEECH_TO_TEXT],
|
|
151
|
+
Domain.TABULAR: [LabelingType.SINGLE_LABEL_CLASSIFICATION],
|
|
152
|
+
}
|
|
183
153
|
|
|
184
154
|
|
|
185
|
-
class
|
|
155
|
+
class QADataset(BaseModel):
|
|
186
156
|
id: typing.Optional[int] = Field(default=None)
|
|
187
157
|
"""
|
|
188
158
|
The ID of the dataset created on the server.
|
|
@@ -228,7 +198,7 @@ class OptimizationDataset(BaseModel):
|
|
|
228
198
|
A full list of possible classes used in classification / object detection.
|
|
229
199
|
It is currently required for clarity and performance.
|
|
230
200
|
"""
|
|
231
|
-
labeling_info: LabelingInfo
|
|
201
|
+
labeling_info: typing.Union[LabelingInfo, list[LabelingInfo]]
|
|
232
202
|
|
|
233
203
|
augmentations: typing.Optional[list[AugmentationName]] = None
|
|
234
204
|
"""
|
|
@@ -236,21 +206,29 @@ class OptimizationDataset(BaseModel):
|
|
|
236
206
|
For audio datasets, this field is ignored.
|
|
237
207
|
If no value is provided, all augmentations are applied to vision datasets.
|
|
238
208
|
"""
|
|
239
|
-
|
|
209
|
+
domain: Domain = Domain.VISION
|
|
240
210
|
"""
|
|
241
|
-
Used to define the
|
|
211
|
+
Used to define the domain of the dataset.
|
|
242
212
|
Defaults to Image.
|
|
243
213
|
"""
|
|
244
214
|
|
|
245
215
|
run_id: typing.Optional[str] = Field(default=None, init=False)
|
|
246
216
|
"""
|
|
247
|
-
The ID of the Dataset
|
|
217
|
+
The ID of the Dataset QA run created on the server.
|
|
248
218
|
"""
|
|
249
219
|
|
|
250
220
|
status: typing.Optional[RunStatus] = None
|
|
251
221
|
|
|
252
222
|
@model_validator(mode="after")
|
|
253
223
|
def validate_dataset(self):
|
|
224
|
+
if self.domain not in DOMAIN_TO_SUPPORTED_LABELING_TYPES:
|
|
225
|
+
raise ValueError(
|
|
226
|
+
f"Domain {self.domain} is not supported. Supported domains are: {list(DOMAIN_TO_SUPPORTED_LABELING_TYPES.keys())}"
|
|
227
|
+
)
|
|
228
|
+
if self.labeling_type not in DOMAIN_TO_SUPPORTED_LABELING_TYPES[self.domain]:
|
|
229
|
+
raise ValueError(
|
|
230
|
+
f"Labeling type {self.labeling_type} is not supported for domain {self.domain}. Supported labeling types are: {DOMAIN_TO_SUPPORTED_LABELING_TYPES[self.domain]}"
|
|
231
|
+
)
|
|
254
232
|
if self.storage_config is None and self.storage_config_id is None:
|
|
255
233
|
raise ValueError(
|
|
256
234
|
"No dataset storage has been provided. Provide one via `storage_config` or `storage_config_id`"
|
|
@@ -267,65 +245,79 @@ class OptimizationDataset(BaseModel):
|
|
|
267
245
|
):
|
|
268
246
|
raise ValueError("Language is only allowed for Speech-to-Text datasets.")
|
|
269
247
|
if (
|
|
270
|
-
self.labeling_info
|
|
248
|
+
not isinstance(self.labeling_info, list)
|
|
249
|
+
and self.labeling_info.type == DatasetMetadataType.YOLO
|
|
271
250
|
and isinstance(self.labeling_info, YOLO)
|
|
272
251
|
and (
|
|
273
252
|
self.labeling_info.data_yaml_url is not None
|
|
274
253
|
and self.classes is not None
|
|
275
254
|
)
|
|
255
|
+
) or (
|
|
256
|
+
isinstance(self.labeling_info, list)
|
|
257
|
+
and self.classes is not None
|
|
258
|
+
and any(
|
|
259
|
+
isinstance(info, YOLO) and info.data_yaml_url is not None
|
|
260
|
+
for info in self.labeling_info
|
|
261
|
+
)
|
|
276
262
|
):
|
|
277
263
|
raise ValueError(
|
|
278
264
|
"Only one of `classes` or `labeling_info.data_yaml_url` should be provided for YOLO datasets"
|
|
279
265
|
)
|
|
266
|
+
if self.storage_config:
|
|
267
|
+
validate_labeling_info(
|
|
268
|
+
self.labeling_type, self.labeling_info, self.storage_config
|
|
269
|
+
)
|
|
270
|
+
if self.data_root_url and self.storage_config:
|
|
271
|
+
validate_url(self.data_root_url, self.storage_config)
|
|
280
272
|
return self
|
|
281
273
|
|
|
282
274
|
@staticmethod
|
|
283
|
-
def get_by_id(dataset_id: int) -> "
|
|
275
|
+
def get_by_id(dataset_id: int) -> "QADataset":
|
|
284
276
|
"""
|
|
285
|
-
Get a `
|
|
277
|
+
Get a `QADataset` instance from the server by its ID
|
|
286
278
|
|
|
287
279
|
Args:
|
|
288
|
-
dataset_id: The ID of the `
|
|
280
|
+
dataset_id: The ID of the `QADataset` instance to get
|
|
289
281
|
"""
|
|
290
282
|
response = requests.get(
|
|
291
|
-
f"{API_HOST}/dataset-
|
|
283
|
+
f"{API_HOST}/dataset-qa/dataset/{dataset_id}",
|
|
292
284
|
headers=get_headers(),
|
|
293
285
|
timeout=READ_TIMEOUT,
|
|
294
286
|
)
|
|
295
287
|
raise_for_status_with_reason(response)
|
|
296
288
|
dataset = response.json()
|
|
297
|
-
return
|
|
289
|
+
return QADataset(**dataset)
|
|
298
290
|
|
|
299
291
|
@staticmethod
|
|
300
|
-
def get_by_name(name: str) -> "
|
|
292
|
+
def get_by_name(name: str) -> "QADataset":
|
|
301
293
|
"""
|
|
302
|
-
Get a `
|
|
294
|
+
Get a `QADataset` instance from the server by its name
|
|
303
295
|
|
|
304
296
|
Args:
|
|
305
|
-
name: The name of the `
|
|
297
|
+
name: The name of the `QADataset` instance to get
|
|
306
298
|
"""
|
|
307
299
|
response = requests.get(
|
|
308
|
-
f"{API_HOST}/dataset-
|
|
300
|
+
f"{API_HOST}/dataset-qa/dataset/by-name/{name}",
|
|
309
301
|
headers=get_headers(),
|
|
310
302
|
timeout=READ_TIMEOUT,
|
|
311
303
|
)
|
|
312
304
|
raise_for_status_with_reason(response)
|
|
313
305
|
dataset = response.json()
|
|
314
|
-
return
|
|
306
|
+
return QADataset(**dataset)
|
|
315
307
|
|
|
316
308
|
@staticmethod
|
|
317
309
|
def list_datasets(
|
|
318
310
|
organization_id: typing.Optional[int] = None,
|
|
319
|
-
) -> list["
|
|
311
|
+
) -> list["QADatasetOut"]:
|
|
320
312
|
"""
|
|
321
|
-
Lists all the
|
|
313
|
+
Lists all the datasets created by user's default organization
|
|
322
314
|
or the `organization_id` passed
|
|
323
315
|
|
|
324
316
|
Args:
|
|
325
317
|
organization_id: The ID of the organization to list the datasets for.
|
|
326
318
|
"""
|
|
327
319
|
response = requests.get(
|
|
328
|
-
f"{API_HOST}/dataset-
|
|
320
|
+
f"{API_HOST}/dataset-qa/dataset/",
|
|
329
321
|
params={"dataset_organization_id": organization_id},
|
|
330
322
|
headers=get_headers(),
|
|
331
323
|
timeout=READ_TIMEOUT,
|
|
@@ -333,7 +325,7 @@ class OptimizationDataset(BaseModel):
|
|
|
333
325
|
raise_for_status_with_reason(response)
|
|
334
326
|
datasets = response.json()
|
|
335
327
|
return [
|
|
336
|
-
|
|
328
|
+
QADatasetOut(
|
|
337
329
|
**ds,
|
|
338
330
|
)
|
|
339
331
|
for ds in datasets
|
|
@@ -342,17 +334,17 @@ class OptimizationDataset(BaseModel):
|
|
|
342
334
|
@staticmethod
|
|
343
335
|
def list_runs(
|
|
344
336
|
organization_id: typing.Optional[int] = None,
|
|
345
|
-
) -> list["
|
|
337
|
+
) -> list["DataQARunOut"]:
|
|
346
338
|
"""
|
|
347
|
-
Lists all the `
|
|
339
|
+
Lists all the `QADataset` instances created by user's default organization
|
|
348
340
|
or the `organization_id` passed
|
|
349
|
-
Note: The return type is `list[dict]` and not `list[
|
|
341
|
+
Note: The return type is `list[dict]` and not `list[QADataset]`
|
|
350
342
|
|
|
351
343
|
Args:
|
|
352
344
|
organization_id: The ID of the organization to list the datasets for.
|
|
353
345
|
"""
|
|
354
346
|
response = requests.get(
|
|
355
|
-
f"{API_HOST}/dataset-
|
|
347
|
+
f"{API_HOST}/dataset-qa/run/list",
|
|
356
348
|
params={"dataset_organization_id": organization_id},
|
|
357
349
|
headers=get_headers(),
|
|
358
350
|
timeout=READ_TIMEOUT,
|
|
@@ -360,7 +352,7 @@ class OptimizationDataset(BaseModel):
|
|
|
360
352
|
raise_for_status_with_reason(response)
|
|
361
353
|
runs = response.json()
|
|
362
354
|
return [
|
|
363
|
-
|
|
355
|
+
DataQARunOut(
|
|
364
356
|
**run,
|
|
365
357
|
)
|
|
366
358
|
for run in runs
|
|
@@ -369,13 +361,13 @@ class OptimizationDataset(BaseModel):
|
|
|
369
361
|
@staticmethod
|
|
370
362
|
def delete_by_id(dataset_id: int) -> None:
|
|
371
363
|
"""
|
|
372
|
-
Deletes a `
|
|
364
|
+
Deletes a `QADataset` instance from the server by its ID
|
|
373
365
|
|
|
374
366
|
Args:
|
|
375
|
-
dataset_id: The ID of the `
|
|
367
|
+
dataset_id: The ID of the `QADataset` instance to delete
|
|
376
368
|
"""
|
|
377
369
|
response = requests.delete(
|
|
378
|
-
f"{API_HOST}/dataset-
|
|
370
|
+
f"{API_HOST}/dataset-qa/dataset/{dataset_id}",
|
|
379
371
|
headers=get_headers(),
|
|
380
372
|
timeout=MODIFY_TIMEOUT,
|
|
381
373
|
)
|
|
@@ -384,14 +376,14 @@ class OptimizationDataset(BaseModel):
|
|
|
384
376
|
|
|
385
377
|
def delete(self, storage_config=True) -> None:
|
|
386
378
|
"""
|
|
387
|
-
Deletes the active `
|
|
388
|
-
It can only be used on a `
|
|
379
|
+
Deletes the active `QADataset` instance from the server.
|
|
380
|
+
It can only be used on a `QADataset` instance that has been created.
|
|
389
381
|
|
|
390
382
|
Args:
|
|
391
|
-
storage_config: If True, the `
|
|
383
|
+
storage_config: If True, the `QADataset`'s `StorageConfig` will also be deleted
|
|
392
384
|
|
|
393
385
|
Note: If `storage_config` is not set to `False` then the `storage_config_id` must be set
|
|
394
|
-
This can either be set manually or by creating the `StorageConfig` instance via the `
|
|
386
|
+
This can either be set manually or by creating the `StorageConfig` instance via the `QADataset`'s
|
|
395
387
|
`create` method
|
|
396
388
|
"""
|
|
397
389
|
if storage_config:
|
|
@@ -408,7 +400,7 @@ class OptimizationDataset(BaseModel):
|
|
|
408
400
|
replace_if_exists: bool = False,
|
|
409
401
|
) -> int:
|
|
410
402
|
"""
|
|
411
|
-
Create a `
|
|
403
|
+
Create a `QADataset` instance on the server.
|
|
412
404
|
If the `storage_config_id` field is not set, the storage config will also be created and the field will be set.
|
|
413
405
|
|
|
414
406
|
Args:
|
|
@@ -417,7 +409,7 @@ class OptimizationDataset(BaseModel):
|
|
|
417
409
|
(this is determined by a dataset of the same name in the same organization).
|
|
418
410
|
|
|
419
411
|
Returns:
|
|
420
|
-
The ID of the created `
|
|
412
|
+
The ID of the created `QADataset` instance
|
|
421
413
|
"""
|
|
422
414
|
if self.storage_config is None and self.storage_config_id is None:
|
|
423
415
|
raise ValueError("No dataset storage has been provided")
|
|
@@ -442,7 +434,7 @@ class OptimizationDataset(BaseModel):
|
|
|
442
434
|
model_dict = self.model_dump(mode="json")
|
|
443
435
|
# ⬆️ Get dict of model fields from Pydantic model instance
|
|
444
436
|
dataset_response = requests.post(
|
|
445
|
-
f"{API_HOST}/dataset-
|
|
437
|
+
f"{API_HOST}/dataset-qa/dataset/",
|
|
446
438
|
json={
|
|
447
439
|
**{k: model_dict[k] for k in model_dict.keys() - {"storage_config"}},
|
|
448
440
|
"organization_id": organization_id,
|
|
@@ -459,17 +451,17 @@ class OptimizationDataset(BaseModel):
|
|
|
459
451
|
return self.id
|
|
460
452
|
|
|
461
453
|
@staticmethod
|
|
462
|
-
def
|
|
454
|
+
def launch_qa_run(
|
|
463
455
|
dataset_id: int,
|
|
464
456
|
organization_id: typing.Optional[int] = None,
|
|
465
457
|
run_args: typing.Optional[RunArgs] = None,
|
|
466
458
|
) -> str:
|
|
467
459
|
"""
|
|
468
|
-
Run the dataset
|
|
460
|
+
Run the dataset QA process on the server using the dataset with the given ID
|
|
469
461
|
i.e. `dataset_id`.
|
|
470
462
|
|
|
471
463
|
Args:
|
|
472
|
-
dataset_id: The ID of the dataset to run
|
|
464
|
+
dataset_id: The ID of the dataset to run QA on.
|
|
473
465
|
|
|
474
466
|
Returns:
|
|
475
467
|
ID of the run (`run_id`).
|
|
@@ -480,7 +472,7 @@ class OptimizationDataset(BaseModel):
|
|
|
480
472
|
if run_args:
|
|
481
473
|
run_info["run_args"] = run_args.model_dump(mode="json")
|
|
482
474
|
run_response = requests.post(
|
|
483
|
-
f"{API_HOST}/dataset-
|
|
475
|
+
f"{API_HOST}/dataset-qa/run/{dataset_id}",
|
|
484
476
|
json=run_info if len(run_info) > 0 else None,
|
|
485
477
|
headers=get_headers(),
|
|
486
478
|
timeout=MODIFY_TIMEOUT,
|
|
@@ -491,12 +483,16 @@ class OptimizationDataset(BaseModel):
|
|
|
491
483
|
def _validate_run_args(self, run_args: RunArgs) -> None:
|
|
492
484
|
if self.labeling_type == LabelingType.SPEECH_TO_TEXT:
|
|
493
485
|
raise Exception("Speech to text cannot have `run_args` set")
|
|
494
|
-
if
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
486
|
+
if (
|
|
487
|
+
self.labeling_type != LabelingType.OBJECT_DETECTION
|
|
488
|
+
and isinstance(run_args, ObjectDetectionRunArgs)
|
|
489
|
+
and any(
|
|
490
|
+
(
|
|
491
|
+
run_args.min_abs_bbox_size != 0,
|
|
492
|
+
run_args.min_abs_bbox_area != 0,
|
|
493
|
+
run_args.min_rel_bbox_size != 0,
|
|
494
|
+
run_args.min_rel_bbox_area != 0,
|
|
495
|
+
)
|
|
500
496
|
)
|
|
501
497
|
):
|
|
502
498
|
raise Exception(
|
|
@@ -505,7 +501,7 @@ class OptimizationDataset(BaseModel):
|
|
|
505
501
|
+ f"labeling type {self.labeling_type}"
|
|
506
502
|
)
|
|
507
503
|
|
|
508
|
-
def
|
|
504
|
+
def run_qa(
|
|
509
505
|
self,
|
|
510
506
|
organization_id: typing.Optional[int] = None,
|
|
511
507
|
replace_dataset_if_exists: bool = False,
|
|
@@ -513,13 +509,13 @@ class OptimizationDataset(BaseModel):
|
|
|
513
509
|
) -> str:
|
|
514
510
|
"""
|
|
515
511
|
If the dataset was not created on the server yet, it is created.
|
|
516
|
-
Run the dataset
|
|
512
|
+
Run the dataset QA process on the server using the active `QADataset` instance
|
|
517
513
|
|
|
518
514
|
Args:
|
|
519
|
-
organization_id: The ID of the organization to run the
|
|
515
|
+
organization_id: The ID of the organization to run the QA for.
|
|
520
516
|
replace_dataset_if_exists: If True, the dataset will be replaced if it already exists
|
|
521
517
|
(this is determined by a dataset of the same name in the same organization).
|
|
522
|
-
run_args: The run arguments to use for the
|
|
518
|
+
run_args: The run arguments to use for the QA run
|
|
523
519
|
|
|
524
520
|
Returns:
|
|
525
521
|
An ID of the run (`run_id`) and stores that `run_id` on the instance
|
|
@@ -529,7 +525,7 @@ class OptimizationDataset(BaseModel):
|
|
|
529
525
|
self.id = self.create(replace_if_exists=replace_dataset_if_exists)
|
|
530
526
|
if run_args is not None:
|
|
531
527
|
self._validate_run_args(run_args)
|
|
532
|
-
run_id = self.
|
|
528
|
+
run_id = self.launch_qa_run(self.id, organization_id, run_args)
|
|
533
529
|
self.run_id = run_id
|
|
534
530
|
logger.info("Started the run with ID: %s", run_id)
|
|
535
531
|
return run_id
|
|
@@ -567,7 +563,7 @@ class OptimizationDataset(BaseModel):
|
|
|
567
563
|
for sse in iter_sse_retrying(
|
|
568
564
|
client,
|
|
569
565
|
"GET",
|
|
570
|
-
f"{API_HOST}/dataset-
|
|
566
|
+
f"{API_HOST}/dataset-qa/run/{run_id}",
|
|
571
567
|
headers=get_headers(),
|
|
572
568
|
):
|
|
573
569
|
if sse.event == "ping":
|
|
@@ -593,39 +589,46 @@ class OptimizationDataset(BaseModel):
|
|
|
593
589
|
raise HirundoError("Unknown error")
|
|
594
590
|
yield data
|
|
595
591
|
if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
|
|
596
|
-
|
|
592
|
+
QADataset._check_run_by_id(run_id, retry + 1)
|
|
593
|
+
|
|
594
|
+
@staticmethod
|
|
595
|
+
def _handle_failure(iteration: dict):
|
|
596
|
+
if iteration["result"]:
|
|
597
|
+
raise HirundoError(f"QA run failed with error: {iteration['result']}")
|
|
598
|
+
else:
|
|
599
|
+
raise HirundoError("QA run failed with an unknown error in _handle_failure")
|
|
597
600
|
|
|
598
601
|
@staticmethod
|
|
599
602
|
@overload
|
|
600
603
|
def check_run_by_id(
|
|
601
604
|
run_id: str, stop_on_manual_approval: typing.Literal[True]
|
|
602
|
-
) -> typing.Optional[
|
|
605
|
+
) -> typing.Optional[DatasetQAResults]: ...
|
|
603
606
|
|
|
604
607
|
@staticmethod
|
|
605
608
|
@overload
|
|
606
609
|
def check_run_by_id(
|
|
607
610
|
run_id: str, stop_on_manual_approval: typing.Literal[False] = False
|
|
608
|
-
) ->
|
|
611
|
+
) -> DatasetQAResults: ...
|
|
609
612
|
|
|
610
613
|
@staticmethod
|
|
611
614
|
@overload
|
|
612
615
|
def check_run_by_id(
|
|
613
616
|
run_id: str, stop_on_manual_approval: bool
|
|
614
|
-
) -> typing.Optional[
|
|
617
|
+
) -> typing.Optional[DatasetQAResults]: ...
|
|
615
618
|
|
|
616
619
|
@staticmethod
|
|
617
620
|
def check_run_by_id(
|
|
618
621
|
run_id: str, stop_on_manual_approval: bool = False
|
|
619
|
-
) -> typing.Optional[
|
|
622
|
+
) -> typing.Optional[DatasetQAResults]:
|
|
620
623
|
"""
|
|
621
624
|
Check the status of a run given its ID
|
|
622
625
|
|
|
623
626
|
Args:
|
|
624
|
-
run_id: The `run_id` produced by a `
|
|
627
|
+
run_id: The `run_id` produced by a `run_qa` call
|
|
625
628
|
stop_on_manual_approval: If True, the function will return `None` if the run is awaiting manual approval
|
|
626
629
|
|
|
627
630
|
Returns:
|
|
628
|
-
A
|
|
631
|
+
A DatasetQAResults object with the results of the QA run
|
|
629
632
|
|
|
630
633
|
Raises:
|
|
631
634
|
HirundoError: If the maximum number of retries is reached or if the run fails
|
|
@@ -633,7 +636,7 @@ class OptimizationDataset(BaseModel):
|
|
|
633
636
|
logger.debug("Checking run with ID: %s", run_id)
|
|
634
637
|
with logging_redirect_tqdm():
|
|
635
638
|
t = tqdm(total=100.0)
|
|
636
|
-
for iteration in
|
|
639
|
+
for iteration in QADataset._check_run_by_id(run_id):
|
|
637
640
|
if iteration["state"] in STATUS_TO_PROGRESS_MAP:
|
|
638
641
|
t.set_description(STATUS_TO_TEXT_MAP[iteration["state"]])
|
|
639
642
|
t.n = STATUS_TO_PROGRESS_MAP[iteration["state"]]
|
|
@@ -644,13 +647,15 @@ class OptimizationDataset(BaseModel):
|
|
|
644
647
|
RunStatus.REJECTED.value,
|
|
645
648
|
RunStatus.REVOKED.value,
|
|
646
649
|
]:
|
|
647
|
-
|
|
648
|
-
|
|
650
|
+
logger.error(
|
|
651
|
+
"State is failure, rejected, or revoked: %s",
|
|
652
|
+
iteration["state"],
|
|
649
653
|
)
|
|
654
|
+
QADataset._handle_failure(iteration)
|
|
650
655
|
elif iteration["state"] == RunStatus.SUCCESS.value:
|
|
651
656
|
t.close()
|
|
652
657
|
zip_temporary_url = iteration["result"]
|
|
653
|
-
logger.debug("
|
|
658
|
+
logger.debug("QA run completed. Downloading results")
|
|
654
659
|
|
|
655
660
|
return download_and_extract_zip(
|
|
656
661
|
run_id,
|
|
@@ -682,7 +687,7 @@ class OptimizationDataset(BaseModel):
|
|
|
682
687
|
stage = "Unknown progress state"
|
|
683
688
|
current_progress_percentage = t.n # Keep the same progress
|
|
684
689
|
desc = (
|
|
685
|
-
"
|
|
690
|
+
"QA run completed. Uploading results"
|
|
686
691
|
if current_progress_percentage == 100.0
|
|
687
692
|
else stage
|
|
688
693
|
)
|
|
@@ -690,26 +695,26 @@ class OptimizationDataset(BaseModel):
|
|
|
690
695
|
t.n = current_progress_percentage
|
|
691
696
|
logger.debug("Setting progress to %s", t.n)
|
|
692
697
|
t.refresh()
|
|
693
|
-
raise HirundoError("
|
|
698
|
+
raise HirundoError("QA run failed with an unknown error in check_run_by_id")
|
|
694
699
|
|
|
695
700
|
@overload
|
|
696
701
|
def check_run(
|
|
697
702
|
self, stop_on_manual_approval: typing.Literal[True]
|
|
698
|
-
) -> typing.Optional[
|
|
703
|
+
) -> typing.Optional[DatasetQAResults]: ...
|
|
699
704
|
|
|
700
705
|
@overload
|
|
701
706
|
def check_run(
|
|
702
707
|
self, stop_on_manual_approval: typing.Literal[False] = False
|
|
703
|
-
) ->
|
|
708
|
+
) -> DatasetQAResults: ...
|
|
704
709
|
|
|
705
710
|
def check_run(
|
|
706
711
|
self, stop_on_manual_approval: bool = False
|
|
707
|
-
) -> typing.Optional[
|
|
712
|
+
) -> typing.Optional[DatasetQAResults]:
|
|
708
713
|
"""
|
|
709
714
|
Check the status of the current active instance's run.
|
|
710
715
|
|
|
711
716
|
Returns:
|
|
712
|
-
A pandas DataFrame with the results of the
|
|
717
|
+
A pandas DataFrame with the results of the QA run
|
|
713
718
|
|
|
714
719
|
"""
|
|
715
720
|
if not self.run_id:
|
|
@@ -726,7 +731,7 @@ class OptimizationDataset(BaseModel):
|
|
|
726
731
|
This generator will produce values to show progress of the run.
|
|
727
732
|
|
|
728
733
|
Args:
|
|
729
|
-
run_id: The `run_id` produced by a `
|
|
734
|
+
run_id: The `run_id` produced by a `run_qa` call
|
|
730
735
|
retry: A number used to track the number of retries to limit re-checks. *Do not* provide this value manually.
|
|
731
736
|
|
|
732
737
|
Yields:
|
|
@@ -745,7 +750,7 @@ class OptimizationDataset(BaseModel):
|
|
|
745
750
|
async_iterator = await aiter_sse_retrying(
|
|
746
751
|
client,
|
|
747
752
|
"GET",
|
|
748
|
-
f"{API_HOST}/dataset-
|
|
753
|
+
f"{API_HOST}/dataset-qa/run/{run_id}",
|
|
749
754
|
headers=get_headers(),
|
|
750
755
|
)
|
|
751
756
|
async for sse in async_iterator:
|
|
@@ -761,7 +766,7 @@ class OptimizationDataset(BaseModel):
|
|
|
761
766
|
last_event = json.loads(sse.data)
|
|
762
767
|
yield last_event["data"]
|
|
763
768
|
if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
|
|
764
|
-
|
|
769
|
+
QADataset.acheck_run_by_id(run_id, retry + 1)
|
|
765
770
|
|
|
766
771
|
async def acheck_run(self) -> AsyncGenerator[dict, None]:
|
|
767
772
|
"""
|
|
@@ -785,16 +790,14 @@ class OptimizationDataset(BaseModel):
|
|
|
785
790
|
@staticmethod
|
|
786
791
|
def cancel_by_id(run_id: str) -> None:
|
|
787
792
|
"""
|
|
788
|
-
Cancel the dataset
|
|
793
|
+
Cancel the dataset QA run for the given `run_id`.
|
|
789
794
|
|
|
790
795
|
Args:
|
|
791
796
|
run_id: The ID of the run to cancel
|
|
792
797
|
"""
|
|
793
|
-
if not run_id:
|
|
794
|
-
raise ValueError("No run has been started")
|
|
795
798
|
logger.info("Cancelling run with ID: %s", run_id)
|
|
796
799
|
response = requests.delete(
|
|
797
|
-
f"{API_HOST}/dataset-
|
|
800
|
+
f"{API_HOST}/dataset-qa/run/{run_id}",
|
|
798
801
|
headers=get_headers(),
|
|
799
802
|
timeout=MODIFY_TIMEOUT,
|
|
800
803
|
)
|
|
@@ -808,8 +811,32 @@ class OptimizationDataset(BaseModel):
|
|
|
808
811
|
raise ValueError("No run has been started")
|
|
809
812
|
self.cancel_by_id(self.run_id)
|
|
810
813
|
|
|
814
|
+
@staticmethod
|
|
815
|
+
def archive_run_by_id(run_id: str) -> None:
|
|
816
|
+
"""
|
|
817
|
+
Archive the dataset QA run for the given `run_id`.
|
|
818
|
+
|
|
819
|
+
Args:
|
|
820
|
+
run_id: The ID of the run to archive
|
|
821
|
+
"""
|
|
822
|
+
logger.info("Archiving run with ID: %s", run_id)
|
|
823
|
+
response = requests.patch(
|
|
824
|
+
f"{API_HOST}/dataset-qa/run/archive/{run_id}",
|
|
825
|
+
headers=get_headers(),
|
|
826
|
+
timeout=MODIFY_TIMEOUT,
|
|
827
|
+
)
|
|
828
|
+
raise_for_status_with_reason(response)
|
|
829
|
+
|
|
830
|
+
def archive(self) -> None:
|
|
831
|
+
"""
|
|
832
|
+
Archive the current active instance's run.
|
|
833
|
+
"""
|
|
834
|
+
if not self.run_id:
|
|
835
|
+
raise ValueError("No run has been started")
|
|
836
|
+
self.archive_run_by_id(self.run_id)
|
|
837
|
+
|
|
811
838
|
|
|
812
|
-
class
|
|
839
|
+
class QADatasetOut(BaseModel):
|
|
813
840
|
id: int
|
|
814
841
|
|
|
815
842
|
name: str
|
|
@@ -820,7 +847,7 @@ class DataOptimizationDatasetOut(BaseModel):
|
|
|
820
847
|
data_root_url: HirundoUrl
|
|
821
848
|
|
|
822
849
|
classes: typing.Optional[list[str]] = None
|
|
823
|
-
labeling_info: LabelingInfo
|
|
850
|
+
labeling_info: typing.Union[LabelingInfo, list[LabelingInfo]]
|
|
824
851
|
|
|
825
852
|
organization_id: typing.Optional[int]
|
|
826
853
|
creator_id: typing.Optional[int]
|
|
@@ -828,7 +855,7 @@ class DataOptimizationDatasetOut(BaseModel):
|
|
|
828
855
|
updated_at: datetime.datetime
|
|
829
856
|
|
|
830
857
|
|
|
831
|
-
class
|
|
858
|
+
class DataQARunOut(BaseModel):
|
|
832
859
|
id: int
|
|
833
860
|
name: str
|
|
834
861
|
dataset_id: int
|