hirundo 0.1.21__py3-none-any.whl → 0.2.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,599 @@
1
+ import datetime
2
+ import typing
3
+ from collections.abc import AsyncGenerator, Generator
4
+ from enum import Enum
5
+ from typing import TYPE_CHECKING, Literal, overload
6
+
7
+ from pydantic import BaseModel, ConfigDict
8
+ from tqdm import tqdm
9
+ from tqdm.contrib.logging import logging_redirect_tqdm
10
+
11
+ from hirundo._env import API_HOST
12
+ from hirundo._headers import get_headers
13
+ from hirundo._http import raise_for_status_with_reason, requests
14
+ from hirundo._llm_pipeline import get_hf_pipeline_for_run_given_model
15
+ from hirundo._run_checking import (
16
+ STATUS_TO_PROGRESS_MAP,
17
+ RunStatus,
18
+ aiter_run_events,
19
+ build_status_text_map,
20
+ get_state,
21
+ handle_run_failure,
22
+ iter_run_events,
23
+ update_progress_from_result,
24
+ )
25
+ from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
26
+ from hirundo.dataset_qa import HirundoError
27
+ from hirundo.logger import get_logger
28
+
29
+ if TYPE_CHECKING:
30
+ from torch import device as torch_device
31
+ from transformers.configuration_utils import PretrainedConfig
32
+ from transformers.pipelines.base import Pipeline
33
+
34
# Module-level logger named after this module.
logger = get_logger(__name__)
35
+
36
+
37
class ModelSourceType(str, Enum):
    """Discriminator values for the supported LLM source variants."""

    # Model loaded from the Hugging Face Hub via `transformers`.
    HUGGINGFACE_TRANSFORMERS = "huggingface_transformers"
    # Model loaded from a local filesystem path in `transformers` format.
    LOCAL_TRANSFORMERS = "local_transformers"
40
+
41
+
42
class HuggingFaceTransformersModel(BaseModel):
    """An LLM sourced from the Hugging Face Hub via the `transformers` library."""

    # Narrow pydantic's protected namespaces so `model_*` field names are allowed.
    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    # Union discriminator for `LlmSources`.
    type: Literal[ModelSourceType.HUGGINGFACE_TRANSFORMERS] = ModelSourceType.HUGGINGFACE_TRANSFORMERS
    # Optional git revision (branch/tag/commit) of the model repository.
    revision: str | None = None
    # Optional git revision for the model's custom code.
    code_revision: str | None = None
    model_name: str
    # Optional Hugging Face access token (e.g. for gated/private models).
    token: str | None = None
52
+
53
+
54
class HuggingFaceTransformersModelOutput(BaseModel):
    """API-returned view of a Hugging Face model source (no token/revision fields)."""

    # Narrow pydantic's protected namespaces so `model_*` field names are allowed.
    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    # Union discriminator for `LlmSourcesOutput`.
    type: Literal[ModelSourceType.HUGGINGFACE_TRANSFORMERS] = ModelSourceType.HUGGINGFACE_TRANSFORMERS
    model_name: str
61
+
62
+
63
class LocalTransformersModel(BaseModel):
    """An LLM loaded from a local path in `transformers` format."""

    # Union discriminator for `LlmSources`.
    type: Literal[ModelSourceType.LOCAL_TRANSFORMERS] = ModelSourceType.LOCAL_TRANSFORMERS
    # Revisions do not apply to local paths; both are pinned to None.
    revision: None = None
    code_revision: None = None
    local_path: str
70
+
71
+
72
# Accepted model-source payloads, discriminated by their `type` field.
LlmSources = HuggingFaceTransformersModel | LocalTransformersModel
# Model-source payloads as returned by the API (credentials stripped).
LlmSourcesOutput = HuggingFaceTransformersModelOutput | LocalTransformersModel
74
+
75
+
76
class LlmModel(BaseModel):
    """
    Client-side handle for an LLM registered with the Hirundo server.

    Wraps the `/unlearning-llm/llm/` REST endpoints: create, fetch, list,
    update and delete. `id` is populated by `create` (or may be assigned
    manually to adopt an existing server-side model).
    """

    # Narrow pydantic's protected namespaces so `model_*` field names are allowed.
    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    # Server-assigned ID; None until `create` succeeds.
    id: int | None = None
    # Organization to register the model under; None defers to the server default.
    organization_id: int | None = None
    model_name: str
    # Where the model weights come from (Hugging Face Hub or a local path).
    model_source: LlmSources
    archive_existing_runs: bool = True

    def create(
        self,
        replace_if_exists: bool = False,
    ) -> int:
        """
        Register this model with the server.

        Args:
            replace_if_exists: If True, ask the server to replace an existing
                model with the same name instead of failing.

        Returns:
            The server-assigned model ID (also stored on `self.id`).
        """
        llm_model_response = requests.post(
            f"{API_HOST}/unlearning-llm/llm/",
            json={
                **self.model_dump(mode="json"),
                "replace_if_exists": replace_if_exists,
            },
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        llm_model_id = llm_model_response.json()["id"]
        self.id = llm_model_id
        return llm_model_id

    @staticmethod
    def get_by_id(llm_model_id: int) -> "LlmModelOut":
        """Fetch a single model record by its server ID."""
        llm_model_response = requests.get(
            f"{API_HOST}/unlearning-llm/llm/{llm_model_id}",
            headers=get_headers(),
            timeout=READ_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        return LlmModelOut.model_validate(llm_model_response.json())

    @staticmethod
    def get_by_name(llm_model_name: str) -> "LlmModelOut":
        """Fetch a single model record by its name."""
        llm_model_response = requests.get(
            f"{API_HOST}/unlearning-llm/llm/by-name/{llm_model_name}",
            headers=get_headers(),
            timeout=READ_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        return LlmModelOut.model_validate(llm_model_response.json())

    @staticmethod
    def list(organization_id: int | None = None) -> list["LlmModelOut"]:
        """
        List models visible to the caller.

        Args:
            organization_id: If given, filter to that organization (sent as
                the `model_organization_id` query parameter).
        """
        params = {}
        if organization_id is not None:
            params["model_organization_id"] = organization_id
        llm_model_response = requests.get(
            f"{API_HOST}/unlearning-llm/llm/",
            params=params,
            headers=get_headers(),
            timeout=READ_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        llm_model_json = llm_model_response.json()
        return [LlmModelOut.model_validate(llm_model) for llm_model in llm_model_json]

    @staticmethod
    def delete_by_id(llm_model_id: int) -> None:
        """Delete the model with the given server ID."""
        llm_model_response = requests.delete(
            f"{API_HOST}/unlearning-llm/llm/{llm_model_id}",
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        logger.info("Deleted LLM model with ID: %s", llm_model_id)

    def delete(self) -> None:
        """
        Delete the model this instance refers to.

        Raises:
            ValueError: If the model has not been created yet (no ID).
        """
        # Compare against None explicitly so a (theoretical) ID of 0 is not
        # mistaken for "not created yet".
        if self.id is None:
            raise ValueError("No LLM model has been created")
        self.delete_by_id(self.id)

    def update(
        self,
        model_name: str | None = None,
        model_source: LlmSources | None = None,
        archive_existing_runs: bool | None = None,
    ) -> None:
        """
        Update this model on the server; only non-None arguments are mirrored
        onto this instance after a successful request.

        Raises:
            ValueError: If the model has not been created yet (no ID).
        """
        if self.id is None:
            raise ValueError("No LLM model has been created")
        # NOTE(review): arguments left as None are still sent as nulls in the
        # payload; this assumes the server treats null as "no change" rather
        # than clearing the field — confirm against the API contract.
        payload: dict[str, typing.Any] = {
            "model_name": model_name,
            "model_source": model_source.model_dump(mode="json")
            if model_source
            else None,
            "archive_existing_runs": archive_existing_runs,
            "organization_id": self.organization_id,
        }
        llm_model_response = requests.put(
            f"{API_HOST}/unlearning-llm/llm/{self.id}",
            json=payload,
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        # Mirror the successful server-side update locally.
        if model_name is not None:
            self.model_name = model_name
        if model_source is not None:
            self.model_source = model_source
        if archive_existing_runs is not None:
            self.archive_existing_runs = archive_existing_runs

    def get_hf_pipeline_for_run(
        self,
        run_id: str,
        config: "PretrainedConfig | None" = None,
        device: "str | int | torch_device | None" = None,
        device_map: str | dict[str, int | str] | None = None,
        trust_remote_code: bool = False,
    ) -> "Pipeline":
        """Build a `transformers` pipeline for this model and the given run."""
        return get_hf_pipeline_for_run_given_model(
            self, run_id, config, device, device_map, trust_remote_code
        )
194
+
195
+
196
class LlmModelOut(BaseModel):
    """Read-only record of an LLM model as returned by the Hirundo API."""

    # Narrow pydantic's protected namespaces so `model_*` field names are allowed.
    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    id: int
    organization_id: int
    creator_id: int
    creator_name: str
    created_at: datetime.datetime
    updated_at: datetime.datetime
    model_name: str
    # Source info as the API exposes it (credentials stripped for HF models).
    model_source: LlmSourcesOutput

    def get_hf_pipeline_for_run(
        self,
        run_id: str,
        config: "PretrainedConfig | None" = None,
        device: "str | int | torch_device | None" = None,
        device_map: str | dict[str, int | str] | None = None,
        trust_remote_code: bool = False,
        token: str | None = None,
    ) -> "Pipeline":
        """Build a `transformers` pipeline for this model and the given run.

        `token` is forwarded for models that require Hugging Face authentication.
        """
        return get_hf_pipeline_for_run_given_model(
            self, run_id, config, device, device_map, trust_remote_code, token=token
        )
226
+
227
+
228
class DatasetType(str, Enum):
    """Dataset roles used when configuring per-dataset token limits."""

    NORMAL = "normal"
    BIAS = "bias"
    UNBIAS = "unbias"
232
+
233
+
234
class UnlearningLlmAdvancedOptions(BaseModel):
    """Advanced tuning options for an LLM unlearning run."""

    # Token limit for the model: a single limit, a per-dataset-type mapping,
    # or None to use the server default.
    max_tokens_for_model: dict[DatasetType, int] | int | None = None
236
+
237
+
238
class BiasType(str, Enum):
    """Bias categories that can be targeted for unlearning."""

    # ALL covers every category below.
    ALL = "ALL"
    RACE = "RACE"
    NATIONALITY = "NATIONALITY"
    GENDER = "GENDER"
    PHYSICAL_APPEARANCE = "PHYSICAL_APPEARANCE"
    RELIGION = "RELIGION"
    AGE = "AGE"
246
+
247
+
248
class UtilityType(str, Enum):
    """Discriminator values for utility-preservation specs."""

    DEFAULT = "DEFAULT"
    CUSTOM = "CUSTOM"
251
+
252
+
253
class DefaultUtility(BaseModel):
    """Utility spec selecting the built-in default (no parameters)."""

    # Union discriminator for `TargetUtility`.
    utility_type: Literal[UtilityType.DEFAULT] = UtilityType.DEFAULT
255
+
256
+
257
class HirundoCSVDataset(BaseModel):
    """A custom dataset supplied as a CSV file reachable at a URL."""

    # Union discriminator for `CustomDataset`.
    type: Literal["HirundoCSV"] = "HirundoCSV"
    csv_url: str
260
+
261
+
262
class HuggingFaceDataset(BaseModel):
    """A custom dataset identified by its Hugging Face Hub name."""

    # Union discriminator for `CustomDataset`.
    type: Literal["HuggingFaceDataset"] = "HuggingFaceDataset"
    hugging_face_dataset_name: str
265
+
266
+
267
# Accepted custom-dataset payloads, discriminated by their `type` field.
CustomDataset = HirundoCSVDataset | HuggingFaceDataset
268
+
269
+
270
class CustomUtility(BaseModel):
    """Utility spec backed by a user-supplied dataset."""

    # Union discriminator for `TargetUtility`.
    utility_type: Literal[UtilityType.CUSTOM] = UtilityType.CUSTOM
    dataset: CustomDataset
273
+
274
+
275
class BiasBehavior(BaseModel):
    """Target behavior: unlearn one category of bias."""

    # Union discriminator for `TargetBehavior`.
    type: Literal["BIAS"] = "BIAS"
    bias_type: BiasType
278
+
279
+
280
class HallucinationType(str, Enum):
    """Domains for hallucination-targeted unlearning."""

    GENERAL = "GENERAL"
    MEDICAL = "MEDICAL"
    LEGAL = "LEGAL"
    DEFENSE = "DEFENSE"
285
+
286
+
287
class HallucinationBehavior(BaseModel):
    """Target behavior: reduce hallucinations in a given domain."""

    # Union discriminator for `TargetBehavior`.
    type: Literal["HALLUCINATION"] = "HALLUCINATION"
    hallucination_type: HallucinationType
290
+
291
+
292
class SecurityBehavior(BaseModel):
    """Target behavior: security-oriented unlearning (takes no parameters)."""

    # Union discriminator for `TargetBehavior`.
    type: Literal["SECURITY"] = "SECURITY"
294
+
295
+
296
class CustomBehavior(BaseModel):
    """Target behavior defined by a user-supplied biased/unbiased dataset pair."""

    # Union discriminator for `TargetBehavior`.
    type: Literal["CUSTOM"] = "CUSTOM"
    biased_dataset: CustomDataset
    unbiased_dataset: CustomDataset
300
+
301
+
302
# All accepted target-behavior payloads, discriminated by their `type` field.
TargetBehavior = (
    BiasBehavior | HallucinationBehavior | SecurityBehavior | CustomBehavior
)

# All accepted utility-preservation payloads, discriminated by `utility_type`.
TargetUtility = DefaultUtility | CustomUtility
307
+
308
+
309
class LlmRunInfo(BaseModel):
    """Full specification of an LLM unlearning run."""

    # Narrow pydantic's protected namespaces so `model_*` field names are allowed.
    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    organization_id: int | None = None
    # Optional human-readable run name.
    name: str | None = None
    # Behaviors to unlearn and utilities to preserve.
    target_behaviors: list[TargetBehavior]
    target_utilities: list[TargetUtility]
    advanced_options: UnlearningLlmAdvancedOptions | None = None
317
+
318
+
319
class BiasRunInfo(BaseModel):
    """Shorthand spec for a bias-unlearning run; expand with `to_run_info`."""

    bias_type: BiasType
    organization_id: int | None = None
    name: str | None = None
    # None means "use the default utility" when converted to a full run spec.
    target_utilities: list[TargetUtility] | None = None
    advanced_options: UnlearningLlmAdvancedOptions | None = None

    def to_run_info(self) -> LlmRunInfo:
        """Expand this shorthand into a full `LlmRunInfo`.

        Returns:
            An `LlmRunInfo` with a single `BiasBehavior` for `bias_type`, and
            either a copy of `target_utilities` or `[DefaultUtility()]` when
            none were given.
        """
        if self.target_utilities is None:
            utilities: list[TargetUtility] = [DefaultUtility()]
        else:
            utilities = list(self.target_utilities)
        return LlmRunInfo(
            organization_id=self.organization_id,
            name=self.name,
            target_behaviors=[BiasBehavior(bias_type=self.bias_type)],
            target_utilities=utilities,
            advanced_options=self.advanced_options,
        )
339
+
340
+
341
# Raw JSON payload of a run's model record (exact schema not pinned down here).
OutputLlm = dict[str, object]
# Run-output-oriented aliases for the behavior/utility unions.
BehaviorOptions = TargetBehavior
UtilityOptions = TargetUtility
# Celery task state string, e.g. "PENDING"/"SUCCESS" (see run-status handling below).
CeleryTaskState = str
345
+
346
+
347
class OutputUnlearningLlmRun(BaseModel):
    """Server-side record of an LLM unlearning run."""

    # Narrow pydantic's protected namespaces so `model_*` field names are allowed.
    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    id: int
    name: str
    model_id: int
    # Raw JSON of the associated model record.
    model: OutputLlm
    target_behaviors: list[BehaviorOptions]
    target_utilities: list[UtilityOptions]
    advanced_options: UnlearningLlmAdvancedOptions | None
    run_id: str
    mlflow_run_id: str | None
    # Celery task state string.
    status: CeleryTaskState
    approved: bool
    created_at: datetime.datetime
    completed_at: datetime.datetime | None
    # Per-phase progress values as reported by the server.
    pre_process_progress: float
    optimization_progress: float
    post_process_progress: float

    # Set when the run has been soft-deleted.
    deleted_at: datetime.datetime | None = None
368
+
369
+
370
# Maps run states to progress-bar descriptions labelled "LLM unlearning".
STATUS_TO_TEXT_MAP = build_status_text_map("LLM unlearning")
371
+
372
+
373
class LlmUnlearningRun:
    """Static helpers for launching and tracking LLM unlearning runs
    via the `/unlearning-llm/run/` REST endpoints."""

    @staticmethod
    def launch(model_id: int, run_info: LlmRunInfo | BiasRunInfo) -> str:
        """
        Launch an unlearning run for the given model.

        Args:
            model_id: ID of a previously created LLM model.
            run_info: Full run spec, or a `BiasRunInfo` shorthand which is
                expanded via `to_run_info` first.

        Returns:
            The server-assigned run ID.

        Raises:
            ValueError: If the server response contains no run ID.
        """
        resolved_run_info = (
            run_info.to_run_info() if isinstance(run_info, BiasRunInfo) else run_info
        )
        run_response = requests.post(
            f"{API_HOST}/unlearning-llm/run/{model_id}",
            json=resolved_run_info.model_dump(mode="json"),
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)
        # The server may answer with a bare string run ID or a JSON object
        # containing a "run_id" key; accept both.
        run_response_json = run_response.json() if run_response.content else {}
        if isinstance(run_response_json, str):
            return run_response_json
        run_id = run_response_json.get("run_id")
        if not run_id:
            raise ValueError("No run ID returned from launch request")
        return run_id

    @staticmethod
    def cancel(run_id: str) -> None:
        """Cancel the given run."""
        run_response = requests.patch(
            f"{API_HOST}/unlearning-llm/run/cancel/{run_id}",
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)

    @staticmethod
    def rename(run_id: str, new_name: str) -> None:
        """Rename the given run to `new_name`."""
        run_response = requests.patch(
            f"{API_HOST}/unlearning-llm/run/rename/{run_id}",
            json={"new_name": new_name},
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)

    @staticmethod
    def archive(run_id: str) -> None:
        """Archive the given run."""
        run_response = requests.patch(
            f"{API_HOST}/unlearning-llm/run/archive/{run_id}",
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)

    @staticmethod
    def restore(run_id: str) -> None:
        """Restore the given run (inverse of `archive`)."""
        run_response = requests.patch(
            f"{API_HOST}/unlearning-llm/run/restore/{run_id}",
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)

    @staticmethod
    def list(
        organization_id: int | None = None,
        archived: bool = False,
    ) -> list[OutputUnlearningLlmRun]:
        """
        List unlearning runs.

        Args:
            organization_id: If given, filter to that organization (sent as
                the `unlearning_organization_id` query parameter).
            archived: Passed through as the `archived` query parameter.
        """
        params: dict[str, bool | int] = {"archived": archived}
        if organization_id is not None:
            params["unlearning_organization_id"] = organization_id
        run_response = requests.get(
            f"{API_HOST}/unlearning-llm/run/list",
            params=params,
            headers=get_headers(),
            timeout=READ_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)
        response_json = run_response.json()
        # The endpoint normally returns a list, but tolerate a single object.
        if isinstance(response_json, list):
            return [
                OutputUnlearningLlmRun.model_validate(run_payload)
                for run_payload in response_json
            ]
        return [OutputUnlearningLlmRun.model_validate(response_json)]

    @staticmethod
    def _check_run_by_id(run_id: str, retry=0) -> Generator[dict, None, None]:
        """Yield raw status events for the given run from the server."""
        yield from iter_run_events(
            f"{API_HOST}/unlearning-llm/run/{run_id}",
            headers=get_headers(),
            retry=retry,
            status_keys=("state", "status"),
            error_cls=HirundoError,
            log=logger,
        )

    @staticmethod
    @overload
    def check_run_by_id(
        run_id: str, stop_on_manual_approval: Literal[True]
    ) -> typing.Any | None: ...

    @staticmethod
    @overload
    def check_run_by_id(
        run_id: str, stop_on_manual_approval: Literal[False] = False
    ) -> typing.Any: ...

    @staticmethod
    @overload
    def check_run_by_id(
        run_id: str, stop_on_manual_approval: bool
    ) -> typing.Any | None: ...

    @staticmethod
    def check_run_by_id(run_id: str, stop_on_manual_approval: bool = False):
        """
        Check the status of a run given its ID

        Args:
            run_id: The `run_id` produced by a `launch` call
            stop_on_manual_approval: If True, the function will return `None` if the run is awaiting manual approval

        Returns:
            The result payload for the run, if available

        Raises:
            HirundoError: If the maximum number of retries is reached or if the run fails
        """
        logger.debug("Checking run with ID: %s", run_id)
        # Use tqdm as a context manager so the progress bar is also closed
        # when an exception escapes the event loop or the loop exhausts
        # (the previous explicit t.close() calls missed those paths).
        with logging_redirect_tqdm(), tqdm(total=100.0) as t:
            for iteration in LlmUnlearningRun._check_run_by_id(run_id):
                state = get_state(iteration, ("state", "status"))
                if state in STATUS_TO_PROGRESS_MAP:
                    t.set_description(STATUS_TO_TEXT_MAP[state])
                    t.n = STATUS_TO_PROGRESS_MAP[state]
                    logger.debug("Setting progress to %s", t.n)
                    t.refresh()
                if state in [
                    RunStatus.FAILURE.value,
                    RunStatus.REJECTED.value,
                    RunStatus.REVOKED.value,
                ]:
                    logger.error(
                        "State is failure, rejected, or revoked: %s",
                        state,
                    )
                    # Expected to raise HirundoError with the failure details.
                    handle_run_failure(
                        iteration,
                        error_cls=HirundoError,
                        run_label="LLM unlearning",
                    )
                elif state == RunStatus.SUCCESS.value:
                    # Prefer the result payload; fall back to the raw event.
                    return iteration.get("result") or iteration
                elif (
                    state == RunStatus.AWAITING_MANUAL_APPROVAL.value
                    and stop_on_manual_approval
                ):
                    return None
                elif state is None:
                    # Progress-only event without a state field.
                    update_progress_from_result(
                        iteration,
                        t,
                        uploading_text="LLM unlearning run completed. Uploading results",
                        log=logger,
                    )
        # Event stream ended without reaching a terminal state.
        raise HirundoError("LLM unlearning run failed with an unknown error")

    @staticmethod
    def check_run(run_id: str, stop_on_manual_approval: bool = False):
        """
        Check the status of the given run.

        Returns:
            The result payload for the run, if available
        """
        return LlmUnlearningRun.check_run_by_id(run_id, stop_on_manual_approval)

    @staticmethod
    async def acheck_run_by_id(run_id: str, retry=0) -> AsyncGenerator[dict, None]:
        """
        Async version of :func:`check_run_by_id`

        Check the status of a run given its ID.

        This generator will produce values to show progress of the run.

        Note: This function does not handle errors nor show progress. It is expected that you do that.

        Args:
            run_id: The `run_id` produced by a `launch` call
            retry: A number used to track the number of retries to limit re-checks. *Do not* provide this value manually.

        Yields:
            Each event will be a dict, where:
            - `"state"` is PENDING, STARTED, RETRY, FAILURE or SUCCESS
            - `"result"` is a string describing the progress as a percentage for a PENDING state, or the error for a FAILURE state or the results for a SUCCESS state

        """
        logger.debug("Checking run with ID: %s", run_id)
        async for iteration in aiter_run_events(
            f"{API_HOST}/unlearning-llm/run/{run_id}",
            headers=get_headers(),
            retry=retry,
            status_keys=("state", "status"),
            error_cls=HirundoError,
            log=logger,
        ):
            yield iteration

    @staticmethod
    async def acheck_run(run_id: str) -> AsyncGenerator[dict, None]:
        """
        Async version of :func:`check_run`

        Check the status of the given run.

        This generator will produce values to show progress of the run.

        Yields:
            Each event will be a dict, where:
            - `"state"` is PENDING, STARTED, RETRY, FAILURE or SUCCESS
            - `"result"` is a string describing the progress as a percentage for a PENDING state, or the error for a FAILURE state or the results for a SUCCESS state

        """
        async for iteration in LlmUnlearningRun.acheck_run_by_id(run_id):
            yield iteration
hirundo/unzip.py CHANGED
@@ -27,7 +27,7 @@ from hirundo.logger import get_logger
27
27
 
28
28
  ZIP_FILE_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB
29
29
 
30
- Dtype = typing.Union[type[int32], type[float32], type[string]]
30
+ Dtype = type[int32] | type[float32] | type[string]
31
31
 
32
32
 
33
33
  CUSTOMER_INTERCHANGE_DTYPES: Mapping[str, Dtype] = {
@@ -75,7 +75,7 @@ def _clean_df_index(df: "pd.DataFrame") -> "pd.DataFrame":
75
75
 
76
76
 
77
77
  def load_df(
78
- file: "typing.Union[str, IO[bytes]]",
78
+ file: "str | IO[bytes]",
79
79
  ) -> "DataFrameType":
80
80
  """
81
81
  Load a DataFrame from a CSV file.
@@ -226,7 +226,7 @@ def download_and_extract_zip(
226
226
 
227
227
  def load_from_zip(
228
228
  zip_path: Path, file_name: str
229
- ) -> "typing.Union[pd.DataFrame, pl.DataFrame, None]":
229
+ ) -> "pd.DataFrame | pl.DataFrame | None":
230
230
  """
231
231
  Load a given file from a given zip file.
232
232