hirundo 0.1.21__py3-none-any.whl → 0.2.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hirundo/__init__.py CHANGED
@@ -5,8 +5,8 @@ from .dataset_enum import (
 )
 from .dataset_qa import (
     ClassificationRunArgs,
-    Domain,
     HirundoError,
+    ModalityType,
     ObjectDetectionRunArgs,
     QADataset,
     RunArgs,
@@ -30,6 +30,15 @@ from .storage import (
     StorageGit,
     StorageS3,
 )
+from .unlearning_llm import (
+    BiasRunInfo,
+    BiasType,
+    HuggingFaceTransformersModel,
+    LlmModel,
+    LlmSources,
+    LlmUnlearningRun,
+    LocalTransformersModel,
+)
 from .unzip import load_df, load_from_zip
 
 __all__ = [
@@ -43,7 +52,7 @@ __all__ = [
     "KeylabsObjSegImages",
     "KeylabsObjSegVideo",
     "QADataset",
-    "Domain",
+    "ModalityType",
     "RunArgs",
     "ClassificationRunArgs",
     "ObjectDetectionRunArgs",
@@ -59,8 +68,15 @@ __all__ = [
     "StorageGit",
     "StorageConfig",
     "DatasetQAResults",
+    "BiasRunInfo",
+    "BiasType",
+    "HuggingFaceTransformersModel",
+    "LlmModel",
+    "LlmSources",
+    "LlmUnlearningRun",
+    "LocalTransformersModel",
     "load_df",
     "load_from_zip",
 ]
 
-__version__ = "0.1.21"
+__version__ = "0.2.3.post1"
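The headline changes: the `Domain` enum is renamed to `ModalityType`, the new LLM-unlearning API is re-exported at the package root, and the version bumps to 0.2.3.post1. A minimal import sketch of the new surface (names taken from `__all__` above; constructor signatures are not part of this diff):

```python
from hirundo import (
    BiasRunInfo,
    BiasType,
    HuggingFaceTransformersModel,
    LlmModel,
    LlmSources,
    LlmUnlearningRun,
    LocalTransformersModel,
    ModalityType,  # replaces the removed Domain enum
)
```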
hirundo/_constraints.py CHANGED
@@ -1,5 +1,4 @@
 import re
-import typing
 from typing import TYPE_CHECKING
 
 from hirundo._urls import (
@@ -135,8 +134,8 @@ def validate_labeling_type(
 
 def validate_labeling_info(
     labeling_type: "LabelingType",
-    labeling_info: "typing.Union[LabelingInfo, list[LabelingInfo]]",
-    storage_config: "typing.Union[StorageConfig, ResponseStorageConfig]",
+    labeling_info: "LabelingInfo | list[LabelingInfo]",
+    storage_config: "StorageConfig | ResponseStorageConfig",
 ) -> None:
     """
     Validate the labeling info for a dataset
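This file, like `cli.py` and the SSE module below, drops `typing.Union`/`typing.Optional` in favor of PEP 604 `X | Y` unions, which also work inside the quoted (string) annotations used here. A small illustrative sketch (not from the package):

```python
# PEP 604 unions (Python 3.10+) replace typing.Union / typing.Optional;
# quoted annotations defer evaluation, so forward references still work.
def describe(value: "int | str | None" = None) -> str:
    return "empty" if value is None else str(value)
```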
hirundo/_iter_sse_retrying.py CHANGED
@@ -1,6 +1,5 @@
 import asyncio
 import time
-import typing
 import uuid
 from collections.abc import AsyncGenerator, Generator
 
@@ -15,13 +14,15 @@ from hirundo.logger import get_logger
 
 logger = get_logger(__name__)
 
+MAX_RETRIES = 50
+
 
 # Credit: https://github.com/florimondmanca/httpx-sse/blob/master/README.md#handling-reconnections
 def iter_sse_retrying(
     client: httpx.Client,
     method: str,
     url: str,
-    headers: typing.Optional[dict[str, str]] = None,
+    headers: dict[str, str] | None = None,
 ) -> Generator[ServerSentEvent, None, None]:
     if headers is None:
         headers = {}
@@ -41,7 +42,8 @@ def iter_sse_retrying(
             httpx.ReadError,
             httpx.RemoteProtocolError,
             urllib3.exceptions.ReadTimeoutError,
-        )
+        ),
+        attempts=MAX_RETRIES,
     )
     def _iter_sse():
         nonlocal last_event_id, reconnection_delay
@@ -105,7 +107,8 @@ async def aiter_sse_retrying(
             httpx.ReadError,
             httpx.RemoteProtocolError,
             urllib3.exceptions.ReadTimeoutError,
-        )
+        ),
+        attempts=MAX_RETRIES,
     )
     async def _iter_sse() -> AsyncGenerator[ServerSentEvent, None]:
         nonlocal last_event_id, reconnection_delay
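The functional change here: the retry decorator wrapping both SSE readers now receives `attempts=MAX_RETRIES`, capping reconnection attempts at 50 instead of retrying indefinitely. The decorator itself sits outside these hunks, so its library is not shown; a hand-rolled sketch of the same bounded-retry shape, for illustration only:

```python
import functools
import time


def retry(on: tuple[type[Exception], ...], attempts: int):
    """Illustrative bounded-retry decorator: retry on the given
    exceptions, re-raise after `attempts` tries."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, attempts + 1):
                try:
                    return fn(*args, **kwargs)
                except on:
                    if attempt == attempts:
                        raise
                    time.sleep(min(2**attempt, 30))  # capped backoff
        return wrapper
    return decorator
```

The real decorator also has to cope with the wrapped function being a generator, which this sketch does not attempt.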
@@ -0,0 +1,153 @@
+import importlib.util
+import tempfile
+import zipfile
+from pathlib import Path
+from typing import TYPE_CHECKING, cast
+
+from hirundo import HirundoError
+from hirundo._http import requests
+from hirundo._timeouts import DOWNLOAD_READ_TIMEOUT
+from hirundo.logger import get_logger
+
+if TYPE_CHECKING:
+    from torch import device as torch_device
+    from transformers.configuration_utils import PretrainedConfig
+    from transformers.modeling_utils import PreTrainedModel
+    from transformers.pipelines.base import Pipeline
+
+    from hirundo.unlearning_llm import LlmModel, LlmModelOut
+
+logger = get_logger(__name__)
+
+
+ZIP_FILE_CHUNK_SIZE = 50 * 1024 * 1024  # 50 MB
+REQUIRED_PACKAGES_FOR_PIPELINE = ["peft", "transformers", "accelerate"]
+
+
+def get_hf_pipeline_for_run_given_model(
+    llm: "LlmModel | LlmModelOut",
+    run_id: str,
+    config: "PretrainedConfig | None" = None,
+    device: "str | int | torch_device | None" = None,
+    device_map: str | dict[str, int | str] | None = None,
+    trust_remote_code: bool = False,
+    token: str | None = None,
+) -> "Pipeline":
+    for package in REQUIRED_PACKAGES_FOR_PIPELINE:
+        if importlib.util.find_spec(package) is None:
+            raise HirundoError(
+                f'{package} is not installed. Please install transformers extra with pip install "hirundo[transformers]"'
+            )
+    from peft import PeftModel
+    from transformers.models.auto.configuration_auto import AutoConfig
+    from transformers.models.auto.modeling_auto import (
+        MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
+        AutoModelForCausalLM,
+        AutoModelForImageTextToText,
+    )
+    from transformers.models.auto.tokenization_auto import AutoTokenizer
+    from transformers.pipelines import pipeline
+
+    from hirundo.unlearning_llm import (
+        HuggingFaceTransformersModel,
+        HuggingFaceTransformersModelOutput,
+        LlmUnlearningRun,
+    )
+
+    run_results = LlmUnlearningRun.check_run_by_id(run_id)
+    if run_results is None:
+        raise HirundoError("No run results found")
+    result_payload = (
+        run_results.get("result", run_results)
+        if isinstance(run_results, dict)
+        else run_results
+    )
+    if isinstance(result_payload, dict):
+        result_url = result_payload.get("result")
+    else:
+        result_url = result_payload
+    if not isinstance(result_url, str):
+        raise HirundoError("Run results did not include a download URL")
+    # Stream the zip file download
+
+    zip_file_path = tempfile.NamedTemporaryFile(delete=False).name
+    with requests.get(
+        result_url,
+        timeout=DOWNLOAD_READ_TIMEOUT,
+        stream=True,
+    ) as r:
+        r.raise_for_status()
+        with open(zip_file_path, "wb") as zip_file:
+            for chunk in r.iter_content(chunk_size=ZIP_FILE_CHUNK_SIZE):
+                zip_file.write(chunk)
+    logger.info(
+        "Successfully downloaded the result zip file for run ID %s to %s",
+        run_id,
+        zip_file_path,
+    )
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_dir_path = Path(temp_dir)
+        with zipfile.ZipFile(zip_file_path, "r") as zip_file:
+            zip_file.extractall(temp_dir_path)
+        # Attempt to load the tokenizer normally
+        base_model_name = (
+            llm.model_source.model_name
+            if isinstance(
+                llm.model_source,
+                HuggingFaceTransformersModel | HuggingFaceTransformersModelOutput,
+            )
+            else llm.model_source.local_path
+        )
+        token = (
+            llm.model_source.token
+            if isinstance(
+                llm.model_source,
+                HuggingFaceTransformersModel,
+            )
+            else token
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            base_model_name,
+            token=token,
+            trust_remote_code=trust_remote_code,
+        )
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+        config = AutoConfig.from_pretrained(
+            base_model_name,
+            token=token,
+            trust_remote_code=trust_remote_code,
+        )
+        config_dict = config.to_dict() if hasattr(config, "to_dict") else config
+        is_multimodal = (
+            config_dict.get("model_type")
+            in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys()
+        )
+        if is_multimodal:
+            base_model = AutoModelForImageTextToText.from_pretrained(
+                base_model_name,
+                token=token,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_name,
+                token=token,
+                trust_remote_code=trust_remote_code,
+            )
+        model = cast(
+            "PreTrainedModel",
+            PeftModel.from_pretrained(
+                base_model, str(temp_dir_path / "unlearned_model_folder")
+            ),
+        )
+
+        return pipeline(
+            task="text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            config=config,
+            device=device,
+            device_map=device_map,
+        )
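This new module's `get_hf_pipeline_for_run_given_model` downloads a finished unlearning run's result zip, extracts the PEFT adapter, applies it to the base model via `PeftModel.from_pretrained`, and returns a ready `transformers` pipeline (image-text-to-text models are detected from the config's `model_type`). A hypothetical usage sketch; `my_llm` and the run ID are placeholders, and the `hirundo[transformers]` extra must be installed:

```python
# my_llm: an LlmModel describing the base model (placeholder).
pipe = get_hf_pipeline_for_run_given_model(
    llm=my_llm,
    run_id="run-id-placeholder",
    device_map="auto",
)
print(pipe("Hello!", max_new_tokens=32)[0]["generated_text"])
```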
@@ -0,0 +1,283 @@
+import json
+from collections.abc import AsyncGenerator, Generator
+from enum import Enum
+
+import httpx
+from tqdm import tqdm
+
+from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
+from hirundo.logger import get_logger
+
+_logger = get_logger(__name__)
+
+DEFAULT_MAX_RETRIES = 200
+
+
+class RunStatus(Enum):
+    PENDING = "PENDING"
+    STARTED = "STARTED"
+    SUCCESS = "SUCCESS"
+    FAILURE = "FAILURE"
+    AWAITING_MANUAL_APPROVAL = "AWAITING MANUAL APPROVAL"
+    REVOKED = "REVOKED"
+    REJECTED = "REJECTED"
+    RETRY = "RETRY"
+
+
+STATUS_TO_PROGRESS_MAP = {
+    RunStatus.STARTED.value: 0.0,
+    RunStatus.PENDING.value: 0.0,
+    RunStatus.SUCCESS.value: 100.0,
+    RunStatus.FAILURE.value: 100.0,
+    RunStatus.AWAITING_MANUAL_APPROVAL.value: 100.0,
+    RunStatus.RETRY.value: 0.0,
+    RunStatus.REVOKED.value: 100.0,
+    RunStatus.REJECTED.value: 0.0,
+}
+
+
+def build_status_text_map(
+    run_label: str, *, started_detail: str | None = None
+) -> dict[str, str]:
+    """
+    Build a status->text mapping for a given run label.
+
+    Args:
+        run_label: Human-readable label used in status text.
+        started_detail: Optional override for the STARTED status text.
+
+    Returns:
+        Mapping of run state values to user-facing status text.
+    """
+    started_text = started_detail or f"{run_label} run in progress"
+    return {
+        RunStatus.STARTED.value: started_text,
+        RunStatus.PENDING.value: f"{run_label} run queued and not yet started",
+        RunStatus.SUCCESS.value: f"{run_label} run completed successfully",
+        RunStatus.FAILURE.value: f"{run_label} run failed",
+        RunStatus.AWAITING_MANUAL_APPROVAL.value: "Awaiting manual approval",
+        RunStatus.RETRY.value: f"{run_label} run failed. Retrying",
+        RunStatus.REVOKED.value: f"{run_label} run was cancelled",
+        RunStatus.REJECTED.value: f"{run_label} run was rejected",
+    }
+
+
+def get_state(payload: dict, status_keys: tuple[str, ...]) -> str | None:
+    """
+    Return the first non-null state value from a payload using a list of keys.
+
+    Args:
+        payload: Run payload containing state/status information.
+        status_keys: Ordered keys to search for state values.
+
+    Returns:
+        The first non-null state value, or None if none are present.
+    """
+    for key in status_keys:
+        value = payload.get(key)
+        if value is not None:
+            return value
+    return None
+
+
+def _extract_event_data(event: dict, error_cls: type[Exception]) -> dict:
+    if "data" in event:
+        return event["data"]
+    if "detail" in event:
+        raise error_cls(event["detail"])
+    if "reason" in event:
+        raise error_cls(event["reason"])
+    raise error_cls("Unknown error")
+
+
+def _should_retry_after_stream(
+    last_event: dict | None,
+    status_keys: tuple[str, ...],
+    pending_state_value: str,
+) -> bool:
+    if not last_event:
+        return True
+    data = last_event.get("data")
+    if data is None:
+        return False
+    last_state = get_state(data, status_keys)
+    return last_state == pending_state_value
+
+
+def iter_run_events(
+    url: str,
+    *,
+    headers: dict[str, str] | None = None,
+    retry: int = 0,
+    max_retries: int = DEFAULT_MAX_RETRIES,
+    pending_state_value: str = RunStatus.PENDING.value,
+    status_keys: tuple[str, ...] = ("state",),
+    error_cls: type[Exception] = RuntimeError,
+    log=_logger,
+) -> Generator[dict, None, None]:
+    """
+    Stream run events from an SSE endpoint with retries.
+
+    Args:
+        url: SSE endpoint URL.
+        headers: Optional HTTP headers.
+        retry: Internal retry counter (do not set manually).
+        max_retries: Maximum number of retry attempts.
+        pending_state_value: State value that triggers a re-check loop.
+        status_keys: Payload keys to search for the run state.
+        error_cls: Exception type to raise on errors.
+        log: Logger instance for debug output.
+
+    Yields:
+        Event payloads decoded from the SSE data field.
+    """
+    while True:
+        if retry > max_retries:
+            raise error_cls("Max retries reached")
+        last_event = None
+        with httpx.Client(timeout=httpx.Timeout(None, connect=5.0)) as client:
+            for sse in iter_sse_retrying(
+                client,
+                "GET",
+                url,
+                headers=headers,
+            ):
+                if sse.event == "ping":
+                    continue
+                log.debug(
+                    "[SYNC] received event: %s with data: %s and ID: %s and retry: %s",
+                    sse.event,
+                    sse.data,
+                    sse.id,
+                    sse.retry,
+                )
+                last_event = json.loads(sse.data)
+                if not last_event:
+                    continue
+                data = _extract_event_data(last_event, error_cls)
+                yield data
+        if _should_retry_after_stream(last_event, status_keys, pending_state_value):
+            retry += 1
+            continue
+        return
+
+
+async def aiter_run_events(
+    url: str,
+    *,
+    headers: dict[str, str] | None = None,
+    retry: int = 0,
+    max_retries: int = DEFAULT_MAX_RETRIES,
+    pending_state_value: str = RunStatus.PENDING.value,
+    status_keys: tuple[str, ...] = ("state",),
+    error_cls: type[Exception] = RuntimeError,
+    log=_logger,
+) -> AsyncGenerator[dict, None]:
+    """
+    Async stream run events from an SSE endpoint with retries.
+
+    Args:
+        url: SSE endpoint URL.
+        headers: Optional HTTP headers.
+        retry: Internal retry counter (do not set manually).
+        max_retries: Maximum number of retry attempts.
+        pending_state_value: State value that triggers a re-check loop.
+        status_keys: Payload keys to search for the run state.
+        error_cls: Exception type to raise on errors.
+        log: Logger instance for debug output.
+
+    Yields:
+        Event payloads decoded from the SSE data field.
+    """
+    while True:
+        if retry > max_retries:
+            raise error_cls("Max retries reached")
+        last_event = None
+        async with httpx.AsyncClient(
+            timeout=httpx.Timeout(None, connect=5.0)
+        ) as client:
+            async_iterator = await aiter_sse_retrying(
+                client,
+                "GET",
+                url,
+                headers=headers or {},
+            )
+            async for sse in async_iterator:
+                if sse.event == "ping":
+                    continue
+                log.debug(
+                    "[ASYNC] Received event: %s with data: %s and ID: %s and retry: %s",
+                    sse.event,
+                    sse.data,
+                    sse.id,
+                    sse.retry,
+                )
+                last_event = json.loads(sse.data)
+                data = _extract_event_data(last_event, error_cls)
+                yield data
+        if _should_retry_after_stream(last_event, status_keys, pending_state_value):
+            retry += 1
+            continue
+        return
+
+
+def update_progress_from_result(
+    iteration: dict,
+    progress: tqdm,
+    *,
+    uploading_text: str,
+    log=_logger,
+) -> bool:
+    """
+    Update a tqdm progress bar based on a serialized progress result string.
+
+    Args:
+        iteration: Payload containing a nested result string.
+        progress: tqdm instance to update.
+        uploading_text: Description to show when progress reaches 100%.
+        log: Logger instance for debug output.
+
+    Returns:
+        True if a progress update occurred, False otherwise.
+    """
+    if (
+        iteration.get("result")
+        and isinstance(iteration["result"], dict)
+        and iteration["result"].get("result")
+        and isinstance(iteration["result"]["result"], str)
+    ):
+        result_info = iteration["result"]["result"].split(":")
+        if len(result_info) > 1:
+            stage = result_info[0]
+            current_progress_percentage = float(
+                result_info[1].removeprefix(" ").removesuffix("% done")
+            )
+        elif len(result_info) == 1:
+            stage = result_info[0]
+            current_progress_percentage = progress.n
+        else:
+            stage = "Unknown progress state"
+            current_progress_percentage = progress.n
+        desc = uploading_text if current_progress_percentage == 100.0 else stage
+        progress.set_description(desc)
+        progress.n = current_progress_percentage
+        log.debug("Setting progress to %s", progress.n)
+        progress.refresh()
+        return True
+    return False
+
+
+def handle_run_failure(
+    iteration: dict, *, error_cls: type[Exception], run_label: str
+) -> None:
+    """
+    Raise a run-specific failure exception based on the iteration payload.
+
+    Args:
+        iteration: Payload containing error details.
+        error_cls: Exception type to raise.
+        run_label: Human-readable label for the run type.
+    """
+    if iteration.get("result"):
+        raise error_cls(f"{run_label} run failed with error: {iteration['result']}")
+    raise error_cls(f"{run_label} run failed with an unknown error")
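This second new module centralizes run-event handling shared by the run clients: `RunStatus` enumerates Celery-style states, `iter_run_events`/`aiter_run_events` wrap the retrying SSE readers, and `update_progress_from_result` parses `"<stage>: <pct>% done"` result strings into a tqdm bar. A hypothetical polling loop built from these helpers (the URL and token are placeholders):

```python
from tqdm import tqdm

progress = tqdm(total=100.0, desc="LLM unlearning")
for event in iter_run_events(
    "https://api.example.com/runs/run-id-placeholder/events",  # placeholder
    headers={"Authorization": "Bearer <token>"},
    error_cls=RuntimeError,
):
    update_progress_from_result(event, progress, uploading_text="Uploading results")
    state = get_state(event, ("state",))
    if state == RunStatus.FAILURE.value:
        handle_run_failure(event, error_cls=RuntimeError, run_label="LLM unlearning")
    if state == RunStatus.SUCCESS.value:
        break
progress.close()
```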
hirundo/_urls.py CHANGED
@@ -54,6 +54,7 @@ HirundoUrl = Annotated[
             "s3",
             "gs",
             "ssh",
+            "hf",
         ]
     ),
 ]
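`hf` joins the allowed URL schemes, presumably to admit Hugging Face Hub locations alongside `s3`, `gs`, and `ssh`. The exact path layout is not shown in this diff; a hypothetical example of a URL the validator would now accept:

```python
repo_url = "hf://datasets/my-org/my-dataset"  # hypothetical layout
```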
hirundo/cli.py CHANGED
@@ -1,7 +1,6 @@
 import os
 import re
 import sys
-import typing
 from pathlib import Path
 from typing import Annotated
 from urllib.parse import urlparse
@@ -28,9 +27,7 @@ app = typer.Typer(
 )
 
 
-def _upsert_env(
-    dotenv_filepath: typing.Union[str, Path], var_name: str, var_value: str
-):
+def _upsert_env(dotenv_filepath: str | Path, var_name: str, var_value: str):
     """
     Change an environment variable in the .env file.
     If the variable does not exist, it will be added.
hirundo/dataset_enum.py CHANGED
@@ -24,6 +24,7 @@ class DatasetMetadataType(str, Enum):
     HIRUNDO_CSV = "HirundoCSV"
     COCO = "COCO"
     YOLO = "YOLO"
+    HuggingFaceAudio = "HuggingFaceAudio"
     KeylabsObjDetImages = "KeylabsObjDetImages"
     KeylabsObjDetVideo = "KeylabsObjDetVideo"
     KeylabsObjSegImages = "KeylabsObjSegImages"
@@ -44,3 +45,4 @@ class StorageTypes(str, Enum):
     """
     Local storage config is only supported for on-premises installations.
     """
+    HUGGINGFACE = "HuggingFace"
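Both enums gain a Hugging Face entry: `DatasetMetadataType.HuggingFaceAudio` for audio dataset metadata and `StorageTypes.HUGGINGFACE` for the new storage backend. A quick sketch of the new members (import path taken from this file; whether they are re-exported at the package root is not shown here):

```python
from hirundo.dataset_enum import DatasetMetadataType, StorageTypes

assert DatasetMetadataType.HuggingFaceAudio.value == "HuggingFaceAudio"
assert StorageTypes.HUGGINGFACE.value == "HuggingFace"
```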