hirundo 0.1.9__py3-none-any.whl → 0.1.18__py3-none-any.whl
- hirundo/__init__.py +30 -11
- hirundo/_constraints.py +164 -53
- hirundo/_dataframe.py +43 -0
- hirundo/_env.py +2 -2
- hirundo/_headers.py +18 -2
- hirundo/_timeouts.py +1 -0
- hirundo/_urls.py +59 -0
- hirundo/cli.py +52 -0
- hirundo/dataset_enum.py +46 -0
- hirundo/dataset_optimization.py +93 -182
- hirundo/dataset_optimization_results.py +42 -0
- hirundo/git.py +12 -19
- hirundo/labeling.py +140 -0
- hirundo/storage.py +48 -67
- hirundo/unzip.py +247 -0
- {hirundo-0.1.9.dist-info → hirundo-0.1.18.dist-info}/METADATA +55 -44
- hirundo-0.1.18.dist-info/RECORD +25 -0
- {hirundo-0.1.9.dist-info → hirundo-0.1.18.dist-info}/WHEEL +1 -1
- hirundo/enum.py +0 -23
- hirundo-0.1.9.dist-info/RECORD +0 -20
- {hirundo-0.1.9.dist-info → hirundo-0.1.18.dist-info}/entry_points.txt +0 -0
- {hirundo-0.1.9.dist-info → hirundo-0.1.18.dist-info/licenses}/LICENSE +0 -0
- {hirundo-0.1.9.dist-info → hirundo-0.1.18.dist-info}/top_level.txt +0 -0
hirundo/dataset_optimization.py
CHANGED
@@ -1,30 +1,29 @@
 import datetime
 import json
 import typing
-from abc import ABC, abstractmethod
 from collections.abc import AsyncGenerator, Generator
 from enum import Enum
-from io import StringIO
 from typing import overload
 
 import httpx
-import numpy as np
-import pandas as pd
 import requests
-from pandas._typing import DtypeArg
 from pydantic import BaseModel, Field, model_validator
 from tqdm import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
 
-from hirundo._constraints import
+from hirundo._constraints import validate_labeling_info, validate_url
 from hirundo._env import API_HOST
-from hirundo._headers import
+from hirundo._headers import get_headers
 from hirundo._http import raise_for_status_with_reason
 from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
 from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
-from hirundo.
+from hirundo._urls import HirundoUrl
+from hirundo.dataset_enum import DatasetMetadataType, LabelingType
+from hirundo.dataset_optimization_results import DatasetOptimizationResults
+from hirundo.labeling import YOLO, LabelingInfo
 from hirundo.logger import get_logger
 from hirundo.storage import ResponseStorageConfig, StorageConfig
+from hirundo.unzip import download_and_extract_zip
 
 logger = get_logger(__name__)
 
@@ -73,105 +72,6 @@ STATUS_TO_PROGRESS_MAP = {
 }
 
 
-class DatasetOptimizationResults(BaseModel):
-    model_config = {"arbitrary_types_allowed": True}
-
-    suspects: pd.DataFrame
-    """
-    A pandas DataFrame containing the results of the optimization run
-    """
-    warnings_and_errors: pd.DataFrame
-    """
-    A pandas DataFrame containing the warnings and errors of the optimization run
-    """
-
-
-CUSTOMER_INTERCHANGE_DTYPES: DtypeArg = {
-    "image_path": str,
-    "label_path": str,
-    "segments_mask_path": str,
-    "segment_id": np.int32,
-    "label": str,
-    "bbox_id": str,
-    "xmin": np.float32,
-    "ymin": np.float32,
-    "xmax": np.float32,
-    "ymax": np.float32,
-    "suspect_level": np.float32,  # If exists, must be one of the values in the enum below
-    "suggested_label": str,
-    "suggested_label_conf": np.float32,
-    "status": str,
-    # ⬆️ If exists, must be one of the following:
-    # NO_LABELS/MISSING_IMAGE/INVALID_IMAGE/INVALID_BBOX/INVALID_BBOX_SIZE/INVALID_SEG/INVALID_SEG_SIZE
-}
-
-
-class Metadata(BaseModel, ABC):
-    type: DatasetMetadataType
-
-    @property
-    @abstractmethod
-    def metadata_url(self) -> HirundoUrl:
-        raise NotImplementedError()
-
-
-class HirundoCSV(Metadata):
-    """
-    A dataset metadata file in the Hirundo CSV format
-    """
-
-    type: DatasetMetadataType = DatasetMetadataType.HIRUNDO_CSV
-    csv_url: HirundoUrl
-    """
-    The URL to access the dataset metadata CSV file.
-    e.g. `s3://my-bucket-name/my-folder/my-metadata.csv`, `gs://my-bucket-name/my-folder/my-metadata.csv`,
-    or `ssh://my-username@my-repo-name/my-folder/my-metadata.csv`
-    (or `file:///datasets/my-folder/my-metadata.csv` if using LOCAL storage type with on-premises installation)
-    """
-
-    @property
-    def metadata_url(self) -> HirundoUrl:
-        return self.csv_url
-
-
-class COCO(Metadata):
-    """
-    A dataset metadata file in the COCO format
-    """
-
-    type: DatasetMetadataType = DatasetMetadataType.COCO
-    json_url: HirundoUrl
-    """
-    The URL to access the dataset metadata JSON file.
-    e.g. `s3://my-bucket-name/my-folder/my-metadata.json`, `gs://my-bucket-name/my-folder/my-metadata.json`,
-    or `ssh://my-username@my-repo-name/my-folder/my-metadata.json`
-    (or `file:///datasets/my-folder/my-metadata.json` if using LOCAL storage type with on-premises installation)
-    """
-
-    @property
-    def metadata_url(self) -> HirundoUrl:
-        return self.json_url
-
-
-class YOLO(Metadata):
-    type: DatasetMetadataType = DatasetMetadataType.YOLO
-    data_yaml_url: typing.Optional[HirundoUrl] = None
-    labels_dir_url: HirundoUrl
-
-    @property
-    def metadata_url(self) -> HirundoUrl:
-        return self.labels_dir_url
-
-
-LabelingInfo = typing.Union[HirundoCSV, COCO, YOLO]
-"""
-The dataset labeling info. The dataset labeling info can be one of the following:
-- `DatasetMetadataType.HirundoCSV`: Indicates that the dataset metadata file is a CSV file with the Hirundo format
-
-Currently no other formats are supported. Future versions of `hirundo` may support additional formats.
-"""
-
-
 class VisionRunArgs(BaseModel):
     upsample: bool = False
     """
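The metadata models deleted here (Metadata, HirundoCSV, COCO, YOLO and the LabelingInfo union) now live in the new hirundo/labeling.py module listed at the top of this diff, and the pandas-only results model is superseded by the new hirundo/dataset_optimization_results.py shown further down. A minimal sketch of building labeling info against the new import path, assuming HirundoCSV moved across with the same fields as the removed code:

    from hirundo.labeling import YOLO, HirundoCSV  # HirundoCSV assumed to be exported alongside YOLO

    # Hirundo-format CSV metadata, addressed by URL exactly as before
    csv_info = HirundoCSV(csv_url="s3://my-bucket-name/my-folder/my-metadata.csv")

    # YOLO metadata: the labels directory is required, data.yaml stays optional
    yolo_info = YOLO(
        labels_dir_url="s3://my-bucket-name/my-dataset/labels",
        data_yaml_url="s3://my-bucket-name/my-dataset/data.yaml",
    )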
@@ -201,13 +101,14 @@ class VisionRunArgs(BaseModel):
 RunArgs = typing.Union[VisionRunArgs]
 
 
-class
-
-
-
-
-
-
+class AugmentationName(str, Enum):
+    RANDOM_HORIZONTAL_FLIP = "RandomHorizontalFlip"
+    RANDOM_VERTICAL_FLIP = "RandomVerticalFlip"
+    RANDOM_ROTATION = "RandomRotation"
+    RANDOM_PERSPECTIVE = "RandomPerspective"
+    GAUSSIAN_NOISE = "GaussianNoise"
+    RANDOM_GRAYSCALE = "RandomGrayscale"
+    GAUSSIAN_BLUR = "GaussianBlur"
 
 
 class Modality(str, Enum):
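Because AugmentationName subclasses str, its members compare equal to and serialize as their literal values, which is what the API payload carries:

    from hirundo.dataset_optimization import AugmentationName  # import path assumed

    assert AugmentationName.RANDOM_ROTATION == "RandomRotation"
    assert AugmentationName.GAUSSIAN_BLUR.value == "GaussianBlur"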
@@ -262,9 +163,9 @@ class OptimizationDataset(BaseModel):
     A full list of possible classes used in classification / object detection.
     It is currently required for clarity and performance.
     """
-    labeling_info: LabelingInfo
+    labeling_info: typing.Union[LabelingInfo, list[LabelingInfo]]
 
-    augmentations: typing.Optional[list[
+    augmentations: typing.Optional[list[AugmentationName]] = None
     """
     Used to define which augmentations are apply to a vision dataset.
     For audio datasets, this field is ignored.
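With labeling_info now accepting either a single entry or a list, and augmentations typed against the new enum, the two field values can be built like this (a sketch; HirundoCSV is assumed to be exported from the new hirundo.labeling module):

    from hirundo.dataset_optimization import AugmentationName
    from hirundo.labeling import HirundoCSV

    # Multiple metadata files for one dataset are now expressible as a list
    labeling_info = [
        HirundoCSV(csv_url="s3://my-bucket-name/my-folder/train.csv"),
        HirundoCSV(csv_url="s3://my-bucket-name/my-folder/val.csv"),
    ]
    # Vision-only; ignored for audio datasets, as the docstring notes
    augmentations = [
        AugmentationName.RANDOM_HORIZONTAL_FLIP,
        AugmentationName.GAUSSIAN_BLUR,
    ]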
@@ -301,16 +202,30 @@ class OptimizationDataset(BaseModel):
         ):
             raise ValueError("Language is only allowed for Speech-to-Text datasets.")
         if (
-            self.labeling_info
+            not isinstance(self.labeling_info, list)
+            and self.labeling_info.type == DatasetMetadataType.YOLO
             and isinstance(self.labeling_info, YOLO)
             and (
                 self.labeling_info.data_yaml_url is not None
                 and self.classes is not None
             )
+        ) or (
+            isinstance(self.labeling_info, list)
+            and self.classes is not None
+            and any(
+                isinstance(info, YOLO) and info.data_yaml_url is not None
+                for info in self.labeling_info
+            )
         ):
             raise ValueError(
                 "Only one of `classes` or `labeling_info.data_yaml_url` should be provided for YOLO datasets"
             )
+        if self.storage_config:
+            validate_labeling_info(
+                self.labeling_type, self.labeling_info, self.storage_config
+            )
+        if self.data_root_url and self.storage_config:
+            validate_url(self.data_root_url, self.storage_config)
         return self
 
     @staticmethod
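The list branch of the validator mirrors the single-info check: supplying classes together with any YOLO entry that carries a data_yaml_url is rejected. The condition it evaluates, restated standalone (HirundoCSV assumed exported from hirundo.labeling):

    from hirundo.labeling import YOLO, HirundoCSV

    classes = ["cat", "dog"]
    labeling_info = [
        HirundoCSV(csv_url="s3://my-bucket-name/my-folder/train.csv"),
        YOLO(
            labels_dir_url="s3://my-bucket-name/my-dataset/labels",
            data_yaml_url="s3://my-bucket-name/my-dataset/data.yaml",
        ),
    ]
    # True here, so the model_validator above would raise its ValueError
    conflict = classes is not None and any(
        isinstance(info, YOLO) and info.data_yaml_url is not None
        for info in labeling_info
    )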
@@ -323,7 +238,7 @@ class OptimizationDataset(BaseModel):
         """
         response = requests.get(
             f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
-            headers=
+            headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(response)
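From here on, every call site passes headers=get_headers() instead of splicing **json_headers and **get_auth_headers() together by hand (the old inline pattern is still visible in the create hunk below). The helper itself is not part of this diff; a plausible sketch of what it consolidates, with the environment-variable name purely hypothetical:

    # Hypothetical sketch -- hirundo/_headers.py is not shown in this diff
    import os

    def get_headers() -> dict[str, str]:
        # One place to combine the content-type and auth headers that
        # callers previously merged at each request site
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ['HIRUNDO_API_KEY']}",  # hypothetical variable name
        }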
@@ -340,7 +255,7 @@ class OptimizationDataset(BaseModel):
         """
         response = requests.get(
             f"{API_HOST}/dataset-optimization/dataset/by-name/{name}",
-            headers=
+            headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(response)
@@ -361,7 +276,7 @@ class OptimizationDataset(BaseModel):
         response = requests.get(
             f"{API_HOST}/dataset-optimization/dataset/",
             params={"dataset_organization_id": organization_id},
-            headers=
+            headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(response)
@@ -388,7 +303,7 @@ class OptimizationDataset(BaseModel):
         response = requests.get(
             f"{API_HOST}/dataset-optimization/run/list",
             params={"dataset_organization_id": organization_id},
-            headers=
+            headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(response)
@@ -410,7 +325,7 @@ class OptimizationDataset(BaseModel):
         """
         response = requests.delete(
             f"{API_HOST}/dataset-optimization/dataset/{dataset_id}",
-            headers=
+            headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
         raise_for_status_with_reason(response)
@@ -482,10 +397,7 @@ class OptimizationDataset(BaseModel):
                 "organization_id": organization_id,
                 "replace_if_exists": replace_if_exists,
             },
-            headers=
-                **json_headers,
-                **get_auth_headers(),
-            },
+            headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
         raise_for_status_with_reason(dataset_response)
@@ -519,7 +431,7 @@ class OptimizationDataset(BaseModel):
         run_response = requests.post(
             f"{API_HOST}/dataset-optimization/run/{dataset_id}",
             json=run_info if len(run_info) > 0 else None,
-            headers=
+            headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
         raise_for_status_with_reason(run_response)
@@ -595,46 +507,6 @@ class OptimizationDataset(BaseModel):
         self.id = None
         self.run_id = None
 
-    @staticmethod
-    def _clean_df_index(df: "pd.DataFrame") -> "pd.DataFrame":
-        """
-        Clean the index of a dataframe in case it has unnamed columns.
-
-        Args:
-            df (DataFrame): Dataframe to clean
-
-        Returns:
-            DataFrame: Cleaned dataframe
-        """
-        index_cols = sorted(
-            [col for col in df.columns if col.startswith("Unnamed")], reverse=True
-        )
-        if len(index_cols) > 0:
-            df.set_index(index_cols.pop(), inplace=True)
-            df.rename_axis(index=None, columns=None, inplace=True)
-        if len(index_cols) > 0:
-            df.drop(columns=index_cols, inplace=True)
-
-        return df
-
-    @staticmethod
-    def _read_csvs_to_df(data: dict):
-        if data["state"] == RunStatus.SUCCESS.value:
-            data["result"]["suspects"] = OptimizationDataset._clean_df_index(
-                pd.read_csv(
-                    StringIO(data["result"]["suspects"]),
-                    dtype=CUSTOMER_INTERCHANGE_DTYPES,
-                )
-            )
-            data["result"]["warnings_and_errors"] = OptimizationDataset._clean_df_index(
-                pd.read_csv(
-                    StringIO(data["result"]["warnings_and_errors"]),
-                    dtype=CUSTOMER_INTERCHANGE_DTYPES,
-                )
-            )
-        else:
-            pass
-
     @staticmethod
     def _check_run_by_id(run_id: str, retry=0) -> Generator[dict, None, None]:
         if retry > MAX_RETRIES:
@@ -645,7 +517,7 @@ class OptimizationDataset(BaseModel):
             client,
             "GET",
             f"{API_HOST}/dataset-optimization/run/{run_id}",
-            headers=
+            headers=get_headers(),
         ):
             if sse.event == "ping":
                 continue
@@ -668,11 +540,21 @@ class OptimizationDataset(BaseModel):
                 raise HirundoError(last_event["reason"])
             else:
                 raise HirundoError("Unknown error")
-        OptimizationDataset._read_csvs_to_df(data)
         yield data
         if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
             OptimizationDataset._check_run_by_id(run_id, retry + 1)
 
+    @staticmethod
+    def _handle_failure(iteration: dict):
+        if iteration["result"]:
+            raise HirundoError(
+                f"Optimization run failed with error: {iteration['result']}"
+            )
+        else:
+            raise HirundoError(
+                "Optimization run failed with an unknown error in _handle_failure"
+            )
+
     @staticmethod
     @overload
     def check_run_by_id(
@@ -722,16 +604,19 @@ class OptimizationDataset(BaseModel):
                 RunStatus.REJECTED.value,
                 RunStatus.REVOKED.value,
             ]:
-
-
+                logger.error(
+                    "State is failure, rejected, or revoked: %s",
+                    iteration["state"],
                 )
+                OptimizationDataset._handle_failure(iteration)
             elif iteration["state"] == RunStatus.SUCCESS.value:
                 t.close()
-
-
-
-
-
+                zip_temporary_url = iteration["result"]
+                logger.debug("Optimization run completed. Downloading results")
+
+                return download_and_extract_zip(
+                    run_id,
+                    zip_temporary_url,
                 )
             elif (
                 iteration["state"] == RunStatus.AWAITING_MANUAL_APPROVAL.value
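Success handling changes shape here: instead of parsing CSV strings out of the event payload into pandas DataFrames, the run result is now a temporary URL to a zip archive, which download_and_extract_zip (from the new hirundo/unzip.py) fetches and unpacks into a DatasetOptimizationResults. A consumer-side sketch, assuming check_run_by_id blocks until the run finishes:

    results = OptimizationDataset.check_run_by_id("my-run-id")
    print(results.cached_zip_path)   # where the downloaded zip was cached
    print(results.suspects)          # pandas or polars DataFrame, per what's installed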
@@ -767,7 +652,9 @@ class OptimizationDataset(BaseModel):
                     t.n = current_progress_percentage
                     logger.debug("Setting progress to %s", t.n)
                     t.refresh()
-        raise HirundoError(
+        raise HirundoError(
+            "Optimization run failed with an unknown error in check_run_by_id"
+        )
 
     @overload
     def check_run(
@@ -823,7 +710,7 @@ class OptimizationDataset(BaseModel):
             client,
             "GET",
             f"{API_HOST}/dataset-optimization/run/{run_id}",
-            headers=
+            headers=get_headers(),
         )
         async for sse in async_iterator:
             if sse.event == "ping":
@@ -867,12 +754,10 @@ class OptimizationDataset(BaseModel):
         Args:
             run_id: The ID of the run to cancel
         """
-        if not run_id:
-            raise ValueError("No run has been started")
         logger.info("Cancelling run with ID: %s", run_id)
         response = requests.delete(
             f"{API_HOST}/dataset-optimization/run/{run_id}",
-            headers=
+            headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
         raise_for_status_with_reason(response)
@@ -885,6 +770,30 @@ class OptimizationDataset(BaseModel):
             raise ValueError("No run has been started")
         self.cancel_by_id(self.run_id)
 
+    @staticmethod
+    def archive_run_by_id(run_id: str) -> None:
+        """
+        Archive the dataset optimization run for the given `run_id`.
+
+        Args:
+            run_id: The ID of the run to archive
+        """
+        logger.info("Archiving run with ID: %s", run_id)
+        response = requests.patch(
+            f"{API_HOST}/dataset-optimization/run/archive/{run_id}",
+            headers=get_headers(),
+            timeout=MODIFY_TIMEOUT,
+        )
+        raise_for_status_with_reason(response)
+
+    def archive(self) -> None:
+        """
+        Archive the current active instance's run.
+        """
+        if not self.run_id:
+            raise ValueError("No run has been started")
+        self.archive_run_by_id(self.run_id)
+
 
 class DataOptimizationDatasetOut(BaseModel):
     id: int
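The archive pair follows the same shape as cancel: a static by-ID method plus an instance method that requires a started run. Usage follows directly from the added code:

    # Archive any run by its ID
    OptimizationDataset.archive_run_by_id("my-run-id")

    # Or, given an OptimizationDataset instance whose run was started via this client:
    dataset.archive()  # raises ValueError("No run has been started") if run_id is unset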
@@ -897,7 +806,7 @@ class DataOptimizationDatasetOut(BaseModel):
     data_root_url: HirundoUrl
 
     classes: typing.Optional[list[str]] = None
-    labeling_info: LabelingInfo
+    labeling_info: typing.Union[LabelingInfo, list[LabelingInfo]]
 
     organization_id: typing.Optional[int]
     creator_id: typing.Optional[int]
@@ -908,7 +817,9 @@ class DataOptimizationDatasetOut(BaseModel):
 class DataOptimizationRunOut(BaseModel):
     id: int
     name: str
+    dataset_id: int
     run_id: str
     status: RunStatus
     approved: bool
     created_at: datetime.datetime
+    run_args: typing.Optional[RunArgs]
hirundo/dataset_optimization_results.py
ADDED
@@ -0,0 +1,42 @@
+import typing
+from pathlib import Path
+
+from pydantic import BaseModel
+from typing_extensions import TypeAliasType
+
+from hirundo._dataframe import has_pandas, has_polars
+
+DataFrameType = TypeAliasType("DataFrameType", None)
+
+if has_pandas:
+    from hirundo._dataframe import pd
+
+    DataFrameType = TypeAliasType("DataFrameType", typing.Union[pd.DataFrame, None])
+if has_polars:
+    from hirundo._dataframe import pl
+
+    DataFrameType = TypeAliasType("DataFrameType", typing.Union[pl.DataFrame, None])
+
+
+T = typing.TypeVar("T")
+
+
+class DatasetOptimizationResults(BaseModel, typing.Generic[T]):
+    model_config = {"arbitrary_types_allowed": True}
+
+    cached_zip_path: Path
+    """
+    The path to the cached zip file of the results
+    """
+    suspects: T
+    """
+    A polars/pandas DataFrame containing the results of the optimization run
+    """
+    object_suspects: typing.Optional[T]
+    """
+    A polars/pandas DataFrame containing the object-level results of the optimization run
+    """
+    warnings_and_errors: T
+    """
+    A polars/pandas DataFrame containing the warnings and errors of the optimization run
+    """
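The replacement results model is generic over the DataFrame type, so one class serves pandas installs, polars installs, or neither (when DataFrameType collapses to None). A construction sketch under the assumption that pandas is installed:

    from pathlib import Path

    import pandas as pd

    from hirundo.dataset_optimization_results import DatasetOptimizationResults

    results = DatasetOptimizationResults[pd.DataFrame](
        cached_zip_path=Path("/tmp/my-run.zip"),  # illustrative path
        suspects=pd.DataFrame({"image_path": [], "suspect_level": []}),
        object_suspects=None,                     # object-level results are optional
        warnings_and_errors=pd.DataFrame(),
    )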
hirundo/git.py
CHANGED
@@ -7,17 +7,17 @@ import requests
 from pydantic import BaseModel, field_validator
 from pydantic_core import Url
 
-from hirundo._constraints import RepoUrl
 from hirundo._env import API_HOST
-from hirundo._headers import
+from hirundo._headers import get_headers
 from hirundo._http import raise_for_status_with_reason
 from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
+from hirundo._urls import RepoUrl
 from hirundo.logger import get_logger
 
 logger = get_logger(__name__)
 
 
-class GitPlainAuthBase(BaseModel):
+class GitPlainAuth(BaseModel):
     username: str
     """
     The username for the Git repository
@@ -28,7 +28,7 @@ class GitPlainAuthBase(BaseModel):
     """
 
 
-class
+class GitSSHAuth(BaseModel):
     ssh_key: str
     """
     The SSH key for the Git repository
@@ -52,7 +52,7 @@ class GitRepo(BaseModel):
     repository_url: typing.Union[str, RepoUrl]
     """
     The URL of the Git repository, it should start with `ssh://` or `https://` or be in the form `user@host:path`.
-    If it is in the form `user@host:path`, it will be rewritten to `ssh://user@host
+    If it is in the form `user@host:path`, it will be rewritten to `ssh://user@host/path`.
     """
     organization_id: typing.Optional[int] = None
     """
@@ -60,14 +60,14 @@ class GitRepo(BaseModel):
     If not provided, it will be assigned to your default organization.
     """
 
-    plain_auth: typing.Optional[
+    plain_auth: typing.Optional[GitPlainAuth] = pydantic.Field(
         default=None, examples=[None, {"username": "ben", "password": "password"}]
     )
     """
     The plain authentication details for the Git repository.
     Use this if using a special user with a username and password for authentication.
     """
-    ssh_auth: typing.Optional[
+    ssh_auth: typing.Optional[GitSSHAuth] = pydantic.Field(
         default=None,
         examples=[
             {
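With the auth models named concretely, wiring up a repository looks like this (a sketch: the name field and the exact key string are assumptions, and a user@host:path URL like the one below would be rewritten to ssh:// form per the docstring above):

    from hirundo.git import GitRepo, GitSSHAuth

    repo = GitRepo(
        name="my-dataset-repo",  # field assumed from the by-name endpoint below
        repository_url="git@github.com:my-org/my-dataset-repo.git",
        ssh_auth=GitSSHAuth(ssh_key="-----BEGIN OPENSSH PRIVATE KEY-----\n..."),
    )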
@@ -124,10 +124,7 @@ class GitRepo(BaseModel):
                 **self.model_dump(mode="json"),
                 "replace_if_exists": replace_if_exists,
             },
-            headers=
-                **json_headers,
-                **get_auth_headers(),
-            },
+            headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
         raise_for_status_with_reason(git_repo)
@@ -145,7 +142,7 @@ class GitRepo(BaseModel):
         """
         git_repo = requests.get(
             f"{API_HOST}/git-repo/{git_repo_id}",
-            headers=
+            headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(git_repo)
@@ -163,7 +160,7 @@ class GitRepo(BaseModel):
         """
         git_repo = requests.get(
             f"{API_HOST}/git-repo/by-name/{name}",
-            headers=
+            headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(git_repo)
@@ -176,9 +173,7 @@ class GitRepo(BaseModel):
         """
         git_repos = requests.get(
             f"{API_HOST}/git-repo/",
-            headers=
-                **get_auth_headers(),
-            },
+            headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
         raise_for_status_with_reason(git_repos)
@@ -200,9 +195,7 @@ class GitRepo(BaseModel):
         """
         git_repo = requests.delete(
             f"{API_HOST}/git-repo/{git_repo_id}",
-            headers=
-                **get_auth_headers(),
-            },
+            headers=get_headers(),
             timeout=MODIFY_TIMEOUT,
         )
         raise_for_status_with_reason(git_repo)