abstractvision 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
- abstractvision/__init__.py +18 -3
- abstractvision/__main__.py +8 -0
- abstractvision/artifacts.py +320 -0
- abstractvision/assets/vision_model_capabilities.json +406 -0
- abstractvision/backends/__init__.py +43 -0
- abstractvision/backends/base_backend.py +63 -0
- abstractvision/backends/huggingface_diffusers.py +1503 -0
- abstractvision/backends/openai_compatible.py +325 -0
- abstractvision/backends/stable_diffusion_cpp.py +751 -0
- abstractvision/cli.py +778 -0
- abstractvision/errors.py +19 -0
- abstractvision/integrations/__init__.py +5 -0
- abstractvision/integrations/abstractcore.py +263 -0
- abstractvision/integrations/abstractcore_plugin.py +193 -0
- abstractvision/model_capabilities.py +255 -0
- abstractvision/types.py +95 -0
- abstractvision/vision_manager.py +115 -0
- abstractvision-0.2.1.dist-info/METADATA +243 -0
- abstractvision-0.2.1.dist-info/RECORD +23 -0
- {abstractvision-0.1.0.dist-info → abstractvision-0.2.1.dist-info}/WHEEL +1 -1
- abstractvision-0.2.1.dist-info/entry_points.txt +5 -0
- abstractvision-0.1.0.dist-info/METADATA +0 -65
- abstractvision-0.1.0.dist-info/RECORD +0 -6
- {abstractvision-0.1.0.dist-info → abstractvision-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {abstractvision-0.1.0.dist-info → abstractvision-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1503 @@
+from __future__ import annotations
+
+import io
+import inspect
+import os
+import hashlib
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from ..errors import CapabilityNotSupportedError, OptionalDependencyMissingError
+from ..types import (
+    GeneratedAsset,
+    ImageEditRequest,
+    ImageGenerationRequest,
+    ImageToVideoRequest,
+    MultiAngleRequest,
+    VideoGenerationRequest,
+    VisionBackendCapabilities,
+)
+from .base_backend import VisionBackend
+
+
+def _require_optional_dep(name: str, install_hint: str) -> None:
+    import sys
+
+    raise OptionalDependencyMissingError(
+        f"Optional dependency missing: {name}. Install via: {install_hint} " f"(python={sys.executable})"
+    )
+
+
+def _lazy_import_diffusers():
+    try:
+        import warnings
+
+        # Some Diffusers modules decorate functions with `torch.autocast(device_type="cuda", ...)`,
+        # which emits noisy warnings on non-CUDA machines (including Apple Silicon / MPS).
+        warnings.filterwarnings("ignore", message=r".*CUDA is not available.*Disabling autocast.*", category=UserWarning)
+        warnings.filterwarnings(
+            "ignore",
+            message=r".*device_type of 'cuda'.*CUDA is not available.*",
+            category=UserWarning,
+        )
+        import diffusers  # type: ignore
+        from diffusers import DiffusionPipeline  # type: ignore
+    except Exception as e:  # pragma: no cover
+        raise OptionalDependencyMissingError(
+            "Optional dependency missing (or failed to import): diffusers. Install via: pip install 'diffusers'. "
+            f"(python={__import__('sys').executable})"
+        ) from e
+
+    # AutoPipeline classes are optional here. Some environments may have diffusers installed but fail to import
+    # AutoPipeline due to version mismatches with transformers/torch or other optional deps. We can still load
+    # many text-to-image models via `DiffusionPipeline` and only require AutoPipeline for i2i/inpaint.
+    AutoPipelineForText2Image = None
+    AutoPipelineForImage2Image = None
+    AutoPipelineForInpainting = None
+    try:
+        from diffusers import AutoPipelineForText2Image as _AutoPipelineForText2Image  # type: ignore
+
+        AutoPipelineForText2Image = _AutoPipelineForText2Image
+    except Exception:
+        pass
+    try:
+        from diffusers import AutoPipelineForImage2Image as _AutoPipelineForImage2Image  # type: ignore
+
+        AutoPipelineForImage2Image = _AutoPipelineForImage2Image
+    except Exception:
+        pass
+    try:
+        from diffusers import AutoPipelineForInpainting as _AutoPipelineForInpainting  # type: ignore
+
+        AutoPipelineForInpainting = _AutoPipelineForInpainting
+    except Exception:
+        pass
+
+    return (
+        DiffusionPipeline,
+        AutoPipelineForText2Image,
+        AutoPipelineForImage2Image,
+        AutoPipelineForInpainting,
+        getattr(diffusers, "__version__", "unknown"),
+    )
+
+
+def _lazy_import_torch():
+    try:
+        import torch  # type: ignore
+    except Exception:  # pragma: no cover
+        _require_optional_dep("torch", "pip install 'torch'")
+    return torch
+
+
+def _lazy_import_pil():
+    try:
+        from PIL import Image  # type: ignore
+    except Exception:  # pragma: no cover
+        _require_optional_dep("pillow", "pip install 'pillow'")
+    return Image
+
+
+_TRANSFORMERS_CLIP_POSITION_IDS_PATCHED = False
+
+
+def _maybe_patch_transformers_clip_position_ids() -> None:
+    """Fix Transformers v5 noisy LOAD REPORTs for common CLIP checkpoints.
+
+    Transformers 5 logs a detailed load report when encountering unexpected keys like
+    `*.embeddings.position_ids` in older CLIP checkpoints (e.g. SD1.5 text encoder / safety checker).
+
+    The root cause is a small architecture/state-dict mismatch: those checkpoints include a persistent
+    `position_ids` buffer, while newer CLIP embedding classes may not. We re-add that buffer so the
+    checkpoint matches the instantiated model and no "UNEXPECTED" keys are reported.
+    """
+
+    global _TRANSFORMERS_CLIP_POSITION_IDS_PATCHED
+    if _TRANSFORMERS_CLIP_POSITION_IDS_PATCHED:
+        return
+
+    try:
+        import transformers  # type: ignore
+        import torch as _torch  # type: ignore
+    except Exception:
+        return
+
+    ver = str(getattr(transformers, "__version__", "0"))
+    try:
+        major = int(ver.split(".", 1)[0])
+    except Exception:
+        major = 0
+    if major < 5:
+        _TRANSFORMERS_CLIP_POSITION_IDS_PATCHED = True
+        return
+
+    try:
+        from transformers.models.clip.modeling_clip import CLIPTextEmbeddings, CLIPVisionEmbeddings  # type: ignore
+    except Exception:
+        _TRANSFORMERS_CLIP_POSITION_IDS_PATCHED = True
+        return
+
+    def _patch(cls: Any) -> None:
+        if bool(getattr(cls, "_abstractvision_position_ids_patched", False)):
+            return
+        orig_init = getattr(cls, "__init__", None)
+        if not callable(orig_init):
+            return
+
+        def __init__(self, *args, **kwargs):  # type: ignore[no-redef]
+            orig_init(self, *args, **kwargs)
+            if hasattr(self, "position_ids"):
+                # In Transformers 5, `position_ids` is sometimes registered as a non-persistent buffer
+                # (`persistent=False`), so it isn't part of the state dict and is reported as UNEXPECTED
+                # when loading older checkpoints that include it. Make it persistent.
+                try:
+                    buffers = getattr(self, "_buffers", None)
+                    if isinstance(buffers, dict) and "position_ids" in buffers:
+                        non_persistent = getattr(self, "_non_persistent_buffers_set", None)
+                        if isinstance(non_persistent, set):
+                            non_persistent.discard("position_ids")
+                        return
+                except Exception:
+                    return
+            pos_emb = getattr(self, "position_embedding", None)
+            num = getattr(pos_emb, "num_embeddings", None) if pos_emb is not None else None
+            if num is None:
+                return
+            try:
+                position_ids = _torch.arange(int(num)).unsqueeze(0)
+                self.register_buffer("position_ids", position_ids, persistent=True)
+            except Exception:
+                return
+
+        setattr(cls, "__init__", __init__)
+        setattr(cls, "_abstractvision_position_ids_patched", True)
+
+    _patch(CLIPTextEmbeddings)
+    _patch(CLIPVisionEmbeddings)
+    _TRANSFORMERS_CLIP_POSITION_IDS_PATCHED = True
+
+
+@contextmanager
+def _hf_offline_env(enabled: bool):
+    """Control Hugging Face offline mode within a scope.
+
+    When `enabled=True`, we force offline mode (no network calls).
+    When `enabled=False`, we force online mode (overrides e.g. HF_HUB_OFFLINE=1 in the user's shell).
+    """
+
+    # These are respected by huggingface_hub / transformers / diffusers.
+    # We scope them to the load/call to avoid surprising other parts of the process.
+    vars_to_set = {
+        "HF_HUB_OFFLINE": "1" if enabled else "0",
+        "TRANSFORMERS_OFFLINE": "1" if enabled else "0",
+        "DIFFUSERS_OFFLINE": "1" if enabled else "0",
+        # Avoid any telemetry even in edge cases.
+        "HF_HUB_DISABLE_TELEMETRY": "1",
+    }
+    old = {k: os.environ.get(k) for k in vars_to_set.keys()}
+    try:
+        for k, v in vars_to_set.items():
+            os.environ[k] = v
+        yield
+    finally:
+        for k, prev in old.items():
+            if prev is None:
+                os.environ.pop(k, None)
+            else:
+                os.environ[k] = prev
+
+
+def _torch_dtype_from_str(torch: Any, value: Optional[str]) -> Any:
+    if value is None:
+        return None
+    v = str(value).strip().lower()
+    if not v or v == "auto":
+        return None
+    if v in {"float16", "fp16"}:
+        return torch.float16
+    if v in {"bfloat16", "bf16"}:
+        return torch.bfloat16
+    if v in {"float32", "fp32"}:
+        return torch.float32
+    raise ValueError(f"Unsupported torch_dtype: {value!r}")
+
+
+def _default_torch_dtype_for_device(torch: Any, device: str) -> Any:
+    d = str(device or "").strip().lower()
+    if not d:
+        return None
+    if d.startswith("cuda"):
+        return torch.float16
+    if d == "mps" or d.startswith("mps:"):
+        # Default to fp16 on Apple Silicon for broad model compatibility.
+        # (Some pipelines mix dtypes when using bf16, which can crash with matmul dtype mismatches.)
+        #
+        # You can still force bf16 explicitly via `torch_dtype=bfloat16`.
+        return torch.float16
+    return None
+
+
+def _require_device_available(torch: Any, device: str) -> None:
+    d = str(device or "").strip().lower()
+    if not d:
+        return
+
+    if d.startswith("cuda"):
+        cuda = getattr(torch, "cuda", None)
+        is_available = getattr(cuda, "is_available", None) if cuda is not None else None
+        ok = bool(is_available()) if callable(is_available) else False
+        if not ok:
+            raise ValueError(
+                "Device 'cuda' was requested, but torch.cuda.is_available() is False. "
+                "Install a CUDA-enabled PyTorch build or use device='cpu'."
+            )
+
+    if d == "mps" or d.startswith("mps:"):
+        backends = getattr(torch, "backends", None)
+        mps = getattr(backends, "mps", None) if backends is not None else None
+        is_available = getattr(mps, "is_available", None) if mps is not None else None
+        ok = bool(is_available()) if callable(is_available) else False
+        if not ok:
+            raise ValueError(
+                "Device 'mps' was requested, but torch.backends.mps.is_available() is False. "
+                "On macOS this typically means you are not using an Apple Silicon + MPS-enabled PyTorch build. "
+                "Use device='cpu', or use the stable-diffusion.cpp (sd-cli) backend for GGUF models."
+            )
+
+
+def _call_param_names(fn: Any) -> Optional[set[str]]:
+    try:
+        sig = inspect.signature(fn)
+        for p in sig.parameters.values():
+            if p.kind == p.VAR_KEYWORD:
+                return None
+        return {str(k) for k in sig.parameters.keys() if str(k) != "self"}
+    except Exception:
+        return None
+
+
+def _looks_like_dtype_mismatch_error(e: Exception) -> bool:
+    msg = str(e or "")
+    m = msg.lower()
+    return (
+        "must have the same dtype" in m
+        or ("input type" in m and "bias type" in m and "should be the same" in m)
+        or ("expected scalar type" in m and "but found" in m)
+    )
+
+
+def _maybe_upcast_vae_for_mps(torch: Any, pipe: Any, device: str) -> None:
+    d = str(device or "").strip().lower()
+    if d != "mps" and not d.startswith("mps:"):
+        return
+
+    # On Apple Silicon, some pipelines can produce NaNs/black images when decoding with a float16 VAE.
+    # A common fix is to keep the main model in fp16 but run VAE encode/decode in fp32.
+    #
+    # Diffusers pipelines do not consistently cast inputs to `vae.dtype` before calling `vae.encode/decode`.
+    # If we upcast only the VAE weights to fp32 while the pipeline still produces fp16 latents/images,
+    # PyTorch can raise dtype mismatch errors like:
+    #   "Input type (c10::Half) and bias type (float) should be the same"
+    #
+    # To keep this backend robust across Diffusers versions, when we upcast the VAE we also wrap
+    # `vae.encode` and `vae.decode` to cast their tensor inputs to the VAE's dtype.
+    vae = getattr(pipe, "vae", None)
+    if vae is None:
+        return
+    to_fn = getattr(vae, "to", None)
+    if not callable(to_fn):
+        return
+    dtype = getattr(vae, "dtype", None)
+    if dtype == getattr(torch, "float16", None):
+        try:
+            vae.to(dtype=torch.float32)
+            _maybe_cast_vae_inputs_to_dtype(vae)
+        except Exception:
+            return
+
+
+def _maybe_cast_pipe_modules_to_dtype(pipe: Any, *, dtype: Any) -> None:
+    if dtype is None:
+        return
+
+    def _to(module: Any) -> None:
+        if module is None:
+            return
+        to_fn = getattr(module, "to", None)
+        if not callable(to_fn):
+            return
+        try:
+            module.to(dtype=dtype)
+        except Exception:
+            return
+
+    # Best-effort: different pipelines use different component names (unet vs transformer, etc).
+    for attr in (
+        "model",
+        "transformer",
+        "unet",
+        "text_encoder",
+        "text_encoder_2",
+        "image_encoder",
+        "prior",
+        "vae",
+        "safety_checker",
+    ):
+        _to(getattr(pipe, attr, None))
+
+    vae = getattr(pipe, "vae", None)
+    if vae is not None:
+        _to(getattr(vae, "encoder", None))
+        _to(getattr(vae, "decoder", None))
+
+    # As a fallback, cast all registered components when available (covers pipelines that don't follow
+    # the common attribute naming patterns above).
+    comps = getattr(pipe, "components", None)
+    if isinstance(comps, dict):
+        for v in comps.values():
+            _to(v)
+
+
+def _maybe_cast_vae_inputs_to_dtype(vae: Any) -> None:
+    if getattr(vae, "_abstractvision_casts_inputs_to_dtype", False):
+        return
+
+    try:
+        import types
+
+        def _wrap(name: str) -> None:
+            orig = getattr(vae, name, None)
+            if not callable(orig):
+                return
+
+            def wrapper(self: Any, x: Any, *args: Any, **kwargs: Any) -> Any:
+                try:
+                    dtype = getattr(self, "dtype", None)
+                    x_dtype = getattr(x, "dtype", None)
+                    to_fn = getattr(x, "to", None)
+                    if dtype is not None and x_dtype is not None and x_dtype != dtype and callable(to_fn):
+                        x = x.to(dtype=dtype)
+                except Exception:
+                    pass
+                return orig(x, *args, **kwargs)
+
+            setattr(vae, name, types.MethodType(wrapper, vae))
+
+        _wrap("encode")
+        _wrap("decode")
+        setattr(vae, "_abstractvision_casts_inputs_to_dtype", True)
+    except Exception:
+        return
+
+
+@dataclass(frozen=True)
+class HuggingFaceDiffusersBackendConfig:
+    """Config for a local Diffusers backend.
+
+    Notes:
+    - Downloads are enabled by default so a fresh environment can work after a `pip install`.
+    - To force offline mode (no network calls / cache-only), set `allow_download=False`.
+    """
+
+    model_id: str
+    device: str = "cpu"  # "cpu" | "cuda" | "mps" | "auto" | ...
+    torch_dtype: Optional[str] = None  # "float16" | "bfloat16" | "float32" | None
+    allow_download: bool = True
+    auto_retry_fp32: bool = True
+    cache_dir: Optional[str] = None
+    revision: Optional[str] = None
+    variant: Optional[str] = None
+    use_safetensors: bool = True
+    low_cpu_mem_usage: bool = True
+
+
+class HuggingFaceDiffusersVisionBackend(VisionBackend):
+    """Local generative vision backend using HuggingFace Diffusers (images only, phase 1)."""
+
+    def __init__(self, *, config: HuggingFaceDiffusersBackendConfig):
+        self._cfg = config
+        self._pipelines: Dict[str, Any] = {}
+        self._call_params: Dict[str, Optional[set[str]]] = {}
+        self._fused_lora_signature: Dict[str, Optional[str]] = {}
+        self._rapid_transformer_key: Optional[str] = None
+        self._rapid_transformer: Any = None
+        self._resolved_device: Optional[str] = None
+
+    def _effective_device(self, torch: Any) -> str:
+        if self._resolved_device is not None:
+            return self._resolved_device
+
+        raw = str(getattr(self._cfg, "device", "") or "").strip()
+        d = raw.lower()
+        if not d or d in {"auto", "default"}:
+            cuda = getattr(torch, "cuda", None)
+            if cuda is not None and callable(getattr(cuda, "is_available", None)) and cuda.is_available():
+                self._resolved_device = "cuda"
+                return self._resolved_device
+
+            backends = getattr(torch, "backends", None)
+            mps = getattr(backends, "mps", None) if backends is not None else None
+            if mps is not None and callable(getattr(mps, "is_available", None)) and mps.is_available():
+                self._resolved_device = "mps"
+                return self._resolved_device
+
+            self._resolved_device = "cpu"
+            return self._resolved_device
+
+        # Normalize common spellings but preserve explicit device indexes (e.g. "cuda:0").
+        if d == "gpu":
+            self._resolved_device = "cuda"
+        else:
+            self._resolved_device = raw
+        return self._resolved_device
+
+    def preload(self) -> None:
+        # Best-effort: preload the most common pipeline.
+        self._get_or_load_pipeline("t2i")
+
+    def unload(self) -> None:
+        # Best-effort: release pipelines and GPU cache.
+        pipes = list(self._pipelines.values())
+        self._pipelines.clear()
+        self._call_params.clear()
+        self._fused_lora_signature.clear()
+        self._rapid_transformer_key = None
+        self._rapid_transformer = None
+
+        # Drop references and aggressively collect.
+        try:
+            for p in pipes:
+                try:
+                    # Try to free adapter weights.
+                    unfuse = getattr(p, "unfuse_lora", None)
+                    if callable(unfuse):
+                        unfuse()
+                except Exception:
+                    pass
+                try:
+                    unload = getattr(p, "unload_lora_weights", None)
+                    if callable(unload):
+                        unload()
+                except Exception:
+                    pass
+        finally:
+            pipes = []
+
+        try:
+            import gc
+
+            gc.collect()
+        except Exception:
+            pass
+
+        try:
+            torch = _lazy_import_torch()
+            cuda = getattr(torch, "cuda", None)
+            if cuda is not None and callable(getattr(cuda, "is_available", None)) and cuda.is_available():
+                empty = getattr(cuda, "empty_cache", None)
+                if callable(empty):
+                    empty()
+                ipc_collect = getattr(cuda, "ipc_collect", None)
+                if callable(ipc_collect):
+                    ipc_collect()
+
+            mps = getattr(torch, "mps", None)
+            empty_mps = getattr(mps, "empty_cache", None) if mps is not None else None
+            if callable(empty_mps):
+                empty_mps()
+        except Exception:
+            pass
+
+    def _lora_signature(self, loras: List[Dict[str, Any]]) -> Optional[str]:
+        if not loras:
+            return None
+        parts: List[str] = []
+        for spec in sorted(loras, key=lambda x: str(x.get("source") or "")):
+            parts.append(
+                "|".join(
+                    [
+                        str(spec.get("source") or ""),
+                        str(spec.get("subfolder") or ""),
+                        str(spec.get("weight_name") or ""),
+                        str(spec.get("scale") or 1.0),
+                    ]
+                )
+            )
+        combined = "::".join(parts)
+        return hashlib.md5(combined.encode("utf-8")).hexdigest()[:12]
+
+    def _parse_loras(self, extra: Any) -> List[Dict[str, Any]]:
+        if not isinstance(extra, dict) or not extra:
+            return []
+
+        raw: Any = None
+        for k in ("loras", "loras_json", "lora", "lora_json"):
+            if k in extra and extra.get(k) is not None:
+                raw = extra.get(k)
+                break
+        if raw is None:
+            return []
+
+        import json
+
+        items: Any = raw
+        if isinstance(raw, str):
+            s = raw.strip()
+            if not s:
+                return []
+            # Prefer JSON, but allow a simple comma-separated list of sources.
+            if s.startswith("[") or s.startswith("{"):
+                try:
+                    items = json.loads(s)
+                except Exception:
+                    items = raw
+            if isinstance(items, str):
+                parts = [p.strip() for p in items.split(",") if p.strip()]
+                items = [{"source": p} for p in parts]
+
+        if isinstance(items, dict):
+            items = [items]
+        if isinstance(items, str):
+            return [{"source": items.strip()}] if items.strip() else []
+        if not isinstance(items, list):
+            return []
+
+        out: List[Dict[str, Any]] = []
+        for el in items:
+            if isinstance(el, str):
+                src = el.strip()
+                if src:
+                    out.append({"source": src, "scale": 1.0})
+                continue
+            if not isinstance(el, dict):
+                continue
+            src = str(el.get("source") or "").strip()
+            if not src:
+                continue
+            spec: Dict[str, Any] = {"source": src}
+            if el.get("subfolder") is not None:
+                spec["subfolder"] = str(el.get("subfolder") or "").strip() or None
+            if el.get("weight_name") is not None:
+                spec["weight_name"] = str(el.get("weight_name") or "").strip() or None
+            if el.get("adapter_name") is not None:
+                spec["adapter_name"] = str(el.get("adapter_name") or "").strip() or None
+            try:
+                spec["scale"] = float(el.get("scale") if el.get("scale") is not None else 1.0)
+            except Exception:
+                spec["scale"] = 1.0
+            out.append(spec)
+        return out
+
+    def _resolved_adapter_name(self, spec: Dict[str, Any]) -> str:
+        name = str(spec.get("adapter_name") or "").strip()
+        if name:
+            return name
+        key = "|".join([str(spec.get("source") or ""), str(spec.get("subfolder") or ""), str(spec.get("weight_name") or "")])
+        return "lora_" + hashlib.md5(key.encode("utf-8")).hexdigest()[:12]
+
+    def _apply_loras(self, *, kind: str, pipe: Any, extra: Any) -> Optional[str]:
+        loras = self._parse_loras(extra)
+        new_sig = self._lora_signature(loras)
+        cur_sig = self._fused_lora_signature.get(kind)
+        if new_sig == cur_sig:
+            return cur_sig
+
+        # Always clear previous adapters before applying a new set.
+        if hasattr(pipe, "unfuse_lora"):
+            try:
+                pipe.unfuse_lora()
+            except Exception:
+                pass
+        if hasattr(pipe, "unload_lora_weights"):
+            try:
+                pipe.unload_lora_weights()
+            except Exception:
+                pass
+
+        if not loras:
+            self._fused_lora_signature[kind] = None
+            return None
+
+        adapter_names: List[str] = []
+        adapter_scales: List[float] = []
+
+        with _hf_offline_env(not bool(self._cfg.allow_download)):
+            for spec in loras:
+                adapter_name = self._resolved_adapter_name(spec)
+                adapter_names.append(adapter_name)
+                adapter_scales.append(float(spec.get("scale") or 1.0))
+
+                kwargs: Dict[str, Any] = {}
+                if spec.get("weight_name"):
+                    kwargs["weight_name"] = spec["weight_name"]
+                if spec.get("subfolder"):
+                    kwargs["subfolder"] = spec["subfolder"]
+                kwargs["local_files_only"] = not bool(self._cfg.allow_download)
+                if self._cfg.cache_dir:
+                    kwargs["cache_dir"] = str(self._cfg.cache_dir)
+
+                load_fn = getattr(pipe, "load_lora_weights", None)
+                if not callable(load_fn):
+                    raise ValueError("This diffusers pipeline does not support LoRA adapters (missing load_lora_weights).")
+                load_fn(spec["source"], adapter_name=adapter_name, **kwargs)
+
+        if hasattr(pipe, "set_adapters"):
+            try:
+                pipe.set_adapters(adapter_names, adapter_weights=adapter_scales)
+            except Exception:
+                pass
+
+        if hasattr(pipe, "fuse_lora"):
+            try:
+                pipe.fuse_lora()
+            except Exception:
+                pass
+
+        if hasattr(pipe, "unload_lora_weights"):
+            try:
+                pipe.unload_lora_weights()
+            except Exception:
+                pass
+
+        self._fused_lora_signature[kind] = new_sig
+        return new_sig
+
+    def _maybe_apply_rapid_aio_transformer(self, *, pipe: Any, extra: Any, torch_dtype: Any) -> Optional[str]:
+        """Optionally swap the pipeline's transformer with a Rapid-AIO distilled transformer.
+
+        This is primarily useful for Qwen Image Edit pipelines (very fast 4-step inference), but we keep it
+        generic: if a pipeline has a `.transformer` module and diffusers provides a compatible transformer
+        class, we can hot-swap it.
+
+        Downloads are enabled by default; set allow_download=False for cache-only/offline mode.
+        """
+
+        if not isinstance(extra, dict) or not extra:
+            return None
+
+        repo = None
+        if extra.get("rapid_aio_repo"):
+            repo = str(extra.get("rapid_aio_repo") or "").strip()
+        elif extra.get("rapid_aio") is True:
+            repo = "linoyts/Qwen-Image-Edit-Rapid-AIO"
+        elif isinstance(extra.get("rapid_aio"), str) and str(extra.get("rapid_aio")).strip():
+            repo = str(extra.get("rapid_aio")).strip()
+        if not repo:
+            return None
+
+        subfolder = str(extra.get("rapid_aio_subfolder") or "transformer").strip() or "transformer"
+        key = f"{repo}|{subfolder}|{torch_dtype}"
+        if key == self._rapid_transformer_key and self._rapid_transformer is not None:
+            tr = self._rapid_transformer
+        else:
+            try:
+                from diffusers.models import QwenImageTransformer2DModel  # type: ignore
+            except Exception:
+                raise ValueError(
+                    "Rapid-AIO transformer override requires diffusers.models.QwenImageTransformer2DModel, "
+                    "which is not available in this diffusers build."
+                )
+            kwargs: Dict[str, Any] = {"subfolder": subfolder, "local_files_only": not bool(self._cfg.allow_download)}
+            if self._cfg.cache_dir:
+                kwargs["cache_dir"] = str(self._cfg.cache_dir)
+            with _hf_offline_env(not bool(self._cfg.allow_download)):
+                tr = QwenImageTransformer2DModel.from_pretrained(repo, torch_dtype=torch_dtype, **kwargs)
+            torch = _lazy_import_torch()
+            device = self._effective_device(torch)
+            try:
+                tr = tr.to(device=str(device), dtype=torch_dtype)
+            except Exception:
+                try:
+                    tr = tr.to(dtype=torch_dtype)
+                    tr = tr.to(str(device))
+                except Exception:
+                    pass
+
+            self._rapid_transformer_key = key
+            self._rapid_transformer = tr
+
+        if hasattr(pipe, "register_modules"):
+            try:
+                pipe.register_modules(transformer=tr)
+            except Exception:
+                setattr(pipe, "transformer", tr)
+        else:
+            setattr(pipe, "transformer", tr)
+
+        _maybe_cast_pipe_modules_to_dtype(pipe, dtype=torch_dtype)
+        return repo
+
+    def get_capabilities(self) -> VisionBackendCapabilities:
+        return VisionBackendCapabilities(
+            supported_tasks=["text_to_image", "image_to_image"],
+            supports_mask=None,  # depends on whether inpaint pipeline loads for the model
+        )
+
+    def _pipeline_common_kwargs(self) -> Dict[str, Any]:
+        kwargs: Dict[str, Any] = {
+            "local_files_only": not bool(self._cfg.allow_download),
+            "use_safetensors": bool(self._cfg.use_safetensors),
+        }
+        if self._cfg.cache_dir:
+            kwargs["cache_dir"] = str(self._cfg.cache_dir)
+        if self._cfg.revision:
+            kwargs["revision"] = str(self._cfg.revision)
+        if self._cfg.variant:
+            kwargs["variant"] = str(self._cfg.variant)
+        return kwargs
+
+    def _hf_cache_root(self) -> Path:
+        if self._cfg.cache_dir:
+            return Path(self._cfg.cache_dir).expanduser()
+        hub_cache = os.environ.get("HF_HUB_CACHE")
+        if hub_cache:
+            return Path(hub_cache).expanduser()
+        hf_home = os.environ.get("HF_HOME")
+        if hf_home:
+            return Path(hf_home).expanduser() / "hub"
+        return Path.home() / ".cache" / "huggingface" / "hub"
+
+    def _resolve_snapshot_dir(self) -> Optional[Path]:
+        model_id = str(self._cfg.model_id).strip()
+        if not model_id:
+            return None
+
+        p = Path(model_id).expanduser()
+        if p.exists():
+            return p
+
+        if "/" not in model_id:
+            return None
+
+        cache_root = self._hf_cache_root()
+        repo_dir = cache_root / ("models--" + model_id.replace("/", "--"))
+        snaps = repo_dir / "snapshots"
+        if not snaps.is_dir():
+            return None
+
+        rev = str(self._cfg.revision or "main").strip() or "main"
+        ref_file = repo_dir / "refs" / rev
+        if ref_file.is_file():
+            commit = ref_file.read_text(encoding="utf-8").strip()
+            snap_dir = snaps / commit
+            if snap_dir.is_dir():
+                return snap_dir
+
+        # Fallback: pick the most recently modified snapshot.
+        candidates = [d for d in snaps.iterdir() if d.is_dir()]
+        if not candidates:
+            return None
+        return max(candidates, key=lambda d: d.stat().st_mtime)
+
+    def _preflight_check_model_index(self) -> None:
+        snap = self._resolve_snapshot_dir()
+        if snap is None:
+            return
+        idx_path = snap / "model_index.json"
+        if not idx_path.is_file():
+            return
+
+        try:
+            import json
+
+            model_index = json.loads(idx_path.read_text(encoding="utf-8"))
+        except Exception:
+            return
+
+        class_name = str(model_index.get("_class_name") or "").strip()
+        if not class_name:
+            return
+
+        (
+            _DiffusionPipeline,
+            _AutoPipelineForText2Image,
+            _AutoPipelineForImage2Image,
+            _AutoPipelineForInpainting,
+            diffusers_version,
+        ) = _lazy_import_diffusers()
+
+        import diffusers as _diffusers  # type: ignore
+
+        if not hasattr(_diffusers, class_name):
+            required = str(model_index.get("_diffusers_version") or "unknown")
+            install_hint = "pip install -U 'git+https://github.com/huggingface/diffusers@main'"
+            install_hint_alt = "pip install -e '.[huggingface-dev]'"
+            extra = ""
+            if class_name == "Flux2KleinPipeline":
+                extra = (
+                    " Note: this model uses a different text encoder than the released Flux2Pipeline in diffusers 0.36 "
+                    "(Klein uses Qwen3; Flux2Pipeline is built around Mistral3), so a newer diffusers is required."
+                )
+            raise ValueError(
+                f"Diffusers pipeline class {class_name!r} is required by this model, but is not available in your "
+                f"installed diffusers ({diffusers_version}). "
+                f"The model's model_index.json was authored for diffusers {required}. "
+                "This class is not available in the latest PyPI release at the time of writing. "
+                f"Install a newer diffusers (offline runtime is still supported): {install_hint}. "
+                f"If you're installing AbstractVision from a repo checkout, you can also use: {install_hint_alt}.{extra}"
+            )
+
+        # Optional: sanity-check that referenced Transformers classes exist to avoid late failures.
+        try:
+            import transformers  # type: ignore
+
+            missing_tf: list[str] = []
+            for v in model_index.values():
+                if (
+                    isinstance(v, list)
+                    and len(v) == 2
+                    and isinstance(v[0], str)
+                    and isinstance(v[1], str)
+                    and v[0].strip().lower() == "transformers"
+                ):
+                    tf_cls = v[1].strip()
+                    if tf_cls and not hasattr(transformers, tf_cls):
+                        missing_tf.append(tf_cls)
+            if missing_tf:
+                tf_ver = getattr(transformers, "__version__", "unknown")
+                raise ValueError(
+                    "This model references Transformers classes that are not available in your environment "
+                    f"(transformers={tf_ver}): {', '.join(sorted(set(missing_tf)))}. "
+                    "Upgrade transformers to a compatible version."
+                )
+        except ValueError:
+            raise
+        except Exception:
+            pass
+
+    def _get_or_load_pipeline(self, kind: str) -> Any:
+        existing = self._pipelines.get(kind)
+        if existing is not None:
+            return existing
+
+        (
+            DiffusionPipeline,
+            AutoPipelineForText2Image,
+            AutoPipelineForImage2Image,
+            AutoPipelineForInpainting,
+            diffusers_version,
+        ) = _lazy_import_diffusers()
+        torch = _lazy_import_torch()
+        device = self._effective_device(torch)
+        _require_device_available(torch, device)
+
+        self._preflight_check_model_index()
+        _maybe_patch_transformers_clip_position_ids()
+
+        torch_dtype = _torch_dtype_from_str(torch, self._cfg.torch_dtype)
+        if torch_dtype is None:
+            torch_dtype = _default_torch_dtype_for_device(torch, device)
+        common = self._pipeline_common_kwargs()
+        if bool(self._cfg.low_cpu_mem_usage):
+            common["low_cpu_mem_usage"] = True
+
+        # Auto-select checkpoint variants when appropriate (best-effort).
+        # Prefer fp16 on GPU backends (CUDA/MPS) to cut memory/disk use, but never on CPU.
+        auto_variant: Optional[str] = None
+        if not str(getattr(self._cfg, "variant", "") or "").strip() and str(device).strip().lower() != "cpu":
+            if torch_dtype == getattr(torch, "float16", object()):
+                auto_variant = "fp16"
+
+        def _looks_like_missing_variant_error(e: Exception, variant: str) -> bool:
+            msg = str(e or "")
+            m = msg.lower()
+            v = str(variant or "").strip().lower()
+            if not v:
+                return False
+            return (
+                (f".{v}." in m or f" {v} " in m or f"'{v}'" in m)
+                and (
+                    "no such file" in m
+                    or "does not exist" in m
+                    or "not found" in m
+                    or "is not present" in m
+                    or "couldn't find" in m
+                    or "cannot find" in m
+                )
+            )
+
+        def _from_pretrained(cls: Any) -> Any:
+            if auto_variant:
+                common2 = dict(common)
+                common2["variant"] = auto_variant
+                try:
+                    return cls.from_pretrained(self._cfg.model_id, torch_dtype=torch_dtype, **common2)
+                except Exception as e:
+                    # If the repo doesn't provide the fp16 variant, fall back to regular weights.
+                    if _looks_like_missing_variant_error(e, auto_variant):
+                        return cls.from_pretrained(self._cfg.model_id, torch_dtype=torch_dtype, **common)
+                    raise
+            return cls.from_pretrained(self._cfg.model_id, torch_dtype=torch_dtype, **common)
+
+        def _maybe_raise_offline_missing_model(e: Exception) -> None:
+            model_id = str(self._cfg.model_id or "").strip()
+            if not model_id or "/" not in model_id:
+                return
+            # If it's not in cache, provide a clearer message than the upstream
+            # "does not appear to have a file named model_index.json" wording.
+            if self._resolve_snapshot_dir() is not None:
+                return
+            msg = str(e)
+            if "model_index.json" not in msg:
+                return
+            raise ValueError(
+                f"Model {model_id!r} is not available locally and downloads are disabled. "
+                "Either pre-download it (e.g. via `huggingface-cli download ...`) or enable downloads "
+                "(set allow_download=True; for AbstractCore Server: set ABSTRACTCORE_VISION_ALLOW_DOWNLOAD=1). "
+                "If the model is gated, accept its terms on Hugging Face and set `HF_TOKEN` before downloading."
+            ) from e
+
+        pipe = None
+        with _hf_offline_env(not bool(self._cfg.allow_download)):
+            if kind == "t2i":
+                # Prefer AutoPipeline when available, but fall back to DiffusionPipeline for robustness.
+                if AutoPipelineForText2Image is not None:
+                    try:
+                        pipe = _from_pretrained(AutoPipelineForText2Image)
+                    except ValueError as e:
+                        _maybe_raise_offline_missing_model(e)
+                        pipe = None
+                if pipe is None:
+                    try:
+                        pipe = _from_pretrained(DiffusionPipeline)
+                    except Exception as e:
+                        _maybe_raise_offline_missing_model(e)
+                        raise
+            elif kind == "i2i":
+                if AutoPipelineForImage2Image is not None:
+                    try:
+                        pipe = _from_pretrained(AutoPipelineForImage2Image)
+                    except ValueError as e:
+                        _maybe_raise_offline_missing_model(e)
+                        pipe = None
+                if pipe is None:
+                    try:
+                        pipe = _from_pretrained(DiffusionPipeline)
+                    except Exception as e:
+                        _maybe_raise_offline_missing_model(e)
+                        raise ValueError(
+                            "Diffusers could not load an image-to-image pipeline for this model id. "
+                            "Install/upgrade diffusers (and compatible transformers/torch), or use a model repo that "
+                            "ships an image-to-image pipeline. "
+                            f"(diffusers={diffusers_version})"
+                        ) from e
+            elif kind == "inpaint":
+                if AutoPipelineForInpainting is None:
+                    raise ValueError(
+                        "Diffusers inpainting pipeline is not available in this environment. "
+                        "Install/upgrade diffusers (and compatible transformers/torch). "
+                        f"(diffusers={diffusers_version})"
+                    )
+                pipe = _from_pretrained(AutoPipelineForInpainting)
+            else:
+                raise ValueError(f"Unknown pipeline kind: {kind!r}")
+
+        # Diffusers pipelines support `.to(<device>)` with a string.
+        pipe = pipe.to(str(device))
+        _maybe_cast_pipe_modules_to_dtype(pipe, dtype=torch_dtype)
+        _maybe_upcast_vae_for_mps(torch, pipe, device)
+        self._pipelines[kind] = pipe
+        self._call_params[kind] = _call_param_names(getattr(pipe, "__call__", None))
+        return pipe
+
+    def _pil_from_bytes(self, data: bytes):
+        Image = _lazy_import_pil()
+        img = Image.open(io.BytesIO(bytes(data)))
+        # Many pipelines expect RGB.
+        return img.convert("RGB")
+
+    def _png_bytes(self, img) -> bytes:
+        buf = io.BytesIO()
+        img.save(buf, format="PNG")
+        return buf.getvalue()
+
+    def _seed_generator(self, seed: Optional[int]):
+        if seed is None:
+            return None
+        torch = _lazy_import_torch()
+        d = str(self._effective_device(torch) or "").strip().lower()
+        gen_device = "cpu" if d == "mps" or d.startswith("mps:") else str(self._effective_device(torch))
+        try:
+            gen = torch.Generator(device=gen_device)
+        except Exception:
+            gen = torch.Generator()
+        gen.manual_seed(int(seed))
+        return gen
+
+    def _is_probably_all_black_image(self, img: Any) -> bool:
+        try:
+            rgb = img.convert("RGB")
+            extrema = rgb.getextrema()
+            if isinstance(extrema, tuple) and len(extrema) == 2 and all(isinstance(x, int) for x in extrema):
+                _, mx = extrema
+                return mx <= 1
+            if isinstance(extrema, tuple):
+                return all(isinstance(x, tuple) and len(x) == 2 and int(x[1]) <= 1 for x in extrema)
+        except Exception:
+            return False
+        return False
+
+    def _pipe_call(self, pipe: Any, kwargs: Dict[str, Any]):
+        import warnings
+
+        call_kwargs = dict(kwargs)
+        if callable(kwargs.get("__abstractvision_progress_callback")):
+            progress_cb = kwargs.get("__abstractvision_progress_callback")
+            total_steps = kwargs.get("__abstractvision_progress_total_steps")
+            try:
+                call_kwargs.pop("__abstractvision_progress_callback", None)
+                call_kwargs.pop("__abstractvision_progress_total_steps", None)
+            except Exception:
+                pass
+            try:
+                call_kwargs = self._inject_progress_kwargs(
+                    pipe=pipe,
+                    kwargs=call_kwargs,
+                    progress_callback=progress_cb,
+                    total_steps=int(total_steps) if total_steps is not None else None,
+                )
+            except Exception:
+                # Best-effort: never break inference for progress reporting.
+                pass
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always", RuntimeWarning)
+            with _hf_offline_env(not bool(self._cfg.allow_download)):
+                out = pipe(**call_kwargs)
+        had_invalid_cast = any(
+            issubclass(getattr(x, "category", Warning), RuntimeWarning)
+            and "invalid value encountered in cast" in str(getattr(x, "message", ""))
+            for x in w
+        )
+        return out, had_invalid_cast
+
+    def _pipe_progress_param_names(self, pipe: Any) -> set[str]:
+        fn = getattr(pipe, "__call__", None)
+        if not callable(fn):
+            return set()
+        try:
+            sig = inspect.signature(fn)
+        except Exception:
+            return set()
+        return {str(k) for k in sig.parameters.keys() if str(k) != "self"}
+
+    def _inject_progress_kwargs(
+        self,
+        *,
+        pipe: Any,
+        kwargs: Dict[str, Any],
+        progress_callback: Callable[[int, Optional[int]], None],
+        total_steps: Optional[int],
+    ) -> Dict[str, Any]:
+        names = self._pipe_progress_param_names(pipe)
+        if not names:
+            return kwargs
+
+        if "callback_on_step_end" in names:
+
+            def _on_step_end(*args: Any, **kw: Any) -> Any:
+                # Expected signature: (pipe, step, timestep, callback_kwargs)
+                step = None
+                cb_kwargs = None
+                try:
+                    if len(args) >= 2:
+                        step = args[1]
+                    if len(args) >= 4:
+                        cb_kwargs = args[3]
+                    if cb_kwargs is None:
+                        cb_kwargs = kw.get("callback_kwargs")
+                except Exception:
+                    pass
+                try:
+                    if step is not None:
+                        progress_callback(int(step) + 1, total_steps)
+                except Exception:
+                    pass
+                return cb_kwargs if cb_kwargs is not None else {}
+
+            kwargs["callback_on_step_end"] = _on_step_end
+            # Avoid passing large tensors through callback_kwargs unless explicitly requested.
+            if "callback_on_step_end_tensor_inputs" in names:
+                kwargs.setdefault("callback_on_step_end_tensor_inputs", [])
+            return kwargs
+
+        if "callback" in names:
+
+            def _callback(*args: Any, **_kw: Any) -> None:
+                # Expected signature: (step, timestep, latents)
+                try:
+                    if args:
+                        progress_callback(int(args[0]) + 1, total_steps)
+                except Exception:
+                    pass
+
+            kwargs["callback"] = _callback
+            if "callback_steps" in names:
+                kwargs["callback_steps"] = 1
+            return kwargs
+
+        return kwargs
+
+    def _maybe_retry_on_dtype_mismatch(
+        self,
+        *,
+        kind: str,
+        pipe: Any,
+        kwargs: Dict[str, Any],
+        error: Exception,
+        progress_callback: Optional[Callable[[int, Optional[int]], None]] = None,
+        total_steps: Optional[int] = None,
+    ) -> Optional[Any]:
+        if not bool(getattr(self._cfg, "auto_retry_fp32", False)):
+            return None
+        if not _looks_like_dtype_mismatch_error(error):
+            return None
+
+        torch = _lazy_import_torch()
+        device = self._effective_device(torch)
+        d = str(device or "").strip().lower()
+        if not (d == "mps" or d.startswith("mps:")):
+            return None
+
+        current_dtype = getattr(pipe, "dtype", None)
+        if current_dtype is None:
+            current_dtype = _torch_dtype_from_str(torch, self._cfg.torch_dtype) or _default_torch_dtype_for_device(
+                torch, device
+            )
+
+        candidates: list[Any] = []
+        if current_dtype == getattr(torch, "bfloat16", object()):
+            candidates.append(torch.float16)
+        if current_dtype != getattr(torch, "float32", object()):
+            candidates.append(torch.float32)
+
+        for target in candidates:
+            try:
+                pipe2 = pipe.to(device=str(device), dtype=target)
+            except Exception:
+                try:
+                    pipe2 = pipe.to(dtype=target)
+                    pipe2 = pipe2.to(str(device))
+                except Exception:
+                    continue
+
+            _maybe_upcast_vae_for_mps(torch, pipe2, device)
+            self._pipelines[kind] = pipe2
+            self._call_params[kind] = _call_param_names(getattr(pipe2, "__call__", None))
+
+            try:
+                call_kwargs = dict(kwargs)
+                if progress_callback is not None:
+                    call_kwargs["__abstractvision_progress_callback"] = progress_callback
+                    call_kwargs["__abstractvision_progress_total_steps"] = total_steps
+                out2, _had_invalid_cast2 = self._pipe_call(pipe2, call_kwargs)
+                return out2
+            except Exception:
+                continue
+        return None
+
+    def _maybe_retry_fp32_on_invalid_output(
+        self,
+        *,
+        kind: str,
+        pipe: Any,
+        kwargs: Dict[str, Any],
+        progress_callback: Optional[Callable[[int, Optional[int]], None]] = None,
+        total_steps: Optional[int] = None,
+    ) -> Optional[Any]:
+        if not bool(getattr(self._cfg, "auto_retry_fp32", False)):
+            return None
+        torch = _lazy_import_torch()
+        device = self._effective_device(torch)
+        d = str(device or "").strip().lower()
+        cfg_dtype = _torch_dtype_from_str(torch, self._cfg.torch_dtype)
+        if cfg_dtype is None:
+            cfg_dtype = _default_torch_dtype_for_device(torch, device)
+
+        # Currently, we only auto-retry on Apple Silicon / MPS when running fp16,
+        # because NaNs/black images are common for some models (e.g. Qwen Image).
+        if not (d == "mps" or d.startswith("mps:")):
+            return None
+        if cfg_dtype != torch.float16:
+            return None
+
+        try:
+            pipe_fp32 = pipe.to(device=str(device), dtype=torch.float32)
+        except Exception:
+            try:
+                pipe_fp32 = pipe.to(dtype=torch.float32)
+                pipe_fp32 = pipe_fp32.to(str(device))
+            except Exception:
+                return None
+
+        _maybe_upcast_vae_for_mps(torch, pipe_fp32, device)
+        self._pipelines[kind] = pipe_fp32
+        self._call_params[kind] = _call_param_names(getattr(pipe_fp32, "__call__", None))
+
+        call_kwargs = dict(kwargs)
+        if progress_callback is not None:
+            call_kwargs["__abstractvision_progress_callback"] = progress_callback
+            call_kwargs["__abstractvision_progress_total_steps"] = total_steps
+        out2, had_invalid_cast2 = self._pipe_call(pipe_fp32, call_kwargs)
+        if had_invalid_cast2:
+            raise ValueError(
+                "Diffusers produced invalid pixel values (NaNs) while decoding the image "
+                "(resulting in an all-black output). "
+                "Tried an automatic fp32 retry on MPS and it still failed. "
+                "Try setting torch_dtype=float32 explicitly, increasing steps, or use the stable-diffusion.cpp backend."
+            )
+        return out2
+
def generate_image(self, request: ImageGenerationRequest) -> GeneratedAsset:
|
|
1254
|
+
return self.generate_image_with_progress(request, progress_callback=None)
|
|
1255
|
+
|
|
1256
|
+
def generate_image_with_progress(
|
|
1257
|
+
self,
|
|
1258
|
+
request: ImageGenerationRequest,
|
|
1259
|
+
progress_callback: Optional[Callable[[int, Optional[int]], None]] = None,
|
|
1260
|
+
) -> GeneratedAsset:
|
|
1261
|
+
pipe = self._get_or_load_pipeline("t2i")
|
|
1262
|
+
call_params = self._call_params.get("t2i")
|
|
1263
|
+
total_steps = int(request.steps) if request.steps is not None else None
|
|
1264
|
+
|
|
1265
|
+
torch_dtype = getattr(pipe, "dtype", None)
|
|
1266
|
+
if torch_dtype is None:
|
|
1267
|
+
torch = _lazy_import_torch()
|
|
1268
|
+
device = self._effective_device(torch)
|
|
1269
|
+
torch_dtype = _torch_dtype_from_str(torch, self._cfg.torch_dtype) or _default_torch_dtype_for_device(torch, device)
|
|
1270
|
+
rapid_repo = self._maybe_apply_rapid_aio_transformer(pipe=pipe, extra=request.extra, torch_dtype=torch_dtype)
|
|
1271
|
+
lora_sig = self._apply_loras(kind="t2i", pipe=pipe, extra=request.extra)
|
|
1272
|
+
|
|
1273
|
+
kwargs: Dict[str, Any] = {
|
|
1274
|
+
"prompt": request.prompt,
|
|
1275
|
+
}
|
|
1276
|
+
if request.negative_prompt is not None:
|
|
1277
|
+
kwargs["negative_prompt"] = request.negative_prompt
|
|
1278
|
+
if request.width is not None:
|
|
1279
|
+
kwargs["width"] = int(request.width)
|
|
1280
|
+
if request.height is not None:
|
|
1281
|
+
kwargs["height"] = int(request.height)
|
|
1282
|
+
if request.steps is not None:
|
|
1283
|
+
kwargs["num_inference_steps"] = int(request.steps)
|
|
1284
|
+
if request.guidance_scale is not None:
|
|
1285
|
+
if call_params is not None and "true_cfg_scale" in call_params:
|
|
1286
|
+
kwargs["true_cfg_scale"] = float(request.guidance_scale)
|
|
1287
|
+
# Some pipelines (e.g. Qwen Image) only enable CFG when a `negative_prompt`
|
|
1288
|
+
# is provided (even an empty one). Make `guidance_scale` behave consistently.
|
|
1289
|
+
if request.negative_prompt is None and (call_params is None or "negative_prompt" in call_params):
|
|
1290
|
+
kwargs["negative_prompt"] = " "
|
|
1291
|
+
else:
|
|
1292
|
+
kwargs["guidance_scale"] = float(request.guidance_scale)
|
|
1293
|
+
gen = self._seed_generator(request.seed)
|
|
1294
|
+
if gen is not None:
|
|
1295
|
+
kwargs["generator"] = gen
|
|
1296
|
+
|
|
1297
|
+
if isinstance(request.extra, dict) and request.extra:
|
|
1298
|
+
kwargs.update(dict(request.extra))
|
|
1299
|
+
|
+        try:
+            call_kwargs = dict(kwargs)
+            if progress_callback is not None:
+                call_kwargs["__abstractvision_progress_callback"] = progress_callback
+                call_kwargs["__abstractvision_progress_total_steps"] = total_steps
+            out, had_invalid_cast = self._pipe_call(pipe, call_kwargs)
+        except Exception as e:
+            out2 = self._maybe_retry_on_dtype_mismatch(
+                kind="t2i",
+                pipe=pipe,
+                kwargs=kwargs,
+                error=e,
+                progress_callback=progress_callback,
+                total_steps=total_steps,
+            )
+            if out2 is None:
+                raise
+            out, had_invalid_cast = out2, False
+        retried_fp32 = False
+        images = getattr(out, "images", None)
+        if not isinstance(images, list) or not images:
+            raise ValueError("Diffusers pipeline returned no images")
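+        # Half-precision decoding can produce NaNs on some devices (notably MPS), which shows
+        # up as an all-black image; when detected, retry once in fp32 before giving up.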
+        if self._is_probably_all_black_image(images[0]):
+            out2 = self._maybe_retry_fp32_on_invalid_output(
+                kind="t2i",
+                pipe=pipe,
+                kwargs=kwargs,
+                progress_callback=progress_callback,
+                total_steps=total_steps,
+            )
+            if out2 is not None:
+                out = out2
+                retried_fp32 = True
+                images = getattr(out, "images", None)
+                if not isinstance(images, list) or not images:
+                    raise ValueError("Diffusers pipeline returned no images")
+            if self._is_probably_all_black_image(images[0]):
+                raise ValueError(
+                    "Diffusers produced an all-black image output. "
+                    + (
+                        "An automatic fp32 retry was attempted and it still produced an all-black image. "
+                        if retried_fp32
+                        else "Try setting torch_dtype=float32. "
+                    )
+                    + "Try increasing steps, adjusting guidance_scale, or use the stable-diffusion.cpp backend."
+                )
+        png = self._png_bytes(images[0])
+        meta = {"source": "diffusers", "model_id": self._cfg.model_id}
+        if rapid_repo:
+            meta["rapid_aio_repo"] = rapid_repo
+        if lora_sig:
+            meta["lora_signature"] = lora_sig
+        if retried_fp32:
+            meta["retried_fp32"] = True
+        if had_invalid_cast:
+            meta["had_invalid_cast_warning"] = True
+        try:
+            current_pipe = self._pipelines.get("t2i", pipe)
+            dtype = getattr(current_pipe, "dtype", None)
+            device = getattr(current_pipe, "device", None)
+            if dtype is not None:
+                meta["dtype"] = str(dtype)
+            if device is not None:
+                meta["device"] = str(device)
+        except Exception:
+            pass
+        return GeneratedAsset(
+            media_type="image",
+            data=png,
+            mime_type="image/png",
+            metadata=meta,
+        )
+
+    def edit_image(self, request: ImageEditRequest) -> GeneratedAsset:
+        return self.edit_image_with_progress(request, progress_callback=None)
+
+    def edit_image_with_progress(
+        self,
+        request: ImageEditRequest,
+        progress_callback: Optional[Callable[[int, Optional[int]], None]] = None,
+    ) -> GeneratedAsset:
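+        # A mask selects the inpainting pipeline; without one the plain image-to-image pipeline
+        # is used. Both paths then follow the same kwargs/retry handling as text-to-image.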
+        if request.mask is not None:
+            pipe = self._get_or_load_pipeline("inpaint")
+            call_params = self._call_params.get("inpaint")
+            kind = "inpaint"
+        else:
+            pipe = self._get_or_load_pipeline("i2i")
+            call_params = self._call_params.get("i2i")
+            kind = "i2i"
+
+        total_steps = int(request.steps) if request.steps is not None else None
+
+        torch_dtype = getattr(pipe, "dtype", None)
+        if torch_dtype is None:
+            torch = _lazy_import_torch()
+            device = self._effective_device(torch)
+            torch_dtype = _torch_dtype_from_str(torch, self._cfg.torch_dtype) or _default_torch_dtype_for_device(torch, device)
+        rapid_repo = self._maybe_apply_rapid_aio_transformer(pipe=pipe, extra=request.extra, torch_dtype=torch_dtype)
+        lora_sig = self._apply_loras(kind=kind, pipe=pipe, extra=request.extra)
+
+        img = self._pil_from_bytes(request.image)
+        kwargs: Dict[str, Any] = {"prompt": request.prompt, "image": img}
+        if request.mask is not None:
+            kwargs["mask_image"] = self._pil_from_bytes(request.mask)
+        if request.negative_prompt is not None:
+            kwargs["negative_prompt"] = request.negative_prompt
+        if request.steps is not None:
+            kwargs["num_inference_steps"] = int(request.steps)
+        if request.guidance_scale is not None:
+            if call_params is not None and "true_cfg_scale" in call_params:
+                kwargs["true_cfg_scale"] = float(request.guidance_scale)
+                if request.negative_prompt is None and (call_params is None or "negative_prompt" in call_params):
+                    kwargs["negative_prompt"] = " "
+            else:
+                kwargs["guidance_scale"] = float(request.guidance_scale)
+        gen = self._seed_generator(request.seed)
+        if gen is not None:
+            kwargs["generator"] = gen
+
+        if isinstance(request.extra, dict) and request.extra:
+            kwargs.update(dict(request.extra))
+
+        try:
+            call_kwargs = dict(kwargs)
+            if progress_callback is not None:
+                call_kwargs["__abstractvision_progress_callback"] = progress_callback
+                call_kwargs["__abstractvision_progress_total_steps"] = total_steps
+            out, had_invalid_cast = self._pipe_call(pipe, call_kwargs)
+        except Exception as e:
+            out2 = self._maybe_retry_on_dtype_mismatch(
+                kind=kind,
+                pipe=pipe,
+                kwargs=kwargs,
+                error=e,
+                progress_callback=progress_callback,
+                total_steps=total_steps,
+            )
+            if out2 is None:
+                raise
+            out, had_invalid_cast = out2, False
+        retried_fp32 = False
+        images = getattr(out, "images", None)
+        if not isinstance(images, list) or not images:
+            raise ValueError("Diffusers pipeline returned no images")
+        if self._is_probably_all_black_image(images[0]):
+            kind = "inpaint" if request.mask is not None else "i2i"
+            out2 = self._maybe_retry_fp32_on_invalid_output(
+                kind=kind,
+                pipe=pipe,
+                kwargs=kwargs,
+                progress_callback=progress_callback,
+                total_steps=total_steps,
+            )
+            if out2 is not None:
+                out = out2
+                retried_fp32 = True
+                images = getattr(out, "images", None)
+                if not isinstance(images, list) or not images:
+                    raise ValueError("Diffusers pipeline returned no images")
+            if self._is_probably_all_black_image(images[0]):
+                raise ValueError(
+                    "Diffusers produced an all-black image output. "
+                    + (
+                        "An automatic fp32 retry was attempted and it still produced an all-black image. "
+                        if retried_fp32
+                        else "Try setting torch_dtype=bfloat16 (recommended on MPS) or torch_dtype=float32. "
+                    )
+                    + "Try increasing steps, adjusting guidance_scale, or use the stable-diffusion.cpp backend."
+                )
+        png = self._png_bytes(images[0])
+        meta = {"source": "diffusers", "model_id": self._cfg.model_id}
+        if rapid_repo:
+            meta["rapid_aio_repo"] = rapid_repo
+        if lora_sig:
+            meta["lora_signature"] = lora_sig
+        if retried_fp32:
+            meta["retried_fp32"] = True
+        if had_invalid_cast:
+            meta["had_invalid_cast_warning"] = True
+        try:
+            current_pipe = self._pipelines.get(kind, pipe)
+            dtype = getattr(current_pipe, "dtype", None)
+            device = getattr(current_pipe, "device", None)
+            if dtype is not None:
+                meta["dtype"] = str(dtype)
+            if device is not None:
+                meta["device"] = str(device)
+        except Exception:
+            pass
+        return GeneratedAsset(
+            media_type="image",
+            data=png,
+            mime_type="image/png",
+            metadata=meta,
+        )
+
+    def generate_angles(self, request: MultiAngleRequest) -> list[GeneratedAsset]:
+        raise CapabilityNotSupportedError("HuggingFaceDiffusersVisionBackend does not implement multi-view generation.")
+
+    def generate_video(self, request: VideoGenerationRequest) -> GeneratedAsset:
+        raise CapabilityNotSupportedError("HuggingFaceDiffusersVisionBackend does not implement text_to_video (phase 2).")
+
+    def image_to_video(self, request: ImageToVideoRequest) -> GeneratedAsset:
+        raise CapabilityNotSupportedError("HuggingFaceDiffusersVisionBackend does not implement image_to_video (phase 2).")