strands-diffusers 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,41 @@
1
+ """Strands Diffusers — the universal entrypoint to HuggingFace diffusers.
2
+
3
+ 100% diffusers coverage with zero hardcoding: every pipeline (text→image, image→
4
+ video, video→world), model, and scheduler, the same way `use_aws` wraps boto3,
5
+ `use_lerobot` wraps lerobot, and `use_transformers` wraps the transformers task
6
+ taxonomy.
7
+
8
+ Special focus: Physical-AI world-foundation models (NVIDIA Cosmos) that emit not
9
+ just video but ROBOT ACTIONS. A single Cosmos3 action-policy run returns a
10
+ playable world video AND a normalized action chunk — both surfaced as artifacts.
11
+
12
+ Quick start:
13
+ from strands import Agent
14
+ from strands_diffusers import use_diffusers
15
+
16
+ agent = Agent(tools=[use_diffusers])
17
+ agent("Generate an image of a robot arm in a kitchen")
18
+ agent("Run a Cosmos action-policy rollout on this robot video and give me the actions")
19
+
20
+ Discovery (the agent never guesses):
21
+ use_diffusers(action="pipelines") # all 300+ pipelines + modality
22
+ use_diffusers(action="wfm") # world-foundation / action models
23
+ use_diffusers(action="modalities") # pipelines grouped by modality
24
+ use_diffusers(action="pipeline_info", target="Cosmos3OmniPipeline")
25
+ use_diffusers(action="inspect", target="StableDiffusionPipeline")
26
+ """
27
+
28
# Prefer the version string written into the generated ``_version`` module;
# fall back to a static value when that module cannot be imported
# (not installed / no git metadata).
try:
    from strands_diffusers._version import version as __version__
except ImportError:
    __version__ = "0.1.0"

from strands_diffusers.core import engine, io, registry
from strands_diffusers.tools.use_diffusers import use_diffusers

# Names re-exported as the package's public API.
__all__ = ["use_diffusers", "registry", "engine", "io"]
@@ -0,0 +1,24 @@
1
+ # file generated by vcs-versioning
2
+ # don't change, don't track in version control
3
+ from __future__ import annotations
4
+
5
+ __all__ = [
6
+ "__version__",
7
+ "__version_tuple__",
8
+ "version",
9
+ "version_tuple",
10
+ "__commit_id__",
11
+ "commit_id",
12
+ ]
13
+
14
+ version: str
15
+ __version__: str
16
+ __version_tuple__: tuple[int | str, ...]
17
+ version_tuple: tuple[int | str, ...]
18
+ commit_id: str | None
19
+ __commit_id__: str | None
20
+
21
+ __version__ = version = '0.1.0'
22
+ __version_tuple__ = version_tuple = (0, 1, 0)
23
+
24
+ __commit_id__ = commit_id = None
@@ -0,0 +1,4 @@
1
+ """Core engine, io, and registry for strands-diffusers."""
2
# Re-export the three core submodules from this package.
from . import engine, io, registry

__all__ = [
    "engine",
    "io",
    "registry",
]
@@ -0,0 +1,163 @@
1
+ """Pipeline engine — load once, cache, run. Auto device/dtype.
2
+
3
+ Wraps diffusers' `DiffusionPipeline.from_pretrained` (and any specific pipeline
4
+ class) as the universal loader. Pipelines are cached per (class, model) so repeat
5
+ runs are cheap, and a generic `load_object` lets you reach any diffusers class
6
+ (schedulers, VAEs, transformers) for low-level control — the equivalent of the
7
+ `call` layer in use_transformers.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import logging
13
+ from typing import Any, Dict, Optional
14
+
15
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Session-scoped cache of loaded objects (pipelines, models, schedulers),
# keyed by the string each loader derives or is handed.
_CACHE: Dict[str, Any] = dict()
19
+
20
+
21
def select_device(device: Optional[str] = None) -> str:
    """Resolve the device string to load onto.

    Args:
        device: An explicit device name, or ``None`` / ``""`` / ``"auto"``
            to auto-detect.

    Returns:
        ``device`` unchanged when it is an explicit, non-"auto" value (it is
        not validated). Otherwise ``"cuda"`` if torch reports CUDA, else
        ``"mps"`` if torch reports MPS, else ``"cpu"`` — including when torch
        cannot be imported.
    """
    wants_autodetect = not device or device == "auto"
    if not wants_autodetect:
        return device

    chosen = "cpu"
    try:
        import torch

        if torch.cuda.is_available():
            chosen = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            # older torch builds have no ``backends.mps`` attribute at all
            chosen = "mps"
    except ImportError:
        pass
    return chosen
33
+
34
+
35
def select_dtype(device: str):
    """Pick a sensible default dtype for diffusion on the device.

    Returns ``torch.bfloat16`` for ``"cuda"``, ``torch.float16`` for ``"mps"``,
    and ``None`` for anything else or when torch cannot be imported. ``None``
    means "do not pass a dtype" (float32 on cpu).
    """
    try:
        import torch

        per_device = (("cuda", torch.bfloat16), ("mps", torch.float16))
        for name, default in per_device:
            if device == name:
                return default
    except ImportError:
        pass
    return None
46
+
47
+
48
def get_pipeline(pipeline_class: str, model: str,
                 device: Optional[str] = None, cache_key: Optional[str] = None,
                 dtype: Optional[str] = None, move_to_device: bool = True,
                 **from_pretrained_kwargs: Any):
    """Build (or fetch cached) a diffusers pipeline.

    A cache hit returns the stored pipeline as-is: the ``device``, ``dtype``
    and ``from_pretrained_kwargs`` of the current call are not compared with
    those used when the entry was created. Pass a distinct ``cache_key`` to
    keep differently-configured copies apart.

    Args:
        pipeline_class: diffusers pipeline class name, e.g. "StableDiffusionPipeline",
            "Cosmos3OmniPipeline", or "DiffusionPipeline" for auto-detection.
        model: HF repo id or local path.
        device: "cuda" / "mps" / "cpu" / "auto".
        cache_key: name to cache under (default derived from class+model).
        dtype: explicit torch dtype name ("bfloat16","float16","float32") or None
            to use the device default. Ignored when ``torch_dtype`` is passed in
            ``from_pretrained_kwargs``. An unrecognized name is reported with a
            warning and no ``torch_dtype`` is passed.
        move_to_device: call .to(device) unless device_map was passed.
        **from_pretrained_kwargs: forwarded to ``from_pretrained``.

    Returns:
        ``(pipeline, key)`` where ``key`` is the cache key the pipeline is
        stored under.
    """
    from . import registry

    key = cache_key or f"pipe::{pipeline_class}::{model}"
    if key in _CACHE:
        return _CACHE[key], key

    cls = registry.resolve_attr(pipeline_class)
    dev = select_device(device)
    kwargs = dict(from_pretrained_kwargs)

    # dtype: explicit name > device default. diffusers accepts torch_dtype.
    # (``dtype`` is a named parameter here, so it can never show up in kwargs.)
    if "torch_dtype" not in kwargs:
        if dtype:
            td = _resolve_dtype(dtype)
            if td is None:
                logger.warning(
                    "Unrecognized dtype %r for %s; not passing torch_dtype",
                    dtype, pipeline_class,
                )
        else:
            td = select_dtype(dev)
        if td is not None:
            kwargs["torch_dtype"] = td

    logger.info("Loading %s from %s on %s", pipeline_class, model, dev)
    pipe = cls.from_pretrained(model, **kwargs)

    # move to device unless from_pretrained already placed it (device_map)
    if move_to_device and "device_map" not in kwargs and hasattr(pipe, "to"):
        try:
            pipe = pipe.to(dev)
        except Exception as e:
            # Best effort: keep the loaded pipeline usable, but make the
            # failed placement visible instead of hiding it at debug level.
            logger.warning("Could not move %s to %s: %s", pipeline_class, dev, e)

    _CACHE[key] = pipe
    return pipe, key
91
+
92
+
93
def load_object(class_name: str, model_path: Optional[str] = None,
                device: Optional[str] = None, cache_key: Optional[str] = None,
                from_config: bool = False, **kwargs: Any):
    """Load any diffusers class via from_pretrained / from_config.

    For lower-level control than full pipelines — schedulers, VAEs, transformers,
    e.g. swap a pipeline's scheduler:
        load_object("UniPCMultistepScheduler", from_config=True, config=cached_cfg)

    Args:
        class_name: name resolved through ``registry.resolve_attr``.
        model_path: HF repo id or local path; required unless ``from_config``.
        device: only used to pick a default dtype for model-like classes; the
            loaded object is not moved to the device here.
        cache_key: name to cache under (default derived from class+model_path).
        from_config: build with ``cls.from_config(**kwargs)`` instead of
            ``cls.from_pretrained(model_path, **kwargs)``.
        **kwargs: forwarded to the chosen constructor.

    Returns:
        ``(obj, key)`` where ``key`` is the cache key the object is stored under.

    Raises:
        ValueError: if ``from_config`` is false and ``model_path`` is empty.
    """
    from . import registry

    key = cache_key or f"obj::{class_name}::{model_path or 'cfg'}"

    # The derived key does not encode the config passed in kwargs, so for
    # from_config builds a lookup under it could hand back an object created
    # from a different config. Only reuse such objects when the caller named
    # the entry explicitly; the fresh object still replaces the stored one.
    may_reuse = bool(cache_key) or not from_config
    if may_reuse and key in _CACHE:
        return _CACHE[key], key

    cls = registry.resolve_attr(class_name)
    if from_config:
        obj = cls.from_config(**kwargs)
    else:
        if not model_path:
            raise ValueError(
                f"model_path is required to load {class_name} via from_pretrained "
                "(pass from_config=True to build from a config instead)"
            )
        # Heuristic on the class name: only model-like classes get a default dtype.
        is_model_like = (class_name.endswith(("Model", "Transformer"))
                         or class_name.startswith("Autoencoder"))
        if is_model_like and "torch_dtype" not in kwargs and "dtype" not in kwargs:
            td = select_dtype(select_device(device))
            if td is not None:
                kwargs["torch_dtype"] = td
        obj = cls.from_pretrained(model_path, **kwargs)

    _CACHE[key] = obj
    return obj, key
122
+
123
+
124
def _resolve_dtype(name: str):
    """Map a dtype name to the corresponding ``torch.dtype``.

    Accepts the short and long spellings listed below, case-insensitively,
    with optional surrounding whitespace and an optional ``torch.`` prefix
    (so ``"torch.float16"`` and ``str(torch.float16)`` both resolve). A
    ``torch.dtype`` instance is returned unchanged.

    Returns:
        The matching ``torch.dtype``, or ``None`` when the name is not recognized.
    """
    import torch

    if isinstance(name, torch.dtype):
        return name

    lookup = str(name).strip().lower()
    prefix = "torch."
    if lookup.startswith(prefix):
        lookup = lookup[len(prefix):]

    aliases = {
        "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
        "float16": torch.float16, "fp16": torch.float16, "half": torch.float16,
        "float32": torch.float32, "fp32": torch.float32, "float": torch.float32,
    }
    return aliases.get(lookup)
131
+
132
+
133
def cache_list() -> Dict[str, str]:
    """Return a mapping of each cache key to the class name of its cached object."""
    summary: Dict[str, str] = {}
    for cache_key, cached in _CACHE.items():
        summary[cache_key] = type(cached).__name__
    return summary
135
+
136
+
137
def cache_clear(key: Optional[str] = None) -> int:
    """Drop one cache entry, or the whole cache, and release memory.

    Args:
        key: entry to remove. A falsy value (``None`` or ``""``) clears
            every entry.

    Returns:
        The number of entries removed (``0`` when ``key`` is not cached).
    """
    if not key:
        # No key given: empty the whole cache.
        removed = len(_CACHE)
        _CACHE.clear()
        _free_memory()
        return removed

    if key not in _CACHE:
        return 0

    del _CACHE[key]
    _free_memory()
    return 1
149
+
150
+
151
def cache_get(key: str) -> Optional[Any]:
    """Return the object cached under ``key``, or ``None`` when there is none."""
    try:
        return _CACHE[key]
    except KeyError:
        return None
153
+
154
+
155
def _free_memory():
    """Best-effort release of memory after cache entries were dropped.

    Always runs a garbage-collection pass, then asks torch to release its
    cached accelerator memory (CUDA, and MPS where available). Never raises
    for torch-related failures: a missing torch is ignored and any other
    failure is logged at debug level.
    """
    import gc

    # Collect first, and regardless of whether torch is importable, so the
    # dropped objects are actually freed before the allocator caches are emptied.
    gc.collect()

    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Older torch builds lack backends.mps and/or torch.mps.empty_cache.
        mps_backend = getattr(torch.backends, "mps", None)
        mps_module = getattr(torch, "mps", None)
        if (mps_backend is not None and mps_backend.is_available()
                and hasattr(mps_module, "empty_cache")):
            mps_module.empty_cache()
    except ImportError:
        pass  # torch not installed: nothing more to release
    except Exception as exc:
        logging.getLogger(__name__).debug("Could not release accelerator cache: %s", exc)