strands-transformers 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,52 @@
1
+ """Strands Transformers — the universal entrypoint to HuggingFace transformers.
2
+
3
+ 100% transformers coverage with zero hardcoding: every task across every modality
4
+ (text, image, video, audio, robot-state) in and out, the same way `use_aws` wraps
5
+ boto3 and `use_lerobot` wraps lerobot.
6
+
7
+ Quick start:
8
+ from strands import Agent
9
+ from strands_transformers import use_transformers
10
+
11
+ agent = Agent(tools=[use_transformers])
12
+ agent("Transcribe recording.wav") # ASR
13
+ agent("Describe scene.jpg and plan a grasp") # image-text-to-text / VLA
14
+ agent("Say 'hello' as audio") # text-to-audio
15
+
16
+ # Or use a local HF model as the agent's brain:
17
+ from strands_transformers import TransformerModel
18
+ model = TransformerModel(model_path="Qwen/Qwen3-1.7B")
19
+ agent = Agent(model=model, tools=[use_transformers])
20
+ """
21
+
22
+ try:
23
+ from strands_transformers._version import version as __version__
24
+ except Exception: # not built / no tags yet
25
+ __version__ = "0.0.0"
26
+
27
+ from strands_transformers.core import engine, io, registry
28
+ from strands_transformers.tools.use_transformers import use_transformers
29
+
30
+
31
def __getattr__(name):
    """Resolve lazily exported package attributes on first access (PEP 562).

    ``TransformerModel`` is imported from ``strands_transformers.models.transformers``
    only when requested (that import pulls in torch). The audio content-block
    helpers are looked up on ``strands_transformers.types.audio``.

    Raises:
        AttributeError: if ``name`` is not one of the lazily exported names.
    """
    audio_exports = ("make_audio_block", "extract_audio_payload", "AudioContent")

    # Reject unknown names first so the lazy imports below are only reached
    # for attributes this package actually promises.
    if name != "TransformerModel" and name not in audio_exports:
        raise AttributeError(f"module 'strands_transformers' has no attribute '{name}'")

    if name in audio_exports:
        # Audio content-block helpers (our extension to the Strands taxonomy).
        from strands_transformers.types import audio as audio_module
        return getattr(audio_module, name)

    # Model provider: deferred because importing it pulls in torch.
    from strands_transformers.models.transformers import TransformerModel
    return TransformerModel
41
+
42
+
43
+ __all__ = [
44
+ "use_transformers",
45
+ "TransformerModel",
46
+ "make_audio_block",
47
+ "extract_audio_payload",
48
+ "AudioContent",
49
+ "registry",
50
+ "engine",
51
+ "io",
52
+ ]
@@ -0,0 +1,24 @@
1
+ # file generated by vcs-versioning
2
+ # don't change, don't track in version control
3
+ from __future__ import annotations
4
+
5
+ __all__ = [
6
+ "__version__",
7
+ "__version_tuple__",
8
+ "version",
9
+ "version_tuple",
10
+ "__commit_id__",
11
+ "commit_id",
12
+ ]
13
+
14
+ version: str
15
+ __version__: str
16
+ __version_tuple__: tuple[int | str, ...]
17
+ version_tuple: tuple[int | str, ...]
18
+ commit_id: str | None
19
+ __commit_id__: str | None
20
+
21
+ __version__ = version = '0.2.0'
22
+ __version_tuple__ = version_tuple = (0, 2, 0)
23
+
24
+ __commit_id__ = commit_id = None
@@ -0,0 +1,5 @@
1
+ """Core primitives: registry (task taxonomy), engine (load/cache/run), io (multimodal)."""
2
+
3
+ from . import engine, io, registry
4
+
5
+ __all__ = ["registry", "engine", "io"]
@@ -0,0 +1,251 @@
1
+ """Backward-compat shims for older `trust_remote_code` models on new transformers.
2
+
3
+ Many published models (e.g. openvla/openvla-7b) ship custom code written against
4
+ transformers 4.x. On transformers 5.x some symbols moved and some Auto* classes
5
+ were renamed/removed. Rather than forcing users to pin an old transformers, we
6
+ patch the gaps at runtime so the model's own code loads unchanged.
7
+
8
+ Currently handled:
9
+ - `transformers.tokenization_utils.{PaddingStrategy,TruncationStrategy,
10
+ PreTokenizedInput,TextInput,...}` re-exported from `tokenization_utils_base`.
11
+ - `AutoModelForVision2Seq` recreated as an alias of `AutoModelForImageTextToText`
12
+ so old `auto_map` entries resolve (used by OpenVLA & friends).
13
+
14
+ `apply()` is idempotent and safe to call repeatedly.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import logging
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ _APPLIED = False
24
+ _VISION2SEQ_ALIAS = None
25
+
26
+
27
def apply(force: bool = False) -> None:
    """Install all compatibility shims; later calls are no-ops unless forced.

    Some trust_remote_code models (OpenVLA) re-import transformers during load,
    which can replace the cached aliases. Pass ``force=True`` (or call
    ``ensure_alias()``) to re-assert the shims right before resolving a class.

    Args:
        force: Run every shim again even if a previous call already completed.
    """
    global _APPLIED
    if force or not _APPLIED:
        shims = (
            _patch_tokenization_utils,
            _patch_vision2seq,
            _patch_tie_weights,
            _patch_broken_torchcodec,
        )
        for shim in shims:
            shim()
        # Only marked as applied once every shim has returned.
        _APPLIED = True
42
+
43
+
44
def ensure_alias() -> None:
    """Re-assert only the AutoModelForVision2Seq alias on the live transformers module.

    Runs just the Vision2Seq shim; unlike ``apply(force=True)`` it does not
    re-run the other patches.
    """
    _patch_vision2seq()
    return None
47
+
48
+
49
+ _MOVED_SYMBOLS = (
50
+ "PaddingStrategy",
51
+ "TruncationStrategy",
52
+ "PreTokenizedInput",
53
+ "TextInput",
54
+ "TextInputPair",
55
+ "PreTokenizedInputPair",
56
+ "EncodedInput",
57
+ "EncodedInputPair",
58
+ )
59
+
60
+
61
def _patch_tokenization_utils() -> None:
    """Make ``from transformers.tokenization_utils import <moved symbol>`` work on 5.x.

    In transformers 5.x `transformers.tokenization_utils` is a virtual alias
    module with `__spec__ = None` / `__file__ = None`. The HuggingFace dynamic
    module loader (used by trust_remote_code models like OpenVLA) runs such
    imports through the import machinery, which rejects a module with no
    location ("unknown location") even when the attribute is set.

    The shim copies every name in ``_MOVED_SYMBOLS`` that exists on
    ``tokenization_utils_base`` but not on the file-backed
    ``tokenization_utils_sentencepiece`` module onto the latter, then rebinds
    ``transformers.tokenization_utils`` (in ``sys.modules`` and on the package)
    to that real module whenever the current entry is absent or has no spec.
    If either transformers module cannot be imported the shim is skipped.
    """
    import sys

    try:
        import transformers.tokenization_utils_base as base_module
        # the concrete file transformers aliases tokenization_utils to
        import transformers.tokenization_utils_sentencepiece as real_module
    except Exception as e:  # pragma: no cover
        logger.debug("tokenization_utils patch skipped: %s", e)
        return

    # Names the legacy import path expects but the file-backed module lacks.
    missing = [
        symbol
        for symbol in _MOVED_SYMBOLS
        if not hasattr(real_module, symbol) and hasattr(base_module, symbol)
    ]
    for symbol in missing:
        setattr(real_module, symbol, getattr(base_module, symbol))

    # A module that already has a spec is importable as-is; leave it alone.
    registered = sys.modules.get("transformers.tokenization_utils")
    if registered is not None and getattr(registered, "__spec__", None) is not None:
        return

    # Point the virtual alias at the real module so the import machinery has a
    # valid location and finds the symbols.
    sys.modules["transformers.tokenization_utils"] = real_module
    import transformers
    transformers.tokenization_utils = real_module
    logger.debug("tokenization_utils compat: injected %s", missing)
100
+
101
+
102
def _patch_vision2seq() -> None:
    """Expose an ``AutoModelForVision2Seq`` class backed by ``AutoModelForImageTextToText``.

    Custom-code models look up their `auto_map` key by the Auto class *name*, so a
    same-named subclass is enough for `AutoModelForVision2Seq.from_pretrained(...,
    trust_remote_code=True)` to find the remote class.

    The subclass is created once and kept in ``_VISION2SEQ_ALIAS``; each call
    re-attaches that same object wherever it is missing. Does nothing when
    transformers has no ``AutoModelForImageTextToText``.
    """
    global _VISION2SEQ_ALIAS

    import sys

    import transformers

    image_text_cls = getattr(transformers, "AutoModelForImageTextToText", None)
    if image_text_cls is None:
        return

    if _VISION2SEQ_ALIAS is None:
        class AutoModelForVision2Seq(image_text_cls):  # type: ignore[misc, valid-type]
            """Compat alias of AutoModelForImageTextToText for legacy auto_map entries."""

        _VISION2SEQ_ALIAS = AutoModelForVision2Seq

    alias = _VISION2SEQ_ALIAS

    # Assert the alias on every live handle to the transformers module. OpenVLA's
    # remote code can re-import transformers and swap sys.modules, so cover both
    # the object imported above and the current sys.modules entry.
    live_modules = {transformers, sys.modules.get("transformers")}
    live_modules.discard(None)
    for module in live_modules:
        if getattr(module, "AutoModelForVision2Seq", None) is alias:
            continue
        try:
            module.AutoModelForVision2Seq = alias
        except Exception:
            pass

    # Best effort: the module that historically defined the class.
    try:
        import transformers.models.auto.modeling_auto as modeling_auto
        if getattr(modeling_auto, "AutoModelForVision2Seq", None) is not alias:
            modeling_auto.AutoModelForVision2Seq = alias
    except Exception:
        pass

    # register_for_auto_class() validates against transformers.models.auto, so
    # the alias must be visible there too (used by remote code during load).
    try:
        import transformers.models.auto as auto_package
        if getattr(auto_package, "AutoModelForVision2Seq", None) is not alias:
            auto_package.AutoModelForVision2Seq = alias
    except Exception:
        pass
148
+
149
+
150
def _patch_tie_weights() -> None:
    """Let legacy ``tie_weights(self)`` overrides survive transformers 5.x kwargs.

    transformers 5.x invokes `model.tie_weights(missing_keys=..., recompute_mapping=...)`
    during from_pretrained, but many 4.x-era custom models override `tie_weights(self)`
    with no extra params, raising TypeError.

    The shim replaces ``PreTrainedModel.init_weights`` with a wrapper. On each
    call the wrapper checks the instance's class: if its ``tie_weights`` is
    neither the base implementation nor an already-installed tolerant wrapper,
    it is swapped for one that retries without arguments when the override
    rejects them. The original ``init_weights`` then runs as before.

    Wrapping is tracked on the wrapper function itself rather than through a
    class attribute. A class attribute is inherited, so a subclass that
    overrides ``tie_weights`` again after its parent was wrapped would
    otherwise be skipped and still raise TypeError.

    Patching happens at most once per ``PreTrainedModel`` class object
    (guarded by ``_st_tie_weights_wrapped``). Does nothing if transformers
    cannot be imported.
    """
    try:
        from transformers.modeling_utils import PreTrainedModel
    except Exception:  # pragma: no cover
        return
    if getattr(PreTrainedModel, "_st_tie_weights_wrapped", False):
        return

    original_init = PreTrainedModel.init_weights

    def _tolerant(orig):
        """Return a tie_weights replacement that retries ``orig`` without extra args."""

        def safe_tie_weights(self, *args, **kwargs):
            try:
                return orig(self, *args, **kwargs)
            except TypeError:
                # With no extra arguments a retry would be the identical call:
                # the error is genuine, so surface it instead of running the
                # override a second time.
                if not args and not kwargs:
                    raise
                return orig(self)

        # Marker lives on the function so inheritance cannot fake it.
        safe_tie_weights._st_tw_wrapper = True
        return safe_tie_weights

    def init_weights(self, *args, **kwargs):
        cls = type(self)
        tw = cls.tie_weights
        needs_wrap = (
            tw is not PreTrainedModel.tie_weights
            and not getattr(tw, "_st_tw_wrapper", False)
        )
        if needs_wrap:
            cls.tie_weights = _tolerant(tw)
            # Kept for anything that inspects the historical class-level flag.
            cls._st_tw_wrapped = True
        return original_init(self, *args, **kwargs)

    PreTrainedModel.init_weights = init_weights
    PreTrainedModel._st_tie_weights_wrapped = True
185
+
186
def _patch_broken_torchcodec() -> None:
    """Make transformers treat torchcodec as unavailable when it cannot be imported.

    transformers' audio pipelines call `is_torchcodec_available()` and then do an
    unconditional `import torchcodec`. If torchcodec is installed but its native
    lib fails to load (common ffmpeg ABI mismatch), this crashes even for already
    decoded array/dict inputs.

    The shim tries the real import once. If that raises anything, every listed
    transformers module that exposes ``is_torchcodec_available`` gets it
    replaced with a function returning ``False``. Modules that fail to import
    are skipped.
    """
    import importlib

    try:
        import torchcodec  # noqa: F401
    except Exception:
        pass
    else:
        return  # torchcodec works — nothing to do

    def _false() -> bool:
        return False

    candidates = (
        "transformers.utils",
        "transformers.utils.import_utils",
        "transformers.pipelines.automatic_speech_recognition",
        "transformers.pipelines.audio_classification",
    )
    patched = 0
    for modname in candidates:
        try:
            module = importlib.import_module(modname)
            if not hasattr(module, "is_torchcodec_available"):
                continue
            module.is_torchcodec_available = _false
            patched += 1
        except Exception:
            continue

    if patched:
        logger.debug("Disabled broken torchcodec in %d module(s)", patched)
222
+
223
def spoof_timm_version(version: str = "0.9.16"):
    """Return a context manager that temporarily overrides ``timm.__version__``.

    Some legacy models (e.g. OpenVLA) hard-assert an exact old timm version in
    their remote code. Inside the ``with`` block ``timm.__version__`` reads as
    ``version``; on exit the previous state is restored exactly. That includes
    removing the attribute again when timm had no ``__version__`` beforehand,
    so the spoofed value never leaks past the block.

    If timm is not importable the context manager does nothing.

    Args:
        version: Version string to expose while the context is active.

    Usage:
        with compat.spoof_timm_version():
            model = AutoModel.from_pretrained(..., trust_remote_code=True)
    """
    import contextlib

    # Distinguishes "attribute absent" from a present-but-None attribute.
    absent = object()

    @contextlib.contextmanager
    def _ctx():
        try:
            import timm
        except ImportError:
            yield
            return
        original = getattr(timm, "__version__", absent)
        timm.__version__ = version
        try:
            yield
        finally:
            if original is absent:
                # There was nothing to restore: undo our own assignment.
                with contextlib.suppress(AttributeError):
                    del timm.__version__
            else:
                timm.__version__ = original

    return _ctx()
@@ -0,0 +1,160 @@
1
+ """Pipeline & model engine — load once, cache, run. Auto device/dtype.
2
+
3
+ Wraps transformers.pipeline() as the universal native-I/O runner, plus a generic
4
+ loader for raw AutoModel/AutoProcessor access when you need lower-level control
5
+ (e.g. robot VLA models that output action tensors).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ import os
12
+ from functools import lru_cache
13
+ from typing import Any, Dict, Optional
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # session-scoped cache of loaded objects (pipelines, models, processors)
18
+ _CACHE: Dict[str, Any] = {}
19
+
20
+
21
def select_device(device: Optional[str] = None) -> str:
    """Return the device to run on.

    An explicit ``device`` other than ``"auto"`` is returned unchanged. Otherwise
    the choice is ``"cuda"`` if torch reports CUDA, else ``"mps"`` if torch reports
    the MPS backend, else ``"cpu"`` (also used when torch is not installed).
    """
    if device and device != "auto":
        return device

    chosen = "cpu"
    try:
        import torch
        if torch.cuda.is_available():
            chosen = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            chosen = "mps"
    except ImportError:
        pass
    return chosen
33
+
34
+
35
def select_dtype(device: str):
    """Pick a sensible default dtype for the device.

    Only the device *type* matters, so indexed forms such as ``"cuda:0"`` or
    ``"mps:0"`` get the same default as ``"cuda"`` / ``"mps"``.

    Args:
        device: Device string, optionally with an index suffix (``"cuda:1"``).

    Returns:
        ``torch.bfloat16`` for CUDA, ``torch.float16`` for MPS, and ``None``
        for anything else or when torch is not installed (lets transformers
        decide; float32 on cpu).
    """
    try:
        import torch
    except ImportError:
        return None

    # "cuda:1" -> "cuda"; a plain "cuda" / "cpu" is left unchanged.
    device_type = str(device).split(":", 1)[0]
    if device_type == "cuda":
        return torch.bfloat16
    if device_type == "mps":
        return torch.float16
    return None
44
+
45
+
46
def get_pipeline(task: str, model: Optional[str] = None,
                 device: Optional[str] = None, cache_key: Optional[str] = None,
                 **pipeline_kwargs: Any):
    """Build (or fetch cached) a transformers pipeline for a task.

    pipeline() natively accepts paths/URLs/PIL/arrays as input and handles
    tokenization, image processing, feature extraction, etc. automatically.

    Args:
        task: Pipeline task name passed to ``transformers.pipeline``.
        model: Optional model id/path; omitted from the call when falsy.
        device: Explicit device, or ``None``/``"auto"`` to auto-select. Indexed
            accelerator strings such as ``"cuda:1"`` are forwarded to the
            pipeline unchanged rather than being downgraded to CPU.
        cache_key: Cache slot to use; defaults to ``pipe::<task>::<model|default>``.
        **pipeline_kwargs: Extra keyword arguments for ``pipeline``; they
            override every default chosen here.

    Returns:
        ``(pipeline, key)`` where ``key`` is the cache slot holding it.
    """
    from . import compat
    compat.apply()
    from transformers import pipeline

    key = cache_key or f"pipe::{task}::{model or 'default'}"
    if key in _CACHE:
        return _CACHE[key], key

    dev = select_device(device)
    # "cuda:1" -> "cuda": decisions below depend on the device type only.
    device_type = str(dev).split(":", 1)[0]

    kwargs: Dict[str, Any] = {"task": task}
    if model:
        kwargs["model"] = model

    # device handling: pipeline takes device int/str or device_map
    if dev == "cuda":
        kwargs["device"] = 0
    elif device_type in ("cuda", "mps"):
        # Preserve an explicit accelerator/index chosen by the caller.
        kwargs["device"] = dev
    else:
        if dev != "cpu":
            logger.warning("Unrecognized device %r; running pipeline on CPU", dev)
        kwargs["device"] = -1

    # Tasks whose post-processing produces images/dense maps need float32 — half
    # precision (bf16/fp16) breaks PIL/numpy conversion ("unsupported ScalarType
    # BFloat16"). Skip the half-precision default for those; callers can still
    # override via pipeline_kwargs.
    _FLOAT32_TASKS = {
        "depth-estimation", "image-segmentation", "image-to-image",
        "semantic-segmentation", "instance-segmentation", "mask-generation",
    }
    # Same rule as load_object: never add a default next to a caller-chosen
    # dtype, including the legacy `torch_dtype` spelling.
    caller_set_dtype = "dtype" in pipeline_kwargs or "torch_dtype" in pipeline_kwargs
    if task not in _FLOAT32_TASKS and not caller_set_dtype:
        dtype = select_dtype(device_type)
        if dtype is not None:
            kwargs["dtype"] = dtype
    kwargs.update(pipeline_kwargs)

    logger.info("Loading pipeline task=%s model=%s device=%s", task, model, dev)
    pipe = pipeline(**kwargs)
    _CACHE[key] = pipe
    return pipe, key
90
+
91
+
92
def load_object(auto_class: str, model_path: str,
                device: Optional[str] = None, cache_key: Optional[str] = None,
                **from_pretrained_kwargs: Any):
    """Load (or fetch cached) any Auto* class instance via ``from_pretrained``.

    For lower-level control than pipelines — e.g. VLA / robot-action models where
    you feed processor(images, text, state) and call model.generate / model(**).

    Args:
        auto_class: Name resolved through ``registry.resolve_attr``.
        model_path: Passed as the first argument to ``from_pretrained``.
        device: Explicit device, or ``None``/``"auto"`` to auto-select.
        cache_key: Cache slot to use; defaults to ``obj::<auto_class>::<model_path>``.
        **from_pretrained_kwargs: Forwarded to ``from_pretrained``.

    Returns:
        ``(object, key)`` where ``key`` is the cache slot holding it.
    """
    from . import compat, registry

    # Re-assert the shims on every call, even when the result is already cached.
    compat.apply(force=True)

    key = cache_key or f"obj::{auto_class}::{model_path}"
    if key in _CACHE:
        return _CACHE[key], key

    cls = registry.resolve_attr(auto_class)
    dev = select_device(device)
    kwargs = dict(from_pretrained_kwargs)

    # only models (not processors/tokenizers) take dtype/device_map
    is_model = auto_class.startswith("AutoModel") or auto_class.startswith("AutoBackbone")
    if is_model:
        default_dtype = select_dtype(dev)
        caller_set_dtype = "dtype" in kwargs or "torch_dtype" in kwargs
        if default_dtype is not None and not caller_set_dtype:
            kwargs["dtype"] = default_dtype

    # NOTE(review): remote code execution is opted in by default here; callers
    # loading untrusted repositories should pass trust_remote_code=False.
    kwargs.setdefault("trust_remote_code", True)

    logger.info("Loading %s from %s on %s", auto_class, model_path, dev)
    obj = cls.from_pretrained(model_path, **kwargs)

    if hasattr(obj, "to") and is_model:
        try:
            obj = obj.to(dev)
        except Exception as e:  # some models pinned via device_map
            logger.debug("Could not .to(%s): %s", dev, e)

    _CACHE[key] = obj
    return obj, key
129
+
130
+
131
def cache_list() -> Dict[str, str]:
    """Return a mapping of each cache key to the class name of the object stored under it."""
    summary: Dict[str, str] = {}
    for cache_key, cached in _CACHE.items():
        summary[cache_key] = type(cached).__name__
    return summary
133
+
134
+
135
def cache_clear(key: Optional[str] = None) -> int:
    """Drop one cached object, or all of them, and return how many were removed.

    Args:
        key: Cache slot to remove. Any falsy value (``None`` or ``""``) clears
            the whole cache.

    Returns:
        Number of entries removed; ``0`` when ``key`` is given but not cached.
        Memory is released via ``_free_memory()`` only when something was
        removed or the whole cache was cleared.
    """
    if not key:
        removed = len(_CACHE)
        _CACHE.clear()
        _free_memory()
        return removed

    if key not in _CACHE:
        return 0

    del _CACHE[key]
    _free_memory()
    return 1
147
+
148
+
149
def cache_get(key: str) -> Optional[Any]:
    """Return the object cached under ``key``, or ``None`` if there is none."""
    try:
        return _CACHE[key]
    except KeyError:
        return None
151
+
152
+
153
def _free_memory():
    """Best-effort release of memory held by objects just dropped from the cache.

    Always runs a Python garbage collection first, so reference cycles are
    reclaimed even when torch is not installed. Then, if torch is importable,
    asks it to return cached accelerator memory (CUDA and, where the backend
    exists, MPS). Accelerator-side failures never propagate.
    """
    import gc

    # Independent of torch: must not be skipped when the torch import fails.
    gc.collect()

    try:
        import torch
    except ImportError:
        return

    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available() and hasattr(torch, "mps"):
            torch.mps.empty_cache()
    except Exception:
        # Deliberately best-effort: cache clearing must not fail on cleanup.
        pass