hirundo 0.1.21__py3-none-any.whl → 0.2.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hirundo/__init__.py +19 -3
- hirundo/_constraints.py +2 -3
- hirundo/_iter_sse_retrying.py +7 -4
- hirundo/_llm_pipeline.py +153 -0
- hirundo/_run_checking.py +283 -0
- hirundo/_urls.py +1 -0
- hirundo/cli.py +1 -4
- hirundo/dataset_enum.py +2 -0
- hirundo/dataset_qa.py +106 -190
- hirundo/dataset_qa_results.py +3 -3
- hirundo/git.py +7 -8
- hirundo/labeling.py +22 -19
- hirundo/storage.py +25 -24
- hirundo/unlearning_llm.py +599 -0
- hirundo/unzip.py +3 -3
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/METADATA +42 -10
- hirundo-0.2.3.post1.dist-info/RECORD +28 -0
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/WHEEL +1 -1
- hirundo-0.1.21.dist-info/RECORD +0 -25
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/entry_points.txt +0 -0
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/licenses/LICENSE +0 -0
- {hirundo-0.1.21.dist-info → hirundo-0.2.3.post1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,599 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import typing
|
|
3
|
+
from collections.abc import AsyncGenerator, Generator
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import TYPE_CHECKING, Literal, overload
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
from tqdm.contrib.logging import logging_redirect_tqdm
|
|
10
|
+
|
|
11
|
+
from hirundo._env import API_HOST
|
|
12
|
+
from hirundo._headers import get_headers
|
|
13
|
+
from hirundo._http import raise_for_status_with_reason, requests
|
|
14
|
+
from hirundo._llm_pipeline import get_hf_pipeline_for_run_given_model
|
|
15
|
+
from hirundo._run_checking import (
|
|
16
|
+
STATUS_TO_PROGRESS_MAP,
|
|
17
|
+
RunStatus,
|
|
18
|
+
aiter_run_events,
|
|
19
|
+
build_status_text_map,
|
|
20
|
+
get_state,
|
|
21
|
+
handle_run_failure,
|
|
22
|
+
iter_run_events,
|
|
23
|
+
update_progress_from_result,
|
|
24
|
+
)
|
|
25
|
+
from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
|
|
26
|
+
from hirundo.dataset_qa import HirundoError
|
|
27
|
+
from hirundo.logger import get_logger
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from torch import device as torch_device
|
|
31
|
+
from transformers.configuration_utils import PretrainedConfig
|
|
32
|
+
from transformers.pipelines.base import Pipeline
|
|
33
|
+
|
|
34
|
+
# Module-level logger used by every helper in this module.
logger = get_logger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ModelSourceType(str, Enum):
    """Kinds of sources an LLM registered for unlearning can be loaded from.

    Members subclass ``str``, so they compare equal to (and serialize as)
    their plain string values.
    """

    # A Transformers model referenced by its Hugging Face model name.
    HUGGINGFACE_TRANSFORMERS = "huggingface_transformers"
    # A Transformers model referenced by a local path.
    LOCAL_TRANSFORMERS = "local_transformers"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class HuggingFaceTransformersModel(BaseModel):
    """Input description of a Transformers model referenced by Hugging Face name.

    ``protected_namespaces`` is narrowed so that the ``model_name`` field does
    not clash with pydantic's reserved ``model_`` prefix.
    """

    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    # Source-type tag; fixed to HUGGINGFACE_TRANSFORMERS for this class.
    type: Literal[
        ModelSourceType.HUGGINGFACE_TRANSFORMERS
    ] = ModelSourceType.HUGGINGFACE_TRANSFORMERS
    # Optional revision pins for the model and for its code.
    revision: str | None = None
    code_revision: str | None = None
    # Name of the model on Hugging Face.
    model_name: str
    # Optional access token (presumably for private/gated models — confirm).
    token: str | None = None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class HuggingFaceTransformersModelOutput(BaseModel):
    """Server-returned view of a Hugging Face Transformers model source.

    Unlike the input model, it carries only the source tag and model name.
    """

    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    # Source-type tag; fixed to HUGGINGFACE_TRANSFORMERS for this class.
    type: Literal[
        ModelSourceType.HUGGINGFACE_TRANSFORMERS
    ] = ModelSourceType.HUGGINGFACE_TRANSFORMERS
    # Name of the model on Hugging Face.
    model_name: str
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class LocalTransformersModel(BaseModel):
    """Transformers model referenced by a local path.

    Used both as input and as server output (see ``LlmSourcesOutput``).
    """

    # Source-type tag; fixed to LOCAL_TRANSFORMERS for this class.
    type: Literal[ModelSourceType.LOCAL_TRANSFORMERS] = ModelSourceType.LOCAL_TRANSFORMERS
    # Always None here; revisions only apply to Hugging Face sources.
    revision: None = None
    code_revision: None = None
    # Path to the model.
    local_path: str
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Accepted model-source descriptions when registering an LLM.
LlmSources = HuggingFaceTransformersModel | LocalTransformersModel
|
|
73
|
+
# Model-source descriptions as returned by the server (used by LlmModelOut).
LlmSourcesOutput = HuggingFaceTransformersModelOutput | LocalTransformersModel
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class LlmModel(BaseModel):
    """An LLM to be registered with (or already registered on) the unlearning API.

    Build an instance locally and call :meth:`create` to register it; the
    server-assigned ID is then stored on ``self.id`` and used by the
    instance-level :meth:`delete` and :meth:`update`.
    """

    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    id: int | None = None
    organization_id: int | None = None
    model_name: str
    model_source: LlmSources
    archive_existing_runs: bool = True

    def create(
        self,
        replace_if_exists: bool = False,
    ) -> int:
        """Register this model with the API.

        Args:
            replace_if_exists: Forwarded to the server as ``replace_if_exists``.

        Returns:
            The server-assigned model ID (also stored on ``self.id``).
        """
        llm_model_response = requests.post(
            f"{API_HOST}/unlearning-llm/llm/",
            json={
                **self.model_dump(mode="json"),
                "replace_if_exists": replace_if_exists,
            },
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        llm_model_id = llm_model_response.json()["id"]
        self.id = llm_model_id
        return llm_model_id

    @staticmethod
    def get_by_id(llm_model_id: int) -> "LlmModelOut":
        """Fetch a registered model by its ID."""
        llm_model_response = requests.get(
            f"{API_HOST}/unlearning-llm/llm/{llm_model_id}",
            headers=get_headers(),
            timeout=READ_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        return LlmModelOut.model_validate(llm_model_response.json())

    @staticmethod
    def get_by_name(llm_model_name: str) -> "LlmModelOut":
        """Fetch a registered model by its name.

        The name is percent-encoded before being placed in the URL path, so
        characters such as ``?``, ``#`` or ``%`` are sent as part of the name
        rather than being interpreted as query/fragment/escape syntax. ``/`` is
        left as-is, matching the previous request shape.
        """
        from urllib.parse import quote

        encoded_name = quote(llm_model_name)
        llm_model_response = requests.get(
            f"{API_HOST}/unlearning-llm/llm/by-name/{encoded_name}",
            headers=get_headers(),
            timeout=READ_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        return LlmModelOut.model_validate(llm_model_response.json())

    @staticmethod
    def list(organization_id: int | None = None) -> list["LlmModelOut"]:
        """List registered models.

        Args:
            organization_id: If given, sent as the ``model_organization_id``
                query parameter.
        """
        params = {}
        if organization_id is not None:
            params["model_organization_id"] = organization_id
        llm_model_response = requests.get(
            f"{API_HOST}/unlearning-llm/llm/",
            params=params,
            headers=get_headers(),
            timeout=READ_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        llm_model_json = llm_model_response.json()
        return [LlmModelOut.model_validate(llm_model) for llm_model in llm_model_json]

    @staticmethod
    def delete_by_id(llm_model_id: int) -> None:
        """Delete the registered model with the given ID."""
        llm_model_response = requests.delete(
            f"{API_HOST}/unlearning-llm/llm/{llm_model_id}",
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        logger.info("Deleted LLM model with ID: %s", llm_model_id)

    def delete(self) -> None:
        """Delete this model from the API.

        Raises:
            ValueError: If the model has not been created (``self.id`` is None).
        """
        # Compare against None explicitly: a falsy-but-valid ID must not be
        # mistaken for "not created".
        if self.id is None:
            raise ValueError("No LLM model has been created")
        self.delete_by_id(self.id)

    def update(
        self,
        model_name: str | None = None,
        model_source: LlmSources | None = None,
        archive_existing_runs: bool | None = None,
    ) -> None:
        """Update this model on the API and mirror the changes locally.

        Arguments left as ``None`` are sent as null and not applied locally.

        Raises:
            ValueError: If the model has not been created (``self.id`` is None).
        """
        if self.id is None:
            raise ValueError("No LLM model has been created")
        payload: dict[str, typing.Any] = {
            "model_name": model_name,
            "model_source": (
                model_source.model_dump(mode="json")
                if model_source is not None
                else None
            ),
            "archive_existing_runs": archive_existing_runs,
            "organization_id": self.organization_id,
        }
        llm_model_response = requests.put(
            f"{API_HOST}/unlearning-llm/llm/{self.id}",
            json=payload,
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(llm_model_response)
        if model_name is not None:
            self.model_name = model_name
        if model_source is not None:
            self.model_source = model_source
        if archive_existing_runs is not None:
            self.archive_existing_runs = archive_existing_runs

    def get_hf_pipeline_for_run(
        self,
        run_id: str,
        config: "PretrainedConfig | None" = None,
        device: "str | int | torch_device | None" = None,
        device_map: str | dict[str, int | str] | None = None,
        trust_remote_code: bool = False,
        token: str | None = None,
    ) -> "Pipeline":
        """Build a Transformers pipeline for the given unlearning run.

        Delegates to ``get_hf_pipeline_for_run_given_model``. ``token`` is
        forwarded only when provided, so the helper's own default applies
        otherwise (same call shape as before this parameter existed).
        """
        extra_kwargs: dict[str, typing.Any] = {} if token is None else {"token": token}
        return get_hf_pipeline_for_run_given_model(
            self, run_id, config, device, device_map, trust_remote_code, **extra_kwargs
        )
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class LlmModelOut(BaseModel):
    """A registered LLM as returned by the unlearning API."""

    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    # Server-assigned identifiers and ownership.
    id: int
    organization_id: int
    creator_id: int
    creator_name: str
    # Timestamps.
    created_at: datetime.datetime
    updated_at: datetime.datetime
    # Model identity and where it is loaded from.
    model_name: str
    model_source: LlmSourcesOutput

    def get_hf_pipeline_for_run(
        self,
        run_id: str,
        config: "PretrainedConfig | None" = None,
        device: "str | int | torch_device | None" = None,
        device_map: str | dict[str, int | str] | None = None,
        trust_remote_code: bool = False,
        token: str | None = None,
    ) -> "Pipeline":
        """Build a Transformers pipeline for the given unlearning run.

        Thin wrapper over ``get_hf_pipeline_for_run_given_model``; ``token`` is
        always passed by keyword.
        """
        positional_args = (run_id, config, device, device_map, trust_remote_code)
        return get_hf_pipeline_for_run_given_model(self, *positional_args, token=token)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
class DatasetType(str, Enum):
    """Dataset categories, used as keys of per-dataset advanced options."""

    NORMAL = "normal"
    BIAS = "bias"
    UNBIAS = "unbias"
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class UnlearningLlmAdvancedOptions(BaseModel):
    """Optional advanced settings for an LLM unlearning run."""

    # Token limit for the model: a single int, a mapping keyed by DatasetType,
    # or None (presumably meaning "use the server default" — confirm).
    max_tokens_for_model: dict[DatasetType, int] | int | None = None
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class BiasType(str, Enum):
    """Bias categories that a bias-unlearning run can target.

    ``ALL`` selects every category at once.
    """

    ALL = "ALL"
    RACE = "RACE"
    NATIONALITY = "NATIONALITY"
    GENDER = "GENDER"
    PHYSICAL_APPEARANCE = "PHYSICAL_APPEARANCE"
    RELIGION = "RELIGION"
    AGE = "AGE"
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class UtilityType(str, Enum):
    """Tag distinguishing the default utility target from a custom one."""

    DEFAULT = "DEFAULT"
    CUSTOM = "CUSTOM"
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class DefaultUtility(BaseModel):
    """Utility target of type DEFAULT; unlike ``CustomUtility`` it has no dataset."""

    utility_type: Literal[UtilityType.DEFAULT] = UtilityType.DEFAULT
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class HirundoCSVDataset(BaseModel):
    """Custom dataset supplied as a CSV file at ``csv_url``."""

    # Dataset-kind tag; fixed to "HirundoCSV".
    type: Literal["HirundoCSV"] = "HirundoCSV"
    csv_url: str
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class HuggingFaceDataset(BaseModel):
    """Custom dataset referenced by its Hugging Face dataset name."""

    # Dataset-kind tag; fixed to "HuggingFaceDataset".
    type: Literal["HuggingFaceDataset"] = "HuggingFaceDataset"
    hugging_face_dataset_name: str
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
# A user-supplied dataset: a CSV at a URL or a Hugging Face dataset.
CustomDataset = HirundoCSVDataset | HuggingFaceDataset
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class CustomUtility(BaseModel):
    """Utility target of type CUSTOM, backed by a user-supplied dataset."""

    utility_type: Literal[UtilityType.CUSTOM] = UtilityType.CUSTOM
    dataset: CustomDataset
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class BiasBehavior(BaseModel):
    """Target behavior selecting a bias category to unlearn."""

    # Behavior-kind tag; fixed to "BIAS".
    type: Literal["BIAS"] = "BIAS"
    bias_type: BiasType
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
class HallucinationType(str, Enum):
    """Domains for which a hallucination behavior can be targeted."""

    GENERAL = "GENERAL"
    MEDICAL = "MEDICAL"
    LEGAL = "LEGAL"
    DEFENSE = "DEFENSE"
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
class HallucinationBehavior(BaseModel):
    """Target behavior selecting a hallucination domain."""

    # Behavior-kind tag; fixed to "HALLUCINATION".
    type: Literal["HALLUCINATION"] = "HALLUCINATION"
    hallucination_type: HallucinationType
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
class SecurityBehavior(BaseModel):
    """Target behavior of kind SECURITY; carries no extra parameters."""

    type: Literal["SECURITY"] = "SECURITY"
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class CustomBehavior(BaseModel):
    """User-defined target behavior described by a pair of datasets."""

    # Behavior-kind tag; fixed to "CUSTOM".
    type: Literal["CUSTOM"] = "CUSTOM"
    # Examples of the behavior to remove, and of the desired behavior.
    biased_dataset: CustomDataset
    unbiased_dataset: CustomDataset
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
# Any behavior an unlearning run can target.
TargetBehavior = BiasBehavior | HallucinationBehavior | SecurityBehavior | CustomBehavior
|
|
305
|
+
|
|
306
|
+
# Any utility an unlearning run can aim to preserve: default or dataset-backed.
TargetUtility = DefaultUtility | CustomUtility
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
class LlmRunInfo(BaseModel):
    """Full description of an LLM unlearning run, sent as the launch payload."""

    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    # Optional ownership and display name.
    organization_id: int | None = None
    name: str | None = None
    # Behaviors to unlearn and utilities to preserve (both required).
    target_behaviors: list[TargetBehavior]
    target_utilities: list[TargetUtility]
    advanced_options: UnlearningLlmAdvancedOptions | None = None
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
class BiasRunInfo(BaseModel):
    """Shorthand run description targeting a single bias category.

    Expanded into a full :class:`LlmRunInfo` by :meth:`to_run_info`.
    """

    bias_type: BiasType
    organization_id: int | None = None
    name: str | None = None
    target_utilities: list[TargetUtility] | None = None
    advanced_options: UnlearningLlmAdvancedOptions | None = None

    def to_run_info(self) -> LlmRunInfo:
        """Expand into an :class:`LlmRunInfo` with one ``BiasBehavior``.

        When ``target_utilities`` is None, a single ``DefaultUtility`` is used;
        otherwise a shallow copy of the given list is used as-is.
        """
        if self.target_utilities is not None:
            utilities: list[TargetUtility] = list(self.target_utilities)
        else:
            utilities = [DefaultUtility()]
        bias_behavior = BiasBehavior(bias_type=self.bias_type)
        return LlmRunInfo(
            organization_id=self.organization_id,
            name=self.name,
            target_behaviors=[bias_behavior],
            target_utilities=utilities,
            advanced_options=self.advanced_options,
        )
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
# Field types used by OutputUnlearningLlmRun.
# The embedded model is kept as an untyped JSON object.
OutputLlm = dict[str, object]
BehaviorOptions = TargetBehavior
UtilityOptions = TargetUtility
# Run status string as reported by the server (presumably a Celery task state).
CeleryTaskState = str
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
class OutputUnlearningLlmRun(BaseModel):
    """An LLM unlearning run as returned by the run-listing endpoint."""

    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

    # Identity and the model the run belongs to.
    id: int
    name: str
    model_id: int
    model: OutputLlm
    # The run's configuration.
    target_behaviors: list[BehaviorOptions]
    target_utilities: list[UtilityOptions]
    advanced_options: UnlearningLlmAdvancedOptions | None
    # Execution tracking.
    run_id: str
    mlflow_run_id: str | None
    status: CeleryTaskState
    approved: bool
    created_at: datetime.datetime
    completed_at: datetime.datetime | None
    # Per-phase progress values.
    pre_process_progress: float
    optimization_progress: float
    post_process_progress: float

    deleted_at: datetime.datetime | None = None
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
# Progress-bar descriptions keyed by run state, built by the shared helper with an
# "LLM unlearning" label; looked up in LlmUnlearningRun.check_run_by_id.
STATUS_TO_TEXT_MAP = build_status_text_map("LLM unlearning")
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
class LlmUnlearningRun:
    """Static helpers to launch, manage and monitor LLM unlearning runs.

    Every method talks to the ``/unlearning-llm/run`` endpoints under ``API_HOST``.
    """

    @staticmethod
    def launch(model_id: int, run_info: LlmRunInfo | BiasRunInfo) -> str:
        """Launch an unlearning run for a registered model.

        Args:
            model_id: ID of the registered model (e.g. from ``LlmModel.create``).
            run_info: Full run description, or a ``BiasRunInfo`` shorthand that
                is expanded with ``BiasRunInfo.to_run_info``.

        Returns:
            The ID of the launched run.

        Raises:
            ValueError: If the response does not contain a non-empty run ID.
        """
        resolved_run_info = (
            run_info.to_run_info() if isinstance(run_info, BiasRunInfo) else run_info
        )
        run_response = requests.post(
            f"{API_HOST}/unlearning-llm/run/{model_id}",
            json=resolved_run_info.model_dump(mode="json"),
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)
        run_response_json = run_response.json() if run_response.content else {}
        # The server may answer with the bare run ID as a JSON string, or with an
        # object holding "run_id". Any other JSON shape carries no usable ID and
        # previously crashed with AttributeError on `.get`.
        if isinstance(run_response_json, str):
            run_id = run_response_json
        elif isinstance(run_response_json, dict):
            run_id = run_response_json.get("run_id")
        else:
            run_id = None
        if not run_id:
            raise ValueError("No run ID returned from launch request")
        return run_id

    @staticmethod
    def cancel(run_id: str) -> None:
        """Request cancellation of the run ``run_id``."""
        run_response = requests.patch(
            f"{API_HOST}/unlearning-llm/run/cancel/{run_id}",
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)

    @staticmethod
    def rename(run_id: str, new_name: str) -> None:
        """Rename the run ``run_id`` to ``new_name``."""
        run_response = requests.patch(
            f"{API_HOST}/unlearning-llm/run/rename/{run_id}",
            json={"new_name": new_name},
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)

    @staticmethod
    def archive(run_id: str) -> None:
        """Archive the run ``run_id``."""
        run_response = requests.patch(
            f"{API_HOST}/unlearning-llm/run/archive/{run_id}",
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)

    @staticmethod
    def restore(run_id: str) -> None:
        """Restore the run ``run_id`` (counterpart of :meth:`archive`)."""
        run_response = requests.patch(
            f"{API_HOST}/unlearning-llm/run/restore/{run_id}",
            headers=get_headers(),
            timeout=MODIFY_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)

    @staticmethod
    def list(
        organization_id: int | None = None,
        archived: bool = False,
    ) -> list[OutputUnlearningLlmRun]:
        """List unlearning runs.

        Args:
            organization_id: If given, sent as ``unlearning_organization_id``.
            archived: Sent as the ``archived`` query parameter.

        Returns:
            The runs in the response; a single JSON object is wrapped in a
            one-element list.
        """
        params: dict[str, bool | int] = {"archived": archived}
        if organization_id is not None:
            params["unlearning_organization_id"] = organization_id
        run_response = requests.get(
            f"{API_HOST}/unlearning-llm/run/list",
            params=params,
            headers=get_headers(),
            timeout=READ_TIMEOUT,
        )
        raise_for_status_with_reason(run_response)
        response_json = run_response.json()
        if isinstance(response_json, list):
            return [
                OutputUnlearningLlmRun.model_validate(run_payload)
                for run_payload in response_json
            ]
        return [OutputUnlearningLlmRun.model_validate(response_json)]

    @staticmethod
    def _check_run_by_id(run_id: str, retry: int = 0) -> Generator[dict, None, None]:
        """Yield raw status events for ``run_id`` from the shared event iterator."""
        yield from iter_run_events(
            f"{API_HOST}/unlearning-llm/run/{run_id}",
            headers=get_headers(),
            retry=retry,
            status_keys=("state", "status"),
            error_cls=HirundoError,
            log=logger,
        )

    @staticmethod
    @overload
    def check_run_by_id(
        run_id: str, stop_on_manual_approval: Literal[True]
    ) -> typing.Any | None: ...

    @staticmethod
    @overload
    def check_run_by_id(
        run_id: str, stop_on_manual_approval: Literal[False] = False
    ) -> typing.Any: ...

    @staticmethod
    @overload
    def check_run_by_id(
        run_id: str, stop_on_manual_approval: bool
    ) -> typing.Any | None: ...

    @staticmethod
    def check_run_by_id(run_id: str, stop_on_manual_approval: bool = False):
        """
        Check the status of a run given its ID, showing a progress bar

        Args:
            run_id: The `run_id` produced by a `launch` call
            stop_on_manual_approval: If True, the function will return `None` if the run is awaiting manual approval

        Returns:
            The result payload for the run, if available

        Raises:
            HirundoError: If the maximum number of retries is reached or if the run fails
        """
        logger.debug("Checking run with ID: %s", run_id)
        # Using the bar as a context manager guarantees it is closed on every exit
        # path, including when the event stream raises or ends without reaching a
        # terminal state (previously the bar leaked in those cases).
        with logging_redirect_tqdm(), tqdm(total=100.0) as t:
            for iteration in LlmUnlearningRun._check_run_by_id(run_id):
                state = get_state(iteration, ("state", "status"))
                if state in STATUS_TO_PROGRESS_MAP:
                    t.set_description(STATUS_TO_TEXT_MAP[state])
                    t.n = STATUS_TO_PROGRESS_MAP[state]
                    logger.debug("Setting progress to %s", t.n)
                    t.refresh()
                    if state in [
                        RunStatus.FAILURE.value,
                        RunStatus.REJECTED.value,
                        RunStatus.REVOKED.value,
                    ]:
                        logger.error(
                            "State is failure, rejected, or revoked: %s",
                            state,
                        )
                        # Close before reporting so the bar doesn't interleave
                        # with the failure output.
                        t.close()
                        handle_run_failure(
                            iteration,
                            error_cls=HirundoError,
                            run_label="LLM unlearning",
                        )
                    elif state == RunStatus.SUCCESS.value:
                        t.close()
                        return iteration.get("result") or iteration
                    elif (
                        state == RunStatus.AWAITING_MANUAL_APPROVAL.value
                        and stop_on_manual_approval
                    ):
                        t.close()
                        return None
                elif state is None:
                    # No state in the event: let the shared helper derive
                    # progress from the event's result payload.
                    update_progress_from_result(
                        iteration,
                        t,
                        uploading_text="LLM unlearning run completed. Uploading results",
                        log=logger,
                    )
        raise HirundoError("LLM unlearning run failed with an unknown error")

    @staticmethod
    def check_run(run_id: str, stop_on_manual_approval: bool = False):
        """
        Check the status of the given run.

        Equivalent to :meth:`check_run_by_id`.

        Returns:
            The result payload for the run, if available
        """
        return LlmUnlearningRun.check_run_by_id(run_id, stop_on_manual_approval)

    @staticmethod
    async def acheck_run_by_id(run_id: str, retry: int = 0) -> AsyncGenerator[dict, None]:
        """
        Async version of :func:`check_run_by_id`

        Check the status of a run given its ID.

        This generator will produce values to show progress of the run.

        Note: This function does not handle errors nor show progress. It is expected that you do that.

        Args:
            run_id: The `run_id` produced by a `launch` call
            retry: A number used to track the number of retries to limit re-checks. *Do not* provide this value manually.

        Yields:
            Each event will be a dict, where:
            - `"state"` (or `"status"`) is the run state, e.g. PENDING, STARTED, RETRY, FAILURE or SUCCESS
            - `"result"` is a string describing the progress as a percentage for a PENDING state, or the error for a FAILURE state or the results for a SUCCESS state

        """
        logger.debug("Checking run with ID: %s", run_id)
        async for iteration in aiter_run_events(
            f"{API_HOST}/unlearning-llm/run/{run_id}",
            headers=get_headers(),
            retry=retry,
            status_keys=("state", "status"),
            error_cls=HirundoError,
            log=logger,
        ):
            yield iteration

    @staticmethod
    async def acheck_run(run_id: str) -> AsyncGenerator[dict, None]:
        """
        Async version of :func:`check_run`

        Check the status of the given run.

        This generator will produce values to show progress of the run.

        Yields:
            Each event will be a dict, where:
            - `"state"` (or `"status"`) is the run state, e.g. PENDING, STARTED, RETRY, FAILURE or SUCCESS
            - `"result"` is a string describing the progress as a percentage for a PENDING state, or the error for a FAILURE state or the results for a SUCCESS state

        """
        async for iteration in LlmUnlearningRun.acheck_run_by_id(run_id):
            yield iteration
|
hirundo/unzip.py
CHANGED
|
@@ -27,7 +27,7 @@ from hirundo.logger import get_logger
|
|
|
27
27
|
|
|
28
28
|
ZIP_FILE_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB
|
|
29
29
|
|
|
30
|
-
Dtype =
|
|
30
|
+
Dtype = type[int32] | type[float32] | type[string]
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
CUSTOMER_INTERCHANGE_DTYPES: Mapping[str, Dtype] = {
|
|
@@ -75,7 +75,7 @@ def _clean_df_index(df: "pd.DataFrame") -> "pd.DataFrame":
|
|
|
75
75
|
|
|
76
76
|
|
|
77
77
|
def load_df(
|
|
78
|
-
file: "
|
|
78
|
+
file: "str | IO[bytes]",
|
|
79
79
|
) -> "DataFrameType":
|
|
80
80
|
"""
|
|
81
81
|
Load a DataFrame from a CSV file.
|
|
@@ -226,7 +226,7 @@ def download_and_extract_zip(
|
|
|
226
226
|
|
|
227
227
|
def load_from_zip(
|
|
228
228
|
zip_path: Path, file_name: str
|
|
229
|
-
) -> "
|
|
229
|
+
) -> "pd.DataFrame | pl.DataFrame | None":
|
|
230
230
|
"""
|
|
231
231
|
Load a given file from a given zip file.
|
|
232
232
|
|