hirundo 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hirundo/__init__.py +17 -9
- hirundo/_constraints.py +34 -2
- hirundo/_env.py +12 -1
- hirundo/_http.py +19 -0
- hirundo/_iter_sse_retrying.py +63 -19
- hirundo/cli.py +75 -16
- hirundo/dataset_optimization.py +519 -127
- hirundo/enum.py +8 -5
- hirundo/git.py +95 -28
- hirundo/logger.py +3 -1
- hirundo/storage.py +246 -75
- hirundo-0.1.9.dist-info/METADATA +212 -0
- hirundo-0.1.9.dist-info/RECORD +20 -0
- {hirundo-0.1.7.dist-info → hirundo-0.1.9.dist-info}/WHEEL +1 -1
- hirundo-0.1.7.dist-info/METADATA +0 -118
- hirundo-0.1.7.dist-info/RECORD +0 -19
- {hirundo-0.1.7.dist-info → hirundo-0.1.9.dist-info}/LICENSE +0 -0
- {hirundo-0.1.7.dist-info → hirundo-0.1.9.dist-info}/entry_points.txt +0 -0
- {hirundo-0.1.7.dist-info → hirundo-0.1.9.dist-info}/top_level.txt +0 -0
hirundo/dataset_optimization.py
CHANGED
```diff
@@ -1,24 +1,30 @@
+import datetime
 import json
 import typing
+from abc import ABC, abstractmethod
 from collections.abc import AsyncGenerator, Generator
 from enum import Enum
 from io import StringIO
-from typing import
+from typing import overload
 
 import httpx
+import numpy as np
 import pandas as pd
 import requests
+from pandas._typing import DtypeArg
 from pydantic import BaseModel, Field, model_validator
 from tqdm import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
 
+from hirundo._constraints import HirundoUrl
 from hirundo._env import API_HOST
 from hirundo._headers import get_auth_headers, json_headers
+from hirundo._http import raise_for_status_with_reason
 from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
 from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
-from hirundo.enum import DatasetMetadataType,
+from hirundo.enum import DatasetMetadataType, LabelingType
 from hirundo.logger import get_logger
-from hirundo.storage import
+from hirundo.storage import ResponseStorageConfig, StorageConfig
 
 logger = get_logger(__name__)
 
@@ -35,70 +41,342 @@ MAX_RETRIES = 200 # Max 200 retries for HTTP SSE connection
 
 
 class RunStatus(Enum):
-    STARTED = "STARTED"
     PENDING = "PENDING"
+    STARTED = "STARTED"
     SUCCESS = "SUCCESS"
     FAILURE = "FAILURE"
     AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
+    REVOKED = "REVOKED"
+    REJECTED = "REJECTED"
+    RETRY = "RETRY"
+
+
+STATUS_TO_TEXT_MAP = {
+    RunStatus.STARTED.value: "Optimization run in progress. Downloading dataset",
+    RunStatus.PENDING.value: "Optimization run queued and not yet started",
+    RunStatus.SUCCESS.value: "Optimization run completed successfully",
+    RunStatus.FAILURE.value: "Optimization run failed",
+    RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
+    RunStatus.RETRY.value: "Optimization run failed. Retrying",
+    RunStatus.REVOKED.value: "Optimization run was cancelled",
+    RunStatus.REJECTED.value: "Optimization run was rejected",
+}
+STATUS_TO_PROGRESS_MAP = {
+    RunStatus.STARTED.value: 0.0,
+    RunStatus.PENDING.value: 0.0,
+    RunStatus.SUCCESS.value: 100.0,
+    RunStatus.FAILURE.value: 100.0,
+    RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
+    RunStatus.RETRY.value: 0.0,
+    RunStatus.REVOKED.value: 100.0,
+    RunStatus.REJECTED.value: 0.0,
+}
+
+
+class DatasetOptimizationResults(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}
+
+    suspects: pd.DataFrame
+    """
+    A pandas DataFrame containing the results of the optimization run
+    """
+    warnings_and_errors: pd.DataFrame
+    """
+    A pandas DataFrame containing the warnings and errors of the optimization run
+    """
+
+
+CUSTOMER_INTERCHANGE_DTYPES: DtypeArg = {
+    "image_path": str,
+    "label_path": str,
+    "segments_mask_path": str,
+    "segment_id": np.int32,
+    "label": str,
+    "bbox_id": str,
+    "xmin": np.float32,
+    "ymin": np.float32,
+    "xmax": np.float32,
+    "ymax": np.float32,
+    "suspect_level": np.float32,  # If exists, must be one of the values in the enum below
+    "suggested_label": str,
+    "suggested_label_conf": np.float32,
+    "status": str,
+    # ⬆️ If exists, must be one of the following:
+    # NO_LABELS/MISSING_IMAGE/INVALID_IMAGE/INVALID_BBOX/INVALID_BBOX_SIZE/INVALID_SEG/INVALID_SEG_SIZE
+}
+
+
+class Metadata(BaseModel, ABC):
+    type: DatasetMetadataType
+
+    @property
+    @abstractmethod
+    def metadata_url(self) -> HirundoUrl:
+        raise NotImplementedError()
+
+
+class HirundoCSV(Metadata):
+    """
+    A dataset metadata file in the Hirundo CSV format
+    """
+
+    type: DatasetMetadataType = DatasetMetadataType.HIRUNDO_CSV
+    csv_url: HirundoUrl
+    """
+    The URL to access the dataset metadata CSV file.
+    e.g. `s3://my-bucket-name/my-folder/my-metadata.csv`, `gs://my-bucket-name/my-folder/my-metadata.csv`,
+    or `ssh://my-username@my-repo-name/my-folder/my-metadata.csv`
+    (or `file:///datasets/my-folder/my-metadata.csv` if using LOCAL storage type with on-premises installation)
+    """
+
+    @property
+    def metadata_url(self) -> HirundoUrl:
+        return self.csv_url
+
+
+class COCO(Metadata):
+    """
+    A dataset metadata file in the COCO format
+    """
+
+    type: DatasetMetadataType = DatasetMetadataType.COCO
+    json_url: HirundoUrl
+    """
+    The URL to access the dataset metadata JSON file.
+    e.g. `s3://my-bucket-name/my-folder/my-metadata.json`, `gs://my-bucket-name/my-folder/my-metadata.json`,
+    or `ssh://my-username@my-repo-name/my-folder/my-metadata.json`
+    (or `file:///datasets/my-folder/my-metadata.json` if using LOCAL storage type with on-premises installation)
+    """
+
+    @property
+    def metadata_url(self) -> HirundoUrl:
+        return self.json_url
+
+
+class YOLO(Metadata):
+    type: DatasetMetadataType = DatasetMetadataType.YOLO
+    data_yaml_url: typing.Optional[HirundoUrl] = None
+    labels_dir_url: HirundoUrl
+
+    @property
+    def metadata_url(self) -> HirundoUrl:
+        return self.labels_dir_url
+
+
+LabelingInfo = typing.Union[HirundoCSV, COCO, YOLO]
+"""
+The dataset labeling info. The dataset labeling info can be one of the following:
+- `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format
+
+Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
+"""
+
+
+class VisionRunArgs(BaseModel):
+    upsample: bool = False
+    """
+    Whether to upsample the dataset to attempt to balance the classes.
+    """
+    min_abs_bbox_size: int = 0
+    """
+    Minimum valid size (in pixels) of a bounding box to keep it in the dataset for optimization.
+    """
+    min_abs_bbox_area: int = 0
+    """
+    Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for optimization.
+    """
+    min_rel_bbox_size: float = 0.0
+    """
+    Minimum valid size (as a fraction of both image height and width) for a bounding box
+    to keep it in the dataset for optimization, relative to the corresponding dimension size,
+    i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
+    value is 0.06 (since both width and height are checked).
+    """
+    min_rel_bbox_area: float = 0.0
+    """
+    Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for optimization.
+    """
+
+
+RunArgs = typing.Union[VisionRunArgs]
+
+
+class AugmentationNames(str, Enum):
+    RandomHorizontalFlip = "RandomHorizontalFlip"
+    RandomVerticalFlip = "RandomVerticalFlip"
+    RandomRotation = "RandomRotation"
+    ColorJitter = "ColorJitter"
+    RandomAffine = "RandomAffine"
+    RandomPerspective = "RandomPerspective"
+
+
+class Modality(str, Enum):
+    IMAGE = "Image"
+    RADAR = "Radar"
+    EKG = "EKG"
 
 
 class OptimizationDataset(BaseModel):
+    id: typing.Optional[int] = Field(default=None)
+    """
+    The ID of the dataset created on the server.
+    """
     name: str
     """
     The name of the dataset. Used to identify it amongst the list of datasets
     belonging to your organization in `hirundo`.
     """
-
+    labeling_type: LabelingType
     """
-    Indicates the
-    - `
-    - `
+    Indicates the labeling type of the dataset. The labeling type can be one of the following:
+    - `LabelingType.SINGLE_LABEL_CLASSIFICATION`: Indicates that the dataset is for classification tasks
+    - `LabelingType.OBJECT_DETECTION`: Indicates that the dataset is for object detection tasks
+    - `LabelingType.SPEECH_TO_TEXT`: Indicates that the dataset is for speech-to-text tasks
     """
-
+    language: typing.Optional[str] = None
     """
-
-    If `None`, the `dataset_id` field must be set.
+    Language of the Speech-to-Text audio dataset. This is required for Speech-to-Text datasets.
     """
-
-    classes: list[str]
+    storage_config_id: typing.Optional[int] = None
     """
-
-
+    The ID of the storage config used to store the dataset and metadata.
+    """
+    storage_config: typing.Optional[
+        typing.Union[StorageConfig, ResponseStorageConfig]
+    ] = None
     """
-
+    The `StorageConfig` instance to link to.
     """
-
-    Note: This path will be prefixed with the `StorageLink`'s `path`.
+    data_root_url: HirundoUrl
     """
-
+    URL for data (e.g. images) within the `StorageConfig` instance,
+    e.g. `s3://my-bucket-name/my-images-folder`, `gs://my-bucket-name/my-images-folder`,
+    or `ssh://my-username@my-repo-name/my-images-folder`
+    (or `file:///datasets/my-images-folder` if using LOCAL storage type with on-premises installation)
+
+    Note: All CSV `image_path` entries in the metadata file should be relative to this folder.
     """
-    The type of dataset metadata file. The dataset metadata file can be one of the following:
-    - `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format
 
-
+    classes: typing.Optional[list[str]] = None
+    """
+    A full list of possible classes used in classification / object detection.
+    It is currently required for clarity and performance.
     """
+    labeling_info: LabelingInfo
 
-
+    augmentations: typing.Optional[list[AugmentationNames]] = None
     """
-
+    Used to define which augmentations are apply to a vision dataset.
+    For audio datasets, this field is ignored.
+    If no value is provided, all augmentations are applied to vision datasets.
     """
-
+    modality: Modality = Modality.IMAGE
     """
-
+    Used to define the modality of the dataset.
+    Defaults to Image.
     """
-
+
+    run_id: typing.Optional[str] = Field(default=None, init=False)
     """
     The ID of the Dataset Optimization run created on the server.
     """
 
+    status: typing.Optional[RunStatus] = None
+
     @model_validator(mode="after")
     def validate_dataset(self):
-        if self.
-            raise ValueError(
+        if self.storage_config is None and self.storage_config_id is None:
+            raise ValueError(
+                "No dataset storage has been provided. Provide one via `storage_config` or `storage_config_id`"
+            )
+        elif self.storage_config is not None and self.storage_config_id is not None:
+            raise ValueError(
+                "Both `storage_config` and `storage_config_id` have been provided. Pick one."
+            )
+        if self.labeling_type == LabelingType.SPEECH_TO_TEXT and self.language is None:
+            raise ValueError("Language is required for Speech-to-Text datasets.")
+        elif (
+            self.labeling_type != LabelingType.SPEECH_TO_TEXT
+            and self.language is not None
+        ):
+            raise ValueError("Language is only allowed for Speech-to-Text datasets.")
+        if (
+            self.labeling_info.type == DatasetMetadataType.YOLO
+            and isinstance(self.labeling_info, YOLO)
+            and (
+                self.labeling_info.data_yaml_url is not None
+                and self.classes is not None
+            )
+        ):
+            raise ValueError(
+                "Only one of `classes` or `labeling_info.data_yaml_url` should be provided for YOLO datasets"
+            )
         return self
 
     @staticmethod
-    def
+    def get_by_id(dataset_id: int) -> "OptimizationDataset":
+        """
+        Get a `OptimizationDataset` instance from the server by its ID
+
+        Args:
+            dataset_id: The ID of the `OptimizationDataset` instance to get
+        """
+        response = requests.get(
+            f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
+            headers=get_auth_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(response)
+        dataset = response.json()
+        return OptimizationDataset(**dataset)
+
+    @staticmethod
+    def get_by_name(name: str) -> "OptimizationDataset":
+        """
+        Get a `OptimizationDataset` instance from the server by its name
+
+        Args:
+            name: The name of the `OptimizationDataset` instance to get
+        """
+        response = requests.get(
+            f"{API_HOST}/dataset-optimization/dataset/by-name/{name}",
+            headers=get_auth_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(response)
+        dataset = response.json()
+        return OptimizationDataset(**dataset)
+
+    @staticmethod
+    def list_datasets(
+        organization_id: typing.Optional[int] = None,
+    ) -> list["DataOptimizationDatasetOut"]:
+        """
+        Lists all the optimization datasets created by user's default organization
+        or the `organization_id` passed
+
+        Args:
+            organization_id: The ID of the organization to list the datasets for.
+        """
+        response = requests.get(
+            f"{API_HOST}/dataset-optimization/dataset/",
+            params={"dataset_organization_id": organization_id},
+            headers=get_auth_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(response)
+        datasets = response.json()
+        return [
+            DataOptimizationDatasetOut(
+                **ds,
+            )
+            for ds in datasets
+        ]
+
+    @staticmethod
+    def list_runs(
+        organization_id: typing.Optional[int] = None,
+    ) -> list["DataOptimizationRunOut"]:
         """
         Lists all the `OptimizationDataset` instances created by user's default organization
         or the `organization_id` passed
@@ -108,13 +386,19 @@ class OptimizationDataset(BaseModel):
             organization_id: The ID of the organization to list the datasets for.
         """
         response = requests.get(
-            f"{API_HOST}/dataset-optimization/
+            f"{API_HOST}/dataset-optimization/run/list",
             params={"dataset_organization_id": organization_id},
            headers=get_auth_headers(),
             timeout=READ_TIMEOUT,
         )
-        response
-
+        raise_for_status_with_reason(response)
+        runs = response.json()
+        return [
+            DataOptimizationRunOut(
+                **run,
+            )
+            for run in runs
+        ]
 
     @staticmethod
     def delete_by_id(dataset_id: int) -> None:
@@ -129,54 +413,74 @@ class OptimizationDataset(BaseModel):
             headers=get_auth_headers(),
             timeout=MODIFY_TIMEOUT,
         )
-        response
+        raise_for_status_with_reason(response)
         logger.info("Deleted dataset with ID: %s", dataset_id)
 
-    def delete(self,
+    def delete(self, storage_config=True) -> None:
         """
         Deletes the active `OptimizationDataset` instance from the server.
         It can only be used on a `OptimizationDataset` instance that has been created.
 
         Args:
-
+            storage_config: If True, the `OptimizationDataset`'s `StorageConfig` will also be deleted
 
-        Note: If `
-        This can either be set manually or by creating the `
+        Note: If `storage_config` is not set to `False` then the `storage_config_id` must be set
+        This can either be set manually or by creating the `StorageConfig` instance via the `OptimizationDataset`'s
        `create` method
         """
-        if
-            if not self.
-                raise ValueError("No storage
-
-            if not self.
+        if storage_config:
+            if not self.storage_config_id:
+                raise ValueError("No storage config has been created")
+            StorageConfig.delete_by_id(self.storage_config_id)
+        if not self.id:
             raise ValueError("No dataset has been created")
-        self.delete_by_id(self.
+        self.delete_by_id(self.id)
 
-    def create(
+    def create(
+        self,
+        organization_id: typing.Optional[int] = None,
+        replace_if_exists: bool = False,
+    ) -> int:
         """
         Create a `OptimizationDataset` instance on the server.
-        If `
+        If the `storage_config_id` field is not set, the storage config will also be created and the field will be set.
+
+        Args:
+            organization_id: The ID of the organization to create the dataset for.
+            replace_if_exists: If True, the dataset will be replaced if it already exists
+                (this is determined by a dataset of the same name in the same organization).
+
+        Returns:
+            The ID of the created `OptimizationDataset` instance
         """
-        if
+        if self.storage_config is None and self.storage_config_id is None:
             raise ValueError("No dataset storage has been provided")
-
-        self.
-
-
+        elif self.storage_config and self.storage_config_id is None:
+            if isinstance(self.storage_config, ResponseStorageConfig):
+                self.storage_config_id = self.storage_config.id
+            elif isinstance(self.storage_config, StorageConfig):
+                self.storage_config_id = self.storage_config.create(
+                    replace_if_exists=replace_if_exists,
+                )
+        elif (
+            self.storage_config is not None
+            and self.storage_config_id is not None
+            and (
+                not isinstance(self.storage_config, ResponseStorageConfig)
+                or self.storage_config.id != self.storage_config_id
+            )
         ):
-
-
+            raise ValueError(
+                "Both `storage_config` and `storage_config_id` have been provided. Storage config IDs do not match."
             )
-        model_dict = self.model_dump()
+        model_dict = self.model_dump(mode="json")
         # ⬆️ Get dict of model fields from Pydantic model instance
         dataset_response = requests.post(
             f"{API_HOST}/dataset-optimization/dataset/",
             json={
-
-
-
-            },
-                **{k: model_dict[k] for k in model_dict.keys() - {"dataset_storage"}},
+                **{k: model_dict[k] for k in model_dict.keys() - {"storage_config"}},
+                "organization_id": organization_id,
+                "replace_if_exists": replace_if_exists,
             },
             headers={
                 **json_headers,
@@ -184,15 +488,19 @@ class OptimizationDataset(BaseModel):
             },
             timeout=MODIFY_TIMEOUT,
         )
-        dataset_response
-        self.
-        if not self.
-            raise HirundoError("
-        logger.info("Created dataset with ID: %s", self.
-        return self.
+        raise_for_status_with_reason(dataset_response)
+        self.id = dataset_response.json()["id"]
+        if not self.id:
+            raise HirundoError("An error ocurred while trying to create the dataset")
+        logger.info("Created dataset with ID: %s", self.id)
+        return self.id
 
     @staticmethod
-    def launch_optimization_run(
+    def launch_optimization_run(
+        dataset_id: int,
+        organization_id: typing.Optional[int] = None,
+        run_args: typing.Optional[RunArgs] = None,
+    ) -> str:
         """
         Run the dataset optimization process on the server using the dataset with the given ID
         i.e. `dataset_id`.
@@ -203,26 +511,62 @@ class OptimizationDataset(BaseModel):
         Returns:
             ID of the run (`run_id`).
         """
+        run_info = {}
+        if organization_id:
+            run_info["organization_id"] = organization_id
+        if run_args:
+            run_info["run_args"] = run_args.model_dump(mode="json")
         run_response = requests.post(
             f"{API_HOST}/dataset-optimization/run/{dataset_id}",
+            json=run_info if len(run_info) > 0 else None,
             headers=get_auth_headers(),
             timeout=MODIFY_TIMEOUT,
         )
-        run_response
+        raise_for_status_with_reason(run_response)
         return run_response.json()["run_id"]
 
-    def
+    def _validate_run_args(self, run_args: RunArgs) -> None:
+        if self.labeling_type == LabelingType.SPEECH_TO_TEXT:
+            raise Exception("Speech to text cannot have `run_args` set")
+        if self.labeling_type != LabelingType.OBJECT_DETECTION and any(
+            (
+                run_args.min_abs_bbox_size != 0,
+                run_args.min_abs_bbox_area != 0,
+                run_args.min_rel_bbox_size != 0,
+                run_args.min_rel_bbox_area != 0,
+            )
+        ):
+            raise Exception(
+                "Cannot set `min_abs_bbox_size`, `min_abs_bbox_area`, "
+                + "`min_rel_bbox_size`, or `min_rel_bbox_area` for "
+                + f"labeling type {self.labeling_type}"
+            )
+
+    def run_optimization(
+        self,
+        organization_id: typing.Optional[int] = None,
+        replace_dataset_if_exists: bool = False,
+        run_args: typing.Optional[RunArgs] = None,
+    ) -> str:
         """
         If the dataset was not created on the server yet, it is created.
         Run the dataset optimization process on the server using the active `OptimizationDataset` instance
 
+        Args:
+            organization_id: The ID of the organization to run the optimization for.
+            replace_dataset_if_exists: If True, the dataset will be replaced if it already exists
+                (this is determined by a dataset of the same name in the same organization).
+            run_args: The run arguments to use for the optimization run
+
         Returns:
             An ID of the run (`run_id`) and stores that `run_id` on the instance
         """
         try:
-            if not self.
-                self.
-
+            if not self.id:
+                self.id = self.create(replace_if_exists=replace_dataset_if_exists)
+            if run_args is not None:
+                self._validate_run_args(run_args)
+            run_id = self.launch_optimization_run(self.id, organization_id, run_args)
             self.run_id = run_id
             logger.info("Started the run with ID: %s", run_id)
             return run_id
@@ -238,17 +582,17 @@ class OptimizationDataset(BaseModel):
            except Exception:
                content = error.response.text
            raise HirundoError(
-                f"
+                f"Unable to start the run. Status code: {error.response.status_code} Content: {content}"
            ) from error
        except Exception as error:
-            raise HirundoError(f"
+            raise HirundoError(f"Unable to start the run: {error}") from error
 
     def clean_ids(self):
         """
-        Reset `dataset_id`, `
+        Reset `dataset_id`, `storage_config_id`, and `run_id` values on the instance to default value of `None`
         """
-        self.
-        self.
+        self.storage_config_id = None
+        self.id = None
         self.run_id = None
 
     @staticmethod
@@ -274,10 +618,19 @@ class OptimizationDataset(BaseModel):
         return df
 
     @staticmethod
-    def
+    def _read_csvs_to_df(data: dict):
         if data["state"] == RunStatus.SUCCESS.value:
-            data["result"] = OptimizationDataset._clean_df_index(
-                pd.read_csv(
+            data["result"]["suspects"] = OptimizationDataset._clean_df_index(
+                pd.read_csv(
+                    StringIO(data["result"]["suspects"]),
+                    dtype=CUSTOMER_INTERCHANGE_DTYPES,
+                )
+            )
+            data["result"]["warnings_and_errors"] = OptimizationDataset._clean_df_index(
+                pd.read_csv(
+                    StringIO(data["result"]["warnings_and_errors"]),
+                    dtype=CUSTOMER_INTERCHANGE_DTYPES,
+                )
             )
         else:
             pass
@@ -306,8 +659,16 @@ class OptimizationDataset(BaseModel):
                 last_event = json.loads(sse.data)
                 if not last_event:
                     continue
-                data
-
+                if "data" in last_event:
+                    data = last_event["data"]
+                else:
+                    if "detail" in last_event:
+                        raise HirundoError(last_event["detail"])
+                    elif "reason" in last_event:
+                        raise HirundoError(last_event["reason"])
+                    else:
+                        raise HirundoError("Unknown error")
+                OptimizationDataset._read_csvs_to_df(data)
                 yield data
                 if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
                     OptimizationDataset._check_run_by_id(run_id, retry + 1)
@@ -316,27 +677,24 @@
     @overload
     def check_run_by_id(
         run_id: str, stop_on_manual_approval: typing.Literal[True]
-    ) -> typing.Optional[
-        ...
+    ) -> typing.Optional[DatasetOptimizationResults]: ...
 
     @staticmethod
     @overload
     def check_run_by_id(
         run_id: str, stop_on_manual_approval: typing.Literal[False] = False
-    ) ->
-        ...
+    ) -> DatasetOptimizationResults: ...
 
     @staticmethod
     @overload
     def check_run_by_id(
         run_id: str, stop_on_manual_approval: bool
-    ) -> typing.Optional[
-        ...
+    ) -> typing.Optional[DatasetOptimizationResults]: ...
 
     @staticmethod
     def check_run_by_id(
         run_id: str, stop_on_manual_approval: bool = False
-    ) -> typing.Optional[
+    ) -> typing.Optional[DatasetOptimizationResults]:
         """
         Check the status of a run given its ID
 
@@ -345,7 +703,7 @@ class OptimizationDataset(BaseModel):
             stop_on_manual_approval: If True, the function will return `None` if the run is awaiting manual approval
 
         Returns:
-            A
+            A DatasetOptimizationResults object with the results of the optimization run
 
         Raises:
             HirundoError: If the maximum number of retries is reached or if the run fails
@@ -354,22 +712,33 @@ class OptimizationDataset(BaseModel):
         with logging_redirect_tqdm():
             t = tqdm(total=100.0)
             for iteration in OptimizationDataset._check_run_by_id(run_id):
-                if iteration["state"]
-                    t.set_description("
-                    t.n =
-                    t.
-                    t.close()
-                    return iteration["result"]
-                elif iteration["state"] == RunStatus.PENDING.value:
-                    t.set_description("Optimization run queued and not yet started")
-                    t.n = 0.0
-                    t.refresh()
-                elif iteration["state"] == RunStatus.STARTED.value:
-                    t.set_description(
-                        "Optimization run in progress. Downloading dataset"
-                    )
-                    t.n = 0.0
+                if iteration["state"] in STATUS_TO_PROGRESS_MAP:
+                    t.set_description(STATUS_TO_TEXT_MAP[iteration["state"]])
+                    t.n = STATUS_TO_PROGRESS_MAP[iteration["state"]]
+                    logger.debug("Setting progress to %s", t.n)
                     t.refresh()
+                if iteration["state"] in [
+                    RunStatus.FAILURE.value,
+                    RunStatus.REJECTED.value,
+                    RunStatus.REVOKED.value,
+                ]:
+                    raise HirundoError(
+                        f"Optimization run failed with error: {iteration['result']}"
+                    )
+                elif iteration["state"] == RunStatus.SUCCESS.value:
+                    t.close()
+                    return DatasetOptimizationResults(
+                        suspects=iteration["result"]["suspects"],
+                        warnings_and_errors=iteration["result"][
+                            "warnings_and_errors"
+                        ],
+                    )
+                elif (
+                    iteration["state"] == RunStatus.AWAITING_MANUAL_APPROVAL.value
+                    and stop_on_manual_approval
+                ):
+                    t.close()
+                    return None
                 elif iteration["state"] is None:
                     if (
                         iteration["result"]
@@ -377,47 +746,42 @@ class OptimizationDataset(BaseModel):
                         and iteration["result"]["result"]
                         and isinstance(iteration["result"]["result"], str)
                     ):
-
-
-
+                        result_info = iteration["result"]["result"].split(":")
+                        if len(result_info) > 1:
+                            stage = result_info[0]
+                            current_progress_percentage = float(
+                                result_info[1].removeprefix(" ").removesuffix("% done")
+                            )
+                        elif len(result_info) == 1:
+                            stage = result_info[0]
+                            current_progress_percentage = t.n  # Keep the same progress
+                        else:
+                            stage = "Unknown progress state"
+                            current_progress_percentage = t.n  # Keep the same progress
                         desc = (
                             "Optimization run completed. Uploading results"
                             if current_progress_percentage == 100.0
-                            else
+                            else stage
                         )
                         t.set_description(desc)
                         t.n = current_progress_percentage
+                        logger.debug("Setting progress to %s", t.n)
                         t.refresh()
-                elif iteration["state"] == RunStatus.AWAITING_MANUAL_APPROVAL.value:
-                    t.set_description("Awaiting manual approval")
-                    t.n = 100.0
-                    t.refresh()
-                    if stop_on_manual_approval:
-                        t.close()
-                        return None
-                elif iteration["state"] == RunStatus.FAILURE.value:
-                    t.set_description("Optimization run failed")
-                    t.close()
-                    raise HirundoError(
-                        f"Optimization run failed with error: {iteration['result']}"
-                    )
             raise HirundoError("Optimization run failed with an unknown error")
 
     @overload
     def check_run(
         self, stop_on_manual_approval: typing.Literal[True]
-    ) -> typing.
-        ...
+    ) -> typing.Optional[DatasetOptimizationResults]: ...
 
     @overload
     def check_run(
         self, stop_on_manual_approval: typing.Literal[False] = False
-    ) ->
-        ...
+    ) -> DatasetOptimizationResults: ...
 
     def check_run(
         self, stop_on_manual_approval: bool = False
-    ) -> typing.
+    ) -> typing.Optional[DatasetOptimizationResults]:
         """
         Check the status of the current active instance's run.
 
@@ -511,7 +875,7 @@ class OptimizationDataset(BaseModel):
             headers=get_auth_headers(),
             timeout=MODIFY_TIMEOUT,
         )
-        response
+        raise_for_status_with_reason(response)
 
     def cancel(self) -> None:
         """
@@ -520,3 +884,31 @@ class OptimizationDataset(BaseModel):
         if not self.run_id:
             raise ValueError("No run has been started")
         self.cancel_by_id(self.run_id)
+
+
+class DataOptimizationDatasetOut(BaseModel):
+    id: int
+
+    name: str
+    labeling_type: LabelingType
+
+    storage_config: ResponseStorageConfig
+
+    data_root_url: HirundoUrl
+
+    classes: typing.Optional[list[str]] = None
+    labeling_info: LabelingInfo
+
+    organization_id: typing.Optional[int]
+    creator_id: typing.Optional[int]
+    created_at: datetime.datetime
+    updated_at: datetime.datetime
+
+
+class DataOptimizationRunOut(BaseModel):
+    id: int
+    name: str
+    run_id: str
+    status: RunStatus
+    approved: bool
+    created_at: datetime.datetime
```
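
Taken together, the changes above move dataset storage onto the new `storage_config`/`storage_config_id` fields (replacing the old `dataset_storage`/`dataset_id` pair) and split labeling metadata into the new `labeling_info` union. A minimal sketch of constructing a dataset against the 0.1.9 API shown in this diff, assuming an already-created storage config; the ID `42` and all URLs are placeholders, and `StorageConfig` itself lives in `hirundo/storage.py`, which is not shown here:

```python
from hirundo.dataset_optimization import HirundoCSV, OptimizationDataset
from hirundo.enum import LabelingType

dataset = OptimizationDataset(
    name="my-dataset",
    labeling_type=LabelingType.OBJECT_DETECTION,
    storage_config_id=42,  # placeholder: ID of an existing storage config
    data_root_url="s3://my-bucket-name/my-images-folder",  # coerced to HirundoUrl
    classes=["cat", "dog"],  # full class list for classification/detection
    labeling_info=HirundoCSV(
        csv_url="s3://my-bucket-name/my-folder/my-metadata.csv",
    ),
)
```

The `validate_dataset` model validator in the diff enforces that exactly one of `storage_config` or `storage_config_id` is set, that `language` is present if and only if `labeling_type` is `SPEECH_TO_TEXT`, and that YOLO datasets do not provide both `classes` and `labeling_info.data_yaml_url`.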
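
Launching and monitoring a run now goes through `run_optimization`, which accepts the new `VisionRunArgs`; per `_validate_run_args`, the bounding-box filters are only allowed for `OBJECT_DETECTION` datasets, and Speech-to-Text runs may not set `run_args` at all. A sketch continuing the example above, with illustrative argument values:

```python
from hirundo.dataset_optimization import VisionRunArgs

run_id = dataset.run_optimization(
    run_args=VisionRunArgs(
        upsample=True,         # try to balance classes by upsampling
        min_abs_bbox_size=10,  # drop boxes under 10 px in either dimension
    ),
)

# Blocks while rendering tqdm progress; per the loop above, raises HirundoError
# on FAILURE/REJECTED/REVOKED and returns DatasetOptimizationResults on SUCCESS.
results = dataset.check_run()
print(results.suspects.head())
print(results.warnings_and_errors.head())
```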
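
The new `DataOptimizationDatasetOut` and `DataOptimizationRunOut` models type the list endpoints' responses, and `get_by_id`/`get_by_name` round-trip a dataset from the server. A sketch of the read-side helpers added in 0.1.9:

```python
from hirundo.dataset_optimization import OptimizationDataset

# Datasets and runs for the user's default organization
for ds in OptimizationDataset.list_datasets():
    print(ds.id, ds.name, ds.labeling_type, ds.created_at)

for run in OptimizationDataset.list_runs():
    print(run.run_id, run.status, run.approved)

same_dataset = OptimizationDataset.get_by_name("my-dataset")
```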
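
Finally, `CUSTOMER_INTERCHANGE_DTYPES` pins the column types of the suspects and warnings CSVs that `_read_csvs_to_df` parses out of the SSE payload. The same mapping can be reused on a locally saved export; a sketch in which the filename and threshold are hypothetical and the `suspect_level` semantics are inferred from the dtype map's comments:

```python
import pandas as pd
from hirundo.dataset_optimization import CUSTOMER_INTERCHANGE_DTYPES

# Parse a previously saved suspects export with the interchange dtypes
df = pd.read_csv("suspects.csv", dtype=CUSTOMER_INTERCHANGE_DTYPES)
flagged = df[df["suspect_level"] > 0.9]  # illustrative threshold
print(flagged[["image_path", "label", "suggested_label", "suggested_label_conf"]])
```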