strands-transformers 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strands_transformers/__init__.py +52 -0
- strands_transformers/_version.py +24 -0
- strands_transformers/core/__init__.py +5 -0
- strands_transformers/core/compat.py +251 -0
- strands_transformers/core/engine.py +160 -0
- strands_transformers/core/io.py +273 -0
- strands_transformers/core/registry.py +195 -0
- strands_transformers/models/__init__.py +5 -0
- strands_transformers/models/transformers.py +1421 -0
- strands_transformers/tools/__init__.py +5 -0
- strands_transformers/tools/use_transformers.py +409 -0
- strands_transformers/types/__init__.py +24 -0
- strands_transformers/types/audio.py +91 -0
- strands_transformers-0.2.0.dist-info/METADATA +252 -0
- strands_transformers-0.2.0.dist-info/RECORD +17 -0
- strands_transformers-0.2.0.dist-info/WHEEL +5 -0
- strands_transformers-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Strands Transformers — the universal entrypoint to HuggingFace transformers.
|
|
2
|
+
|
|
3
|
+
100% transformers coverage with zero hardcoding: every task across every modality
|
|
4
|
+
(text, image, video, audio, robot-state) in and out, the same way `use_aws` wraps
|
|
5
|
+
boto3 and `use_lerobot` wraps lerobot.
|
|
6
|
+
|
|
7
|
+
Quick start:
|
|
8
|
+
from strands import Agent
|
|
9
|
+
from strands_transformers import use_transformers
|
|
10
|
+
|
|
11
|
+
agent = Agent(tools=[use_transformers])
|
|
12
|
+
agent("Transcribe recording.wav") # ASR
|
|
13
|
+
agent("Describe scene.jpg and plan a grasp") # image-text-to-text / VLA
|
|
14
|
+
agent("Say 'hello' as audio") # text-to-audio
|
|
15
|
+
|
|
16
|
+
# Or use a local HF model as the agent's brain:
|
|
17
|
+
from strands_transformers import TransformerModel
|
|
18
|
+
model = TransformerModel(model_path="Qwen/Qwen3-1.7B")
|
|
19
|
+
agent = Agent(model=model, tools=[use_transformers])
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
# Resolve the package version from the build-generated `_version.py`. That file
# only exists once the package has been built, so a missing module falls back to
# a placeholder. Only ImportError is caught: any other failure inside the
# generated file is a real error and should surface.
try:
    from strands_transformers._version import version as __version__
except ImportError:  # not built / no tags yet
    __version__ = "0.0.0"
|
|
26
|
+
|
|
27
|
+
from strands_transformers.core import engine, io, registry
|
|
28
|
+
from strands_transformers.tools.use_transformers import use_transformers
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def __getattr__(name):
    """Module-level attribute hook: resolve heavy public names on first access.

    ``TransformerModel`` lives in a module that pulls in torch, so it is only
    imported when somebody actually asks for it. The audio content-block helpers
    (our extension to the Strands taxonomy) are fetched from
    ``strands_transformers.types.audio`` by name. Anything else raises
    ``AttributeError``.
    """
    if name == "TransformerModel":
        from strands_transformers.models.transformers import TransformerModel as model_cls

        return model_cls

    audio_names = {"make_audio_block", "extract_audio_payload", "AudioContent"}
    if name not in audio_names:
        raise AttributeError(f"module 'strands_transformers' has no attribute '{name}'")

    from strands_transformers.types import audio as audio_module

    return getattr(audio_module, name)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Public API. TransformerModel and the audio helpers are served lazily by the
# module-level __getattr__; the rest are imported eagerly above.
__all__ = [
    # entry points
    "use_transformers", "TransformerModel",
    # audio content-block helpers
    "make_audio_block", "extract_audio_payload", "AudioContent",
    # core submodules
    "registry", "engine", "io",
]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# file generated by vcs-versioning
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
# Public names exported by this generated module; each public name has a
# dunder twin bound to the same value below.
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Type declarations for the values assigned below.
version: str
__version__: str
__version_tuple__: tuple[int | str, ...]
version_tuple: tuple[int | str, ...]
commit_id: str | None
__commit_id__: str | None

# Release 0.2.0 as a string and as a comparable tuple.
__version__ = version = '0.2.0'
__version_tuple__ = version_tuple = (0, 2, 0)

# No commit id was recorded for this build.
__commit_id__ = commit_id = None
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
"""Backward-compat shims for older `trust_remote_code` models on new transformers.
|
|
2
|
+
|
|
3
|
+
Many published models (e.g. openvla/openvla-7b) ship custom code written against
|
|
4
|
+
transformers 4.x. On transformers 5.x some symbols moved and some Auto* classes
|
|
5
|
+
were renamed/removed. Rather than forcing users to pin an old transformers, we
|
|
6
|
+
patch the gaps at runtime so the model's own code loads unchanged.
|
|
7
|
+
|
|
8
|
+
Currently handled:
|
|
9
|
+
- `transformers.tokenization_utils.{PaddingStrategy,TruncationStrategy,
|
|
10
|
+
PreTokenizedInput,TextInput,...}` re-exported from `tokenization_utils_base`.
|
|
11
|
+
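- `AutoModelForVision2Seq` recreated as an alias of `AutoModelForImageTextToText`
  so old `auto_map` entries resolve (used by OpenVLA & friends).
- `PreTrainedModel.init_weights` wrapped so legacy `tie_weights(self)` overrides
  tolerate transformers 5.x calling `tie_weights(missing_keys=..., recompute_mapping=...)`.
- `is_torchcodec_available()` forced to False when torchcodec is installed but
  fails to import, so audio pipelines fall back to the ffmpeg/array path.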
|
|
13
|
+
|
|
14
|
+
`apply()` is idempotent and safe to call repeatedly.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import logging
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)

# Module state shared by the shim functions:
#   _APPLIED          -> flipped to True once apply() has run every shim
#   _VISION2SEQ_ALIAS -> the AutoModelForVision2Seq subclass, built once and
#                        reused so each re-assertion installs the same object
_APPLIED: bool = False
_VISION2SEQ_ALIAS: "type | None" = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def apply(force: bool = False) -> None:
    """Install every compat shim; after the first run this is a no-op unless forced.

    Some trust_remote_code models (OpenVLA) re-import transformers while loading,
    which can clobber the cached aliases. Pass ``force=True`` (or call
    ``ensure_alias()``) to re-install the shims right before resolving a class.
    The shims run in a fixed order; the applied flag is only set once all of them
    have returned.
    """
    global _APPLIED
    if force or not _APPLIED:
        for shim in (
            _patch_tokenization_utils,
            _patch_vision2seq,
            _patch_tie_weights,
            _patch_broken_torchcodec,
        ):
            shim()
        _APPLIED = True
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def ensure_alias() -> None:
    """Re-install only the Vision2Seq alias on the live transformers module.

    A lighter alternative to ``apply(force=True)``: it delegates to
    ``_patch_vision2seq()`` and skips every other shim.
    """
    return _patch_vision2seq()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Names that moved to `tokenization_utils_base` in transformers 5.x but that
# legacy remote code still imports from `tokenization_utils`.
_MOVED_SYMBOLS = (
    # strategy types
    "PaddingStrategy", "TruncationStrategy",
    # text / token input type names (single and pair variants)
    "PreTokenizedInput", "TextInput", "TextInputPair", "PreTokenizedInputPair",
    "EncodedInput", "EncodedInputPair",
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _patch_tokenization_utils() -> None:
    """Re-export symbols that moved to tokenization_utils_base in transformers 5.x.

    In transformers 5.x `transformers.tokenization_utils` is a virtual alias
    module with `__spec__ = None` / `__file__ = None`. The HuggingFace dynamic
    module loader (used by trust_remote_code models like OpenVLA) executes
    `from transformers.tokenization_utils import PaddingStrategy` through the
    import machinery, which rejects a module with no location ("unknown
    location") even when the attribute is set.

    Fix: rebind `transformers.tokenization_utils` in sys.modules to a *real*
    module object that has a proper spec/loader (we reuse the concrete
    `tokenization_utils_sentencepiece` module file) and inject the moved
    symbols onto it. This makes `from ... import ...` succeed.

    Best-effort: if either transformers module cannot be imported, the patch is
    skipped (logged at debug level). Repeat calls are harmless: symbols already
    present are left alone, and the alias is only rebound while the current
    sys.modules entry has no spec.
    """
    import sys

    try:
        import transformers.tokenization_utils_base as tub
        # the concrete file transformers aliases tokenization_utils to
        import transformers.tokenization_utils_sentencepiece as concrete
    except Exception as e:  # pragma: no cover
        logger.debug("tokenization_utils patch skipped: %s", e)
        return

    # Symbols the base module has but the concrete module lacks. Computed once
    # and used both for the injection and for the debug log below.
    missing = [n for n in _MOVED_SYMBOLS if not hasattr(concrete, n) and hasattr(tub, n)]
    for name in missing:
        setattr(concrete, name, getattr(tub, name))

    # Point the virtual alias at the real module so import machinery has a
    # valid location and finds the symbols.
    current = sys.modules.get("transformers.tokenization_utils")
    if current is None or getattr(current, "__spec__", None) is None:
        sys.modules["transformers.tokenization_utils"] = concrete
        import transformers

        transformers.tokenization_utils = concrete
    logger.debug("tokenization_utils compat: injected %s", missing)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _patch_vision2seq() -> None:
    """Recreate AutoModelForVision2Seq (removed in 5.x) as an ImageTextToText alias.

    Custom-code models look up their `auto_map` key by the Auto class *name*, so a
    same-named subclass is enough for `AutoModelForVision2Seq.from_pretrained(...,
    trust_remote_code=True)` to find the remote class.

    The subclass is built once and cached in `_VISION2SEQ_ALIAS`; every call
    re-installs that same object everywhere legacy code may resolve it.
    Best-effort, like the other shims: returns quietly when transformers cannot
    be imported or has no `AutoModelForImageTextToText`, so `apply()` still runs
    the remaining shims. Failures to install on individual modules are logged at
    debug level rather than raised.
    """
    global _VISION2SEQ_ALIAS
    import importlib
    import sys

    try:
        import transformers
    except Exception:  # transformers missing or broken: nothing to alias
        return

    base = getattr(transformers, "AutoModelForImageTextToText", None)
    if base is None:
        return

    if _VISION2SEQ_ALIAS is None:

        class AutoModelForVision2Seq(base):  # type: ignore[misc, valid-type]
            """Compat alias of AutoModelForImageTextToText for legacy auto_map entries."""

        _VISION2SEQ_ALIAS = AutoModelForVision2Seq

    alias = _VISION2SEQ_ALIAS

    def _install(target) -> None:
        """Set the alias on *target* unless it already points at it."""
        try:
            if getattr(target, "AutoModelForVision2Seq", None) is not alias:
                target.AutoModelForVision2Seq = alias
        except Exception as e:
            logger.debug("Vision2Seq alias: cannot set on %r: %s", target, e)

    # OpenVLA's remote code can re-import transformers and swap sys.modules, so
    # cover both the imported object and the current sys.modules entry.
    _install(transformers)
    live = sys.modules.get("transformers")
    if live is not None and live is not transformers:
        _install(live)

    # register_for_auto_class() validates against transformers.models.auto, so
    # the alias must be visible there (and on modeling_auto) too.
    for modname in ("transformers.models.auto.modeling_auto", "transformers.models.auto"):
        try:
            module = importlib.import_module(modname)
        except Exception as e:
            logger.debug("Vision2Seq alias: cannot import %s: %s", modname, e)
            continue
        _install(module)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _patch_tie_weights() -> None:
    """Tolerate transformers 5.x calling tie_weights(missing_keys, recompute_mapping).

    transformers 5.x invokes `model.tie_weights(missing_keys=..., recompute_mapping=...)`
    during from_pretrained, but many 4.x-era custom models override `tie_weights(self)`
    with no extra params, raising TypeError.

    We wrap `PreTrainedModel.init_weights` (defined on the base class, not
    overridden by legacy models). Each call inspects the concrete class's
    `tie_weights`. If it is a subclass override that has not been wrapped yet,
    it is replaced by a wrapper that:
    - forwards the new arguments when the override's signature accepts them;
    - otherwise calls it the legacy way, `tie_weights(self)`.

    Idempotent: guarded by `PreTrainedModel._st_tie_weights_wrapped`.
    """
    import functools
    import inspect

    try:
        from transformers.modeling_utils import PreTrainedModel
    except Exception:  # pragma: no cover
        return
    if getattr(PreTrainedModel, "_st_tie_weights_wrapped", False):
        return

    original_init = PreTrainedModel.init_weights

    def _make_safe_tie_weights(orig):
        """Return a tie_weights wrapper that drops arguments *orig* cannot accept."""
        try:
            sig = inspect.signature(orig)
        except (TypeError, ValueError):  # no introspectable signature
            sig = None

        def safe_tie_weights(self, *args, **kwargs):
            if sig is None:
                # Fallback without a signature: retry the legacy call on TypeError.
                try:
                    return orig(self, *args, **kwargs)
                except TypeError:
                    return orig(self)
            try:
                sig.bind(self, *args, **kwargs)
            except TypeError:
                # Legacy override that predates the new kwargs.
                return orig(self)
            # The signature accepts the call, so a TypeError raised from here on
            # comes from the override's body and propagates instead of being
            # masked by a second, argument-less call.
            return orig(self, *args, **kwargs)

        # Flag the *function*, not the class: a class attribute would be
        # inherited and stop a subclass's own override from being wrapped.
        safe_tie_weights._st_tw_wrapped = True
        return safe_tie_weights

    @functools.wraps(original_init)
    def init_weights(self, *args, **kwargs):
        cls = type(self)
        tw = cls.tie_weights
        if tw is not PreTrainedModel.tie_weights and not getattr(tw, "_st_tw_wrapped", False):
            cls.tie_weights = _make_safe_tie_weights(tw)
        return original_init(self, *args, **kwargs)

    PreTrainedModel.init_weights = init_weights
    PreTrainedModel._st_tie_weights_wrapped = True
|
|
185
|
+
|
|
186
|
+
def _patch_broken_torchcodec() -> None:
    """Disable torchcodec detection when the installed torchcodec is broken.

    transformers' audio pipelines call `is_torchcodec_available()` and then do an
    unconditional `import torchcodec`. If torchcodec is installed but its native
    lib fails to load (common ffmpeg ABI mismatch), this crashes even for already
    decoded array/dict inputs. We probe the actual import; if it fails while a
    torchcodec package is still findable, we override the availability checks to
    return False so pipelines fall back to the ffmpeg/array path.

    When torchcodec is not installed at all there is no broken install to work
    around, so nothing is patched and the audio pipeline modules are not imported.
    """
    import importlib
    import importlib.util

    try:
        import torchcodec  # noqa: F401
    except Exception:
        pass  # missing or broken: distinguished below
    else:
        return  # torchcodec works — nothing to do

    try:
        if importlib.util.find_spec("torchcodec") is None:
            return  # not installed: nothing broken to work around
    except (ImportError, ValueError):
        pass  # spec lookup itself failed: treat as broken and patch

    def _false() -> bool:
        return False

    patched = 0
    for modname in (
        "transformers.utils",
        "transformers.utils.import_utils",
        "transformers.pipelines.automatic_speech_recognition",
        "transformers.pipelines.audio_classification",
    ):
        try:
            mod = importlib.import_module(modname)
            if hasattr(mod, "is_torchcodec_available"):
                mod.is_torchcodec_available = _false
                patched += 1
        except Exception:
            continue  # best-effort: skip modules that cannot be imported/patched
    if patched:
        logger.debug("Disabled broken torchcodec in %d module(s)", patched)
|
|
222
|
+
|
|
223
|
+
def spoof_timm_version(version: str = "0.9.16"):
    """Temporarily spoof `timm.__version__` for models with hard version pins.

    Some legacy models (e.g. OpenVLA) hard-assert an exact old timm version in
    their remote code. Newer timm is usually API-compatible for inference. This
    returns a context manager that swaps `timm.__version__` and, on exit,
    restores the previous state exactly: the old value is put back, or the
    attribute is removed again if timm had none. When timm is not installed the
    context manager does nothing.

    Usage:
        with compat.spoof_timm_version():
            model = AutoModel.from_pretrained(..., trust_remote_code=True)
    """
    import contextlib

    @contextlib.contextmanager
    def _ctx():
        try:
            import timm
        except ImportError:
            yield
            return
        missing = object()  # sentinel: timm had no __version__ attribute
        original = getattr(timm, "__version__", missing)
        try:
            timm.__version__ = version
            yield
        finally:
            if original is missing:
                # Don't leak the spoofed value past the with-block.
                try:
                    del timm.__version__
                except AttributeError:
                    pass
            else:
                timm.__version__ = original

    return _ctx()
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""Pipeline & model engine — load once, cache, run. Auto device/dtype.
|
|
2
|
+
|
|
3
|
+
Wraps transformers.pipeline() as the universal native-I/O runner, plus a generic
|
|
4
|
+
loader for raw AutoModel/AutoProcessor access when you need lower-level control
|
|
5
|
+
(e.g. robot VLA models that output action tensors).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
from functools import lru_cache
|
|
13
|
+
from typing import Any, Dict, Optional
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)

# Session-wide store for everything loaded so far (pipelines, models,
# processors), keyed by the cache keys returned by get_pipeline()/load_object().
_CACHE: Dict[str, Any] = dict()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def select_device(device: Optional[str] = None) -> str:
    """Resolve *device* to a concrete device string.

    Any truthy value other than ``"auto"`` is returned unchanged. Otherwise the
    first available accelerator wins: CUDA, then Apple MPS, falling back to
    ``"cpu"`` (also when torch is not installed).
    """
    if device and device != "auto":
        return device  # caller pinned a device explicitly

    chosen = "cpu"
    try:
        import torch

        if torch.cuda.is_available():
            chosen = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            chosen = "mps"
    except ImportError:
        pass
    return chosen
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def select_dtype(device: str):
    """Pick a sensible default dtype for the device.

    Args:
        device: a device string such as ``"cpu"``, ``"cuda"``, ``"cuda:1"`` or
            ``"mps"``. Any ``":N"`` index suffix is ignored, so indexed devices
            get the same default as their backend.

    Returns:
        For CUDA, ``torch.bfloat16`` — or ``torch.float16`` when CUDA is
        available but the GPU reports no native bf16 support. For MPS,
        ``torch.float16``. Otherwise, or when torch is not installed, ``None``,
        leaving the choice to transformers.
    """
    backend = str(device).split(":", 1)[0] if device else ""
    if backend not in ("cuda", "mps"):
        return None  # let transformers decide (float32 on cpu)
    try:
        import torch
    except ImportError:
        return None
    if backend == "mps":
        return torch.float16

    # Pre-Ampere GPUs lack native bf16; fall back to fp16 there. Only downgrade
    # when the probe positively reports no support.
    try:
        if torch.cuda.is_available():
            try:
                native_bf16 = torch.cuda.is_bf16_supported(including_emulation=False)
            except TypeError:  # older torch: no `including_emulation` parameter
                native_bf16 = torch.cuda.is_bf16_supported()
            if not native_bf16:
                return torch.float16
    except Exception:
        pass  # capability probe failed: keep the bf16 default
    return torch.bfloat16
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_pipeline(task: str, model: Optional[str] = None,
                 device: Optional[str] = None, cache_key: Optional[str] = None,
                 **pipeline_kwargs: Any):
    """Build (or fetch cached) a transformers pipeline for a task.

    pipeline() natively accepts paths/URLs/PIL/arrays as input and handles
    tokenization, image processing, feature extraction, etc. automatically.

    Args:
        task: transformers pipeline task name.
        model: model id or path; ``None`` lets transformers pick the task default.
        device: ``None``/``"auto"`` to auto-select, or an explicit device string
            (``"cpu"``, ``"cuda"``, ``"cuda:1"``, ``"mps"``, ...).
        cache_key: explicit cache key. By default one is derived from task and
            model, plus an explicit device and any extra kwargs so differently
            configured pipelines do not collide.
        **pipeline_kwargs: forwarded to ``transformers.pipeline``; they override
            the defaults chosen here.

    Returns:
        ``(pipeline, cache_key)``.
    """
    from . import compat

    compat.apply()
    from transformers import pipeline

    key = cache_key
    if not key:
        key = f"pipe::{task}::{model or 'default'}"
        # A pinned device or extra kwargs produce a different pipeline object,
        # so they must not be served from the plain task/model entry.
        if device and device != "auto":
            key += f"::device={device}"
        if pipeline_kwargs:
            key += f"::{sorted(pipeline_kwargs.items())!r}"
    if key in _CACHE:
        return _CACHE[key], key

    dev = select_device(device)
    kwargs: Dict[str, Any] = {"task": task}
    if model:
        kwargs["model"] = model
    # device handling: pipeline takes device int/str or device_map. When the
    # caller supplies device_map we do not add a device of our own.
    if "device_map" not in pipeline_kwargs:
        if dev == "cuda":
            kwargs["device"] = 0
        elif dev == "cpu":
            kwargs["device"] = -1
        else:
            kwargs["device"] = dev  # "mps", "cuda:1", ... passed through as-is

    # Tasks whose post-processing produces images/dense maps need float32 — half
    # precision (bf16/fp16) breaks PIL/numpy conversion ("unsupported ScalarType
    # BFloat16"). Skip the half-precision default for those; callers can still
    # override via pipeline_kwargs.
    float32_tasks = frozenset({
        "depth-estimation", "image-segmentation", "image-to-image",
        "semantic-segmentation", "instance-segmentation", "mask-generation",
    })
    # Like load_object(): never add a default dtype next to a caller-supplied
    # dtype/torch_dtype.
    caller_dtype = "dtype" in pipeline_kwargs or "torch_dtype" in pipeline_kwargs
    if task not in float32_tasks and not caller_dtype:
        dtype = select_dtype(dev)
        if dtype is not None:
            kwargs["dtype"] = dtype
    kwargs.update(pipeline_kwargs)

    logger.info("Loading pipeline task=%s model=%s device=%s", task, model, dev)
    pipe = pipeline(**kwargs)
    _CACHE[key] = pipe
    return pipe, key
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def load_object(auto_class: str, model_path: str,
                device: Optional[str] = None, cache_key: Optional[str] = None,
                **from_pretrained_kwargs: Any):
    """Load any AutoModel*/AutoProcessor/AutoTokenizer via from_pretrained.

    For lower-level control than pipelines — e.g. VLA / robot-action models where
    you feed processor(images, text, state) and call model.generate / model(**).

    Args:
        auto_class: name of the transformers class to resolve (e.g. ``"AutoModel"``).
        model_path: model id or local path passed to ``from_pretrained``.
        device: ``None``/``"auto"`` to auto-select, or an explicit device string.
        cache_key: explicit cache key. By default one is derived from class and
            path, plus an explicit device and any extra kwargs.
        **from_pretrained_kwargs: forwarded to ``from_pretrained``.

    Returns:
        ``(object, cache_key)``.
    """
    from . import compat, registry

    compat.apply(force=True)
    key = cache_key
    if not key:
        key = f"obj::{auto_class}::{model_path}"
        # Different devices / load kwargs yield different objects.
        if device and device != "auto":
            key += f"::device={device}"
        if from_pretrained_kwargs:
            key += f"::{sorted(from_pretrained_kwargs.items())!r}"
    if key in _CACHE:
        return _CACHE[key], key

    cls = registry.resolve_attr(auto_class)
    dev = select_device(device)
    kwargs = dict(from_pretrained_kwargs)
    # only models (not processors/tokenizers) take dtype/device_map
    is_model = auto_class.startswith(("AutoModel", "AutoBackbone"))
    if is_model:
        dtype = select_dtype(dev)
        if dtype is not None and "dtype" not in kwargs and "torch_dtype" not in kwargs:
            kwargs["dtype"] = dtype
    # NOTE(review): SECURITY — defaulting trust_remote_code=True executes Python
    # code shipped in the model repository. Callers loading untrusted model ids
    # should pass trust_remote_code=False explicitly.
    kwargs.setdefault("trust_remote_code", True)

    logger.info("Loading %s from %s on %s", auto_class, model_path, dev)
    obj = cls.from_pretrained(model_path, **kwargs)

    # With a device_map, from_pretrained already placed the weights; don't move
    # them afterwards.
    if is_model and "device_map" not in kwargs and hasattr(obj, "to"):
        try:
            obj = obj.to(dev)
        except Exception as e:
            logger.debug("Could not .to(%s): %s", dev, e)

    _CACHE[key] = obj
    return obj, key
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def cache_list() -> Dict[str, str]:
    """Describe the cache: map every cache key to its object's class name."""
    listing: Dict[str, str] = {}
    for cache_key, cached in _CACHE.items():
        listing[cache_key] = type(cached).__name__
    return listing
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def cache_clear(key: Optional[str] = None) -> int:
    """Evict cached objects and report how many were removed.

    A truthy *key* evicts only that entry: returns 1 if it existed (and frees
    memory), else 0 with nothing else done. A falsy key (``None`` or ``""``)
    empties the whole cache, frees memory (even if the cache was already
    empty) and returns the number of entries dropped.
    """
    if not key:
        evicted = len(_CACHE)
        _CACHE.clear()
        _free_memory()
        return evicted
    if key not in _CACHE:
        return 0
    del _CACHE[key]
    _free_memory()
    return 1
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def cache_get(key: str) -> Optional[Any]:
    """Return the object cached under *key*, or ``None`` when there is none."""
    try:
        return _CACHE[key]
    except KeyError:
        return None
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _free_memory():
    """Best-effort release of memory held by evicted cache entries.

    Always runs Python's garbage collector; previously it was skipped whenever
    torch failed to import, because both imports shared one statement. Then, if
    torch is available, asks the CUDA and MPS allocators to return their cached
    blocks. Accelerator errors are logged at debug level, never raised.
    """
    import gc

    gc.collect()
    try:
        import torch
    except Exception:
        return
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        backends_mps = getattr(torch.backends, "mps", None)
        if (backends_mps is not None and backends_mps.is_available()
                and hasattr(getattr(torch, "mps", None), "empty_cache")):
            torch.mps.empty_cache()
    except Exception as e:
        logger.debug("Could not empty accelerator cache: %s", e)
|