hirundo 0.1.8__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the changes exactly as they appear in that registry.
- hirundo/__init__.py +28 -13
- hirundo/_constraints.py +34 -2
- hirundo/_dataframe.py +43 -0
- hirundo/_env.py +2 -2
- hirundo/_headers.py +18 -2
- hirundo/_http.py +7 -2
- hirundo/_iter_sse_retrying.py +61 -17
- hirundo/_timeouts.py +1 -0
- hirundo/cli.py +52 -0
- hirundo/dataset_enum.py +23 -0
- hirundo/dataset_optimization.py +427 -164
- hirundo/dataset_optimization_results.py +42 -0
- hirundo/git.py +93 -35
- hirundo/storage.py +236 -68
- hirundo/unzip.py +247 -0
- {hirundo-0.1.8.dist-info → hirundo-0.1.16.dist-info}/METADATA +84 -44
- hirundo-0.1.16.dist-info/RECORD +23 -0
- {hirundo-0.1.8.dist-info → hirundo-0.1.16.dist-info}/WHEEL +1 -1
- hirundo/enum.py +0 -20
- hirundo-0.1.8.dist-info/RECORD +0 -20
- {hirundo-0.1.8.dist-info → hirundo-0.1.16.dist-info}/entry_points.txt +0 -0
- {hirundo-0.1.8.dist-info → hirundo-0.1.16.dist-info/licenses}/LICENSE +0 -0
- {hirundo-0.1.8.dist-info → hirundo-0.1.16.dist-info}/top_level.txt +0 -0
hirundo/dataset_optimization.py
CHANGED
```diff
@@ -1,27 +1,28 @@
+import datetime
 import json
 import typing
+from abc import ABC, abstractmethod
 from collections.abc import AsyncGenerator, Generator
 from enum import Enum
-from io import StringIO
 from typing import overload
 
 import httpx
-import numpy as np
-import pandas as pd
 import requests
-from pandas._typing import DtypeArg
 from pydantic import BaseModel, Field, model_validator
 from tqdm import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
 
+from hirundo._constraints import HirundoUrl
 from hirundo._env import API_HOST
-from hirundo._headers import …
+from hirundo._headers import get_headers
 from hirundo._http import raise_for_status_with_reason
 from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
 from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
-from hirundo.…
+from hirundo.dataset_enum import DatasetMetadataType, LabelingType
+from hirundo.dataset_optimization_results import DatasetOptimizationResults
 from hirundo.logger import get_logger
-from hirundo.storage import …
+from hirundo.storage import ResponseStorageConfig, StorageConfig
+from hirundo.unzip import download_and_extract_zip
 
 logger = get_logger(__name__)
 
@@ -38,12 +39,14 @@ MAX_RETRIES = 200 # Max 200 retries for HTTP SSE connection
 
 
 class RunStatus(Enum):
-    STARTED = "STARTED"
     PENDING = "PENDING"
+    STARTED = "STARTED"
     SUCCESS = "SUCCESS"
     FAILURE = "FAILURE"
     AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
-    …
+    REVOKED = "REVOKED"
+    REJECTED = "REJECTED"
+    RETRY = "RETRY"
 
 
 STATUS_TO_TEXT_MAP = {
@@ -52,7 +55,9 @@ STATUS_TO_TEXT_MAP = {
     RunStatus.SUCCESS.value: "Optimization run completed successfully",
     RunStatus.FAILURE.value: "Optimization run failed",
     RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
-    RunStatus.…
+    RunStatus.RETRY.value: "Optimization run failed. Retrying",
+    RunStatus.REVOKED.value: "Optimization run was cancelled",
+    RunStatus.REJECTED.value: "Optimization run was rejected",
 }
 STATUS_TO_PROGRESS_MAP = {
     RunStatus.STARTED.value: 0.0,
@@ -60,100 +65,284 @@ STATUS_TO_PROGRESS_MAP = {
     RunStatus.SUCCESS.value: 100.0,
     RunStatus.FAILURE.value: 100.0,
     RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
-    RunStatus.…
+    RunStatus.RETRY.value: 0.0,
+    RunStatus.REVOKED.value: 100.0,
+    RunStatus.REJECTED.value: 0.0,
 }
 
 
-class …
-…
+class Metadata(BaseModel, ABC):
+    type: DatasetMetadataType
+
+    @property
+    @abstractmethod
+    def metadata_url(self) -> HirundoUrl:
+        raise NotImplementedError()
 
-…
+
+class HirundoCSV(Metadata):
     """
-    A …
+    A dataset metadata file in the Hirundo CSV format
     """
-    …
+
+    type: DatasetMetadataType = DatasetMetadataType.HIRUNDO_CSV
+    csv_url: HirundoUrl
     """
-    …
+    The URL to access the dataset metadata CSV file.
+    e.g. `s3://my-bucket-name/my-folder/my-metadata.csv`, `gs://my-bucket-name/my-folder/my-metadata.csv`,
+    or `ssh://my-username@my-repo-name/my-folder/my-metadata.csv`
+    (or `file:///datasets/my-folder/my-metadata.csv` if using LOCAL storage type with on-premises installation)
     """
 
+    @property
+    def metadata_url(self) -> HirundoUrl:
+        return self.csv_url
+
+
+class COCO(Metadata):
+    """
+    A dataset metadata file in the COCO format
+    """
+
+    type: DatasetMetadataType = DatasetMetadataType.COCO
+    json_url: HirundoUrl
+    """
+    The URL to access the dataset metadata JSON file.
+    e.g. `s3://my-bucket-name/my-folder/my-metadata.json`, `gs://my-bucket-name/my-folder/my-metadata.json`,
+    or `ssh://my-username@my-repo-name/my-folder/my-metadata.json`
+    (or `file:///datasets/my-folder/my-metadata.json` if using LOCAL storage type with on-premises installation)
+    """
+
+    @property
+    def metadata_url(self) -> HirundoUrl:
+        return self.json_url
+
+
+class YOLO(Metadata):
+    type: DatasetMetadataType = DatasetMetadataType.YOLO
+    data_yaml_url: typing.Optional[HirundoUrl] = None
+    labels_dir_url: HirundoUrl
+
+    @property
+    def metadata_url(self) -> HirundoUrl:
+        return self.labels_dir_url
+
+
+LabelingInfo = typing.Union[HirundoCSV, COCO, YOLO]
+"""
+The dataset labeling info. The dataset labeling info can be one of the following:
+- `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format
+
+Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
+"""
+
+
+class VisionRunArgs(BaseModel):
+    upsample: bool = False
+    """
+    Whether to upsample the dataset to attempt to balance the classes.
+    """
+    min_abs_bbox_size: int = 0
+    """
+    Minimum valid size (in pixels) of a bounding box to keep it in the dataset for optimization.
+    """
+    min_abs_bbox_area: int = 0
+    """
+    Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for optimization.
+    """
+    min_rel_bbox_size: float = 0.0
+    """
+    Minimum valid size (as a fraction of both image height and width) for a bounding box
+    to keep it in the dataset for optimization, relative to the corresponding dimension size,
+    i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
+    value is 0.06 (since both width and height are checked).
+    """
+    min_rel_bbox_area: float = 0.0
+    """
+    Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for optimization.
+    """
+
+
+RunArgs = typing.Union[VisionRunArgs]
+
+
+class AugmentationName(str, Enum):
+    RANDOM_HORIZONTAL_FLIP = "RandomHorizontalFlip"
+    RANDOM_VERTICAL_FLIP = "RandomVerticalFlip"
+    RANDOM_ROTATION = "RandomRotation"
+    RANDOM_PERSPECTIVE = "RandomPerspective"
+    GAUSSIAN_NOISE = "GaussianNoise"
+    RANDOM_GRAYSCALE = "RandomGrayscale"
+    GAUSSIAN_BLUR = "GaussianBlur"
 
-…
-…
-    "…
-    "…
-    "…
-    "label": str,
-    "bbox_id": str,
-    "xmin": np.int32,
-    "ymin": np.int32,
-    "xmax": np.int32,
-    "ymax": np.int32,
-    "suspect_level": np.float32,  # If exists, must be one of the values in the enum below
-    "suggested_label": str,
-    "suggested_label_conf": np.float32,
-    "status": str,
-    # ⬆️ If exists, must be one of the following:
-    # NO_LABELS/MISSING_IMAGE/INVALID_IMAGE/INVALID_BBOX/INVALID_BBOX_SIZE/INVALID_SEG/INVALID_SEG_SIZE
-}
+
+class Modality(str, Enum):
+    IMAGE = "Image"
+    RADAR = "Radar"
+    EKG = "EKG"
 
 
 class OptimizationDataset(BaseModel):
+    id: typing.Optional[int] = Field(default=None)
+    """
+    The ID of the dataset created on the server.
+    """
     name: str
     """
     The name of the dataset. Used to identify it amongst the list of datasets
     belonging to your organization in `hirundo`.
     """
-    …
+    labeling_type: LabelingType
     """
-    Indicates the …
-    - `…
-    - `…
+    Indicates the labeling type of the dataset. The labeling type can be one of the following:
+    - `LabelingType.SINGLE_LABEL_CLASSIFICATION`: Indicates that the dataset is for classification tasks
+    - `LabelingType.OBJECT_DETECTION`: Indicates that the dataset is for object detection tasks
+    - `LabelingType.SPEECH_TO_TEXT`: Indicates that the dataset is for speech-to-text tasks
     """
-    …
+    language: typing.Optional[str] = None
     """
-    …
-    If `None`, the `dataset_id` field must be set.
+    Language of the Speech-to-Text audio dataset. This is required for Speech-to-Text datasets.
     """
-    …
-    classes: typing.Optional[list[str]] = None
+    storage_config_id: typing.Optional[int] = None
     """
-    …
-    …
+    The ID of the storage config used to store the dataset and metadata.
+    """
+    storage_config: typing.Optional[
+        typing.Union[StorageConfig, ResponseStorageConfig]
+    ] = None
     """
-    …
+    The `StorageConfig` instance to link to.
     """
-    …
-    Note: This path will be prefixed with the `StorageLink`'s `path`.
+    data_root_url: HirundoUrl
     """
-    …
+    URL for data (e.g. images) within the `StorageConfig` instance,
+    e.g. `s3://my-bucket-name/my-images-folder`, `gs://my-bucket-name/my-images-folder`,
+    or `ssh://my-username@my-repo-name/my-images-folder`
+    (or `file:///datasets/my-images-folder` if using LOCAL storage type with on-premises installation)
+
+    Note: All CSV `image_path` entries in the metadata file should be relative to this folder.
     """
-    The type of dataset metadata file. The dataset metadata file can be one of the following:
-    - `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format
 
-    …
+    classes: typing.Optional[list[str]] = None
+    """
+    A full list of possible classes used in classification / object detection.
+    It is currently required for clarity and performance.
     """
+    labeling_info: LabelingInfo
 
-    …
+    augmentations: typing.Optional[list[AugmentationName]] = None
     """
-    …
+    Used to define which augmentations are apply to a vision dataset.
+    For audio datasets, this field is ignored.
+    If no value is provided, all augmentations are applied to vision datasets.
    """
-    …
+    modality: Modality = Modality.IMAGE
     """
-    …
+    Used to define the modality of the dataset.
+    Defaults to Image.
     """
+
     run_id: typing.Optional[str] = Field(default=None, init=False)
     """
     The ID of the Dataset Optimization run created on the server.
     """
 
+    status: typing.Optional[RunStatus] = None
+
     @model_validator(mode="after")
     def validate_dataset(self):
-        if self.…
-            raise ValueError(…
+        if self.storage_config is None and self.storage_config_id is None:
+            raise ValueError(
+                "No dataset storage has been provided. Provide one via `storage_config` or `storage_config_id`"
+            )
+        elif self.storage_config is not None and self.storage_config_id is not None:
+            raise ValueError(
+                "Both `storage_config` and `storage_config_id` have been provided. Pick one."
+            )
+        if self.labeling_type == LabelingType.SPEECH_TO_TEXT and self.language is None:
+            raise ValueError("Language is required for Speech-to-Text datasets.")
+        elif (
+            self.labeling_type != LabelingType.SPEECH_TO_TEXT
+            and self.language is not None
+        ):
+            raise ValueError("Language is only allowed for Speech-to-Text datasets.")
+        if (
+            self.labeling_info.type == DatasetMetadataType.YOLO
+            and isinstance(self.labeling_info, YOLO)
+            and (
+                self.labeling_info.data_yaml_url is not None
+                and self.classes is not None
+            )
+        ):
+            raise ValueError(
+                "Only one of `classes` or `labeling_info.data_yaml_url` should be provided for YOLO datasets"
+            )
         return self
 
     @staticmethod
-    def …
+    def get_by_id(dataset_id: int) -> "OptimizationDataset":
+        """
+        Get a `OptimizationDataset` instance from the server by its ID
+
+        Args:
+            dataset_id: The ID of the `OptimizationDataset` instance to get
+        """
+        response = requests.get(
+            f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
+            headers=get_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(response)
+        dataset = response.json()
+        return OptimizationDataset(**dataset)
+
+    @staticmethod
+    def get_by_name(name: str) -> "OptimizationDataset":
+        """
+        Get a `OptimizationDataset` instance from the server by its name
+
+        Args:
+            name: The name of the `OptimizationDataset` instance to get
+        """
+        response = requests.get(
+            f"{API_HOST}/dataset-optimization/dataset/by-name/{name}",
+            headers=get_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(response)
+        dataset = response.json()
+        return OptimizationDataset(**dataset)
+
+    @staticmethod
+    def list_datasets(
+        organization_id: typing.Optional[int] = None,
+    ) -> list["DataOptimizationDatasetOut"]:
+        """
+        Lists all the optimization datasets created by user's default organization
+        or the `organization_id` passed
+
+        Args:
+            organization_id: The ID of the organization to list the datasets for.
+        """
+        response = requests.get(
+            f"{API_HOST}/dataset-optimization/dataset/",
+            params={"dataset_organization_id": organization_id},
+            headers=get_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(response)
+        datasets = response.json()
+        return [
+            DataOptimizationDatasetOut(
+                **ds,
+            )
+            for ds in datasets
+        ]
+
+    @staticmethod
+    def list_runs(
+        organization_id: typing.Optional[int] = None,
+    ) -> list["DataOptimizationRunOut"]:
         """
         Lists all the `OptimizationDataset` instances created by user's default organization
         or the `organization_id` passed
@@ -163,13 +352,19 @@ class OptimizationDataset(BaseModel):
             organization_id: The ID of the organization to list the datasets for.
         """
         response = requests.get(
-            f"{API_HOST}/dataset-optimization/…
+            f"{API_HOST}/dataset-optimization/run/list",
             params={"dataset_organization_id": organization_id},
-            headers=…
+            headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(response)
-        …
+        runs = response.json()
+        return [
+            DataOptimizationRunOut(
+                **run,
+            )
+            for run in runs
+        ]
 
     @staticmethod
     def delete_by_id(dataset_id: int) -> None:
@@ -181,73 +376,94 @@ class OptimizationDataset(BaseModel):
         """
         response = requests.delete(
             f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
-            headers=…
+            headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
         raise_for_status_with_reason(response)
         logger.info("Deleted dataset with ID: %s", dataset_id)
 
-    def delete(self, …
+    def delete(self, storage_config=True) -> None:
         """
         Deletes the active `OptimizationDataset` instance from the server.
         It can only be used on a `OptimizationDataset` instance that has been created.
 
         Args:
-            …
+            storage_config: If True, the `OptimizationDataset`'s `StorageConfig` will also be deleted
 
-        Note: If `…
-        This can either be set manually or by creating the `…
+        Note: If `storage_config` is not set to `False` then the `storage_config_id` must be set
+        This can either be set manually or by creating the `StorageConfig` instance via the `OptimizationDataset`'s
         `create` method
         """
-        if …
-            if not self.…
-                raise ValueError("No storage …
-            …
-        if not self.…
+        if storage_config:
+            if not self.storage_config_id:
+                raise ValueError("No storage config has been created")
+            StorageConfig.delete_by_id(self.storage_config_id)
+        if not self.id:
             raise ValueError("No dataset has been created")
-        self.delete_by_id(self.…
+        self.delete_by_id(self.id)
 
-    def create(…
+    def create(
+        self,
+        organization_id: typing.Optional[int] = None,
+        replace_if_exists: bool = False,
+    ) -> int:
         """
         Create a `OptimizationDataset` instance on the server.
-        If `…
+        If the `storage_config_id` field is not set, the storage config will also be created and the field will be set.
+
+        Args:
+            organization_id: The ID of the organization to create the dataset for.
+            replace_if_exists: If True, the dataset will be replaced if it already exists
+            (this is determined by a dataset of the same name in the same organization).
+
+        Returns:
+            The ID of the created `OptimizationDataset` instance
         """
-        if …
+        if self.storage_config is None and self.storage_config_id is None:
             raise ValueError("No dataset storage has been provided")
-        …
-        self.…
-        …
-        …
+        elif self.storage_config and self.storage_config_id is None:
+            if isinstance(self.storage_config, ResponseStorageConfig):
+                self.storage_config_id = self.storage_config.id
+            elif isinstance(self.storage_config, StorageConfig):
+                self.storage_config_id = self.storage_config.create(
+                    replace_if_exists=replace_if_exists,
+                )
+        elif (
+            self.storage_config is not None
+            and self.storage_config_id is not None
+            and (
+                not isinstance(self.storage_config, ResponseStorageConfig)
+                or self.storage_config.id != self.storage_config_id
+            )
         ):
-            …
-            …
+            raise ValueError(
+                "Both `storage_config` and `storage_config_id` have been provided. Storage config IDs do not match."
             )
-        model_dict = self.model_dump()
+        model_dict = self.model_dump(mode="json")
         # ⬆️ Get dict of model fields from Pydantic model instance
         dataset_response = requests.post(
             f"{API_HOST}/dataset-optimization/dataset/",
             json={
-                …
-                …
-                …
-                },
-                **{k: model_dict[k] for k in model_dict.keys() - {"dataset_storage"}},
-            },
-            headers={
-                **json_headers,
-                **get_auth_headers(),
+                **{k: model_dict[k] for k in model_dict.keys() - {"storage_config"}},
+                "organization_id": organization_id,
+                "replace_if_exists": replace_if_exists,
             },
+            headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
        raise_for_status_with_reason(dataset_response)
-        self.…
-        if not self.…
-            raise HirundoError("…
-        logger.info("Created dataset with ID: %s", self.…
-        return self.…
+        self.id = dataset_response.json()["id"]
+        if not self.id:
+            raise HirundoError("An error ocurred while trying to create the dataset")
+        logger.info("Created dataset with ID: %s", self.id)
+        return self.id
 
     @staticmethod
-    def launch_optimization_run(…
+    def launch_optimization_run(
+        dataset_id: int,
+        organization_id: typing.Optional[int] = None,
+        run_args: typing.Optional[RunArgs] = None,
+    ) -> str:
         """
         Run the dataset optimization process on the server using the dataset with the given ID
         i.e. `dataset_id`.
@@ -258,26 +474,62 @@ class OptimizationDataset(BaseModel):
         Returns:
             ID of the run (`run_id`).
         """
+        run_info = {}
+        if organization_id:
+            run_info["organization_id"] = organization_id
+        if run_args:
+            run_info["run_args"] = run_args.model_dump(mode="json")
         run_response = requests.post(
             f"{API_HOST}/dataset-optimization/run/{dataset_id}",
-            …
+            json=run_info if len(run_info) > 0 else None,
+            headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
         raise_for_status_with_reason(run_response)
         return run_response.json()["run_id"]
 
-    def …
+    def _validate_run_args(self, run_args: RunArgs) -> None:
+        if self.labeling_type == LabelingType.SPEECH_TO_TEXT:
+            raise Exception("Speech to text cannot have `run_args` set")
+        if self.labeling_type != LabelingType.OBJECT_DETECTION and any(
+            (
+                run_args.min_abs_bbox_size != 0,
+                run_args.min_abs_bbox_area != 0,
+                run_args.min_rel_bbox_size != 0,
+                run_args.min_rel_bbox_area != 0,
+            )
+        ):
+            raise Exception(
+                "Cannot set `min_abs_bbox_size`, `min_abs_bbox_area`, "
+                + "`min_rel_bbox_size`, or `min_rel_bbox_area` for "
+                + f"labeling type {self.labeling_type}"
+            )
+
+    def run_optimization(
+        self,
+        organization_id: typing.Optional[int] = None,
+        replace_dataset_if_exists: bool = False,
+        run_args: typing.Optional[RunArgs] = None,
+    ) -> str:
         """
         If the dataset was not created on the server yet, it is created.
         Run the dataset optimization process on the server using the active `OptimizationDataset` instance
 
+        Args:
+            organization_id: The ID of the organization to run the optimization for.
+            replace_dataset_if_exists: If True, the dataset will be replaced if it already exists
+            (this is determined by a dataset of the same name in the same organization).
+            run_args: The run arguments to use for the optimization run
+
         Returns:
             An ID of the run (`run_id`) and stores that `run_id` on the instance
         """
         try:
-            if not self.…
-                self.…
-            …
+            if not self.id:
+                self.id = self.create(replace_if_exists=replace_dataset_if_exists)
+            if run_args is not None:
+                self._validate_run_args(run_args)
+            run_id = self.launch_optimization_run(self.id, organization_id, run_args)
             self.run_id = run_id
             logger.info("Started the run with ID: %s", run_id)
             return run_id
@@ -293,59 +545,19 @@ class OptimizationDataset(BaseModel):
         except Exception:
             content = error.response.text
             raise HirundoError(
-                f"…
+                f"Unable to start the run. Status code: {error.response.status_code} Content: {content}"
             ) from error
         except Exception as error:
-            raise HirundoError(f"…
+            raise HirundoError(f"Unable to start the run: {error}") from error
 
     def clean_ids(self):
         """
-        Reset `dataset_id`, `…
+        Reset `dataset_id`, `storage_config_id`, and `run_id` values on the instance to default value of `None`
         """
-        self.…
-        self.…
+        self.storage_config_id = None
+        self.id = None
         self.run_id = None
 
-    @staticmethod
-    def _clean_df_index(df: "pd.DataFrame") -> "pd.DataFrame":
-        """
-        Clean the index of a dataframe in case it has unnamed columns.
-
-        Args:
-            df (DataFrame): Dataframe to clean
-
-        Returns:
-            DataFrame: Cleaned dataframe
-        """
-        index_cols = sorted(
-            [col for col in df.columns if col.startswith("Unnamed")], reverse=True
-        )
-        if len(index_cols) > 0:
-            df.set_index(index_cols.pop(), inplace=True)
-            df.rename_axis(index=None, columns=None, inplace=True)
-        if len(index_cols) > 0:
-            df.drop(columns=index_cols, inplace=True)
-
-        return df
-
-    @staticmethod
-    def _read_csvs_to_df(data: dict):
-        if data["state"] == RunStatus.SUCCESS.value:
-            data["result"]["suspects"] = OptimizationDataset._clean_df_index(
-                pd.read_csv(
-                    StringIO(data["result"]["suspects"]),
-                    dtype=CUSTOMER_INTERCHANGE_DTYPES,
-                )
-            )
-            data["result"]["warnings_and_errors"] = OptimizationDataset._clean_df_index(
-                pd.read_csv(
-                    StringIO(data["result"]["warnings_and_errors"]),
-                    dtype=CUSTOMER_INTERCHANGE_DTYPES,
-                )
-            )
-        else:
-            pass
-
     @staticmethod
     def _check_run_by_id(run_id: str, retry=0) -> Generator[dict, None, None]:
         if retry > MAX_RETRIES:
@@ -356,7 +568,7 @@ class OptimizationDataset(BaseModel):
             client,
             "GET",
             f"{API_HOST}/dataset-optimization/run/{run_id}",
-            headers=…
+            headers=get_headers(),
         ):
             if sse.event == "ping":
                 continue
@@ -370,8 +582,15 @@ class OptimizationDataset(BaseModel):
                 last_event = json.loads(sse.data)
                 if not last_event:
                     continue
-                data …
-                …
+                if "data" in last_event:
+                    data = last_event["data"]
+                else:
+                    if "detail" in last_event:
+                        raise HirundoError(last_event["detail"])
+                    elif "reason" in last_event:
+                        raise HirundoError(last_event["reason"])
+                    else:
+                        raise HirundoError("Unknown error")
                 yield data
         if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
             OptimizationDataset._check_run_by_id(run_id, retry + 1)
@@ -420,17 +639,22 @@ class OptimizationDataset(BaseModel):
                     t.n = STATUS_TO_PROGRESS_MAP[iteration["state"]]
                    logger.debug("Setting progress to %s", t.n)
                    t.refresh()
-                    if iteration["state"] …
+                    if iteration["state"] in [
+                        RunStatus.FAILURE.value,
+                        RunStatus.REJECTED.value,
+                        RunStatus.REVOKED.value,
+                    ]:
                        raise HirundoError(
                            f"Optimization run failed with error: {iteration['result']}"
                        )
                    elif iteration["state"] == RunStatus.SUCCESS.value:
                        t.close()
-                        …
-                        …
-                        …
-                        …
-                        …
+                        zip_temporary_url = iteration["result"]
+                        logger.debug("Optimization run completed. Downloading results")
+
+                        return download_and_extract_zip(
+                            run_id,
+                            zip_temporary_url,
                        )
                    elif (
                        iteration["state"] == RunStatus.AWAITING_MANUAL_APPROVAL.value
@@ -445,13 +669,22 @@ class OptimizationDataset(BaseModel):
                        and iteration["result"]["result"]
                        and isinstance(iteration["result"]["result"], str)
                    ):
-                        …
-                        …
-                        …
+                        result_info = iteration["result"]["result"].split(":")
+                        if len(result_info) > 1:
+                            stage = result_info[0]
+                            current_progress_percentage = float(
+                                result_info[1].removeprefix(" ").removesuffix("% done")
+                            )
+                        elif len(result_info) == 1:
+                            stage = result_info[0]
+                            current_progress_percentage = t.n  # Keep the same progress
+                        else:
+                            stage = "Unknown progress state"
+                            current_progress_percentage = t.n  # Keep the same progress
                        desc = (
                            "Optimization run completed. Uploading results"
                            if current_progress_percentage == 100.0
-                            else …
+                            else stage
                        )
                        t.set_description(desc)
                        t.n = current_progress_percentage
@@ -513,7 +746,7 @@ class OptimizationDataset(BaseModel):
            client,
            "GET",
            f"{API_HOST}/dataset-optimization/run/{run_id}",
-            headers=…
+            headers=get_headers(),
        )
        async for sse in async_iterator:
            if sse.event == "ping":
@@ -562,7 +795,7 @@ class OptimizationDataset(BaseModel):
        logger.info("Cancelling run with ID: %s", run_id)
        response = requests.delete(
            f"{API_HOST}/dataset-optimization/run/{run_id}",
-            headers=…
+            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(response)
@@ -574,3 +807,33 @@ class OptimizationDataset(BaseModel):
        if not self.run_id:
            raise ValueError("No run has been started")
        self.cancel_by_id(self.run_id)
+
+
+class DataOptimizationDatasetOut(BaseModel):
+    id: int
+
+    name: str
+    labeling_type: LabelingType
+
+    storage_config: ResponseStorageConfig
+
+    data_root_url: HirundoUrl
+
+    classes: typing.Optional[list[str]] = None
+    labeling_info: LabelingInfo
+
+    organization_id: typing.Optional[int]
+    creator_id: typing.Optional[int]
+    created_at: datetime.datetime
+    updated_at: datetime.datetime
+
+
+class DataOptimizationRunOut(BaseModel):
+    id: int
+    name: str
+    dataset_id: int
+    run_id: str
+    status: RunStatus
+    approved: bool
+    created_at: datetime.datetime
+    run_args: typing.Optional[RunArgs]
```