hirundo 0.1.21__py3-none-any.whl → 0.2.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hirundo/__init__.py +19 -3
- hirundo/_constraints.py +2 -3
- hirundo/_iter_sse_retrying.py +7 -4
- hirundo/_llm_pipeline.py +153 -0
- hirundo/_run_checking.py +283 -0
- hirundo/_urls.py +1 -0
- hirundo/cli.py +1 -4
- hirundo/dataset_enum.py +2 -0
- hirundo/dataset_qa.py +106 -190
- hirundo/dataset_qa_results.py +3 -3
- hirundo/git.py +7 -8
- hirundo/labeling.py +22 -19
- hirundo/storage.py +25 -24
- hirundo/unlearning_llm.py +599 -0
- hirundo/unzip.py +3 -3
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/METADATA +42 -10
- hirundo-0.2.3.post1.dist-info/RECORD +28 -0
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/WHEEL +1 -1
- hirundo-0.1.21.dist-info/RECORD +0 -25
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/entry_points.txt +0 -0
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/licenses/LICENSE +0 -0
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/top_level.txt +0 -0
hirundo/__init__.py
CHANGED
```diff
@@ -5,8 +5,8 @@ from .dataset_enum import (
 )
 from .dataset_qa import (
     ClassificationRunArgs,
-    Domain,
     HirundoError,
+    ModalityType,
     ObjectDetectionRunArgs,
     QADataset,
     RunArgs,
@@ -30,6 +30,15 @@ from .storage import (
     StorageGit,
     StorageS3,
 )
+from .unlearning_llm import (
+    BiasRunInfo,
+    BiasType,
+    HuggingFaceTransformersModel,
+    LlmModel,
+    LlmSources,
+    LlmUnlearningRun,
+    LocalTransformersModel,
+)
 from .unzip import load_df, load_from_zip
 
 __all__ = [
@@ -43,7 +52,7 @@ __all__ = [
     "KeylabsObjSegImages",
     "KeylabsObjSegVideo",
     "QADataset",
-    "Domain",
+    "ModalityType",
     "RunArgs",
     "ClassificationRunArgs",
     "ObjectDetectionRunArgs",
@@ -59,8 +68,15 @@ __all__ = [
     "StorageGit",
     "StorageConfig",
     "DatasetQAResults",
+    "BiasRunInfo",
+    "BiasType",
+    "HuggingFaceTransformersModel",
+    "LlmModel",
+    "LlmSources",
+    "LlmUnlearningRun",
+    "LocalTransformersModel",
     "load_df",
     "load_from_zip",
 ]
 
-__version__ = "0.1.21"
+__version__ = "0.2.3.post1"
```
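The notable changes: the `Domain` export is replaced by `ModalityType`, the new `unlearning_llm` names are re-exported at the package root, and the version bumps to 0.2.3.post1. A minimal import sketch of the new surface (the inline descriptions are inferred from the names and from usage elsewhere in this diff, not from documentation):

```python
# Sketch: the LLM-unlearning API added in 0.2.x, imported from the package root.
from hirundo import (
    BiasRunInfo,       # run parameters for bias unlearning (inferred from the name)
    LlmModel,
    LlmUnlearningRun,  # exposes check_run_by_id(), used by _llm_pipeline below
    ModalityType,      # replaces the removed Domain export
)
```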
hirundo/_constraints.py
CHANGED
```diff
@@ -1,5 +1,4 @@
 import re
-import typing
 from typing import TYPE_CHECKING
 
 from hirundo._urls import (
@@ -135,8 +134,8 @@ def validate_labeling_type(
 
 def validate_labeling_info(
     labeling_type: "LabelingType",
-    labeling_info: "typing.Union[LabelingInfo, list[LabelingInfo]]",
-    storage_config: "typing.Union[StorageConfig, ResponseStorageConfig]",
+    labeling_info: "LabelingInfo | list[LabelingInfo]",
+    storage_config: "StorageConfig | ResponseStorageConfig",
 ) -> None:
     """
     Validate the labeling info for a dataset
```
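This file's only changes drop `import typing` and rewrite annotations with PEP 604 unions, a migration applied across the release (see `cli.py` and `_iter_sse_retrying.py` below). For reference, the two spellings are interchangeable:

```python
import typing


# Pre-0.2.x spelling, removed in this release:
def f_old(info: typing.Union[int, list[int]], name: typing.Optional[str] = None): ...


# PEP 604 spelling used in 0.2.3.post1; valid at runtime on Python 3.10+,
# and on any version inside quoted annotations like validate_labeling_info's.
def f_new(info: int | list[int], name: str | None = None): ...
```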
hirundo/_iter_sse_retrying.py
CHANGED
```diff
@@ -1,6 +1,5 @@
 import asyncio
 import time
-import typing
 import uuid
 from collections.abc import AsyncGenerator, Generator
 
@@ -15,13 +14,15 @@ from hirundo.logger import get_logger
 
 logger = get_logger(__name__)
 
+MAX_RETRIES = 50
+
 
 # Credit: https://github.com/florimondmanca/httpx-sse/blob/master/README.md#handling-reconnections
 def iter_sse_retrying(
     client: httpx.Client,
     method: str,
     url: str,
-    headers: typing.Optional[dict[str, str]] = None,
+    headers: dict[str, str] | None = None,
 ) -> Generator[ServerSentEvent, None, None]:
     if headers is None:
         headers = {}
@@ -41,7 +42,8 @@ def iter_sse_retrying(
         httpx.ReadError,
         httpx.RemoteProtocolError,
         urllib3.exceptions.ReadTimeoutError,
-    )
+    ),
+    attempts=MAX_RETRIES,
 )
 def _iter_sse():
     nonlocal last_event_id, reconnection_delay
@@ -105,7 +107,8 @@ async def aiter_sse_retrying(
         httpx.ReadError,
         httpx.RemoteProtocolError,
         urllib3.exceptions.ReadTimeoutError,
-    )
+    ),
+    attempts=MAX_RETRIES,
 )
 async def _iter_sse() -> AsyncGenerator[ServerSentEvent, None]:
     nonlocal last_event_id, reconnection_delay
```
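Besides the typing cleanup, both SSE helpers now pass `attempts=MAX_RETRIES` to their retry decorator, capping reconnection attempts at 50 rather than relying on the decorator's default. The decorator itself sits outside the hunks shown; the following is a hypothetical stand-in (not the library's actual decorator) illustrating what the `on=(...)` / `attempts=...` pattern does:

```python
import functools
import time


def retry(on: tuple[type[BaseException], ...], attempts: int, delay: float = 1.0):
    """Hypothetical stand-in for the retry decorator used in _iter_sse_retrying:
    re-run the wrapped function when one of the exceptions in `on` is raised,
    giving up (re-raising) after `attempts` tries."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, attempts + 1):
                try:
                    return fn(*args, **kwargs)
                except on:
                    if attempt == attempts:
                        raise  # e.g. MAX_RETRIES (50) exhausted
                    time.sleep(delay)
        return wrapper
    return decorator
```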
hirundo/_llm_pipeline.py
ADDED
```diff
@@ -0,0 +1,153 @@
+import importlib.util
+import tempfile
+import zipfile
+from pathlib import Path
+from typing import TYPE_CHECKING, cast
+
+from hirundo import HirundoError
+from hirundo._http import requests
+from hirundo._timeouts import DOWNLOAD_READ_TIMEOUT
+from hirundo.logger import get_logger
+
+if TYPE_CHECKING:
+    from torch import device as torch_device
+    from transformers.configuration_utils import PretrainedConfig
+    from transformers.modeling_utils import PreTrainedModel
+    from transformers.pipelines.base import Pipeline
+
+    from hirundo.unlearning_llm import LlmModel, LlmModelOut
+
+logger = get_logger(__name__)
+
+
+ZIP_FILE_CHUNK_SIZE = 50 * 1024 * 1024  # 50 MB
+REQUIRED_PACKAGES_FOR_PIPELINE = ["peft", "transformers", "accelerate"]
+
+
+def get_hf_pipeline_for_run_given_model(
+    llm: "LlmModel | LlmModelOut",
+    run_id: str,
+    config: "PretrainedConfig | None" = None,
+    device: "str | int | torch_device | None" = None,
+    device_map: str | dict[str, int | str] | None = None,
+    trust_remote_code: bool = False,
+    token: str | None = None,
+) -> "Pipeline":
+    for package in REQUIRED_PACKAGES_FOR_PIPELINE:
+        if importlib.util.find_spec(package) is None:
+            raise HirundoError(
+                f'{package} is not installed. Please install transformers extra with pip install "hirundo[transformers]"'
+            )
+    from peft import PeftModel
+    from transformers.models.auto.configuration_auto import AutoConfig
+    from transformers.models.auto.modeling_auto import (
+        MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
+        AutoModelForCausalLM,
+        AutoModelForImageTextToText,
+    )
+    from transformers.models.auto.tokenization_auto import AutoTokenizer
+    from transformers.pipelines import pipeline
+
+    from hirundo.unlearning_llm import (
+        HuggingFaceTransformersModel,
+        HuggingFaceTransformersModelOutput,
+        LlmUnlearningRun,
+    )
+
+    run_results = LlmUnlearningRun.check_run_by_id(run_id)
+    if run_results is None:
+        raise HirundoError("No run results found")
+    result_payload = (
+        run_results.get("result", run_results)
+        if isinstance(run_results, dict)
+        else run_results
+    )
+    if isinstance(result_payload, dict):
+        result_url = result_payload.get("result")
+    else:
+        result_url = result_payload
+    if not isinstance(result_url, str):
+        raise HirundoError("Run results did not include a download URL")
+    # Stream the zip file download
+
+    zip_file_path = tempfile.NamedTemporaryFile(delete=False).name
+    with requests.get(
+        result_url,
+        timeout=DOWNLOAD_READ_TIMEOUT,
+        stream=True,
+    ) as r:
+        r.raise_for_status()
+        with open(zip_file_path, "wb") as zip_file:
+            for chunk in r.iter_content(chunk_size=ZIP_FILE_CHUNK_SIZE):
+                zip_file.write(chunk)
+    logger.info(
+        "Successfully downloaded the result zip file for run ID %s to %s",
+        run_id,
+        zip_file_path,
+    )
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_dir_path = Path(temp_dir)
+        with zipfile.ZipFile(zip_file_path, "r") as zip_file:
+            zip_file.extractall(temp_dir_path)
+        # Attempt to load the tokenizer normally
+        base_model_name = (
+            llm.model_source.model_name
+            if isinstance(
+                llm.model_source,
+                HuggingFaceTransformersModel | HuggingFaceTransformersModelOutput,
+            )
+            else llm.model_source.local_path
+        )
+        token = (
+            llm.model_source.token
+            if isinstance(
+                llm.model_source,
+                HuggingFaceTransformersModel,
+            )
+            else token
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            base_model_name,
+            token=token,
+            trust_remote_code=trust_remote_code,
+        )
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        config = AutoConfig.from_pretrained(
+            base_model_name,
+            token=token,
+            trust_remote_code=trust_remote_code,
+        )
+        config_dict = config.to_dict() if hasattr(config, "to_dict") else config
+        is_multimodal = (
+            config_dict.get("model_type")
+            in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys()
+        )
+        if is_multimodal:
+            base_model = AutoModelForImageTextToText.from_pretrained(
+                base_model_name,
+                token=token,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_name,
+                token=token,
+                trust_remote_code=trust_remote_code,
+            )
+        model = cast(
+            "PreTrainedModel",
+            PeftModel.from_pretrained(
+                base_model, str(temp_dir_path / "unlearned_model_folder")
+            ),
+        )
+
+    return pipeline(
+        task="text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        config=config,
+        device=device,
+        device_map=device_map,
+    )
```
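This new module downloads a finished unlearning run's output archive, extracts it, applies the contained PEFT adapter to the base model, and wraps the result in a `transformers` text-generation pipeline. A hypothetical call (only the `get_hf_pipeline_for_run_given_model` signature and the `model_name` field are taken from the diff; the rest of the model construction is assumed):

```python
from hirundo import HuggingFaceTransformersModel, LlmModel
from hirundo._llm_pipeline import get_hf_pipeline_for_run_given_model

# Assumed construction; field names other than `model_name` are illustrative.
llm = LlmModel(
    model_source=HuggingFaceTransformersModel(model_name="org/base-model"),
)

# Requires the transformers extra: pip install "hirundo[transformers]"
pipe = get_hf_pipeline_for_run_given_model(llm, run_id="<run-id>", device_map="auto")
print(pipe("Hello")[0]["generated_text"])
```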
hirundo/_run_checking.py
ADDED
```diff
@@ -0,0 +1,283 @@
+import json
+from collections.abc import AsyncGenerator, Generator
+from enum import Enum
+
+import httpx
+from tqdm import tqdm
+
+from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
+from hirundo.logger import get_logger
+
+_logger = get_logger(__name__)
+
+DEFAULT_MAX_RETRIES = 200
+
+
+class RunStatus(Enum):
+    PENDING = "PENDING"
+    STARTED = "STARTED"
+    SUCCESS = "SUCCESS"
+    FAILURE = "FAILURE"
+    AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
+    REVOKED = "REVOKED"
+    REJECTED = "REJECTED"
+    RETRY = "RETRY"
+
+
+STATUS_TO_PROGRESS_MAP = {
+    RunStatus.STARTED.value: 0.0,
+    RunStatus.PENDING.value: 0.0,
+    RunStatus.SUCCESS.value: 100.0,
+    RunStatus.FAILURE.value: 100.0,
+    RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
+    RunStatus.RETRY.value: 0.0,
+    RunStatus.REVOKED.value: 100.0,
+    RunStatus.REJECTED.value: 0.0,
+}
+
+
+def build_status_text_map(
+    run_label: str, *, started_detail: str | None = None
+) -> dict[str, str]:
+    """
+    Build a status->text mapping for a given run label.
+
+    Args:
+        run_label: Human-readable label used in status text.
+        started_detail: Optional override for the STARTED status text.
+
+    Returns:
+        Mapping of run state values to user-facing status text.
+    """
+    started_text = started_detail or f"{run_label} run in progress"
+    return {
+        RunStatus.STARTED.value: started_text,
+        RunStatus.PENDING.value: f"{run_label} run queued and not yet started",
+        RunStatus.SUCCESS.value: f"{run_label} run completed successfully",
+        RunStatus.FAILURE.value: f"{run_label} run failed",
+        RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
+        RunStatus.RETRY.value: f"{run_label} run failed. Retrying",
+        RunStatus.REVOKED.value: f"{run_label} run was cancelled",
+        RunStatus.REJECTED.value: f"{run_label} run was rejected",
+    }
+
+
+def get_state(payload: dict, status_keys: tuple[str, ...]) -> str | None:
+    """
+    Return the first non-null state value from a payload using a list of keys.
+
+    Args:
+        payload: Run payload containing state/status information.
+        status_keys: Ordered keys to search for state values.
+
+    Returns:
+        The first non-null state value, or None if none are present.
+    """
+    for key in status_keys:
+        value = payload.get(key)
+        if value is not None:
+            return value
+    return None
+
+
+def _extract_event_data(event: dict, error_cls: type[Exception]) -> dict:
+    if "data" in event:
+        return event["data"]
+    if "detail" in event:
+        raise error_cls(event["detail"])
+    if "reason" in event:
+        raise error_cls(event["reason"])
+    raise error_cls("Unknown error")
+
+
+def _should_retry_after_stream(
+    last_event: dict | None,
+    status_keys: tuple[str, ...],
+    pending_state_value: str,
+) -> bool:
+    if not last_event:
+        return True
+    data = last_event.get("data")
+    if data is None:
+        return False
+    last_state = get_state(data, status_keys)
+    return last_state == pending_state_value
+
+
+def iter_run_events(
+    url: str,
+    *,
+    headers: dict[str, str] | None = None,
+    retry: int = 0,
+    max_retries: int = DEFAULT_MAX_RETRIES,
+    pending_state_value: str = RunStatus.PENDING.value,
+    status_keys: tuple[str, ...] = ("state",),
+    error_cls: type[Exception] = RuntimeError,
+    log=_logger,
+) -> Generator[dict, None, None]:
+    """
+    Stream run events from an SSE endpoint with retries.
+
+    Args:
+        url: SSE endpoint URL.
+        headers: Optional HTTP headers.
+        retry: Internal retry counter (do not set manually).
+        max_retries: Maximum number of retry attempts.
+        pending_state_value: State value that triggers a re-check loop.
+        status_keys: Payload keys to search for the run state.
+        error_cls: Exception type to raise on errors.
+        log: Logger instance for debug output.
+
+    Yields:
+        Event payloads decoded from the SSE data field.
+    """
+    while True:
+        if retry > max_retries:
+            raise error_cls("Max retries reached")
+        last_event = None
+        with httpx.Client(timeout=httpx.Timeout(None, connect=5.0)) as client:
+            for sse in iter_sse_retrying(
+                client,
+                "GET",
+                url,
+                headers=headers,
+            ):
+                if sse.event == "ping":
+                    continue
+                log.debug(
+                    "[SYNC] received event: %s with data: %s and ID: %s and retry: %s",
+                    sse.event,
+                    sse.data,
+                    sse.id,
+                    sse.retry,
+                )
+                last_event = json.loads(sse.data)
+                if not last_event:
+                    continue
+                data = _extract_event_data(last_event, error_cls)
+                yield data
+        if _should_retry_after_stream(last_event, status_keys, pending_state_value):
+            retry += 1
+            continue
+        return
+
+
+async def aiter_run_events(
+    url: str,
+    *,
+    headers: dict[str, str] | None = None,
+    retry: int = 0,
+    max_retries: int = DEFAULT_MAX_RETRIES,
+    pending_state_value: str = RunStatus.PENDING.value,
+    status_keys: tuple[str, ...] = ("state",),
+    error_cls: type[Exception] = RuntimeError,
+    log=_logger,
+) -> AsyncGenerator[dict, None]:
+    """
+    Async stream run events from an SSE endpoint with retries.
+
+    Args:
+        url: SSE endpoint URL.
+        headers: Optional HTTP headers.
+        retry: Internal retry counter (do not set manually).
+        max_retries: Maximum number of retry attempts.
+        pending_state_value: State value that triggers a re-check loop.
+        status_keys: Payload keys to search for the run state.
+        error_cls: Exception type to raise on errors.
+        log: Logger instance for debug output.
+
+    Yields:
+        Event payloads decoded from the SSE data field.
+    """
+    while True:
+        if retry > max_retries:
+            raise error_cls("Max retries reached")
+        last_event = None
+        async with httpx.AsyncClient(
+            timeout=httpx.Timeout(None, connect=5.0)
+        ) as client:
+            async_iterator = await aiter_sse_retrying(
+                client,
+                "GET",
+                url,
+                headers=headers or {},
+            )
+            async for sse in async_iterator:
+                if sse.event == "ping":
+                    continue
+                log.debug(
+                    "[ASYNC] Received event: %s with data: %s and ID: %s and retry: %s",
+                    sse.event,
+                    sse.data,
+                    sse.id,
+                    sse.retry,
+                )
+                last_event = json.loads(sse.data)
+                data = _extract_event_data(last_event, error_cls)
+                yield data
+        if _should_retry_after_stream(last_event, status_keys, pending_state_value):
+            retry += 1
+            continue
+        return
+
+
+def update_progress_from_result(
+    iteration: dict,
+    progress: tqdm,
+    *,
+    uploading_text: str,
+    log=_logger,
+) -> bool:
+    """
+    Update a tqdm progress bar based on a serialized progress result string.
+
+    Args:
+        iteration: Payload containing a nested result string.
+        progress: tqdm instance to update.
+        uploading_text: Description to show when progress reaches 100%.
+        log: Logger instance for debug output.
+
+    Returns:
+        True if a progress update occurred, False otherwise.
+    """
+    if (
+        iteration.get("result")
+        and isinstance(iteration["result"], dict)
+        and iteration["result"].get("result")
+        and isinstance(iteration["result"]["result"], str)
+    ):
+        result_info = iteration["result"]["result"].split(":")
+        if len(result_info) > 1:
+            stage = result_info[0]
+            current_progress_percentage = float(
+                result_info[1].removeprefix(" ").removesuffix("% done")
+            )
+        elif len(result_info) == 1:
+            stage = result_info[0]
+            current_progress_percentage = progress.n
+        else:
+            stage = "Unknown progress state"
+            current_progress_percentage = progress.n
+        desc = uploading_text if current_progress_percentage == 100.0 else stage
+        progress.set_description(desc)
+        progress.n = current_progress_percentage
+        log.debug("Setting progress to %s", progress.n)
+        progress.refresh()
+        return True
+    return False
+
+
+def handle_run_failure(
+    iteration: dict, *, error_cls: type[Exception], run_label: str
+) -> None:
+    """
+    Raise a run-specific failure exception based on the iteration payload.
+
+    Args:
+        iteration: Payload containing error details.
+        error_cls: Exception type to raise.
+        run_label: Human-readable label for the run type.
+    """
+    if iteration.get("result"):
+        raise error_cls(f"{run_label} run failed with error: {iteration['result']}")
+    raise error_cls(f"{run_label} run failed with an unknown error")
```
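This new module gathers the shared run-polling helpers: SSE event streaming with bounded retries, state extraction, tqdm progress parsing, and failure handling. A hypothetical polling loop wired from these helpers (the endpoint URL and run label are made up for illustration; the function signatures are taken from the diff above):

```python
from tqdm import tqdm

from hirundo._run_checking import (
    RunStatus,
    build_status_text_map,
    get_state,
    handle_run_failure,
    iter_run_events,
    update_progress_from_result,
)

status_text = build_status_text_map("Unlearning")
with tqdm(total=100.0, desc=status_text[RunStatus.PENDING.value]) as progress:
    for event in iter_run_events("https://api.example.com/runs/123/events"):
        state = get_state(event, ("state",))
        if state == RunStatus.FAILURE.value:
            handle_run_failure(event, error_cls=RuntimeError, run_label="Unlearning")
        # Progress events carry strings like "Some stage: 42.5% done" in
        # event["result"]["result"], which update_progress_from_result parses.
        updated = update_progress_from_result(
            event, progress, uploading_text="Uploading results"
        )
        if not updated and state in status_text:
            progress.set_description(status_text[state])
        if state == RunStatus.SUCCESS.value:
            break
```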
hirundo/_urls.py
CHANGED
hirundo/cli.py
CHANGED
```diff
@@ -1,7 +1,6 @@
 import os
 import re
 import sys
-import typing
 from pathlib import Path
 from typing import Annotated
 from urllib.parse import urlparse
@@ -28,9 +27,7 @@ app = typer.Typer(
 )
 
 
-def _upsert_env(
-    dotenv_filepath: typing.Union[str, Path], var_name: str, var_value: str
-):
+def _upsert_env(dotenv_filepath: str | Path, var_name: str, var_value: str):
     """
     Change an environment variable in the .env file.
     If the variable does not exist, it will be added.
```
hirundo/dataset_enum.py
CHANGED
```diff
@@ -24,6 +24,7 @@ class DatasetMetadataType(str, Enum):
     HIRUNDO_CSV = "HirundoCSV"
     COCO = "COCO"
     YOLO = "YOLO"
+    HuggingFaceAudio = "HuggingFaceAudio"
     KeylabsObjDetImages = "KeylabsObjDetImages"
     KeylabsObjDetVideo = "KeylabsObjDetVideo"
     KeylabsObjSegImages = "KeylabsObjSegImages"
@@ -44,3 +45,4 @@ class StorageTypes(str, Enum):
     """
     Local storage config is only supported for on-premises installations.
     """
+    HUGGINGFACE = "HuggingFace"
```
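Two new enum members extend the Hugging Face support: a `HuggingFaceAudio` dataset metadata type and a `HuggingFace` storage type. A trivial check of the new values as a sketch (what each member enables beyond its name is not documented in this diff):

```python
from hirundo.dataset_enum import DatasetMetadataType, StorageTypes

assert DatasetMetadataType.HuggingFaceAudio.value == "HuggingFaceAudio"
assert StorageTypes.HUGGINGFACE.value == "HuggingFace"
```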