hirundo 0.1.21__tar.gz → 0.2.3.post1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/PKG-INFO +42 -10
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/README.md +27 -3
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/__init__.py +19 -3
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/_constraints.py +2 -3
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/_iter_sse_retrying.py +7 -4
- hirundo-0.2.3.post1/hirundo/_llm_pipeline.py +153 -0
- hirundo-0.2.3.post1/hirundo/_run_checking.py +283 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/_urls.py +1 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/cli.py +1 -4
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/dataset_enum.py +2 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/dataset_qa.py +106 -190
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/dataset_qa_results.py +3 -3
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/git.py +7 -8
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/labeling.py +22 -19
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/storage.py +25 -24
- hirundo-0.2.3.post1/hirundo/unlearning_llm.py +599 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/unzip.py +3 -3
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo.egg-info/PKG-INFO +42 -10
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo.egg-info/SOURCES.txt +5 -1
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo.egg-info/requires.txt +14 -5
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/pyproject.toml +26 -13
- hirundo-0.2.3.post1/tests/testing_utils.py +7 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/LICENSE +0 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/__main__.py +0 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/_dataframe.py +0 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/_env.py +0 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/_headers.py +0 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/_http.py +0 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/_timeouts.py +0 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo/logger.py +0 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo.egg-info/dependency_links.txt +0 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo.egg-info/entry_points.txt +0 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/hirundo.egg-info/top_level.txt +0 -0
- {hirundo-0.1.21 → hirundo-0.2.3.post1}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: hirundo
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.3.post1
|
|
4
4
|
Summary: This package is used to interface with Hirundo's platform. It provides a simple API to optimize your ML datasets.
|
|
5
5
|
Author-email: Hirundo <dev@hirundo.io>
|
|
6
6
|
License: MIT License
|
|
@@ -18,7 +18,7 @@ Keywords: dataset,machine learning,data science,data engineering
|
|
|
18
18
|
Classifier: License :: OSI Approved :: MIT License
|
|
19
19
|
Classifier: Programming Language :: Python
|
|
20
20
|
Classifier: Programming Language :: Python :: 3
|
|
21
|
-
Requires-Python: >=3.
|
|
21
|
+
Requires-Python: >=3.10
|
|
22
22
|
Description-Content-Type: text/markdown
|
|
23
23
|
License-File: LICENSE
|
|
24
24
|
Requires-Dist: pyyaml>=6.0.1
|
|
@@ -34,8 +34,9 @@ Requires-Dist: httpx-sse>=0.4.0
|
|
|
34
34
|
Requires-Dist: tqdm>=4.66.5
|
|
35
35
|
Requires-Dist: h11>=0.16.0
|
|
36
36
|
Requires-Dist: requests>=2.32.4
|
|
37
|
-
Requires-Dist: urllib3>=2.
|
|
37
|
+
Requires-Dist: urllib3>=2.6.3
|
|
38
38
|
Requires-Dist: setuptools>=78.1.1
|
|
39
|
+
Requires-Dist: docutils<0.22.0
|
|
39
40
|
Provides-Extra: dev
|
|
40
41
|
Requires-Dist: pyyaml>=6.0.1; extra == "dev"
|
|
41
42
|
Requires-Dist: types-PyYAML>=6.0.12; extra == "dev"
|
|
@@ -50,15 +51,18 @@ Requires-Dist: stamina>=24.2.0; extra == "dev"
|
|
|
50
51
|
Requires-Dist: httpx-sse>=0.4.0; extra == "dev"
|
|
51
52
|
Requires-Dist: pytest>=8.2.0; extra == "dev"
|
|
52
53
|
Requires-Dist: pytest-asyncio>=0.23.6; extra == "dev"
|
|
53
|
-
Requires-Dist: uv>=0.
|
|
54
|
+
Requires-Dist: uv>=0.9.6; extra == "dev"
|
|
54
55
|
Requires-Dist: pre-commit>=3.7.1; extra == "dev"
|
|
56
|
+
Requires-Dist: basedpyright==1.37.1; extra == "dev"
|
|
55
57
|
Requires-Dist: virtualenv>=20.6.6; extra == "dev"
|
|
58
|
+
Requires-Dist: authlib>=1.6.6; extra == "dev"
|
|
56
59
|
Requires-Dist: ruff>=0.12.0; extra == "dev"
|
|
57
|
-
Requires-Dist: bumpver; extra == "dev"
|
|
60
|
+
Requires-Dist: bumpver>=2025.1131; extra == "dev"
|
|
58
61
|
Requires-Dist: platformdirs>=4.3.6; extra == "dev"
|
|
59
|
-
Requires-Dist: safety>=3.2.13; extra == "dev"
|
|
60
62
|
Requires-Dist: cryptography>=44.0.1; extra == "dev"
|
|
61
63
|
Requires-Dist: jinja2>=3.1.6; extra == "dev"
|
|
64
|
+
Requires-Dist: filelock>=3.20.1; extra == "dev"
|
|
65
|
+
Requires-Dist: marshmallow>=3.26.2; extra == "dev"
|
|
62
66
|
Provides-Extra: docs
|
|
63
67
|
Requires-Dist: sphinx>=7.4.7; extra == "docs"
|
|
64
68
|
Requires-Dist: sphinx-autobuild>=2024.9.3; extra == "docs"
|
|
@@ -67,13 +71,17 @@ Requires-Dist: autodoc_pydantic>=2.2.0; extra == "docs"
|
|
|
67
71
|
Requires-Dist: furo; extra == "docs"
|
|
68
72
|
Requires-Dist: sphinx-multiversion; extra == "docs"
|
|
69
73
|
Requires-Dist: esbonio; extra == "docs"
|
|
70
|
-
Requires-Dist: starlette>=0.
|
|
74
|
+
Requires-Dist: starlette>=0.49.1; extra == "docs"
|
|
71
75
|
Requires-Dist: markupsafe>=3.0.2; extra == "docs"
|
|
72
76
|
Requires-Dist: jinja2>=3.1.6; extra == "docs"
|
|
73
77
|
Provides-Extra: pandas
|
|
74
78
|
Requires-Dist: pandas>=2.2.3; extra == "pandas"
|
|
75
79
|
Provides-Extra: polars
|
|
76
80
|
Requires-Dist: polars>=1.0.0; extra == "polars"
|
|
81
|
+
Provides-Extra: transformers
|
|
82
|
+
Requires-Dist: transformers>=4.57.3; extra == "transformers"
|
|
83
|
+
Requires-Dist: peft>=0.18.1; extra == "transformers"
|
|
84
|
+
Requires-Dist: accelerate>=1.12.0; extra == "transformers"
|
|
77
85
|
Dynamic: license-file
|
|
78
86
|
|
|
79
87
|
# Hirundo
|
|
@@ -145,7 +153,31 @@ You can install the codebase with a simple `pip install hirundo` to install the
|
|
|
145
153
|
|
|
146
154
|
## Usage
|
|
147
155
|
|
|
148
|
-
|
|
156
|
+
### Unlearning LLM behavior
|
|
157
|
+
|
|
158
|
+
Make sure to install the `transformers` extra, i.e. `pip install hirundo[transformers]`, or `uv pip install hirundo[transformers]` if you have `uv` installed (`uv` is much faster than `pip`).
|
|
159
|
+
|
|
160
|
+
```python
|
|
161
|
+
llm = LlmModel(
|
|
162
|
+
model_name="Nemotron-Flash-1B",
|
|
163
|
+
model_source=HuggingFaceTransformersModel(
|
|
164
|
+
model_name="nvidia/Nemotron-Flash-1B",
|
|
165
|
+
),
|
|
166
|
+
)
|
|
167
|
+
llm_id = llm.create()
|
|
168
|
+
run_info = BiasRunInfo(
|
|
169
|
+
bias_type=BiasType.ALL,
|
|
170
|
+
)
|
|
171
|
+
run_id = LlmUnlearningRun.launch(
|
|
172
|
+
llm_id,
|
|
173
|
+
run_info,
|
|
174
|
+
)
|
|
175
|
+
new_adapter = llm.get_hf_pipeline_for_run(run_id)
|
|
176
|
+
```
|
|
177
|
+
|
|
178
|
+
### Dataset QA
|
|
179
|
+
|
|
180
|
+
#### Classification example:
|
|
149
181
|
|
|
150
182
|
```python
|
|
151
183
|
from hirundo import (
|
|
@@ -182,7 +214,7 @@ results = test_dataset.check_run()
|
|
|
182
214
|
print(results)
|
|
183
215
|
```
|
|
184
216
|
|
|
185
|
-
Object detection example:
|
|
217
|
+
#### Object detection example:
|
|
186
218
|
|
|
187
219
|
```python
|
|
188
220
|
from hirundo import (
|
|
@@ -223,7 +255,7 @@ results = test_dataset.check_run()
|
|
|
223
255
|
print(results)
|
|
224
256
|
```
|
|
225
257
|
|
|
226
|
-
Note: Currently we only support the main CPython release 3.
|
|
258
|
+
Note: Currently we only support the main CPython releases 3.10, 3.11, 3.12 & 3.13. PyPy support may be introduced in the future.
|
|
227
259
|
|
|
228
260
|
## Further documentation
|
|
229
261
|
|
|
@@ -67,7 +67,31 @@ You can install the codebase with a simple `pip install hirundo` to install the
|
|
|
67
67
|
|
|
68
68
|
## Usage
|
|
69
69
|
|
|
70
|
-
|
|
70
|
+
### Unlearning LLM behavior
|
|
71
|
+
|
|
72
|
+
Make sure to install the `transformers` extra, i.e. `pip install hirundo[transformers]`, or `uv pip install hirundo[transformers]` if you have `uv` installed (`uv` is much faster than `pip`).
|
|
73
|
+
|
|
74
|
+
```python
|
|
75
|
+
llm = LlmModel(
|
|
76
|
+
model_name="Nemotron-Flash-1B",
|
|
77
|
+
model_source=HuggingFaceTransformersModel(
|
|
78
|
+
model_name="nvidia/Nemotron-Flash-1B",
|
|
79
|
+
),
|
|
80
|
+
)
|
|
81
|
+
llm_id = llm.create()
|
|
82
|
+
run_info = BiasRunInfo(
|
|
83
|
+
bias_type=BiasType.ALL,
|
|
84
|
+
)
|
|
85
|
+
run_id = LlmUnlearningRun.launch(
|
|
86
|
+
llm_id,
|
|
87
|
+
run_info,
|
|
88
|
+
)
|
|
89
|
+
new_adapter = llm.get_hf_pipeline_for_run(run_id)
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
### Dataset QA
|
|
93
|
+
|
|
94
|
+
#### Classification example:
|
|
71
95
|
|
|
72
96
|
```python
|
|
73
97
|
from hirundo import (
|
|
@@ -104,7 +128,7 @@ results = test_dataset.check_run()
|
|
|
104
128
|
print(results)
|
|
105
129
|
```
|
|
106
130
|
|
|
107
|
-
Object detection example:
|
|
131
|
+
#### Object detection example:
|
|
108
132
|
|
|
109
133
|
```python
|
|
110
134
|
from hirundo import (
|
|
@@ -145,7 +169,7 @@ results = test_dataset.check_run()
|
|
|
145
169
|
print(results)
|
|
146
170
|
```
|
|
147
171
|
|
|
148
|
-
Note: Currently we only support the main CPython release 3.
|
|
172
|
+
Note: Currently we only support the main CPython releases 3.10, 3.11, 3.12 & 3.13. PyPy support may be introduced in the future.
|
|
149
173
|
|
|
150
174
|
## Further documentation
|
|
151
175
|
|
|
@@ -5,8 +5,8 @@ from .dataset_enum import (
|
|
|
5
5
|
)
|
|
6
6
|
from .dataset_qa import (
|
|
7
7
|
ClassificationRunArgs,
|
|
8
|
-
Domain,
|
|
9
8
|
HirundoError,
|
|
9
|
+
ModalityType,
|
|
10
10
|
ObjectDetectionRunArgs,
|
|
11
11
|
QADataset,
|
|
12
12
|
RunArgs,
|
|
@@ -30,6 +30,15 @@ from .storage import (
|
|
|
30
30
|
StorageGit,
|
|
31
31
|
StorageS3,
|
|
32
32
|
)
|
|
33
|
+
from .unlearning_llm import (
|
|
34
|
+
BiasRunInfo,
|
|
35
|
+
BiasType,
|
|
36
|
+
HuggingFaceTransformersModel,
|
|
37
|
+
LlmModel,
|
|
38
|
+
LlmSources,
|
|
39
|
+
LlmUnlearningRun,
|
|
40
|
+
LocalTransformersModel,
|
|
41
|
+
)
|
|
33
42
|
from .unzip import load_df, load_from_zip
|
|
34
43
|
|
|
35
44
|
__all__ = [
|
|
@@ -43,7 +52,7 @@ __all__ = [
|
|
|
43
52
|
"KeylabsObjSegImages",
|
|
44
53
|
"KeylabsObjSegVideo",
|
|
45
54
|
"QADataset",
|
|
46
|
-
"
|
|
55
|
+
"ModalityType",
|
|
47
56
|
"RunArgs",
|
|
48
57
|
"ClassificationRunArgs",
|
|
49
58
|
"ObjectDetectionRunArgs",
|
|
@@ -59,8 +68,15 @@ __all__ = [
|
|
|
59
68
|
"StorageGit",
|
|
60
69
|
"StorageConfig",
|
|
61
70
|
"DatasetQAResults",
|
|
71
|
+
"BiasRunInfo",
|
|
72
|
+
"BiasType",
|
|
73
|
+
"HuggingFaceTransformersModel",
|
|
74
|
+
"LlmModel",
|
|
75
|
+
"LlmSources",
|
|
76
|
+
"LlmUnlearningRun",
|
|
77
|
+
"LocalTransformersModel",
|
|
62
78
|
"load_df",
|
|
63
79
|
"load_from_zip",
|
|
64
80
|
]
|
|
65
81
|
|
|
66
|
-
__version__ = "0.
|
|
82
|
+
__version__ = "0.2.3.post1"
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import re
|
|
2
|
-
import typing
|
|
3
2
|
from typing import TYPE_CHECKING
|
|
4
3
|
|
|
5
4
|
from hirundo._urls import (
|
|
@@ -135,8 +134,8 @@ def validate_labeling_type(
|
|
|
135
134
|
|
|
136
135
|
def validate_labeling_info(
|
|
137
136
|
labeling_type: "LabelingType",
|
|
138
|
-
labeling_info: "
|
|
139
|
-
storage_config: "
|
|
137
|
+
labeling_info: "LabelingInfo | list[LabelingInfo]",
|
|
138
|
+
storage_config: "StorageConfig | ResponseStorageConfig",
|
|
140
139
|
) -> None:
|
|
141
140
|
"""
|
|
142
141
|
Validate the labeling info for a dataset
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import time
|
|
3
|
-
import typing
|
|
4
3
|
import uuid
|
|
5
4
|
from collections.abc import AsyncGenerator, Generator
|
|
6
5
|
|
|
@@ -15,13 +14,15 @@ from hirundo.logger import get_logger
|
|
|
15
14
|
|
|
16
15
|
logger = get_logger(__name__)
|
|
17
16
|
|
|
17
|
+
MAX_RETRIES = 50
|
|
18
|
+
|
|
18
19
|
|
|
19
20
|
# Credit: https://github.com/florimondmanca/httpx-sse/blob/master/README.md#handling-reconnections
|
|
20
21
|
def iter_sse_retrying(
|
|
21
22
|
client: httpx.Client,
|
|
22
23
|
method: str,
|
|
23
24
|
url: str,
|
|
24
|
-
headers:
|
|
25
|
+
headers: dict[str, str] | None = None,
|
|
25
26
|
) -> Generator[ServerSentEvent, None, None]:
|
|
26
27
|
if headers is None:
|
|
27
28
|
headers = {}
|
|
@@ -41,7 +42,8 @@ def iter_sse_retrying(
|
|
|
41
42
|
httpx.ReadError,
|
|
42
43
|
httpx.RemoteProtocolError,
|
|
43
44
|
urllib3.exceptions.ReadTimeoutError,
|
|
44
|
-
)
|
|
45
|
+
),
|
|
46
|
+
attempts=MAX_RETRIES,
|
|
45
47
|
)
|
|
46
48
|
def _iter_sse():
|
|
47
49
|
nonlocal last_event_id, reconnection_delay
|
|
@@ -105,7 +107,8 @@ async def aiter_sse_retrying(
|
|
|
105
107
|
httpx.ReadError,
|
|
106
108
|
httpx.RemoteProtocolError,
|
|
107
109
|
urllib3.exceptions.ReadTimeoutError,
|
|
108
|
-
)
|
|
110
|
+
),
|
|
111
|
+
attempts=MAX_RETRIES,
|
|
109
112
|
)
|
|
110
113
|
async def _iter_sse() -> AsyncGenerator[ServerSentEvent, None]:
|
|
111
114
|
nonlocal last_event_id, reconnection_delay
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
import importlib.util
|
|
2
|
+
import tempfile
|
|
3
|
+
import zipfile
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING, cast
|
|
6
|
+
|
|
7
|
+
from hirundo import HirundoError
|
|
8
|
+
from hirundo._http import requests
|
|
9
|
+
from hirundo._timeouts import DOWNLOAD_READ_TIMEOUT
|
|
10
|
+
from hirundo.logger import get_logger
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from torch import device as torch_device
|
|
14
|
+
from transformers.configuration_utils import PretrainedConfig
|
|
15
|
+
from transformers.modeling_utils import PreTrainedModel
|
|
16
|
+
from transformers.pipelines.base import Pipeline
|
|
17
|
+
|
|
18
|
+
from hirundo.unlearning_llm import LlmModel, LlmModelOut
|
|
19
|
+
|
|
20
|
+
logger = get_logger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
ZIP_FILE_CHUNK_SIZE = 50 * 1024 * 1024 # 50 MB
|
|
24
|
+
REQUIRED_PACKAGES_FOR_PIPELINE = ["peft", "transformers", "accelerate"]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _resolve_result_url(run_results: object) -> str:
    """Return the download URL carried by *run_results*.

    The results may be the URL itself, a dict holding it under ``"result"``,
    or a dict whose ``"result"`` entry is itself such a dict.

    Raises:
        HirundoError: If no string URL can be found.
    """
    payload = (
        run_results.get("result", run_results)
        if isinstance(run_results, dict)
        else run_results
    )
    result_url = payload.get("result") if isinstance(payload, dict) else payload
    if not isinstance(result_url, str):
        raise HirundoError("Run results did not include a download URL")
    return result_url


def _download_result_zip(result_url: str, destination: Path) -> None:
    """Stream the archive at *result_url* into the file *destination*."""
    with requests.get(
        result_url,
        timeout=DOWNLOAD_READ_TIMEOUT,
        stream=True,
    ) as response:
        response.raise_for_status()
        with open(destination, "wb") as zip_file:
            for chunk in response.iter_content(chunk_size=ZIP_FILE_CHUNK_SIZE):
                zip_file.write(chunk)


def get_hf_pipeline_for_run_given_model(
    llm: "LlmModel | LlmModelOut",
    run_id: str,
    config: "PretrainedConfig | None" = None,
    device: "str | int | torch_device | None" = None,
    device_map: str | dict[str, int | str] | None = None,
    trust_remote_code: bool = False,
    token: str | None = None,
) -> "Pipeline":
    """Build a ``text-generation`` pipeline from the result of an unlearning run.

    Downloads the result archive of run *run_id*, extracts it, loads the base
    model described by ``llm.model_source`` and applies the adapter found in
    the archive's ``unlearned_model_folder`` directory on top of it.

    Args:
        llm: Model whose ``model_source`` names the base model (a Hugging Face
            model name or a local path).
        run_id: ID of the unlearning run whose result should be loaded.
        config: Model configuration to use. When ``None`` it is loaded from
            the base model.
        device: Forwarded to ``transformers.pipeline``.
        device_map: Forwarded to ``transformers.pipeline``.
        trust_remote_code: Forwarded to the tokenizer, config and model
            ``from_pretrained`` calls.
        token: Access token used when the model source does not provide its
            own token.

    Returns:
        A pipeline wrapping the adapted model and its tokenizer.

    Raises:
        HirundoError: If a required optional package is missing, no run
            results exist, or the results contain no download URL.
    """
    for package in REQUIRED_PACKAGES_FOR_PIPELINE:
        if importlib.util.find_spec(package) is None:
            raise HirundoError(
                f'{package} is not installed. Please install transformers extra with pip install "hirundo[transformers]"'
            )
    # Heavy optional dependencies are imported only after the availability
    # check above, so a missing extra yields a HirundoError instead of an
    # ImportError.
    from peft import PeftModel
    from transformers.models.auto.configuration_auto import AutoConfig
    from transformers.models.auto.modeling_auto import (
        MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
        AutoModelForCausalLM,
        AutoModelForImageTextToText,
    )
    from transformers.models.auto.tokenization_auto import AutoTokenizer
    from transformers.pipelines import pipeline

    from hirundo.unlearning_llm import (
        HuggingFaceTransformersModel,
        HuggingFaceTransformersModelOutput,
        LlmUnlearningRun,
    )

    run_results = LlmUnlearningRun.check_run_by_id(run_id)
    if run_results is None:
        raise HirundoError("No run results found")
    result_url = _resolve_result_url(run_results)

    model_source = llm.model_source
    base_model_name = (
        model_source.model_name
        if isinstance(
            model_source,
            (HuggingFaceTransformersModel, HuggingFaceTransformersModelOutput),
        )
        else model_source.local_path
    )
    # The model source's own token wins when it has one; otherwise keep the
    # token passed by the caller instead of silently replacing it with None.
    if (
        isinstance(model_source, HuggingFaceTransformersModel)
        and model_source.token is not None
    ):
        token = model_source.token

    # Both the archive and its extracted contents live inside one temporary
    # directory, so nothing is left on disk once the model has been loaded.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        zip_file_path = temp_dir_path / "result.zip"
        _download_result_zip(result_url, zip_file_path)
        logger.info(
            "Successfully downloaded the result zip file for run ID %s to %s",
            run_id,
            zip_file_path,
        )

        extract_dir = temp_dir_path / "extracted"
        extract_dir.mkdir()
        with zipfile.ZipFile(zip_file_path, "r") as archive:
            archive.extractall(extract_dir)

        tokenizer = AutoTokenizer.from_pretrained(
            base_model_name,
            token=token,
            trust_remote_code=trust_remote_code,
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Honour a caller-supplied config; only fall back to the base model's.
        if config is None:
            config = AutoConfig.from_pretrained(
                base_model_name,
                token=token,
                trust_remote_code=trust_remote_code,
            )
        config_dict = config.to_dict() if hasattr(config, "to_dict") else config
        is_multimodal = (
            config_dict.get("model_type") in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
        )
        model_cls = (
            AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM
        )
        base_model = model_cls.from_pretrained(
            base_model_name,
            token=token,
            trust_remote_code=trust_remote_code,
        )
        model = cast(
            "PreTrainedModel",
            PeftModel.from_pretrained(
                base_model, str(extract_dir / "unlearned_model_folder")
            ),
        )

    return pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        config=config,
        device=device,
        device_map=device_map,
    )
|