hirundo 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hirundo/__init__.py +17 -9
- hirundo/_constraints.py +34 -2
- hirundo/_http.py +7 -2
- hirundo/_iter_sse_retrying.py +61 -17
- hirundo/dataset_optimization.py +421 -83
- hirundo/enum.py +8 -5
- hirundo/git.py +85 -20
- hirundo/storage.py +233 -62
- {hirundo-0.1.8.dist-info → hirundo-0.1.9.dist-info}/METADATA +78 -42
- hirundo-0.1.9.dist-info/RECORD +20 -0
- {hirundo-0.1.8.dist-info → hirundo-0.1.9.dist-info}/WHEEL +1 -1
- hirundo-0.1.8.dist-info/RECORD +0 -20
- {hirundo-0.1.8.dist-info → hirundo-0.1.9.dist-info}/LICENSE +0 -0
- {hirundo-0.1.8.dist-info → hirundo-0.1.9.dist-info}/entry_points.txt +0 -0
- {hirundo-0.1.8.dist-info → hirundo-0.1.9.dist-info}/top_level.txt +0 -0
hirundo/dataset_optimization.py
CHANGED
```diff
@@ -1,5 +1,7 @@
+import datetime
 import json
 import typing
+from abc import ABC, abstractmethod
 from collections.abc import AsyncGenerator, Generator
 from enum import Enum
 from io import StringIO
@@ -14,14 +16,15 @@ from pydantic import BaseModel, Field, model_validator
 from tqdm import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
 
+from hirundo._constraints import HirundoUrl
 from hirundo._env import API_HOST
 from hirundo._headers import get_auth_headers, json_headers
 from hirundo._http import raise_for_status_with_reason
 from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
 from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
-from hirundo.enum import DatasetMetadataType,
+from hirundo.enum import DatasetMetadataType, LabelingType
 from hirundo.logger import get_logger
-from hirundo.storage import
+from hirundo.storage import ResponseStorageConfig, StorageConfig
 
 logger = get_logger(__name__)
 
@@ -38,12 +41,14 @@ MAX_RETRIES = 200 # Max 200 retries for HTTP SSE connection
 
 
 class RunStatus(Enum):
-    STARTED = "STARTED"
     PENDING = "PENDING"
+    STARTED = "STARTED"
     SUCCESS = "SUCCESS"
     FAILURE = "FAILURE"
     AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
-
+    REVOKED = "REVOKED"
+    REJECTED = "REJECTED"
+    RETRY = "RETRY"
 
 
 STATUS_TO_TEXT_MAP = {
@@ -52,7 +57,9 @@ STATUS_TO_TEXT_MAP = {
     RunStatus.SUCCESS.value: "Optimization run completed successfully",
     RunStatus.FAILURE.value: "Optimization run failed",
     RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
-    RunStatus.
+    RunStatus.RETRY.value: "Optimization run failed. Retrying",
+    RunStatus.REVOKED.value: "Optimization run was cancelled",
+    RunStatus.REJECTED.value: "Optimization run was rejected",
 }
 STATUS_TO_PROGRESS_MAP = {
     RunStatus.STARTED.value: 0.0,
@@ -60,7 +67,9 @@ STATUS_TO_PROGRESS_MAP = {
     RunStatus.SUCCESS.value: 100.0,
     RunStatus.FAILURE.value: 100.0,
     RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
-    RunStatus.
+    RunStatus.RETRY.value: 0.0,
+    RunStatus.REVOKED.value: 100.0,
+    RunStatus.REJECTED.value: 0.0,
 }
 
 
```
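The new `REVOKED`, `REJECTED`, and `RETRY` states are wired into both lookup tables, which are plain dicts keyed by the enum *values*, so a consumer can map an SSE state string straight to display text and a progress percentage. A minimal sketch of reading the two maps together; the `state` value is illustrative of what the SSE stream reports:

```python
# Minimal sketch of how the status maps added above combine.
# The `state` string is an illustrative SSE payload value, not a real run.
from hirundo.dataset_optimization import (
    STATUS_TO_PROGRESS_MAP,
    STATUS_TO_TEXT_MAP,
    RunStatus,
)

state = RunStatus.RETRY.value  # e.g. received as `iteration["state"]`
print(STATUS_TO_TEXT_MAP[state])      # "Optimization run failed. Retrying"
print(STATUS_TO_PROGRESS_MAP[state])  # 0.0, i.e. the progress bar resets while retrying
```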
```diff
@@ -84,10 +93,10 @@ CUSTOMER_INTERCHANGE_DTYPES: DtypeArg = {
     "segment_id": np.int32,
     "label": str,
     "bbox_id": str,
-    "xmin": np.
-    "ymin": np.
-    "xmax": np.
-    "ymax": np.
+    "xmin": np.float32,
+    "ymin": np.float32,
+    "xmax": np.float32,
+    "ymax": np.float32,
     "suspect_level": np.float32,  # If exists, must be one of the values in the enum below
     "suggested_label": str,
     "suggested_label_conf": np.float32,
```
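The bounding-box columns now load as `np.float32` (the previous dtypes are truncated in this diff view). A hedged sketch of reading an interchange CSV with this schema; the file path is illustrative:

```python
# Hedged sketch: loading a customer-interchange CSV with the float32 bbox
# dtypes introduced above. The file path is illustrative.
import pandas as pd
from hirundo.dataset_optimization import CUSTOMER_INTERCHANGE_DTYPES

df = pd.read_csv("interchange.csv", dtype=CUSTOMER_INTERCHANGE_DTYPES)
# xmin/ymin/xmax/ymax now parse as float32, so sub-pixel coordinates survive.
print(df.dtypes[["xmin", "ymin", "xmax", "ymax"]])
```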
```diff
@@ -97,63 +106,277 @@ CUSTOMER_INTERCHANGE_DTYPES: DtypeArg = {
 }
 
 
+class Metadata(BaseModel, ABC):
+    type: DatasetMetadataType
+
+    @property
+    @abstractmethod
+    def metadata_url(self) -> HirundoUrl:
+        raise NotImplementedError()
+
+
+class HirundoCSV(Metadata):
+    """
+    A dataset metadata file in the Hirundo CSV format
+    """
+
+    type: DatasetMetadataType = DatasetMetadataType.HIRUNDO_CSV
+    csv_url: HirundoUrl
+    """
+    The URL to access the dataset metadata CSV file.
+    e.g. `s3://my-bucket-name/my-folder/my-metadata.csv`, `gs://my-bucket-name/my-folder/my-metadata.csv`,
+    or `ssh://my-username@my-repo-name/my-folder/my-metadata.csv`
+    (or `file:///datasets/my-folder/my-metadata.csv` if using LOCAL storage type with on-premises installation)
+    """
+
+    @property
+    def metadata_url(self) -> HirundoUrl:
+        return self.csv_url
+
+
+class COCO(Metadata):
+    """
+    A dataset metadata file in the COCO format
+    """
+
+    type: DatasetMetadataType = DatasetMetadataType.COCO
+    json_url: HirundoUrl
+    """
+    The URL to access the dataset metadata JSON file.
+    e.g. `s3://my-bucket-name/my-folder/my-metadata.json`, `gs://my-bucket-name/my-folder/my-metadata.json`,
+    or `ssh://my-username@my-repo-name/my-folder/my-metadata.json`
+    (or `file:///datasets/my-folder/my-metadata.json` if using LOCAL storage type with on-premises installation)
+    """
+
+    @property
+    def metadata_url(self) -> HirundoUrl:
+        return self.json_url
+
+
+class YOLO(Metadata):
+    type: DatasetMetadataType = DatasetMetadataType.YOLO
+    data_yaml_url: typing.Optional[HirundoUrl] = None
+    labels_dir_url: HirundoUrl
+
+    @property
+    def metadata_url(self) -> HirundoUrl:
+        return self.labels_dir_url
+
+
+LabelingInfo = typing.Union[HirundoCSV, COCO, YOLO]
+"""
+The dataset labeling info. The dataset labeling info can be one of the following:
+- `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format
+
+Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
+"""
+
+
+class VisionRunArgs(BaseModel):
+    upsample: bool = False
+    """
+    Whether to upsample the dataset to attempt to balance the classes.
+    """
+    min_abs_bbox_size: int = 0
+    """
+    Minimum valid size (in pixels) of a bounding box to keep it in the dataset for optimization.
+    """
+    min_abs_bbox_area: int = 0
+    """
+    Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for optimization.
+    """
+    min_rel_bbox_size: float = 0.0
+    """
+    Minimum valid size (as a fraction of both image height and width) for a bounding box
+    to keep it in the dataset for optimization, relative to the corresponding dimension size,
+    i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
+    value is 0.06 (since both width and height are checked).
+    """
+    min_rel_bbox_area: float = 0.0
+    """
+    Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for optimization.
+    """
+
+
+RunArgs = typing.Union[VisionRunArgs]
+
+
+class AugmentationNames(str, Enum):
+    RandomHorizontalFlip = "RandomHorizontalFlip"
+    RandomVerticalFlip = "RandomVerticalFlip"
+    RandomRotation = "RandomRotation"
+    ColorJitter = "ColorJitter"
+    RandomAffine = "RandomAffine"
+    RandomPerspective = "RandomPerspective"
+
+
+class Modality(str, Enum):
+    IMAGE = "Image"
+    RADAR = "Radar"
+    EKG = "EKG"
+
+
 class OptimizationDataset(BaseModel):
+    id: typing.Optional[int] = Field(default=None)
+    """
+    The ID of the dataset created on the server.
+    """
     name: str
     """
     The name of the dataset. Used to identify it amongst the list of datasets
     belonging to your organization in `hirundo`.
     """
-
+    labeling_type: LabelingType
     """
-    Indicates the
-    - `
-    - `
+    Indicates the labeling type of the dataset. The labeling type can be one of the following:
+    - `LabelingType.SINGLE_LABEL_CLASSIFICATION`: Indicates that the dataset is for classification tasks
+    - `LabelingType.OBJECT_DETECTION`: Indicates that the dataset is for object detection tasks
+    - `LabelingType.SPEECH_TO_TEXT`: Indicates that the dataset is for speech-to-text tasks
     """
-
+    language: typing.Optional[str] = None
     """
-
-    If `None`, the `dataset_id` field must be set.
+    Language of the Speech-to-Text audio dataset. This is required for Speech-to-Text datasets.
     """
-
-    classes: typing.Optional[list[str]] = None
+    storage_config_id: typing.Optional[int] = None
     """
-
-    It is currently required for clarity and performance.
+    The ID of the storage config used to store the dataset and metadata.
     """
-
+    storage_config: typing.Optional[
+        typing.Union[StorageConfig, ResponseStorageConfig]
+    ] = None
     """
-    The
-    Note: This path will be prefixed with the `StorageLink`'s `path`.
+    The `StorageConfig` instance to link to.
     """
-
+    data_root_url: HirundoUrl
     """
-
-    - `
+    URL for data (e.g. images) within the `StorageConfig` instance,
+    e.g. `s3://my-bucket-name/my-images-folder`, `gs://my-bucket-name/my-images-folder`,
+    or `ssh://my-username@my-repo-name/my-images-folder`
+    (or `file:///datasets/my-images-folder` if using LOCAL storage type with on-premises installation)
 
-
+    Note: All CSV `image_path` entries in the metadata file should be relative to this folder.
     """
 
-
+    classes: typing.Optional[list[str]] = None
     """
-
+    A full list of possible classes used in classification / object detection.
+    It is currently required for clarity and performance.
     """
-
+    labeling_info: LabelingInfo
+
+    augmentations: typing.Optional[list[AugmentationNames]] = None
     """
-
+    Used to define which augmentations are apply to a vision dataset.
+    For audio datasets, this field is ignored.
+    If no value is provided, all augmentations are applied to vision datasets.
+    """
+    modality: Modality = Modality.IMAGE
     """
+    Used to define the modality of the dataset.
+    Defaults to Image.
+    """
+
     run_id: typing.Optional[str] = Field(default=None, init=False)
     """
     The ID of the Dataset Optimization run created on the server.
     """
 
+    status: typing.Optional[RunStatus] = None
+
     @model_validator(mode="after")
     def validate_dataset(self):
-        if self.
-            raise ValueError(
+        if self.storage_config is None and self.storage_config_id is None:
+            raise ValueError(
+                "No dataset storage has been provided. Provide one via `storage_config` or `storage_config_id`"
+            )
+        elif self.storage_config is not None and self.storage_config_id is not None:
+            raise ValueError(
+                "Both `storage_config` and `storage_config_id` have been provided. Pick one."
+            )
+        if self.labeling_type == LabelingType.SPEECH_TO_TEXT and self.language is None:
+            raise ValueError("Language is required for Speech-to-Text datasets.")
+        elif (
+            self.labeling_type != LabelingType.SPEECH_TO_TEXT
+            and self.language is not None
+        ):
+            raise ValueError("Language is only allowed for Speech-to-Text datasets.")
+        if (
+            self.labeling_info.type == DatasetMetadataType.YOLO
+            and isinstance(self.labeling_info, YOLO)
+            and (
+                self.labeling_info.data_yaml_url is not None
+                and self.classes is not None
+            )
+        ):
+            raise ValueError(
+                "Only one of `classes` or `labeling_info.data_yaml_url` should be provided for YOLO datasets"
+            )
         return self
 
     @staticmethod
-    def
+    def get_by_id(dataset_id: int) -> "OptimizationDataset":
+        """
+        Get a `OptimizationDataset` instance from the server by its ID
+
+        Args:
+            dataset_id: The ID of the `OptimizationDataset` instance to get
+        """
+        response = requests.get(
+            f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
+            headers=get_auth_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(response)
+        dataset = response.json()
+        return OptimizationDataset(**dataset)
+
+    @staticmethod
+    def get_by_name(name: str) -> "OptimizationDataset":
+        """
+        Get a `OptimizationDataset` instance from the server by its name
+
+        Args:
+            name: The name of the `OptimizationDataset` instance to get
+        """
+        response = requests.get(
+            f"{API_HOST}/dataset-optimization/dataset/by-name/{name}",
+            headers=get_auth_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(response)
+        dataset = response.json()
+        return OptimizationDataset(**dataset)
+
+    @staticmethod
+    def list_datasets(
+        organization_id: typing.Optional[int] = None,
+    ) -> list["DataOptimizationDatasetOut"]:
+        """
+        Lists all the optimization datasets created by user's default organization
+        or the `organization_id` passed
+
+        Args:
+            organization_id: The ID of the organization to list the datasets for.
+        """
+        response = requests.get(
+            f"{API_HOST}/dataset-optimization/dataset/",
+            params={"dataset_organization_id": organization_id},
+            headers=get_auth_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(response)
+        datasets = response.json()
+        return [
+            DataOptimizationDatasetOut(
+                **ds,
+            )
+            for ds in datasets
+        ]
+
+    @staticmethod
+    def list_runs(
+        organization_id: typing.Optional[int] = None,
+    ) -> list["DataOptimizationRunOut"]:
         """
         Lists all the `OptimizationDataset` instances created by user's default organization
         or the `organization_id` passed
```
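Taken together, the new `labeling_info` classes and the renamed storage fields change how a dataset is declared. A hedged construction sketch; the name, ID, class list, and URLs are illustrative, and `storage_config_id` is assumed to reference an already-created storage config:

```python
# Hedged sketch of the construction surface added above. All values are
# illustrative; `storage_config_id=123` assumes an existing config on the server.
from hirundo.dataset_optimization import HirundoCSV, OptimizationDataset
from hirundo.enum import LabelingType

dataset = OptimizationDataset(
    name="my-dataset",
    labeling_type=LabelingType.SINGLE_LABEL_CLASSIFICATION,
    storage_config_id=123,  # mutually exclusive with `storage_config`
    data_root_url="s3://my-bucket-name/my-images-folder",
    classes=["cat", "dog"],
    labeling_info=HirundoCSV(
        csv_url="s3://my-bucket-name/my-folder/my-metadata.csv",
    ),
)
```

The `model_validator` shown above rejects inconsistent combinations at construction time: passing both `storage_config` and `storage_config_id`, setting `language` on anything other than a Speech-to-Text dataset, omitting `language` on a Speech-to-Text dataset, or providing both `classes` and `labeling_info.data_yaml_url` for a YOLO dataset.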
```diff
@@ -163,13 +386,19 @@ class OptimizationDataset(BaseModel):
             organization_id: The ID of the organization to list the datasets for.
         """
         response = requests.get(
-            f"{API_HOST}/dataset-optimization/
+            f"{API_HOST}/dataset-optimization/run/list",
             params={"dataset_organization_id": organization_id},
             headers=get_auth_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(response)
-
+        runs = response.json()
+        return [
+            DataOptimizationRunOut(
+                **run,
+            )
+            for run in runs
+        ]
 
     @staticmethod
     def delete_by_id(dataset_id: int) -> None:
```
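A short sketch of the new read-side helpers. The returned items are the `DataOptimizationDatasetOut` and `DataOptimizationRunOut` models added at the bottom of this file:

```python
# Sketch: listing datasets and runs with the static helpers added above.
from hirundo.dataset_optimization import OptimizationDataset

for ds in OptimizationDataset.list_datasets():
    print(ds.id, ds.name, ds.labeling_type)

for run in OptimizationDataset.list_runs():
    print(run.run_id, run.status)
```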
```diff
@@ -187,51 +416,71 @@ class OptimizationDataset(BaseModel):
         raise_for_status_with_reason(response)
         logger.info("Deleted dataset with ID: %s", dataset_id)
 
-    def delete(self,
+    def delete(self, storage_config=True) -> None:
         """
         Deletes the active `OptimizationDataset` instance from the server.
         It can only be used on a `OptimizationDataset` instance that has been created.
 
         Args:
-
+            storage_config: If True, the `OptimizationDataset`'s `StorageConfig` will also be deleted
 
-        Note: If `
-        This can either be set manually or by creating the `
+        Note: If `storage_config` is not set to `False` then the `storage_config_id` must be set
+        This can either be set manually or by creating the `StorageConfig` instance via the `OptimizationDataset`'s
         `create` method
         """
-        if
-            if not self.
-                raise ValueError("No storage
-
-        if not self.
+        if storage_config:
+            if not self.storage_config_id:
+                raise ValueError("No storage config has been created")
+            StorageConfig.delete_by_id(self.storage_config_id)
+        if not self.id:
             raise ValueError("No dataset has been created")
-        self.delete_by_id(self.
+        self.delete_by_id(self.id)
 
-    def create(
+    def create(
+        self,
+        organization_id: typing.Optional[int] = None,
+        replace_if_exists: bool = False,
+    ) -> int:
         """
         Create a `OptimizationDataset` instance on the server.
-        If `
+        If the `storage_config_id` field is not set, the storage config will also be created and the field will be set.
+
+        Args:
+            organization_id: The ID of the organization to create the dataset for.
+            replace_if_exists: If True, the dataset will be replaced if it already exists
+            (this is determined by a dataset of the same name in the same organization).
+
+        Returns:
+            The ID of the created `OptimizationDataset` instance
         """
-        if
+        if self.storage_config is None and self.storage_config_id is None:
             raise ValueError("No dataset storage has been provided")
-
-        self.
-
-
+        elif self.storage_config and self.storage_config_id is None:
+            if isinstance(self.storage_config, ResponseStorageConfig):
+                self.storage_config_id = self.storage_config.id
+            elif isinstance(self.storage_config, StorageConfig):
+                self.storage_config_id = self.storage_config.create(
+                    replace_if_exists=replace_if_exists,
+                )
+        elif (
+            self.storage_config is not None
+            and self.storage_config_id is not None
+            and (
+                not isinstance(self.storage_config, ResponseStorageConfig)
+                or self.storage_config.id != self.storage_config_id
+            )
         ):
-
-
+            raise ValueError(
+                "Both `storage_config` and `storage_config_id` have been provided. Storage config IDs do not match."
             )
-        model_dict = self.model_dump()
+        model_dict = self.model_dump(mode="json")
         # ⬆️ Get dict of model fields from Pydantic model instance
         dataset_response = requests.post(
             f"{API_HOST}/dataset-optimization/dataset/",
             json={
-
-
-
-            },
-                **{k: model_dict[k] for k in model_dict.keys() - {"dataset_storage"}},
+                **{k: model_dict[k] for k in model_dict.keys() - {"storage_config"}},
+                "organization_id": organization_id,
+                "replace_if_exists": replace_if_exists,
             },
             headers={
                 **json_headers,
```
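Continuing the construction sketch above (`dataset` is the `OptimizationDataset` built there), a hedged view of the lifecycle changes in this hunk: `create` now accepts `organization_id` and `replace_if_exists`, and `delete` removes the storage config too unless told otherwise:

```python
# Continues the earlier sketch; `dataset` is the instance built above.
dataset_id = dataset.create(replace_if_exists=True)  # returns the new server-side ID

# Later: drop the dataset but keep the shared storage config.
dataset.delete(storage_config=False)
```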
```diff
@@ -240,14 +489,18 @@ class OptimizationDataset(BaseModel):
             timeout=MODIFY_TIMEOUT,
         )
         raise_for_status_with_reason(dataset_response)
-        self.
-        if not self.
-            raise HirundoError("
-        logger.info("Created dataset with ID: %s", self.
-        return self.
+        self.id = dataset_response.json()["id"]
+        if not self.id:
+            raise HirundoError("An error ocurred while trying to create the dataset")
+        logger.info("Created dataset with ID: %s", self.id)
+        return self.id
 
     @staticmethod
-    def launch_optimization_run(
+    def launch_optimization_run(
+        dataset_id: int,
+        organization_id: typing.Optional[int] = None,
+        run_args: typing.Optional[RunArgs] = None,
+    ) -> str:
         """
         Run the dataset optimization process on the server using the dataset with the given ID
         i.e. `dataset_id`.
```
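The static `launch_optimization_run` gains optional `organization_id` and `run_args` parameters. A hedged sketch, assuming `dataset_id` points at an existing object-detection dataset (the bbox thresholds only make sense there); values are illustrative:

```python
# Sketch: launching a run by dataset ID with the new `run_args` parameter.
# Assumes `dataset_id` refers to an existing object-detection dataset.
from hirundo.dataset_optimization import OptimizationDataset, VisionRunArgs

run_id = OptimizationDataset.launch_optimization_run(
    dataset_id=dataset_id,
    run_args=VisionRunArgs(min_abs_bbox_size=8, min_abs_bbox_area=64),
)
```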
```diff
@@ -258,26 +511,62 @@ class OptimizationDataset(BaseModel):
         Returns:
             ID of the run (`run_id`).
         """
+        run_info = {}
+        if organization_id:
+            run_info["organization_id"] = organization_id
+        if run_args:
+            run_info["run_args"] = run_args.model_dump(mode="json")
         run_response = requests.post(
             f"{API_HOST}/dataset-optimization/run/{dataset_id}",
+            json=run_info if len(run_info) > 0 else None,
             headers=get_auth_headers(),
             timeout=MODIFY_TIMEOUT,
         )
         raise_for_status_with_reason(run_response)
         return run_response.json()["run_id"]
 
-    def
+    def _validate_run_args(self, run_args: RunArgs) -> None:
+        if self.labeling_type == LabelingType.SPEECH_TO_TEXT:
+            raise Exception("Speech to text cannot have `run_args` set")
+        if self.labeling_type != LabelingType.OBJECT_DETECTION and any(
+            (
+                run_args.min_abs_bbox_size != 0,
+                run_args.min_abs_bbox_area != 0,
+                run_args.min_rel_bbox_size != 0,
+                run_args.min_rel_bbox_area != 0,
+            )
+        ):
+            raise Exception(
+                "Cannot set `min_abs_bbox_size`, `min_abs_bbox_area`, "
+                + "`min_rel_bbox_size`, or `min_rel_bbox_area` for "
+                + f"labeling type {self.labeling_type}"
+            )
+
+    def run_optimization(
+        self,
+        organization_id: typing.Optional[int] = None,
+        replace_dataset_if_exists: bool = False,
+        run_args: typing.Optional[RunArgs] = None,
+    ) -> str:
         """
         If the dataset was not created on the server yet, it is created.
         Run the dataset optimization process on the server using the active `OptimizationDataset` instance
 
+        Args:
+            organization_id: The ID of the organization to run the optimization for.
+            replace_dataset_if_exists: If True, the dataset will be replaced if it already exists
+            (this is determined by a dataset of the same name in the same organization).
+            run_args: The run arguments to use for the optimization run
+
         Returns:
             An ID of the run (`run_id`) and stores that `run_id` on the instance
         """
         try:
-            if not self.
-                self.
-
+            if not self.id:
+                self.id = self.create(replace_if_exists=replace_dataset_if_exists)
+            if run_args is not None:
+                self._validate_run_args(run_args)
+            run_id = self.launch_optimization_run(self.id, organization_id, run_args)
             self.run_id = run_id
             logger.info("Started the run with ID: %s", run_id)
             return run_id
```
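`run_optimization` is the instance-level entry point: it creates the dataset if needed, validates `run_args` against the labeling type via `_validate_run_args`, launches the run, and stores `run_id` on the instance. Continuing the earlier classification-dataset sketch:

```python
# Continues the earlier sketch; `dataset` is the classification dataset above.
# Bbox thresholds would raise here: `_validate_run_args` only allows them for
# OBJECT_DETECTION datasets, so only `upsample` is set.
run_id = dataset.run_optimization(
    replace_dataset_if_exists=True,
    run_args=VisionRunArgs(upsample=True),
)
print(run_id == dataset.run_id)  # True
```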
```diff
@@ -293,17 +582,17 @@ class OptimizationDataset(BaseModel):
             except Exception:
                 content = error.response.text
             raise HirundoError(
-                f"
+                f"Unable to start the run. Status code: {error.response.status_code} Content: {content}"
             ) from error
         except Exception as error:
-            raise HirundoError(f"
+            raise HirundoError(f"Unable to start the run: {error}") from error
 
     def clean_ids(self):
         """
-        Reset `dataset_id`, `
+        Reset `dataset_id`, `storage_config_id`, and `run_id` values on the instance to default value of `None`
         """
-        self.
-        self.
+        self.storage_config_id = None
+        self.id = None
         self.run_id = None
 
     @staticmethod
@@ -370,7 +659,15 @@ class OptimizationDataset(BaseModel):
             last_event = json.loads(sse.data)
             if not last_event:
                 continue
-            data
+            if "data" in last_event:
+                data = last_event["data"]
+            else:
+                if "detail" in last_event:
+                    raise HirundoError(last_event["detail"])
+                elif "reason" in last_event:
+                    raise HirundoError(last_event["reason"])
+                else:
+                    raise HirundoError("Unknown error")
             OptimizationDataset._read_csvs_to_df(data)
             yield data
             if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
```
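The SSE consumer now distinguishes a progress event, keyed `data`, from server error payloads keyed `detail` or `reason`, raising `HirundoError` for the latter instead of failing on a missing key. A standalone illustration of the two shapes (the payloads are made up):

```python
# Illustration of the branching added above. Both payloads are made-up
# examples of the event shapes the client now distinguishes.
ok_event = {"data": {"state": "STARTED", "result": None}}
err_event = {"detail": "Run not found"}

for event in (ok_event, err_event):
    if "data" in event:
        print("progress:", event["data"]["state"])
    else:
        # mirrors raise HirundoError(last_event["detail"] / ["reason"])
        print("error:", event.get("detail") or event.get("reason") or "Unknown error")
```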
```diff
@@ -420,7 +717,11 @@ class OptimizationDataset(BaseModel):
                 t.n = STATUS_TO_PROGRESS_MAP[iteration["state"]]
                 logger.debug("Setting progress to %s", t.n)
                 t.refresh()
-                if iteration["state"]
+                if iteration["state"] in [
+                    RunStatus.FAILURE.value,
+                    RunStatus.REJECTED.value,
+                    RunStatus.REVOKED.value,
+                ]:
                     raise HirundoError(
                         f"Optimization run failed with error: {iteration['result']}"
                     )
@@ -445,13 +746,22 @@ class OptimizationDataset(BaseModel):
                 and iteration["result"]["result"]
                 and isinstance(iteration["result"]["result"], str)
             ):
-
-
-
+                result_info = iteration["result"]["result"].split(":")
+                if len(result_info) > 1:
+                    stage = result_info[0]
+                    current_progress_percentage = float(
+                        result_info[1].removeprefix(" ").removesuffix("% done")
+                    )
+                elif len(result_info) == 1:
+                    stage = result_info[0]
+                    current_progress_percentage = t.n  # Keep the same progress
+                else:
+                    stage = "Unknown progress state"
+                    current_progress_percentage = t.n  # Keep the same progress
                 desc = (
                     "Optimization run completed. Uploading results"
                     if current_progress_percentage == 100.0
-                    else
+                    else stage
                 )
                 t.set_description(desc)
                 t.n = current_progress_percentage
```
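The progress parser above splits result strings of the form `"<stage>: <percent>% done"`. Note that `str.split` always returns at least one element, so the final `else` branch is purely defensive. A worked example on an illustrative result string:

```python
# Worked example of the parsing added above, on an illustrative result string.
result = "Optimizing dataset: 42.5% done"
result_info = result.split(":")
stage = result_info[0]  # "Optimizing dataset"
pct = float(result_info[1].removeprefix(" ").removesuffix("% done"))  # 42.5
print(stage, pct)
```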
```diff
@@ -574,3 +884,31 @@ class OptimizationDataset(BaseModel):
         if not self.run_id:
             raise ValueError("No run has been started")
         self.cancel_by_id(self.run_id)
+
+
+class DataOptimizationDatasetOut(BaseModel):
+    id: int
+
+    name: str
+    labeling_type: LabelingType
+
+    storage_config: ResponseStorageConfig
+
+    data_root_url: HirundoUrl
+
+    classes: typing.Optional[list[str]] = None
+    labeling_info: LabelingInfo
+
+    organization_id: typing.Optional[int]
+    creator_id: typing.Optional[int]
+    created_at: datetime.datetime
+    updated_at: datetime.datetime
+
+
+class DataOptimizationRunOut(BaseModel):
+    id: int
+    name: str
+    run_id: str
+    status: RunStatus
+    approved: bool
+    created_at: datetime.datetime
```