strands-diffusers 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strands_diffusers/__init__.py +41 -0
- strands_diffusers/_version.py +24 -0
- strands_diffusers/core/__init__.py +4 -0
- strands_diffusers/core/engine.py +163 -0
- strands_diffusers/core/io.py +552 -0
- strands_diffusers/core/registry.py +349 -0
- strands_diffusers/core/viz.py +256 -0
- strands_diffusers/tools/__init__.py +4 -0
- strands_diffusers/tools/use_diffusers.py +420 -0
- strands_diffusers-0.1.0.dist-info/METADATA +199 -0
- strands_diffusers-0.1.0.dist-info/RECORD +13 -0
- strands_diffusers-0.1.0.dist-info/WHEEL +5 -0
- strands_diffusers-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Strands Diffusers — the universal entrypoint to HuggingFace diffusers.
|
|
2
|
+
|
|
3
|
+
100% diffusers coverage with zero hardcoding: every pipeline (text→image, image→
|
|
4
|
+
video, video→world), model, and scheduler is exposed the same way `use_aws` wraps boto3,
|
|
5
|
+
`use_lerobot` wraps lerobot, and `use_transformers` wraps the transformers task
|
|
6
|
+
taxonomy.
|
|
7
|
+
|
|
8
|
+
Special focus: Physical-AI world-foundation models (NVIDIA Cosmos) that emit not
|
|
9
|
+
just video but ROBOT ACTIONS. A single Cosmos3 action-policy run returns a
|
|
10
|
+
playable world video AND a normalized action chunk — both surfaced as artifacts.
|
|
11
|
+
|
|
12
|
+
Quick start:
|
|
13
|
+
from strands import Agent
|
|
14
|
+
from strands_diffusers import use_diffusers
|
|
15
|
+
|
|
16
|
+
agent = Agent(tools=[use_diffusers])
|
|
17
|
+
agent("Generate an image of a robot arm in a kitchen")
|
|
18
|
+
agent("Run a Cosmos action-policy rollout on this robot video and give me the actions")
|
|
19
|
+
|
|
20
|
+
Discovery (the agent never guesses):
|
|
21
|
+
use_diffusers(action="pipelines") # all 300+ pipelines + modality
|
|
22
|
+
use_diffusers(action="wfm") # world-foundation / action models
|
|
23
|
+
use_diffusers(action="modalities") # pipelines grouped by modality
|
|
24
|
+
use_diffusers(action="pipeline_info", target="Cosmos3OmniPipeline")
|
|
25
|
+
use_diffusers(action="inspect", target="StableDiffusionPipeline")
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
from strands_diffusers._version import version as __version__
|
|
30
|
+
except ImportError: # not installed / no git metadata
|
|
31
|
+
__version__ = "0.1.0"
|
|
32
|
+
|
|
33
|
+
from strands_diffusers.core import engine, io, registry
|
|
34
|
+
from strands_diffusers.tools.use_diffusers import use_diffusers
|
|
35
|
+
|
|
36
|
+
__all__ = [
|
|
37
|
+
"use_diffusers",
|
|
38
|
+
"registry",
|
|
39
|
+
"engine",
|
|
40
|
+
"io",
|
|
41
|
+
]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# file generated by vcs-versioning
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"__version__",
|
|
7
|
+
"__version_tuple__",
|
|
8
|
+
"version",
|
|
9
|
+
"version_tuple",
|
|
10
|
+
"__commit_id__",
|
|
11
|
+
"commit_id",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
version: str
|
|
15
|
+
__version__: str
|
|
16
|
+
__version_tuple__: tuple[int | str, ...]
|
|
17
|
+
version_tuple: tuple[int | str, ...]
|
|
18
|
+
commit_id: str | None
|
|
19
|
+
__commit_id__: str | None
|
|
20
|
+
|
|
21
|
+
__version__ = version = '0.1.0'
|
|
22
|
+
__version_tuple__ = version_tuple = (0, 1, 0)
|
|
23
|
+
|
|
24
|
+
__commit_id__ = commit_id = None
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""Pipeline engine — load once, cache, run. Auto device/dtype.
|
|
2
|
+
|
|
3
|
+
Wraps diffusers' `DiffusionPipeline.from_pretrained` (and any specific pipeline
|
|
4
|
+
class) as the universal loader. Pipelines are cached per (class, model) so repeat
|
|
5
|
+
runs are cheap, and a generic `load_object` lets you reach any diffusers class
|
|
6
|
+
(schedulers, VAEs, transformers) for low-level control — the equivalent of the
|
|
7
|
+
`call` layer in use_transformers.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import logging
|
|
13
|
+
from typing import Any, Dict, Optional
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
# session-scoped cache of loaded objects (pipelines, models, schedulers)
|
|
18
|
+
_CACHE: Dict[str, Any] = {}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def select_device(device: Optional[str] = None) -> str:
    """Return the device string to run on.

    An explicit value other than "auto" is handed back untouched. For a
    missing, empty, or "auto" request the accelerators are probed in the
    order CUDA, then MPS; "cpu" is the answer when neither is available or
    when torch cannot be imported.
    """
    explicit = bool(device) and device != "auto"
    if explicit:
        return device

    try:
        import torch

        # Ordered probes: the first backend reporting availability wins.
        probes = (
            ("cuda", lambda: torch.cuda.is_available()),
            ("mps", lambda: hasattr(torch.backends, "mps")
                and torch.backends.mps.is_available()),
        )
        for backend, is_available in probes:
            if is_available():
                return backend
    except ImportError:
        pass
    return "cpu"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def select_dtype(device: str):
    """Return the default torch dtype for ``device``, or None.

    "cuda" maps to ``torch.bfloat16`` and "mps" to ``torch.float16``. Every
    other device string, and a missing torch install, yields None.
    """
    try:
        import torch
    except ImportError:
        return None

    per_device = {
        "cuda": torch.bfloat16,
        "mps": torch.float16,
    }
    # Anything not listed (e.g. "cpu") has no override.
    return per_device.get(device)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_pipeline(pipeline_class: str, model: str,
                 device: Optional[str] = None, cache_key: Optional[str] = None,
                 dtype: Optional[str] = None, move_to_device: bool = True,
                 **from_pretrained_kwargs: Any):
    """Build (or fetch cached) a diffusers pipeline.

    A cache hit returns the stored pipeline immediately; in that case
    ``device``, ``dtype`` and ``from_pretrained_kwargs`` are not consulted, so
    pass a distinct ``cache_key`` to hold several variants of one model.

    Args:
        pipeline_class: diffusers pipeline class name, e.g. "StableDiffusionPipeline",
            "Cosmos3OmniPipeline", or "DiffusionPipeline" for auto-detection.
        model: HF repo id or local path.
        device: "cuda" / "mps" / "cpu" / "auto".
        cache_key: name to cache under (default derived from class+model).
        dtype: explicit torch dtype name ("bfloat16","float16","float32") or None.
            An unrecognised name is reported with a warning and the device
            default is used instead.
        move_to_device: call .to(device) unless device_map was passed.
        **from_pretrained_kwargs: forwarded to ``from_pretrained``. A
            ``torch_dtype`` entry here takes precedence over ``dtype``.

    Returns:
        A ``(pipeline, key)`` tuple, where ``key`` is the cache entry name.
    """
    from . import registry

    key = cache_key or f"pipe::{pipeline_class}::{model}"
    if key in _CACHE:
        return _CACHE[key], key

    cls = registry.resolve_attr(pipeline_class)
    dev = select_device(device)
    kwargs = dict(from_pretrained_kwargs)

    # dtype: explicit name > device default. diffusers accepts torch_dtype.
    # (A "dtype" keyword can never land in kwargs: it binds to the named
    # parameter above, so only torch_dtype needs checking.)
    if "torch_dtype" not in kwargs:
        td = _resolve_dtype(dtype) if dtype else None
        if dtype and td is None:
            # Previously an unknown name was dropped silently and the
            # pipeline loaded without any dtype override.
            logger.warning(
                "Unknown dtype %r for %s; using the default for device %s",
                dtype, pipeline_class, dev,
            )
        if td is None:
            td = select_dtype(dev)
        if td is not None:
            kwargs["torch_dtype"] = td

    logger.info("Loading %s from %s on %s", pipeline_class, model, dev)
    pipe = cls.from_pretrained(model, **kwargs)

    # move to device unless from_pretrained already placed it (device_map)
    if move_to_device and "device_map" not in kwargs and hasattr(pipe, "to"):
        try:
            pipe = pipe.to(dev)
        except Exception as e:
            # Best effort: keep the pipeline where it was loaded, but make
            # the failed placement visible instead of hiding it at debug level.
            logger.warning("Could not .to(%s): %s", dev, e)

    _CACHE[key] = pipe
    return pipe, key
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def load_object(class_name: str, model_path: Optional[str] = None,
                device: Optional[str] = None, cache_key: Optional[str] = None,
                from_config: bool = False, **kwargs: Any):
    """Load any diffusers class via from_pretrained / from_config.

    For lower-level control than full pipelines — schedulers, VAEs, transformers,
    e.g. swap a pipeline's scheduler:
        load_object("UniPCMultistepScheduler", from_config=True, config=cached_cfg)

    Args:
        class_name: name of the class to resolve through the registry.
        model_path: HF repo id or local path; required unless ``from_config``.
        device: "cuda" / "mps" / "cpu" / "auto". Only used to pick a default
            dtype; the loaded object is not moved to the device here.
        cache_key: name to cache under. The default is derived from
            ``class_name`` and ``model_path``; without a ``model_path`` all
            loads of one class share the same ``...::cfg`` key, so pass an
            explicit key to keep differently configured instances apart.
        from_config: build with ``cls.from_config(**kwargs)`` instead of
            ``cls.from_pretrained(model_path, **kwargs)``.
        **kwargs: forwarded to the chosen constructor.

    Returns:
        A ``(object, key)`` tuple, where ``key`` is the cache entry name.

    Raises:
        ValueError: if ``from_config`` is false and no ``model_path`` is given.
    """
    from . import registry

    key = cache_key or f"obj::{class_name}::{model_path or 'cfg'}"
    if key in _CACHE:
        return _CACHE[key], key

    cls = registry.resolve_attr(class_name)
    if from_config:
        obj = cls.from_config(**kwargs)
    else:
        if model_path is None:
            # Fail with a clear message rather than passing None into
            # from_pretrained and surfacing an opaque library error.
            raise ValueError(
                f"model_path is required to load {class_name} "
                "unless from_config=True"
            )
        dev = select_device(device)
        # Only model-like classes get a default dtype.
        is_model_like = (
            class_name.endswith(("Model", "Transformer"))
            or class_name.startswith("Autoencoder")
        )
        if is_model_like and "torch_dtype" not in kwargs and "dtype" not in kwargs:
            td = select_dtype(dev)
            if td is not None:
                kwargs["torch_dtype"] = td
        obj = cls.from_pretrained(model_path, **kwargs)

    _CACHE[key] = obj
    return obj, key
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _resolve_dtype(name: str):
    """Translate a dtype name into a ``torch.dtype``.

    Matching ignores letter case, surrounding whitespace and an optional
    "torch." prefix. Accepted inputs:

    * the shorthand aliases "bf16", "fp16", "half", "fp32" and "float";
    * any public attribute of ``torch`` that is a dtype, e.g. "bfloat16",
      "float16", "float32", "float64";
    * an actual ``torch.dtype`` instance, which is returned unchanged.

    Returns:
        The matching ``torch.dtype``, or None when the name is not recognised.
    """
    import torch

    if isinstance(name, torch.dtype):
        return name

    aliases = {
        "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
        "float16": torch.float16, "fp16": torch.float16, "half": torch.float16,
        "float32": torch.float32, "fp32": torch.float32, "float": torch.float32,
    }

    key = str(name).strip().lower()
    if key.startswith("torch."):
        key = key[len("torch."):]
    if key in aliases:
        return aliases[key]

    # Fall back to torch's own dtype attributes. Restricting the lookup to
    # public identifiers avoids touching private or lazily loaded names.
    if key.isidentifier() and not key.startswith("_"):
        candidate = getattr(torch, key, None)
        if isinstance(candidate, torch.dtype):
            return candidate
    return None
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def cache_list() -> Dict[str, str]:
    """Map each cache key to the class name of the object stored under it."""
    summary: Dict[str, str] = {}
    for entry_key, cached in _CACHE.items():
        summary[entry_key] = type(cached).__name__
    return summary
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def cache_clear(key: Optional[str] = None) -> int:
    """Drop one cache entry, or all of them, and report how many were removed.

    A falsy ``key`` (None or "") empties the whole cache. Otherwise only the
    named entry is removed; an unknown key removes nothing and returns 0.
    Memory is released via ``_free_memory`` whenever something may have been
    dropped.
    """
    if not key:
        # Clear-all path: count first, since clearing loses the size.
        removed = len(_CACHE)
        _CACHE.clear()
        _free_memory()
        return removed

    if key not in _CACHE:
        return 0

    del _CACHE[key]
    _free_memory()
    return 1
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def cache_get(key: str) -> Optional[Any]:
    """Return the object cached under ``key``, or None when there is none."""
    try:
        return _CACHE[key]
    except KeyError:
        return None
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _free_memory():
    """Best-effort release of memory after cache entries are dropped.

    Runs the garbage collector first (this no longer depends on torch being
    importable), then asks torch to release its cached accelerator memory:
    CUDA when available, and MPS when that backend is available and the
    installed torch exposes ``torch.mps.empty_cache``.

    Never raises; any failure is logged at debug level.
    """
    import gc

    gc.collect()
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # select_device can pick "mps", so release that allocator's cache too.
        mps_module = getattr(torch, "mps", None)
        if (
            hasattr(torch.backends, "mps")
            and torch.backends.mps.is_available()
            and hasattr(mps_module, "empty_cache")
        ):
            mps_module.empty_cache()
    except Exception as exc:
        # Deliberately non-fatal: freeing memory is an optimisation only.
        logger.debug("Could not release accelerator memory: %s", exc)
|