hirundo 0.1.21__py3-none-any.whl → 0.2.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hirundo/__init__.py +19 -3
- hirundo/_constraints.py +2 -3
- hirundo/_iter_sse_retrying.py +7 -4
- hirundo/_llm_pipeline.py +153 -0
- hirundo/_run_checking.py +283 -0
- hirundo/_urls.py +1 -0
- hirundo/cli.py +1 -4
- hirundo/dataset_enum.py +2 -0
- hirundo/dataset_qa.py +106 -190
- hirundo/dataset_qa_results.py +3 -3
- hirundo/git.py +7 -8
- hirundo/labeling.py +22 -19
- hirundo/storage.py +25 -24
- hirundo/unlearning_llm.py +599 -0
- hirundo/unzip.py +3 -3
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/METADATA +42 -10
- hirundo-0.2.3.post1.dist-info/RECORD +28 -0
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/WHEEL +1 -1
- hirundo-0.1.21.dist-info/RECORD +0 -25
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/entry_points.txt +0 -0
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/licenses/LICENSE +0 -0
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/top_level.txt +0 -0
hirundo/dataset_qa.py
CHANGED
@@ -1,11 +1,9 @@
 import datetime
-import json
 import typing
 from collections.abc import AsyncGenerator, Generator
 from enum import Enum
 from typing import overload

-import httpx
 from pydantic import BaseModel, Field, model_validator
 from tqdm import tqdm
 from tqdm.contrib.logging import logging_redirect_tqdm
@@ -14,7 +12,16 @@ from hirundo._constraints import validate_labeling_info, validate_url
 from hirundo._env import API_HOST
 from hirundo._headers import get_headers
 from hirundo._http import raise_for_status_with_reason, requests
-from hirundo.
+from hirundo._run_checking import (
+    STATUS_TO_PROGRESS_MAP,
+    RunStatus,
+    aiter_run_events,
+    build_status_text_map,
+    get_state,
+    handle_run_failure,
+    iter_run_events,
+    update_progress_from_result,
+)
 from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
 from hirundo._urls import HirundoUrl
 from hirundo.dataset_enum import DatasetMetadataType, LabelingType
@@ -35,87 +42,57 @@ class HirundoError(Exception):
     pass


-
-
-
-
-    PENDING = "PENDING"
-    STARTED = "STARTED"
-    SUCCESS = "SUCCESS"
-    FAILURE = "FAILURE"
-    AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
-    REVOKED = "REVOKED"
-    REJECTED = "REJECTED"
-    RETRY = "RETRY"
-
-
-STATUS_TO_TEXT_MAP = {
-    RunStatus.STARTED.value: "Dataset QA run in progress. Downloading dataset",
-    RunStatus.PENDING.value: "Dataset QA run queued and not yet started",
-    RunStatus.SUCCESS.value: "Dataset QA run completed successfully",
-    RunStatus.FAILURE.value: "Dataset QA run failed",
-    RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
-    RunStatus.RETRY.value: "Dataset QA run failed. Retrying",
-    RunStatus.REVOKED.value: "Dataset QA run was cancelled",
-    RunStatus.REJECTED.value: "Dataset QA run was rejected",
-}
-STATUS_TO_PROGRESS_MAP = {
-    RunStatus.STARTED.value: 0.0,
-    RunStatus.PENDING.value: 0.0,
-    RunStatus.SUCCESS.value: 100.0,
-    RunStatus.FAILURE.value: 100.0,
-    RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
-    RunStatus.RETRY.value: 0.0,
-    RunStatus.REVOKED.value: 100.0,
-    RunStatus.REJECTED.value: 0.0,
-}
+STATUS_TO_TEXT_MAP = build_status_text_map(
+    "Dataset QA",
+    started_detail="Dataset QA run in progress. Downloading dataset",
+)


 class ClassificationRunArgs(BaseModel):
-    image_size:
+    image_size: tuple[int, int] | None = (224, 224)
     """
     Size (width, height) to which to resize classification images.
     It is recommended to keep this value at (224, 224) unless your classes are differentiated by very small differences.
     """
-    upsample:
+    upsample: bool | None = False
     """
     Whether to upsample the dataset to attempt to balance the classes.
     """


 class ObjectDetectionRunArgs(ClassificationRunArgs):
-    min_abs_bbox_size:
+    min_abs_bbox_size: int | None = None
     """
     Minimum valid size (in pixels) of a bounding box to keep it in the dataset for QA.
     """
-    min_abs_bbox_area:
+    min_abs_bbox_area: int | None = None
     """
     Minimum valid absolute area (in pixels²) of a bounding box to keep it in the dataset for QA.
     """
-    min_rel_bbox_size:
+    min_rel_bbox_size: float | None = None
     """
     Minimum valid size (as a fraction of both image height and width) for a bounding box
     to keep it in the dataset for QA, relative to the corresponding dimension size,
     i.e. if the bounding box is 10% of the image width and 5% of the image height, it will be kept if this value is 0.05, but not if the
     value is 0.06 (since both width and height are checked).
     """
-    min_rel_bbox_area:
+    min_rel_bbox_area: float | None = None
     """
     Minimum valid relative area (as a fraction of the image area) of a bounding box to keep it in the dataset for QA.
     """
-    crop_ratio:
+    crop_ratio: float | None = None
     """
     Ratio of the bounding box to crop.
     Change this value at your own risk. It is recommended to keep it at 1.0 unless you know what you are doing.
     """
-    add_mask_channel:
+    add_mask_channel: bool | None = None
     """
     Whether to add a mask channel to the image.
     Change at your own risk. It is recommended to keep it at False unless you know what you are doing.
     """


-RunArgs =
+RunArgs = ClassificationRunArgs | ObjectDetectionRunArgs


 class AugmentationName(str, Enum):
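For orientation, the `RunArgs` alias means either run-args model can be passed wherever `RunArgs | None` is accepted (e.g. `launch_qa_run` and `run_qa` below). A minimal sketch using the fields and defaults visible in this hunk; the `hirundo.dataset_qa` import path is taken from the file being diffed:

    from hirundo.dataset_qa import ClassificationRunArgs, ObjectDetectionRunArgs

    # Classification run args: (224, 224) resize and no upsampling are the defaults shown above
    cls_args = ClassificationRunArgs(image_size=(224, 224), upsample=False)

    # Object detection inherits the classification fields and adds bounding-box filters
    det_args = ObjectDetectionRunArgs(
        image_size=(640, 640),     # resize target (width, height)
        min_abs_bbox_size=8,       # drop boxes smaller than 8 px in either dimension
        min_rel_bbox_area=0.0001,  # drop boxes covering less than 0.01% of the image area
    )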
@@ -128,32 +105,32 @@ class AugmentationName(str, Enum):
     GAUSSIAN_BLUR = "GaussianBlur"


-class
+class ModalityType(str, Enum):
     RADAR = "RADAR"
     VISION = "VISION"
     SPEECH = "SPEECH"
     TABULAR = "TABULAR"


-
-
+MODALITY_TO_SUPPORTED_LABELING_TYPES = {
+    ModalityType.RADAR: [
         LabelingType.SINGLE_LABEL_CLASSIFICATION,
         LabelingType.OBJECT_DETECTION,
     ],
-
+    ModalityType.VISION: [
         LabelingType.SINGLE_LABEL_CLASSIFICATION,
         LabelingType.OBJECT_DETECTION,
         LabelingType.OBJECT_SEGMENTATION,
         LabelingType.SEMANTIC_SEGMENTATION,
         LabelingType.PANOPTIC_SEGMENTATION,
     ],
-
-
+    ModalityType.SPEECH: [LabelingType.SPEECH_TO_TEXT],
+    ModalityType.TABULAR: [LabelingType.SINGLE_LABEL_CLASSIFICATION],
 }


 class QADataset(BaseModel):
-    id:
+    id: int | None = Field(default=None)
     """
     The ID of the dataset created on the server.
     """
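The new `MODALITY_TO_SUPPORTED_LABELING_TYPES` table drives the `validate_dataset` check shown in the next hunk. A small sketch of the lookup, assuming the table and enums are importable from the modules named in the imports above:

    from hirundo.dataset_enum import LabelingType
    from hirundo.dataset_qa import MODALITY_TO_SUPPORTED_LABELING_TYPES, ModalityType

    # SPEECH datasets only support speech-to-text labeling per the table above
    assert MODALITY_TO_SUPPORTED_LABELING_TYPES[ModalityType.SPEECH] == [
        LabelingType.SPEECH_TO_TEXT
    ]
    # TABULAR datasets cannot use object detection, so validate_dataset would reject that combination
    assert (
        LabelingType.OBJECT_DETECTION
        not in MODALITY_TO_SUPPORTED_LABELING_TYPES[ModalityType.TABULAR]
    )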
@@ -169,17 +146,15 @@ class QADataset(BaseModel):
     - `LabelingType.OBJECT_DETECTION`: Indicates that the dataset is for object detection tasks
     - `LabelingType.SPEECH_TO_TEXT`: Indicates that the dataset is for speech-to-text tasks
     """
-    language:
+    language: str | None = None
     """
     Language of the Speech-to-Text audio dataset. This is required for Speech-to-Text datasets.
     """
-    storage_config_id:
+    storage_config_id: int | None = None
     """
     The ID of the storage config used to store the dataset and metadata.
     """
-    storage_config:
-        typing.Union[StorageConfig, ResponseStorageConfig]
-    ] = None
+    storage_config: StorageConfig | ResponseStorageConfig | None = None
     """
     The `StorageConfig` instance to link to.
     """
@@ -193,41 +168,44 @@ class QADataset(BaseModel):
     Note: All CSV `image_path` entries in the metadata file should be relative to this folder.
     """

-    classes:
+    classes: list[str] | None = None
     """
     A full list of possible classes used in classification / object detection.
     It is currently required for clarity and performance.
     """
-    labeling_info:
+    labeling_info: LabelingInfo | list[LabelingInfo]

-    augmentations:
+    augmentations: list[AugmentationName] | None = None
     """
     Used to define which augmentations are apply to a vision dataset.
     For audio datasets, this field is ignored.
     If no value is provided, all augmentations are applied to vision datasets.
     """
-
+    modality: ModalityType = ModalityType.VISION
     """
-    Used to define the
+    Used to define the modality of the dataset.
     Defaults to Image.
     """

-    run_id:
+    run_id: str | None = Field(default=None, init=False)
     """
     The ID of the Dataset QA run created on the server.
     """

-    status:
+    status: RunStatus | None = None

     @model_validator(mode="after")
     def validate_dataset(self):
-        if self.
+        if self.modality not in MODALITY_TO_SUPPORTED_LABELING_TYPES:
             raise ValueError(
-                f"
+                f"Modality {self.modality} is not supported. Supported modalities are: {list(MODALITY_TO_SUPPORTED_LABELING_TYPES.keys())}"
             )
-        if
+        if (
+            self.labeling_type
+            not in MODALITY_TO_SUPPORTED_LABELING_TYPES[self.modality]
+        ):
             raise ValueError(
-                f"Labeling type {self.labeling_type} is not supported for
+                f"Labeling type {self.labeling_type} is not supported for modality {self.modality}. Supported labeling types are: {MODALITY_TO_SUPPORTED_LABELING_TYPES[self.modality]}"
             )
         if self.storage_config is None and self.storage_config_id is None:
             raise ValueError(
@@ -307,7 +285,7 @@ class QADataset(BaseModel):

     @staticmethod
     def list_datasets(
-        organization_id:
+        organization_id: int | None = None,
     ) -> list["QADatasetOut"]:
         """
         Lists all the datasets created by user's default organization
@@ -333,7 +311,8 @@ class QADataset(BaseModel):

     @staticmethod
     def list_runs(
-        organization_id:
+        organization_id: int | None = None,
+        archived: bool | None = False,
     ) -> list["DataQARunOut"]:
         """
         Lists all the `QADataset` instances created by user's default organization
@@ -342,10 +321,11 @@ class QADataset(BaseModel):

         Args:
             organization_id: The ID of the organization to list the datasets for.
+            archived: Whether to list archived runs.
         """
         response = requests.get(
             f"{API_HOST}/dataset-qa/run/list",
-            params={"dataset_organization_id": organization_id},
+            params={"dataset_organization_id": organization_id, "archived": archived},
             headers=get_headers(),
             timeout=READ_TIMEOUT,
         )
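A short usage sketch for the new `archived` parameter (the client is assumed to already be configured with an API key):

    from hirundo.dataset_qa import QADataset

    runs = QADataset.list_runs()                         # default: archived=False, as before
    archived_runs = QADataset.list_runs(archived=True)   # list archived runs per the new flag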
@@ -396,7 +376,7 @@ class QADataset(BaseModel):

     def create(
         self,
-        organization_id:
+        organization_id: int | None = None,
         replace_if_exists: bool = False,
     ) -> int:
         """
@@ -453,8 +433,8 @@ class QADataset(BaseModel):
     @staticmethod
     def launch_qa_run(
         dataset_id: int,
-        organization_id:
-        run_args:
+        organization_id: int | None = None,
+        run_args: RunArgs | None = None,
     ) -> str:
         """
         Run the dataset QA process on the server using the dataset with the given ID
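With `run_args` now typed as `RunArgs | None`, a run can be launched with per-run tuning. A sketch, using a placeholder dataset ID (in practice this would be the int returned by `create()`):

    from hirundo.dataset_qa import ObjectDetectionRunArgs, QADataset

    dataset_id = 42  # placeholder; normally the value returned by QADataset.create()
    run_id = QADataset.launch_qa_run(
        dataset_id,
        run_args=ObjectDetectionRunArgs(min_abs_bbox_size=8),
    )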
@@ -503,9 +483,9 @@ class QADataset(BaseModel):

     def run_qa(
         self,
-        organization_id:
+        organization_id: int | None = None,
         replace_dataset_if_exists: bool = False,
-        run_args:
+        run_args: RunArgs | None = None,
     ) -> str:
         """
         If the dataset was not created on the server yet, it is created.
@@ -556,53 +536,20 @@ class QADataset(BaseModel):

     @staticmethod
     def _check_run_by_id(run_id: str, retry=0) -> Generator[dict, None, None]:
-
-
-
-
-
-
-
-
-            headers=get_headers(),
-        ):
-            if sse.event == "ping":
-                continue
-            logger.debug(
-                "[SYNC] received event: %s with data: %s and ID: %s and retry: %s",
-                sse.event,
-                sse.data,
-                sse.id,
-                sse.retry,
-            )
-            last_event = json.loads(sse.data)
-            if not last_event:
-                continue
-            if "data" in last_event:
-                data = last_event["data"]
-            else:
-                if "detail" in last_event:
-                    raise HirundoError(last_event["detail"])
-                elif "reason" in last_event:
-                    raise HirundoError(last_event["reason"])
-                else:
-                    raise HirundoError("Unknown error")
-            yield data
-        if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
-            QADataset._check_run_by_id(run_id, retry + 1)
-
-    @staticmethod
-    def _handle_failure(iteration: dict):
-        if iteration["result"]:
-            raise HirundoError(f"QA run failed with error: {iteration['result']}")
-        else:
-            raise HirundoError("QA run failed with an unknown error in _handle_failure")
+        yield from iter_run_events(
+            f"{API_HOST}/dataset-qa/run/{run_id}",
+            headers=get_headers(),
+            retry=retry,
+            status_keys=("state",),
+            error_cls=HirundoError,
+            log=logger,
+        )

     @staticmethod
     @overload
     def check_run_by_id(
         run_id: str, stop_on_manual_approval: typing.Literal[True]
-    ) ->
+    ) -> DatasetQAResults | None: ...

     @staticmethod
     @overload
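Downstream, the sync polling helper now returns `DatasetQAResults | None` rather than an untyped value. A hedged consumption sketch:

    from hirundo.dataset_qa import QADataset

    run_id = "..."  # the string returned by launch_qa_run() or run_qa()

    # Blocks and shows tqdm progress until the run finishes or fails.
    # With stop_on_manual_approval=True it returns None when the run is awaiting approval.
    results = QADataset.check_run_by_id(run_id, stop_on_manual_approval=True)
    if results is None:
        print("Run stopped: awaiting manual approval")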
@@ -614,12 +561,12 @@ class QADataset(BaseModel):
     @overload
     def check_run_by_id(
         run_id: str, stop_on_manual_approval: bool
-    ) ->
+    ) -> DatasetQAResults | None: ...

     @staticmethod
     def check_run_by_id(
         run_id: str, stop_on_manual_approval: bool = False
-    ) ->
+    ) -> DatasetQAResults | None:
         """
         Check the status of a run given its ID

@@ -637,22 +584,25 @@ class QADataset(BaseModel):
         with logging_redirect_tqdm():
             t = tqdm(total=100.0)
             for iteration in QADataset._check_run_by_id(run_id):
-
-
-                t.
+                state = get_state(iteration, ("state",))
+                if state in STATUS_TO_PROGRESS_MAP:
+                    t.set_description(STATUS_TO_TEXT_MAP[state])
+                    t.n = STATUS_TO_PROGRESS_MAP[state]
                 logger.debug("Setting progress to %s", t.n)
                 t.refresh()
-                if
+                if state in [
                     RunStatus.FAILURE.value,
                     RunStatus.REJECTED.value,
                     RunStatus.REVOKED.value,
                 ]:
                     logger.error(
                         "State is failure, rejected, or revoked: %s",
-
+                        state,
                     )
-
-
+                    handle_run_failure(
+                        iteration, error_cls=HirundoError, run_label="QA"
+                    )
+                elif state == RunStatus.SUCCESS.value:
                     t.close()
                     zip_temporary_url = iteration["result"]
                     logger.debug("QA run completed. Downloading results")
@@ -662,45 +612,24 @@ class QADataset(BaseModel):
                         zip_temporary_url,
                     )
                 elif (
-
+                    state == RunStatus.AWAITING_MANUAL_APPROVAL.value
                     and stop_on_manual_approval
                 ):
                     t.close()
                     return None
-                elif
-
-                    iteration
-
-
-
-                )
-                    result_info = iteration["result"]["result"].split(":")
-                    if len(result_info) > 1:
-                        stage = result_info[0]
-                        current_progress_percentage = float(
-                            result_info[1].removeprefix(" ").removesuffix("% done")
-                        )
-                    elif len(result_info) == 1:
-                        stage = result_info[0]
-                        current_progress_percentage = t.n  # Keep the same progress
-                    else:
-                        stage = "Unknown progress state"
-                        current_progress_percentage = t.n  # Keep the same progress
-                    desc = (
-                        "QA run completed. Uploading results"
-                        if current_progress_percentage == 100.0
-                        else stage
-                    )
-                    t.set_description(desc)
-                    t.n = current_progress_percentage
-                    logger.debug("Setting progress to %s", t.n)
-                    t.refresh()
+                elif state is None:
+                    update_progress_from_result(
+                        iteration,
+                        t,
+                        uploading_text="QA run completed. Uploading results",
+                        log=logger,
+                    )
         raise HirundoError("QA run failed with an unknown error in check_run_by_id")

     @overload
     def check_run(
         self, stop_on_manual_approval: typing.Literal[True]
-    ) ->
+    ) -> DatasetQAResults | None: ...

     @overload
     def check_run(
@@ -709,7 +638,7 @@ class QADataset(BaseModel):

     def check_run(
         self, stop_on_manual_approval: bool = False
-    ) ->
+    ) -> DatasetQAResults | None:
         """
         Check the status of the current active instance's run.

@@ -741,32 +670,15 @@ class QADataset(BaseModel):

         """
         logger.debug("Checking run with ID: %s", run_id)
-
-
-
-
-
-
-
-
-
-            f"{API_HOST}/dataset-qa/run/{run_id}",
-            headers=get_headers(),
-        )
-        async for sse in async_iterator:
-            if sse.event == "ping":
-                continue
-            logger.debug(
-                "[ASYNC] Received event: %s with data: %s and ID: %s and retry: %s",
-                sse.event,
-                sse.data,
-                sse.id,
-                sse.retry,
-            )
-            last_event = json.loads(sse.data)
-            yield last_event["data"]
-        if not last_event or last_event["data"]["state"] == RunStatus.PENDING.value:
-            QADataset.acheck_run_by_id(run_id, retry + 1)
+        async for iteration in aiter_run_events(
+            f"{API_HOST}/dataset-qa/run/{run_id}",
+            headers=get_headers(),
+            retry=retry,
+            status_keys=("state",),
+            error_cls=HirundoError,
+            log=logger,
+        ):
+            yield iteration

     async def acheck_run(self) -> AsyncGenerator[dict, None]:
         """
@@ -776,6 +688,8 @@ class QADataset(BaseModel):

         This generator will produce values to show progress of the run.

+        Note: This function does not handle errors nor show progress. It is expected that you do that.
+
         Yields:
             Each event will be a dict, where:
             - `"state"` is PENDING, STARTED, RETRY, FAILURE or SUCCESS
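Per the added note, `acheck_run` leaves error handling and progress display to the caller. A minimal async consumption sketch, assuming `dataset` is a `QADataset` whose run was already started with `run_qa()`:

    import asyncio

    async def watch(dataset):
        # Each yielded event is a dict; per the docstring, "state" is
        # PENDING, STARTED, RETRY, FAILURE or SUCCESS.
        async for event in dataset.acheck_run():
            print(event.get("state"), event.get("result"))

    # asyncio.run(watch(dataset))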
@@ -846,11 +760,11 @@ class QADatasetOut(BaseModel):

     data_root_url: HirundoUrl

-    classes:
-    labeling_info:
+    classes: list[str] | None = None
+    labeling_info: LabelingInfo | list[LabelingInfo]

-    organization_id:
-    creator_id:
+    organization_id: int | None
+    creator_id: int | None
     created_at: datetime.datetime
     updated_at: datetime.datetime

@@ -863,4 +777,6 @@ class DataQARunOut(BaseModel):
     status: RunStatus
     approved: bool
     created_at: datetime.datetime
-    run_args:
+    run_args: RunArgs | None
+
+    deleted_at: datetime.datetime | None = None
hirundo/dataset_qa_results.py
CHANGED
@@ -11,11 +11,11 @@ DataFrameType = TypeAliasType("DataFrameType", None)
 if has_pandas:
     from hirundo._dataframe import pd

-    DataFrameType = TypeAliasType("DataFrameType",
+    DataFrameType = TypeAliasType("DataFrameType", pd.DataFrame | None)
 if has_polars:
     from hirundo._dataframe import pl

-    DataFrameType = TypeAliasType("DataFrameType",
+    DataFrameType = TypeAliasType("DataFrameType", pl.DataFrame | None)


 T = typing.TypeVar("T")
@@ -32,7 +32,7 @@ class DatasetQAResults(BaseModel, typing.Generic[T]):
     """
     A polars/pandas DataFrame containing the results of the data QA run
     """
-    object_suspects:
+    object_suspects: T | None
     """
     A polars/pandas DataFrame containing the object-level results of the data QA run
     """
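With these changes, which DataFrame type `DatasetQAResults` carries depends on whether pandas or polars is installed. A hedged sketch of reading the object-level suspects off a finished run (only `object_suspects` is shown in this diff; other result fields are omitted):

    # `dataset` is a QADataset whose run has completed
    results = dataset.check_run()        # DatasetQAResults | None
    if results is not None and results.object_suspects is not None:
        df = results.object_suspects     # pandas or polars DataFrame, per installed backend
        print(df.head())                 # .head() exists on both backends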
hirundo/git.py
CHANGED
@@ -1,6 +1,5 @@
 import datetime
 import re
-import typing

 import pydantic
 from pydantic import BaseModel, field_validator
@@ -32,14 +31,14 @@ class GitSSHAuth(BaseModel):
     """
     The SSH key for the Git repository
     """
-    ssh_password:
+    ssh_password: str | None
     """
     The password for the SSH key for the Git repository.
     """


 class GitRepo(BaseModel):
-    id:
+    id: int | None = None
     """
     The ID of the Git repository.
     """
@@ -48,25 +47,25 @@ class GitRepo(BaseModel):
     """
     A name to identify the Git repository in the Hirundo system.
     """
-    repository_url:
+    repository_url: str | RepoUrl
     """
     The URL of the Git repository, it should start with `ssh://` or `https://` or be in the form `user@host:path`.
     If it is in the form `user@host:path`, it will be rewritten to `ssh://user@host/path`.
     """
-    organization_id:
+    organization_id: int | None = None
     """
     The ID of the organization that the Git repository belongs to.
     If not provided, it will be assigned to your default organization.
     """

-    plain_auth:
+    plain_auth: GitPlainAuth | None = pydantic.Field(
         default=None, examples=[None, {"username": "ben", "password": "password"}]
     )
     """
     The plain authentication details for the Git repository.
     Use this if using a special user with a username and password for authentication.
     """
-    ssh_auth:
+    ssh_auth: GitSSHAuth | None = pydantic.Field(
         default=None,
         examples=[
             {
@@ -84,7 +83,7 @@ class GitRepo(BaseModel):

     @field_validator("repository_url", mode="before", check_fields=True)
     @classmethod
-    def check_valid_repository_url(cls, repository_url:
+    def check_valid_repository_url(cls, repository_url: str | RepoUrl):
         # Check if the URL has the `@` and `:` pattern with a non-numeric section before the next slash
         match = re.match("([^@]+@[^:]+):([^0-9/][^/]*)/(.+)", str(repository_url))
         if match: