hirundo 0.1.18__py3-none-any.whl → 0.2.3.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
- hirundo/__init__.py +28 -8
- hirundo/_constraints.py +3 -4
- hirundo/_headers.py +1 -1
- hirundo/_http.py +53 -0
- hirundo/_iter_sse_retrying.py +8 -5
- hirundo/_llm_pipeline.py +153 -0
- hirundo/_run_checking.py +283 -0
- hirundo/_urls.py +1 -0
- hirundo/cli.py +8 -11
- hirundo/dataset_enum.py +2 -0
- hirundo/{dataset_optimization.py → dataset_qa.py} +213 -256
- hirundo/{dataset_optimization_results.py → dataset_qa_results.py} +7 -7
- hirundo/git.py +8 -10
- hirundo/labeling.py +22 -19
- hirundo/storage.py +26 -26
- hirundo/unlearning_llm.py +599 -0
- hirundo/unzip.py +12 -13
- {hirundo-0.1.18.dist-info → hirundo-0.2.3.post1.dist-info}/METADATA +59 -20
- hirundo-0.2.3.post1.dist-info/RECORD +28 -0
- {hirundo-0.1.18.dist-info → hirundo-0.2.3.post1.dist-info}/WHEEL +1 -1
- hirundo-0.1.18.dist-info/RECORD +0 -25
- {hirundo-0.1.18.dist-info → hirundo-0.2.3.post1.dist-info}/entry_points.txt +0 -0
- {hirundo-0.1.18.dist-info → hirundo-0.2.3.post1.dist-info}/licenses/LICENSE +0 -0
- {hirundo-0.1.18.dist-info → hirundo-0.2.3.post1.dist-info}/top_level.txt +0 -0
hirundo/unlearning_llm.py ADDED
@@ -0,0 +1,599 @@
+import datetime
+import typing
+from collections.abc import AsyncGenerator, Generator
+from enum import Enum
+from typing import TYPE_CHECKING, Literal, overload
+
+from pydantic import BaseModel, ConfigDict
+from tqdm import tqdm
+from tqdm.contrib.logging import logging_redirect_tqdm
+
+from hirundo._env import API_HOST
+from hirundo._headers import get_headers
+from hirundo._http import raise_for_status_with_reason, requests
+from hirundo._llm_pipeline import get_hf_pipeline_for_run_given_model
+from hirundo._run_checking import (
+    STATUS_TO_PROGRESS_MAP,
+    RunStatus,
+    aiter_run_events,
+    build_status_text_map,
+    get_state,
+    handle_run_failure,
+    iter_run_events,
+    update_progress_from_result,
+)
+from hirundo._timeouts import MODIFY_TIMEOUT, READ_TIMEOUT
+from hirundo.dataset_qa import HirundoError
+from hirundo.logger import get_logger
+
+if TYPE_CHECKING:
+    from torch import device as torch_device
+    from transformers.configuration_utils import PretrainedConfig
+    from transformers.pipelines.base import Pipeline
+
+logger = get_logger(__name__)
+
+
+class ModelSourceType(str, Enum):
+    HUGGINGFACE_TRANSFORMERS = "huggingface_transformers"
+    LOCAL_TRANSFORMERS = "local_transformers"
+
+
+class HuggingFaceTransformersModel(BaseModel):
+    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))
+
+    type: Literal[ModelSourceType.HUGGINGFACE_TRANSFORMERS] = (
+        ModelSourceType.HUGGINGFACE_TRANSFORMERS
+    )
+    revision: str | None = None
+    code_revision: str | None = None
+    model_name: str
+    token: str | None = None
+
+
+class HuggingFaceTransformersModelOutput(BaseModel):
+    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))
+
+    type: Literal[ModelSourceType.HUGGINGFACE_TRANSFORMERS] = (
+        ModelSourceType.HUGGINGFACE_TRANSFORMERS
+    )
+    model_name: str
+
+
+class LocalTransformersModel(BaseModel):
+    type: Literal[ModelSourceType.LOCAL_TRANSFORMERS] = (
+        ModelSourceType.LOCAL_TRANSFORMERS
+    )
+    revision: None = None
+    code_revision: None = None
+    local_path: str
+
+
+LlmSources = HuggingFaceTransformersModel | LocalTransformersModel
+LlmSourcesOutput = HuggingFaceTransformersModelOutput | LocalTransformersModel
+
+
+class LlmModel(BaseModel):
+    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))
+
+    id: int | None = None
+    organization_id: int | None = None
+    model_name: str
+    model_source: LlmSources
+    archive_existing_runs: bool = True
+
+    def create(
+        self,
+        replace_if_exists: bool = False,
+    ) -> int:
+        llm_model_response = requests.post(
+            f"{API_HOST}/unlearning-llm/llm/",
+            json={
+                **self.model_dump(mode="json"),
+                "replace_if_exists": replace_if_exists,
+            },
+            headers=get_headers(),
+            timeout=MODIFY_TIMEOUT,
+        )
+        raise_for_status_with_reason(llm_model_response)
+        llm_model_id = llm_model_response.json()["id"]
+        self.id = llm_model_id
+        return llm_model_id
+
+    @staticmethod
+    def get_by_id(llm_model_id: int) -> "LlmModelOut":
+        llm_model_response = requests.get(
+            f"{API_HOST}/unlearning-llm/llm/{llm_model_id}",
+            headers=get_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(llm_model_response)
+        return LlmModelOut.model_validate(llm_model_response.json())
+
+    @staticmethod
+    def get_by_name(llm_model_name: str) -> "LlmModelOut":
+        llm_model_response = requests.get(
+            f"{API_HOST}/unlearning-llm/llm/by-name/{llm_model_name}",
+            headers=get_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(llm_model_response)
+        return LlmModelOut.model_validate(llm_model_response.json())
+
+    @staticmethod
+    def list(organization_id: int | None = None) -> list["LlmModelOut"]:
+        params = {}
+        if organization_id is not None:
+            params["model_organization_id"] = organization_id
+        llm_model_response = requests.get(
+            f"{API_HOST}/unlearning-llm/llm/",
+            params=params,
+            headers=get_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(llm_model_response)
+        llm_model_json = llm_model_response.json()
+        return [LlmModelOut.model_validate(llm_model) for llm_model in llm_model_json]
+
+    @staticmethod
+    def delete_by_id(llm_model_id: int) -> None:
+        llm_model_response = requests.delete(
+            f"{API_HOST}/unlearning-llm/llm/{llm_model_id}",
+            headers=get_headers(),
+            timeout=MODIFY_TIMEOUT,
+        )
+        raise_for_status_with_reason(llm_model_response)
+        logger.info("Deleted LLM model with ID: %s", llm_model_id)
+
+    def delete(self) -> None:
+        if not self.id:
+            raise ValueError("No LLM model has been created")
+        self.delete_by_id(self.id)
+
+    def update(
+        self,
+        model_name: str | None = None,
+        model_source: LlmSources | None = None,
+        archive_existing_runs: bool | None = None,
+    ) -> None:
+        if not self.id:
+            raise ValueError("No LLM model has been created")
+        payload: dict[str, typing.Any] = {
+            "model_name": model_name,
+            "model_source": model_source.model_dump(mode="json")
+            if model_source
+            else None,
+            "archive_existing_runs": archive_existing_runs,
+            "organization_id": self.organization_id,
+        }
+        llm_model_response = requests.put(
+            f"{API_HOST}/unlearning-llm/llm/{self.id}",
+            json=payload,
+            headers=get_headers(),
+            timeout=MODIFY_TIMEOUT,
+        )
+        raise_for_status_with_reason(llm_model_response)
+        if model_name is not None:
+            self.model_name = model_name
+        if model_source is not None:
+            self.model_source = model_source
+        if archive_existing_runs is not None:
+            self.archive_existing_runs = archive_existing_runs
+
+    def get_hf_pipeline_for_run(
+        self,
+        run_id: str,
+        config: "PretrainedConfig | None" = None,
+        device: "str | int | torch_device | None" = None,
+        device_map: str | dict[str, int | str] | None = None,
+        trust_remote_code: bool = False,
+    ) -> "Pipeline":
+        return get_hf_pipeline_for_run_given_model(
+            self, run_id, config, device, device_map, trust_remote_code
+        )
+
+
+class LlmModelOut(BaseModel):
+    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))
+
+    id: int
+    organization_id: int
+    creator_id: int
+    creator_name: str
+    created_at: datetime.datetime
+    updated_at: datetime.datetime
+    model_name: str
+    model_source: LlmSourcesOutput
+
+    def get_hf_pipeline_for_run(
+        self,
+        run_id: str,
+        config: "PretrainedConfig | None" = None,
+        device: "str | int | torch_device | None" = None,
+        device_map: str | dict[str, int | str] | None = None,
+        trust_remote_code: bool = False,
+        token: str | None = None,
+    ) -> "Pipeline":
+        return get_hf_pipeline_for_run_given_model(
+            self,
+            run_id,
+            config,
+            device,
+            device_map,
+            trust_remote_code,
+            token=token,
+        )
+
+
+class DatasetType(str, Enum):
+    NORMAL = "normal"
+    BIAS = "bias"
+    UNBIAS = "unbias"
+
+
+class UnlearningLlmAdvancedOptions(BaseModel):
+    max_tokens_for_model: dict[DatasetType, int] | int | None = None
+
+
+class BiasType(str, Enum):
+    ALL = "ALL"
+    RACE = "RACE"
+    NATIONALITY = "NATIONALITY"
+    GENDER = "GENDER"
+    PHYSICAL_APPEARANCE = "PHYSICAL_APPEARANCE"
+    RELIGION = "RELIGION"
+    AGE = "AGE"
+
+
+class UtilityType(str, Enum):
+    DEFAULT = "DEFAULT"
+    CUSTOM = "CUSTOM"
+
+
+class DefaultUtility(BaseModel):
+    utility_type: Literal[UtilityType.DEFAULT] = UtilityType.DEFAULT
+
+
+class HirundoCSVDataset(BaseModel):
+    type: Literal["HirundoCSV"] = "HirundoCSV"
+    csv_url: str
+
+
+class HuggingFaceDataset(BaseModel):
+    type: Literal["HuggingFaceDataset"] = "HuggingFaceDataset"
+    hugging_face_dataset_name: str
+
+
+CustomDataset = HirundoCSVDataset | HuggingFaceDataset
+
+
+class CustomUtility(BaseModel):
+    utility_type: Literal[UtilityType.CUSTOM] = UtilityType.CUSTOM
+    dataset: CustomDataset
+
+
+class BiasBehavior(BaseModel):
+    type: Literal["BIAS"] = "BIAS"
+    bias_type: BiasType
+
+
+class HallucinationType(str, Enum):
+    GENERAL = "GENERAL"
+    MEDICAL = "MEDICAL"
+    LEGAL = "LEGAL"
+    DEFENSE = "DEFENSE"
+
+
+class HallucinationBehavior(BaseModel):
+    type: Literal["HALLUCINATION"] = "HALLUCINATION"
+    hallucination_type: HallucinationType
+
+
+class SecurityBehavior(BaseModel):
+    type: Literal["SECURITY"] = "SECURITY"
+
+
+class CustomBehavior(BaseModel):
+    type: Literal["CUSTOM"] = "CUSTOM"
+    biased_dataset: CustomDataset
+    unbiased_dataset: CustomDataset
+
+
+TargetBehavior = (
+    BiasBehavior | HallucinationBehavior | SecurityBehavior | CustomBehavior
+)
+
+TargetUtility = DefaultUtility | CustomUtility
+
+
+class LlmRunInfo(BaseModel):
+    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))
+
+    organization_id: int | None = None
+    name: str | None = None
+    target_behaviors: list[TargetBehavior]
+    target_utilities: list[TargetUtility]
+    advanced_options: UnlearningLlmAdvancedOptions | None = None
+
+
+class BiasRunInfo(BaseModel):
+    bias_type: BiasType
+    organization_id: int | None = None
+    name: str | None = None
+    target_utilities: list[TargetUtility] | None = None
+    advanced_options: UnlearningLlmAdvancedOptions | None = None
+
+    def to_run_info(self) -> LlmRunInfo:
+        default_utilities: list[TargetUtility] = (
+            [DefaultUtility()]
+            if self.target_utilities is None
+            else list(self.target_utilities)
+        )
+        return LlmRunInfo(
+            organization_id=self.organization_id,
+            name=self.name,
+            target_behaviors=[BiasBehavior(bias_type=self.bias_type)],
+            target_utilities=default_utilities,
+            advanced_options=self.advanced_options,
+        )
+
+
+OutputLlm = dict[str, object]
+BehaviorOptions = TargetBehavior
+UtilityOptions = TargetUtility
+CeleryTaskState = str
+
+
+class OutputUnlearningLlmRun(BaseModel):
+    model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))
+
+    id: int
+    name: str
+    model_id: int
+    model: OutputLlm
+    target_behaviors: list[BehaviorOptions]
+    target_utilities: list[UtilityOptions]
+    advanced_options: UnlearningLlmAdvancedOptions | None
+    run_id: str
+    mlflow_run_id: str | None
+    status: CeleryTaskState
+    approved: bool
+    created_at: datetime.datetime
+    completed_at: datetime.datetime | None
+    pre_process_progress: float
+    optimization_progress: float
+    post_process_progress: float
+
+    deleted_at: datetime.datetime | None = None
+
+
+STATUS_TO_TEXT_MAP = build_status_text_map("LLM unlearning")
+
+
+class LlmUnlearningRun:
+    @staticmethod
+    def launch(model_id: int, run_info: LlmRunInfo | BiasRunInfo) -> str:
+        resolved_run_info = (
+            run_info.to_run_info() if isinstance(run_info, BiasRunInfo) else run_info
+        )
+        run_response = requests.post(
+            f"{API_HOST}/unlearning-llm/run/{model_id}",
+            json=resolved_run_info.model_dump(mode="json"),
+            headers=get_headers(),
+            timeout=MODIFY_TIMEOUT,
+        )
+        raise_for_status_with_reason(run_response)
+        run_response_json = run_response.json() if run_response.content else {}
+        if isinstance(run_response_json, str):
+            return run_response_json
+        run_id = run_response_json.get("run_id")
+        if not run_id:
+            raise ValueError("No run ID returned from launch request")
+        return run_id
+
+    @staticmethod
+    def cancel(run_id: str) -> None:
+        run_response = requests.patch(
+            f"{API_HOST}/unlearning-llm/run/cancel/{run_id}",
+            headers=get_headers(),
+            timeout=MODIFY_TIMEOUT,
+        )
+        raise_for_status_with_reason(run_response)
+
+    @staticmethod
+    def rename(run_id: str, new_name: str) -> None:
+        run_response = requests.patch(
+            f"{API_HOST}/unlearning-llm/run/rename/{run_id}",
+            json={"new_name": new_name},
+            headers=get_headers(),
+            timeout=MODIFY_TIMEOUT,
+        )
+        raise_for_status_with_reason(run_response)
+
+    @staticmethod
+    def archive(run_id: str) -> None:
+        run_response = requests.patch(
+            f"{API_HOST}/unlearning-llm/run/archive/{run_id}",
+            headers=get_headers(),
+            timeout=MODIFY_TIMEOUT,
+        )
+        raise_for_status_with_reason(run_response)
+
+    @staticmethod
+    def restore(run_id: str) -> None:
+        run_response = requests.patch(
+            f"{API_HOST}/unlearning-llm/run/restore/{run_id}",
+            headers=get_headers(),
+            timeout=MODIFY_TIMEOUT,
+        )
+        raise_for_status_with_reason(run_response)
+
+    @staticmethod
+    def list(
+        organization_id: int | None = None,
+        archived: bool = False,
+    ) -> list[OutputUnlearningLlmRun]:
+        params: dict[str, bool | int] = {"archived": archived}
+        if organization_id is not None:
+            params["unlearning_organization_id"] = organization_id
+        run_response = requests.get(
+            f"{API_HOST}/unlearning-llm/run/list",
+            params=params,
+            headers=get_headers(),
+            timeout=READ_TIMEOUT,
+        )
+        raise_for_status_with_reason(run_response)
+        response_json = run_response.json()
+        if isinstance(response_json, list):
+            return [
+                OutputUnlearningLlmRun.model_validate(run_payload)
+                for run_payload in response_json
+            ]
+        return [OutputUnlearningLlmRun.model_validate(response_json)]
+
+    @staticmethod
+    def _check_run_by_id(run_id: str, retry=0) -> Generator[dict, None, None]:
+        yield from iter_run_events(
+            f"{API_HOST}/unlearning-llm/run/{run_id}",
+            headers=get_headers(),
+            retry=retry,
+            status_keys=("state", "status"),
+            error_cls=HirundoError,
+            log=logger,
+        )
+
+    @staticmethod
+    @overload
+    def check_run_by_id(
+        run_id: str, stop_on_manual_approval: Literal[True]
+    ) -> typing.Any | None: ...
+
+    @staticmethod
+    @overload
+    def check_run_by_id(
+        run_id: str, stop_on_manual_approval: Literal[False] = False
+    ) -> typing.Any: ...
+
+    @staticmethod
+    @overload
+    def check_run_by_id(
+        run_id: str, stop_on_manual_approval: bool
+    ) -> typing.Any | None: ...
+
+    @staticmethod
+    def check_run_by_id(run_id: str, stop_on_manual_approval: bool = False):
+        """
+        Check the status of a run given its ID
+
+        Args:
+            run_id: The `run_id` produced by a `launch` call
+            stop_on_manual_approval: If True, the function will return `None` if the run is awaiting manual approval
+
+        Returns:
+            The result payload for the run, if available
+
+        Raises:
+            HirundoError: If the maximum number of retries is reached or if the run fails
+        """
+        logger.debug("Checking run with ID: %s", run_id)
+        with logging_redirect_tqdm():
+            t = tqdm(total=100.0)
+            for iteration in LlmUnlearningRun._check_run_by_id(run_id):
+                state = get_state(iteration, ("state", "status"))
+                if state in STATUS_TO_PROGRESS_MAP:
+                    t.set_description(STATUS_TO_TEXT_MAP[state])
+                    t.n = STATUS_TO_PROGRESS_MAP[state]
+                    logger.debug("Setting progress to %s", t.n)
+                    t.refresh()
+                if state in [
+                    RunStatus.FAILURE.value,
+                    RunStatus.REJECTED.value,
+                    RunStatus.REVOKED.value,
+                ]:
+                    logger.error(
+                        "State is failure, rejected, or revoked: %s",
+                        state,
+                    )
+                    t.close()
+                    handle_run_failure(
+                        iteration,
+                        error_cls=HirundoError,
+                        run_label="LLM unlearning",
+                    )
+                elif state == RunStatus.SUCCESS.value:
+                    t.close()
+                    return iteration.get("result") or iteration
+                elif (
+                    state == RunStatus.AWAITING_MANUAL_APPROVAL.value
+                    and stop_on_manual_approval
+                ):
+                    t.close()
+                    return None
+                elif state is None:
+                    update_progress_from_result(
+                        iteration,
+                        t,
+                        uploading_text="LLM unlearning run completed. Uploading results",
+                        log=logger,
+                    )
+        raise HirundoError("LLM unlearning run failed with an unknown error")
+
+    @staticmethod
+    def check_run(run_id: str, stop_on_manual_approval: bool = False):
+        """
+        Check the status of the given run.
+
+        Returns:
+            The result payload for the run, if available
+        """
+        return LlmUnlearningRun.check_run_by_id(run_id, stop_on_manual_approval)
+
+    @staticmethod
+    async def acheck_run_by_id(run_id: str, retry=0) -> AsyncGenerator[dict, None]:
+        """
+        Async version of :func:`check_run_by_id`
+
+        Check the status of a run given its ID.
+
+        This generator will produce values to show progress of the run.
+
+        Note: This function does not handle errors nor show progress. It is expected that you do that.
+
+        Args:
+            run_id: The `run_id` produced by a `launch` call
+            retry: A number used to track the number of retries to limit re-checks. *Do not* provide this value manually.
+
+        Yields:
+            Each event will be a dict, where:
+            - `"state"` is PENDING, STARTED, RETRY, FAILURE or SUCCESS
+            - `"result"` is a string describing the progress as a percentage for a PENDING state, or the error for a FAILURE state or the results for a SUCCESS state
+        """
+        logger.debug("Checking run with ID: %s", run_id)
+        async for iteration in aiter_run_events(
+            f"{API_HOST}/unlearning-llm/run/{run_id}",
+            headers=get_headers(),
+            retry=retry,
+            status_keys=("state", "status"),
+            error_cls=HirundoError,
+            log=logger,
+        ):
+            yield iteration
+
+    @staticmethod
+    async def acheck_run(run_id: str) -> AsyncGenerator[dict, None]:
+        """
+        Async version of :func:`check_run`
+
+        Check the status of the given run.
+
+        This generator will produce values to show progress of the run.
+
+        Yields:
+            Each event will be a dict, where:
+            - `"state"` is PENDING, STARTED, RETRY, FAILURE or SUCCESS
+            - `"result"` is a string describing the progress as a percentage for a PENDING state, or the error for a FAILURE state or the results for a SUCCESS state
+        """
+        async for iteration in LlmUnlearningRun.acheck_run_by_id(run_id):
+            yield iteration
hirundo/unzip.py CHANGED
@@ -4,7 +4,6 @@ from collections.abc import Mapping
 from pathlib import Path
 from typing import IO, cast
 
-import requests
 from pydantic_core import Url
 
 from hirundo._dataframe import (
@@ -18,16 +17,17 @@ from hirundo._dataframe import (
 )
 from hirundo._env import API_HOST
 from hirundo._headers import _get_auth_headers
+from hirundo._http import requests
 from hirundo._timeouts import DOWNLOAD_READ_TIMEOUT
-from hirundo.dataset_optimization_results import (
+from hirundo.dataset_qa_results import (
     DataFrameType,
-    DatasetOptimizationResults,
+    DatasetQAResults,
 )
 from hirundo.logger import get_logger
 
 ZIP_FILE_CHUNK_SIZE = 50 * 1024 * 1024  # 50 MB
 
-Dtype =
+Dtype = type[int32] | type[float32] | type[string]
 
 
 CUSTOMER_INTERCHANGE_DTYPES: Mapping[str, Dtype] = {
@@ -75,7 +75,7 @@ def _clean_df_index(df: "pd.DataFrame") -> "pd.DataFrame":
 
 
 def load_df(
-    file: "
+    file: "str | IO[bytes]",
 ) -> "DataFrameType":
     """
     Load a DataFrame from a CSV file.
@@ -117,7 +117,7 @@ def get_mislabel_suspect_filename(filenames: list[str]):
 
 def download_and_extract_zip(
     run_id: str, zip_url: str
-) -> DatasetOptimizationResults[DataFrameType]:
+) -> DatasetQAResults[DataFrameType]:
     """
     Download and extract the zip file from the given URL.
 
@@ -127,11 +127,11 @@ def download_and_extract_zip(
     and `warnings_and_errors.csv` files from the zip file.
 
     Args:
-        run_id: The ID of the dataset optimization run.
+        run_id: The ID of the dataset QA run.
         zip_url: The URL of the zip file to download.
 
     Returns:
-        The dataset optimization results object.
+        The dataset QA results object.
     """
     # Define the local file path
     cache_dir = Path.home() / ".hirundo" / "cache"
@@ -140,9 +140,8 @@ def download_and_extract_zip(
 
     headers = None
     if Url(zip_url).scheme == "file":
-        zip_url = (
-            f"{API_HOST}/dataset-optimization/run/local-download"
-            + zip_url.replace("file://", "")
+        zip_url = f"{API_HOST}/dataset-qa/run/local-download" + zip_url.replace(
+            "file://", ""
         )
         headers = _get_auth_headers()
     # Stream the zip file download
@@ -217,7 +216,7 @@ def download_and_extract_zip(
             "Failed to load warnings and errors into DataFrame", exc_info=e
         )
 
-    return DatasetOptimizationResults[DataFrameType](
+    return DatasetQAResults[DataFrameType](
         cached_zip_path=zip_file_path,
         suspects=suspects_df,
         object_suspects=object_suspects_df,
@@ -227,7 +226,7 @@
 
 def load_from_zip(
     zip_path: Path, file_name: str
-) -> "
+) -> "pd.DataFrame | pl.DataFrame | None":
     """
     Load a given file from a given zip file.
 
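
For reference, a minimal sketch of consuming the reworked results loader once a dataset QA run has produced an archive; the `run_id` and `zip_url` values are hypothetical placeholders, and the field accesses follow the `DatasetQAResults` construction shown in the diff above.

```python
from hirundo.unzip import download_and_extract_zip

# Both identifiers would come from a completed dataset QA run (hypothetical here).
results = download_and_extract_zip(
    run_id="qa-run-123",
    zip_url="https://example.com/results.zip",
)

# The archive is cached under ~/.hirundo/cache and its CSVs are parsed
# into DataFrames (pandas or polars, per DataFrameType).
print(results.cached_zip_path)
if results.suspects is not None:
    print(results.suspects.head())
```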