abstractvision 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff covers the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
@@ -0,0 +1,1503 @@
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ import inspect
5
+ import os
6
+ import hashlib
7
+ from contextlib import contextmanager
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
11
+
12
+ from ..errors import CapabilityNotSupportedError, OptionalDependencyMissingError
13
+ from ..types import (
14
+ GeneratedAsset,
15
+ ImageEditRequest,
16
+ ImageGenerationRequest,
17
+ ImageToVideoRequest,
18
+ MultiAngleRequest,
19
+ VideoGenerationRequest,
20
+ VisionBackendCapabilities,
21
+ )
22
+ from .base_backend import VisionBackend
23
+
24
+
25
+ def _require_optional_dep(name: str, install_hint: str) -> None:
26
+ import sys
27
+
28
+ raise OptionalDependencyMissingError(
29
+ f"Optional dependency missing: {name}. Install via: {install_hint} " f"(python={sys.executable})"
30
+ )
31
+
32
+
33
+ def _lazy_import_diffusers():
34
+ try:
35
+ import warnings
36
+
37
+ # Some Diffusers modules decorate functions with `torch.autocast(device_type="cuda", ...)`,
38
+ # which emits noisy warnings on non-CUDA machines (including Apple Silicon / MPS).
39
+ warnings.filterwarnings("ignore", message=r".*CUDA is not available.*Disabling autocast.*", category=UserWarning)
40
+ warnings.filterwarnings(
41
+ "ignore",
42
+ message=r".*device_type of 'cuda'.*CUDA is not available.*",
43
+ category=UserWarning,
44
+ )
45
+ import diffusers # type: ignore
46
+ from diffusers import DiffusionPipeline # type: ignore
47
+ except Exception as e: # pragma: no cover
48
+ raise OptionalDependencyMissingError(
49
+ "Optional dependency missing (or failed to import): diffusers. Install via: pip install 'diffusers'. "
50
+ f"(python={__import__('sys').executable})"
51
+ ) from e
52
+
53
+ # AutoPipeline classes are optional here. Some environments may have diffusers installed but fail to import
54
+ # AutoPipeline due to version mismatches with transformers/torch or other optional deps. We can still load
55
+ # many text-to-image models via `DiffusionPipeline` and only require AutoPipeline for i2i/inpaint.
56
+ AutoPipelineForText2Image = None
57
+ AutoPipelineForImage2Image = None
58
+ AutoPipelineForInpainting = None
59
+ try:
60
+ from diffusers import AutoPipelineForText2Image as _AutoPipelineForText2Image # type: ignore
61
+
62
+ AutoPipelineForText2Image = _AutoPipelineForText2Image
63
+ except Exception:
64
+ pass
65
+ try:
66
+ from diffusers import AutoPipelineForImage2Image as _AutoPipelineForImage2Image # type: ignore
67
+
68
+ AutoPipelineForImage2Image = _AutoPipelineForImage2Image
69
+ except Exception:
70
+ pass
71
+ try:
72
+ from diffusers import AutoPipelineForInpainting as _AutoPipelineForInpainting # type: ignore
73
+
74
+ AutoPipelineForInpainting = _AutoPipelineForInpainting
75
+ except Exception:
76
+ pass
77
+
78
+ return (
79
+ DiffusionPipeline,
80
+ AutoPipelineForText2Image,
81
+ AutoPipelineForImage2Image,
82
+ AutoPipelineForInpainting,
83
+ getattr(diffusers, "__version__", "unknown"),
84
+ )
85
+
86
+
87
+ def _lazy_import_torch():
88
+ try:
89
+ import torch # type: ignore
90
+ except Exception: # pragma: no cover
91
+ _require_optional_dep("torch", "pip install 'torch'")
92
+ return torch
93
+
94
+
95
+ def _lazy_import_pil():
96
+ try:
97
+ from PIL import Image # type: ignore
98
+ except Exception: # pragma: no cover
99
+ _require_optional_dep("pillow", "pip install 'pillow'")
100
+ return Image
101
+
102
+
103
+ _TRANSFORMERS_CLIP_POSITION_IDS_PATCHED = False
104
+
105
+
106
+ def _maybe_patch_transformers_clip_position_ids() -> None:
107
+ """Fix Transformers v5 noisy LOAD REPORTs for common CLIP checkpoints.
108
+
109
+ Transformers 5 logs a detailed load report when encountering unexpected keys like
110
+ `*.embeddings.position_ids` in older CLIP checkpoints (e.g. SD1.5 text encoder / safety checker).
111
+
112
+ The root cause is a small architecture/state-dict mismatch: those checkpoints include a persistent
113
+ `position_ids` buffer, while newer CLIP embedding classes may not. We re-add that buffer so the
114
+ checkpoint matches the instantiated model and no "UNEXPECTED" keys are reported.
115
+ """
116
+
117
+ global _TRANSFORMERS_CLIP_POSITION_IDS_PATCHED
118
+ if _TRANSFORMERS_CLIP_POSITION_IDS_PATCHED:
119
+ return
120
+
121
+ try:
122
+ import transformers # type: ignore
123
+ import torch as _torch # type: ignore
124
+ except Exception:
125
+ return
126
+
127
+ ver = str(getattr(transformers, "__version__", "0"))
128
+ try:
129
+ major = int(ver.split(".", 1)[0])
130
+ except Exception:
131
+ major = 0
132
+ if major < 5:
133
+ _TRANSFORMERS_CLIP_POSITION_IDS_PATCHED = True
134
+ return
135
+
136
+ try:
137
+ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings, CLIPVisionEmbeddings # type: ignore
138
+ except Exception:
139
+ _TRANSFORMERS_CLIP_POSITION_IDS_PATCHED = True
140
+ return
141
+
142
+ def _patch(cls: Any) -> None:
143
+ if bool(getattr(cls, "_abstractvision_position_ids_patched", False)):
144
+ return
145
+ orig_init = getattr(cls, "__init__", None)
146
+ if not callable(orig_init):
147
+ return
148
+
149
+ def __init__(self, *args, **kwargs): # type: ignore[no-redef]
150
+ orig_init(self, *args, **kwargs)
151
+ if hasattr(self, "position_ids"):
152
+ # In Transformers 5, `position_ids` is sometimes registered as a non-persistent buffer
153
+ # (`persistent=False`), so it isn't part of the state dict and is reported as UNEXPECTED
154
+ # when loading older checkpoints that include it. Make it persistent.
155
+ try:
156
+ buffers = getattr(self, "_buffers", None)
157
+ if isinstance(buffers, dict) and "position_ids" in buffers:
158
+ non_persistent = getattr(self, "_non_persistent_buffers_set", None)
159
+ if isinstance(non_persistent, set):
160
+ non_persistent.discard("position_ids")
161
+ return
162
+ except Exception:
163
+ return
164
+ pos_emb = getattr(self, "position_embedding", None)
165
+ num = getattr(pos_emb, "num_embeddings", None) if pos_emb is not None else None
166
+ if num is None:
167
+ return
168
+ try:
169
+ position_ids = _torch.arange(int(num)).unsqueeze(0)
170
+ self.register_buffer("position_ids", position_ids, persistent=True)
171
+ except Exception:
172
+ return
173
+
174
+ setattr(cls, "__init__", __init__)
175
+ setattr(cls, "_abstractvision_position_ids_patched", True)
176
+
177
+ _patch(CLIPTextEmbeddings)
178
+ _patch(CLIPVisionEmbeddings)
179
+ _TRANSFORMERS_CLIP_POSITION_IDS_PATCHED = True
180
+
181
+
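As a quick illustration of what the patch changes, the sketch below (not part of the package; it assumes transformers and torch are importable) builds a default CLIPTextEmbeddings after applying the patch and checks whether the buffer is part of the state dict. On transformers 5 the expected result is True; on older versions the patch is a no-op.

from transformers.models.clip.configuration_clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings

_maybe_patch_transformers_clip_position_ids()
emb = CLIPTextEmbeddings(CLIPTextConfig())
# With `position_ids` persistent again, older checkpoints that ship the buffer
# no longer show up as UNEXPECTED keys in the Transformers load report.
print("position_ids" in emb.state_dict())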
182
+ @contextmanager
183
+ def _hf_offline_env(enabled: bool):
184
+ """Control Hugging Face offline mode within a scope.
185
+
186
+ When `enabled=True`, we force offline mode (no network calls).
187
+ When `enabled=False`, we force online mode (overrides e.g. HF_HUB_OFFLINE=1 in the user's shell).
188
+ """
189
+
190
+ # These are respected by huggingface_hub / transformers / diffusers.
191
+ # We scope them to the load/call to avoid surprising other parts of the process.
192
+ vars_to_set = {
193
+ "HF_HUB_OFFLINE": "1" if enabled else "0",
194
+ "TRANSFORMERS_OFFLINE": "1" if enabled else "0",
195
+ "DIFFUSERS_OFFLINE": "1" if enabled else "0",
196
+ # Avoid any telemetry even in edge cases.
197
+ "HF_HUB_DISABLE_TELEMETRY": "1",
198
+ }
199
+ old = {k: os.environ.get(k) for k in vars_to_set.keys()}
200
+ try:
201
+ for k, v in vars_to_set.items():
202
+ os.environ[k] = v
203
+ yield
204
+ finally:
205
+ for k, prev in old.items():
206
+ if prev is None:
207
+ os.environ.pop(k, None)
208
+ else:
209
+ os.environ[k] = prev
210
+
211
+
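A small usage sketch for the context manager above (illustrative only): inside the block the three offline flags are pinned, and on exit the caller's previous environment, including variables that were unset, is restored.

import os

with _hf_offline_env(enabled=True):
    # cache-only: huggingface_hub / transformers / diffusers will not hit the network here
    assert os.environ["HF_HUB_OFFLINE"] == "1"
    assert os.environ["TRANSFORMERS_OFFLINE"] == "1"
# outside the block, any pre-existing values (or their absence) are back in place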
212
+ def _torch_dtype_from_str(torch: Any, value: Optional[str]) -> Any:
213
+ if value is None:
214
+ return None
215
+ v = str(value).strip().lower()
216
+ if not v or v == "auto":
217
+ return None
218
+ if v in {"float16", "fp16"}:
219
+ return torch.float16
220
+ if v in {"bfloat16", "bf16"}:
221
+ return torch.bfloat16
222
+ if v in {"float32", "fp32"}:
223
+ return torch.float32
224
+ raise ValueError(f"Unsupported torch_dtype: {value!r}")
225
+
226
+
227
+ def _default_torch_dtype_for_device(torch: Any, device: str) -> Any:
228
+ d = str(device or "").strip().lower()
229
+ if not d:
230
+ return None
231
+ if d.startswith("cuda"):
232
+ return torch.float16
233
+ if d == "mps" or d.startswith("mps:"):
234
+ # Default to fp16 on Apple Silicon for broad model compatibility.
235
+ # (Some pipelines mix dtypes when using bf16, which can crash with matmul dtype mismatches.)
236
+ #
237
+ # You can still force bf16 explicitly via `torch_dtype=bfloat16`.
238
+ return torch.float16
239
+ return None
240
+
241
+
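For reference, a few illustrative inputs and outputs of the two dtype helpers above (assuming torch is installed); "auto" and empty strings map to None on purpose so the per-device default applies.

import torch

assert _torch_dtype_from_str(torch, "bf16") is torch.bfloat16
assert _torch_dtype_from_str(torch, "auto") is None              # caller falls back to the device default
assert _default_torch_dtype_for_device(torch, "cuda:0") is torch.float16
assert _default_torch_dtype_for_device(torch, "cpu") is None     # CPU keeps the pipeline's own default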
242
+ def _require_device_available(torch: Any, device: str) -> None:
243
+ d = str(device or "").strip().lower()
244
+ if not d:
245
+ return
246
+
247
+ if d.startswith("cuda"):
248
+ cuda = getattr(torch, "cuda", None)
249
+ is_available = getattr(cuda, "is_available", None) if cuda is not None else None
250
+ ok = bool(is_available()) if callable(is_available) else False
251
+ if not ok:
252
+ raise ValueError(
253
+ "Device 'cuda' was requested, but torch.cuda.is_available() is False. "
254
+ "Install a CUDA-enabled PyTorch build or use device='cpu'."
255
+ )
256
+
257
+ if d == "mps" or d.startswith("mps:"):
258
+ backends = getattr(torch, "backends", None)
259
+ mps = getattr(backends, "mps", None) if backends is not None else None
260
+ is_available = getattr(mps, "is_available", None) if mps is not None else None
261
+ ok = bool(is_available()) if callable(is_available) else False
262
+ if not ok:
263
+ raise ValueError(
264
+ "Device 'mps' was requested, but torch.backends.mps.is_available() is False. "
265
+ "On macOS this typically means you are not using an Apple Silicon + MPS-enabled PyTorch build. "
266
+ "Use device='cpu', or use the stable-diffusion.cpp (sd-cli) backend for GGUF models."
267
+ )
268
+
269
+
270
+ def _call_param_names(fn: Any) -> Optional[set[str]]:
271
+ try:
272
+ sig = inspect.signature(fn)
273
+ for p in sig.parameters.values():
274
+ if p.kind == p.VAR_KEYWORD:
275
+ return None
276
+ return {str(k) for k in sig.parameters.keys() if str(k) != "self"}
277
+ except Exception:
278
+ return None
279
+
280
+
281
+ def _looks_like_dtype_mismatch_error(e: Exception) -> bool:
282
+ msg = str(e or "")
283
+ m = msg.lower()
284
+ return (
285
+ "must have the same dtype" in m
286
+ or ("input type" in m and "bias type" in m and "should be the same" in m)
287
+ or ("expected scalar type" in m and "but found" in m)
288
+ )
289
+
290
+
291
+ def _maybe_upcast_vae_for_mps(torch: Any, pipe: Any, device: str) -> None:
292
+ d = str(device or "").strip().lower()
293
+ if d != "mps" and not d.startswith("mps:"):
294
+ return
295
+
296
+ # On Apple Silicon, some pipelines can produce NaNs/black images when decoding with a float16 VAE.
297
+ # A common fix is to keep the main model in fp16 but run VAE encode/decode in fp32.
298
+ #
299
+ # Diffusers pipelines do not consistently cast inputs to `vae.dtype` before calling `vae.encode/decode`.
300
+ # If we upcast only the VAE weights to fp32 while the pipeline still produces fp16 latents/images,
301
+ # PyTorch can raise dtype mismatch errors like:
302
+ # "Input type (c10::Half) and bias type (float) should be the same"
303
+ #
304
+ # To keep this backend robust across Diffusers versions, when we upcast the VAE we also wrap
305
+ # `vae.encode` and `vae.decode` to cast their tensor inputs to the VAE's dtype.
306
+ vae = getattr(pipe, "vae", None)
307
+ if vae is None:
308
+ return
309
+ to_fn = getattr(vae, "to", None)
310
+ if not callable(to_fn):
311
+ return
312
+ dtype = getattr(vae, "dtype", None)
313
+ if dtype == getattr(torch, "float16", None):
314
+ try:
315
+ vae.to(dtype=torch.float32)
316
+ _maybe_cast_vae_inputs_to_dtype(vae)
317
+ except Exception:
318
+ return
319
+
320
+
321
+ def _maybe_cast_pipe_modules_to_dtype(pipe: Any, *, dtype: Any) -> None:
322
+ if dtype is None:
323
+ return
324
+
325
+ def _to(module: Any) -> None:
326
+ if module is None:
327
+ return
328
+ to_fn = getattr(module, "to", None)
329
+ if not callable(to_fn):
330
+ return
331
+ try:
332
+ module.to(dtype=dtype)
333
+ except Exception:
334
+ return
335
+
336
+ # Best-effort: different pipelines use different component names (unet vs. transformer, etc.).
337
+ for attr in (
338
+ "model",
339
+ "transformer",
340
+ "unet",
341
+ "text_encoder",
342
+ "text_encoder_2",
343
+ "image_encoder",
344
+ "prior",
345
+ "vae",
346
+ "safety_checker",
347
+ ):
348
+ _to(getattr(pipe, attr, None))
349
+
350
+ vae = getattr(pipe, "vae", None)
351
+ if vae is not None:
352
+ _to(getattr(vae, "encoder", None))
353
+ _to(getattr(vae, "decoder", None))
354
+
355
+ # As a fallback, cast all registered components when available (covers pipelines that don't follow
356
+ # the common attribute naming patterns above).
357
+ comps = getattr(pipe, "components", None)
358
+ if isinstance(comps, dict):
359
+ for v in comps.values():
360
+ _to(v)
361
+
362
+
363
+ def _maybe_cast_vae_inputs_to_dtype(vae: Any) -> None:
364
+ if getattr(vae, "_abstractvision_casts_inputs_to_dtype", False):
365
+ return
366
+
367
+ try:
368
+ import types
369
+
370
+ def _wrap(name: str) -> None:
371
+ orig = getattr(vae, name, None)
372
+ if not callable(orig):
373
+ return
374
+
375
+ def wrapper(self: Any, x: Any, *args: Any, **kwargs: Any) -> Any:
376
+ try:
377
+ dtype = getattr(self, "dtype", None)
378
+ x_dtype = getattr(x, "dtype", None)
379
+ to_fn = getattr(x, "to", None)
380
+ if dtype is not None and x_dtype is not None and x_dtype != dtype and callable(to_fn):
381
+ x = x.to(dtype=dtype)
382
+ except Exception:
383
+ pass
384
+ return orig(x, *args, **kwargs)
385
+
386
+ setattr(vae, name, types.MethodType(wrapper, vae))
387
+
388
+ _wrap("encode")
389
+ _wrap("decode")
390
+ setattr(vae, "_abstractvision_casts_inputs_to_dtype", True)
391
+ except Exception:
392
+ return
393
+
394
+
395
+ @dataclass(frozen=True)
396
+ class HuggingFaceDiffusersBackendConfig:
397
+ """Config for a local Diffusers backend.
398
+
399
+ Notes:
400
+ - Downloads are enabled by default so a fresh environment can work after a `pip install`.
401
+ - To force offline mode (no network calls / cache-only), set `allow_download=False`.
402
+ """
403
+
404
+ model_id: str
405
+ device: str = "cpu" # "cpu" | "cuda" | "mps" | "auto" | ...
406
+ torch_dtype: Optional[str] = None # "float16" | "bfloat16" | "float32" | None
407
+ allow_download: bool = True
408
+ auto_retry_fp32: bool = True
409
+ cache_dir: Optional[str] = None
410
+ revision: Optional[str] = None
411
+ variant: Optional[str] = None
412
+ use_safetensors: bool = True
413
+ low_cpu_mem_usage: bool = True
414
+
415
+
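A minimal end-to-end sketch pairing this config with the backend class defined below (illustrative: the model id is an example, and the ImageGenerationRequest keywords are assumed to match the fields that generate_image reads).

from pathlib import Path

cfg = HuggingFaceDiffusersBackendConfig(
    model_id="stabilityai/stable-diffusion-2-1",  # example repo id
    device="auto",                                # resolved to cuda/mps/cpu at load time
    allow_download=True,                          # set False for cache-only / offline use
)
backend = HuggingFaceDiffusersVisionBackend(config=cfg)
asset = backend.generate_image(
    ImageGenerationRequest(prompt="a lighthouse at dusk", width=512, height=512, steps=20, seed=42)
)
Path("out.png").write_bytes(asset.data)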
416
+ class HuggingFaceDiffusersVisionBackend(VisionBackend):
417
+ """Local generative vision backend using HuggingFace Diffusers (images only, phase 1)."""
418
+
419
+ def __init__(self, *, config: HuggingFaceDiffusersBackendConfig):
420
+ self._cfg = config
421
+ self._pipelines: Dict[str, Any] = {}
422
+ self._call_params: Dict[str, Optional[set[str]]] = {}
423
+ self._fused_lora_signature: Dict[str, Optional[str]] = {}
424
+ self._rapid_transformer_key: Optional[str] = None
425
+ self._rapid_transformer: Any = None
426
+ self._resolved_device: Optional[str] = None
427
+
428
+ def _effective_device(self, torch: Any) -> str:
429
+ if self._resolved_device is not None:
430
+ return self._resolved_device
431
+
432
+ raw = str(getattr(self._cfg, "device", "") or "").strip()
433
+ d = raw.lower()
434
+ if not d or d in {"auto", "default"}:
435
+ cuda = getattr(torch, "cuda", None)
436
+ if cuda is not None and callable(getattr(cuda, "is_available", None)) and cuda.is_available():
437
+ self._resolved_device = "cuda"
438
+ return self._resolved_device
439
+
440
+ backends = getattr(torch, "backends", None)
441
+ mps = getattr(backends, "mps", None) if backends is not None else None
442
+ if mps is not None and callable(getattr(mps, "is_available", None)) and mps.is_available():
443
+ self._resolved_device = "mps"
444
+ return self._resolved_device
445
+
446
+ self._resolved_device = "cpu"
447
+ return self._resolved_device
448
+
449
+ # Normalize common spellings but preserve explicit device indexes (e.g. "cuda:0").
450
+ if d == "gpu":
451
+ self._resolved_device = "cuda"
452
+ else:
453
+ self._resolved_device = raw
454
+ return self._resolved_device
455
+
456
+ def preload(self) -> None:
457
+ # Best-effort: preload the most common pipeline.
458
+ self._get_or_load_pipeline("t2i")
459
+
460
+ def unload(self) -> None:
461
+ # Best-effort: release pipelines and GPU cache.
462
+ pipes = list(self._pipelines.values())
463
+ self._pipelines.clear()
464
+ self._call_params.clear()
465
+ self._fused_lora_signature.clear()
466
+ self._rapid_transformer_key = None
467
+ self._rapid_transformer = None
468
+
469
+ # Drop references and aggressively collect.
470
+ try:
471
+ for p in pipes:
472
+ try:
473
+ # Try to free adapter weights.
474
+ unfuse = getattr(p, "unfuse_lora", None)
475
+ if callable(unfuse):
476
+ unfuse()
477
+ except Exception:
478
+ pass
479
+ try:
480
+ unload = getattr(p, "unload_lora_weights", None)
481
+ if callable(unload):
482
+ unload()
483
+ except Exception:
484
+ pass
485
+ finally:
486
+ pipes = []
487
+
488
+ try:
489
+ import gc
490
+
491
+ gc.collect()
492
+ except Exception:
493
+ pass
494
+
495
+ try:
496
+ torch = _lazy_import_torch()
497
+ cuda = getattr(torch, "cuda", None)
498
+ if cuda is not None and callable(getattr(cuda, "is_available", None)) and cuda.is_available():
499
+ empty = getattr(cuda, "empty_cache", None)
500
+ if callable(empty):
501
+ empty()
502
+ ipc_collect = getattr(cuda, "ipc_collect", None)
503
+ if callable(ipc_collect):
504
+ ipc_collect()
505
+
506
+ mps = getattr(torch, "mps", None)
507
+ empty_mps = getattr(mps, "empty_cache", None) if mps is not None else None
508
+ if callable(empty_mps):
509
+ empty_mps()
510
+ except Exception:
511
+ pass
512
+
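Lifecycle sketch, reusing the backend from the configuration example above (illustrative): preload() warms the text-to-image pipeline before the first request, and unload() is the best-effort way to hand memory back between jobs.

backend.preload()                        # loads the "t2i" pipeline onto the configured device
asset = backend.generate_image(request)  # `request` as in the earlier sketch
backend.unload()                         # drops pipelines and empties CUDA/MPS caches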
513
+ def _lora_signature(self, loras: List[Dict[str, Any]]) -> Optional[str]:
514
+ if not loras:
515
+ return None
516
+ parts: List[str] = []
517
+ for spec in sorted(loras, key=lambda x: str(x.get("source") or "")):
518
+ parts.append(
519
+ "|".join(
520
+ [
521
+ str(spec.get("source") or ""),
522
+ str(spec.get("subfolder") or ""),
523
+ str(spec.get("weight_name") or ""),
524
+ str(spec.get("scale") or 1.0),
525
+ ]
526
+ )
527
+ )
528
+ combined = "::".join(parts)
529
+ return hashlib.md5(combined.encode("utf-8")).hexdigest()[:12]
530
+
531
+ def _parse_loras(self, extra: Any) -> List[Dict[str, Any]]:
532
+ if not isinstance(extra, dict) or not extra:
533
+ return []
534
+
535
+ raw: Any = None
536
+ for k in ("loras", "loras_json", "lora", "lora_json"):
537
+ if k in extra and extra.get(k) is not None:
538
+ raw = extra.get(k)
539
+ break
540
+ if raw is None:
541
+ return []
542
+
543
+ import json
544
+
545
+ items: Any = raw
546
+ if isinstance(raw, str):
547
+ s = raw.strip()
548
+ if not s:
549
+ return []
550
+ # Prefer JSON, but allow a simple comma-separated list of sources.
551
+ if s.startswith("[") or s.startswith("{"):
552
+ try:
553
+ items = json.loads(s)
554
+ except Exception:
555
+ items = raw
556
+ if isinstance(items, str):
557
+ parts = [p.strip() for p in items.split(",") if p.strip()]
558
+ items = [{"source": p} for p in parts]
559
+
560
+ if isinstance(items, dict):
561
+ items = [items]
562
+ if isinstance(items, str):
563
+ return [{"source": items.strip()}] if items.strip() else []
564
+ if not isinstance(items, list):
565
+ return []
566
+
567
+ out: List[Dict[str, Any]] = []
568
+ for el in items:
569
+ if isinstance(el, str):
570
+ src = el.strip()
571
+ if src:
572
+ out.append({"source": src, "scale": 1.0})
573
+ continue
574
+ if not isinstance(el, dict):
575
+ continue
576
+ src = str(el.get("source") or "").strip()
577
+ if not src:
578
+ continue
579
+ spec: Dict[str, Any] = {"source": src}
580
+ if el.get("subfolder") is not None:
581
+ spec["subfolder"] = str(el.get("subfolder") or "").strip() or None
582
+ if el.get("weight_name") is not None:
583
+ spec["weight_name"] = str(el.get("weight_name") or "").strip() or None
584
+ if el.get("adapter_name") is not None:
585
+ spec["adapter_name"] = str(el.get("adapter_name") or "").strip() or None
586
+ try:
587
+ spec["scale"] = float(el.get("scale") if el.get("scale") is not None else 1.0)
588
+ except Exception:
589
+ spec["scale"] = 1.0
590
+ out.append(spec)
591
+ return out
592
+
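For illustration, these are the `extra` payload shapes the parser above accepts; the repo ids and weight names are placeholders.

extra_simple = {"loras": "some-org/detail-lora, some-org/style-lora"}    # comma-separated sources
extra_full = {
    "loras": [
        {"source": "some-org/detail-lora", "scale": 0.8, "weight_name": "detail.safetensors"},
        {"source": "some-org/style-lora"},                                # scale defaults to 1.0
    ]
}
extra_json = {"loras_json": '[{"source": "some-org/style-lora", "adapter_name": "style"}]'}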
593
+ def _resolved_adapter_name(self, spec: Dict[str, Any]) -> str:
594
+ name = str(spec.get("adapter_name") or "").strip()
595
+ if name:
596
+ return name
597
+ key = "|".join([str(spec.get("source") or ""), str(spec.get("subfolder") or ""), str(spec.get("weight_name") or "")])
598
+ return "lora_" + hashlib.md5(key.encode("utf-8")).hexdigest()[:12]
599
+
600
+ def _apply_loras(self, *, kind: str, pipe: Any, extra: Any) -> Optional[str]:
601
+ loras = self._parse_loras(extra)
602
+ new_sig = self._lora_signature(loras)
603
+ cur_sig = self._fused_lora_signature.get(kind)
604
+ if new_sig == cur_sig:
605
+ return cur_sig
606
+
607
+ # Always clear previous adapters before applying a new set.
608
+ if hasattr(pipe, "unfuse_lora"):
609
+ try:
610
+ pipe.unfuse_lora()
611
+ except Exception:
612
+ pass
613
+ if hasattr(pipe, "unload_lora_weights"):
614
+ try:
615
+ pipe.unload_lora_weights()
616
+ except Exception:
617
+ pass
618
+
619
+ if not loras:
620
+ self._fused_lora_signature[kind] = None
621
+ return None
622
+
623
+ adapter_names: List[str] = []
624
+ adapter_scales: List[float] = []
625
+
626
+ with _hf_offline_env(not bool(self._cfg.allow_download)):
627
+ for spec in loras:
628
+ adapter_name = self._resolved_adapter_name(spec)
629
+ adapter_names.append(adapter_name)
630
+ adapter_scales.append(float(spec.get("scale") or 1.0))
631
+
632
+ kwargs: Dict[str, Any] = {}
633
+ if spec.get("weight_name"):
634
+ kwargs["weight_name"] = spec["weight_name"]
635
+ if spec.get("subfolder"):
636
+ kwargs["subfolder"] = spec["subfolder"]
637
+ kwargs["local_files_only"] = not bool(self._cfg.allow_download)
638
+ if self._cfg.cache_dir:
639
+ kwargs["cache_dir"] = str(self._cfg.cache_dir)
640
+
641
+ load_fn = getattr(pipe, "load_lora_weights", None)
642
+ if not callable(load_fn):
643
+ raise ValueError("This diffusers pipeline does not support LoRA adapters (missing load_lora_weights).")
644
+ load_fn(spec["source"], adapter_name=adapter_name, **kwargs)
645
+
646
+ if hasattr(pipe, "set_adapters"):
647
+ try:
648
+ pipe.set_adapters(adapter_names, adapter_weights=adapter_scales)
649
+ except Exception:
650
+ pass
651
+
652
+ if hasattr(pipe, "fuse_lora"):
653
+ try:
654
+ pipe.fuse_lora()
655
+ except Exception:
656
+ pass
657
+
658
+ if hasattr(pipe, "unload_lora_weights"):
659
+ try:
660
+ pipe.unload_lora_weights()
661
+ except Exception:
662
+ pass
663
+
664
+ self._fused_lora_signature[kind] = new_sig
665
+ return new_sig
666
+
667
+ def _maybe_apply_rapid_aio_transformer(self, *, pipe: Any, extra: Any, torch_dtype: Any) -> Optional[str]:
668
+ """Optionally swap the pipeline's transformer with a Rapid-AIO distilled transformer.
669
+
670
+ This is primarily useful for Qwen Image Edit pipelines (very fast 4-step inference), but we keep it
671
+ generic: if a pipeline has a `.transformer` module and diffusers provides a compatible transformer
672
+ class, we can hot-swap it.
673
+
674
+ Downloads are enabled by default; set allow_download=False for cache-only/offline mode.
675
+ """
676
+
677
+ if not isinstance(extra, dict) or not extra:
678
+ return None
679
+
680
+ repo = None
681
+ if extra.get("rapid_aio_repo"):
682
+ repo = str(extra.get("rapid_aio_repo") or "").strip()
683
+ elif extra.get("rapid_aio") is True:
684
+ repo = "linoyts/Qwen-Image-Edit-Rapid-AIO"
685
+ elif isinstance(extra.get("rapid_aio"), str) and str(extra.get("rapid_aio")).strip():
686
+ repo = str(extra.get("rapid_aio")).strip()
687
+ if not repo:
688
+ return None
689
+
690
+ subfolder = str(extra.get("rapid_aio_subfolder") or "transformer").strip() or "transformer"
691
+ key = f"{repo}|{subfolder}|{torch_dtype}"
692
+ if key == self._rapid_transformer_key and self._rapid_transformer is not None:
693
+ tr = self._rapid_transformer
694
+ else:
695
+ try:
696
+ from diffusers.models import QwenImageTransformer2DModel # type: ignore
697
+ except Exception:
698
+ raise ValueError(
699
+ "Rapid-AIO transformer override requires diffusers.models.QwenImageTransformer2DModel, "
700
+ "which is not available in this diffusers build."
701
+ )
702
+ kwargs: Dict[str, Any] = {"subfolder": subfolder, "local_files_only": not bool(self._cfg.allow_download)}
703
+ if self._cfg.cache_dir:
704
+ kwargs["cache_dir"] = str(self._cfg.cache_dir)
705
+ with _hf_offline_env(not bool(self._cfg.allow_download)):
706
+ tr = QwenImageTransformer2DModel.from_pretrained(repo, torch_dtype=torch_dtype, **kwargs)
707
+ torch = _lazy_import_torch()
708
+ device = self._effective_device(torch)
709
+ try:
710
+ tr = tr.to(device=str(device), dtype=torch_dtype)
711
+ except Exception:
712
+ try:
713
+ tr = tr.to(dtype=torch_dtype)
714
+ tr = tr.to(str(device))
715
+ except Exception:
716
+ pass
717
+
718
+ self._rapid_transformer_key = key
719
+ self._rapid_transformer = tr
720
+
721
+ if hasattr(pipe, "register_modules"):
722
+ try:
723
+ pipe.register_modules(transformer=tr)
724
+ except Exception:
725
+ setattr(pipe, "transformer", tr)
726
+ else:
727
+ setattr(pipe, "transformer", tr)
728
+
729
+ _maybe_cast_pipe_modules_to_dtype(pipe, dtype=torch_dtype)
730
+ return repo
731
+
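The transformer swap above is driven entirely by request `extra` keys; two illustrative payloads follow (the default repo id comes from the code above, the custom one is a placeholder).

extra_default = {"rapid_aio": True}   # uses "linoyts/Qwen-Image-Edit-Rapid-AIO", subfolder "transformer"
extra_custom = {
    "rapid_aio_repo": "some-org/custom-rapid-aio",   # placeholder repo id
    "rapid_aio_subfolder": "transformer",
}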
732
+ def get_capabilities(self) -> VisionBackendCapabilities:
733
+ return VisionBackendCapabilities(
734
+ supported_tasks=["text_to_image", "image_to_image"],
735
+ supports_mask=None, # depends on whether inpaint pipeline loads for the model
736
+ )
737
+
738
+ def _pipeline_common_kwargs(self) -> Dict[str, Any]:
739
+ kwargs: Dict[str, Any] = {
740
+ "local_files_only": not bool(self._cfg.allow_download),
741
+ "use_safetensors": bool(self._cfg.use_safetensors),
742
+ }
743
+ if self._cfg.cache_dir:
744
+ kwargs["cache_dir"] = str(self._cfg.cache_dir)
745
+ if self._cfg.revision:
746
+ kwargs["revision"] = str(self._cfg.revision)
747
+ if self._cfg.variant:
748
+ kwargs["variant"] = str(self._cfg.variant)
749
+ return kwargs
750
+
751
+ def _hf_cache_root(self) -> Path:
752
+ if self._cfg.cache_dir:
753
+ return Path(self._cfg.cache_dir).expanduser()
754
+ hub_cache = os.environ.get("HF_HUB_CACHE")
755
+ if hub_cache:
756
+ return Path(hub_cache).expanduser()
757
+ hf_home = os.environ.get("HF_HOME")
758
+ if hf_home:
759
+ return Path(hf_home).expanduser() / "hub"
760
+ return Path.home() / ".cache" / "huggingface" / "hub"
761
+
762
+ def _resolve_snapshot_dir(self) -> Optional[Path]:
763
+ model_id = str(self._cfg.model_id).strip()
764
+ if not model_id:
765
+ return None
766
+
767
+ p = Path(model_id).expanduser()
768
+ if p.exists():
769
+ return p
770
+
771
+ if "/" not in model_id:
772
+ return None
773
+
774
+ cache_root = self._hf_cache_root()
775
+ repo_dir = cache_root / ("models--" + model_id.replace("/", "--"))
776
+ snaps = repo_dir / "snapshots"
777
+ if not snaps.is_dir():
778
+ return None
779
+
780
+ rev = str(self._cfg.revision or "main").strip() or "main"
781
+ ref_file = repo_dir / "refs" / rev
782
+ if ref_file.is_file():
783
+ commit = ref_file.read_text(encoding="utf-8").strip()
784
+ snap_dir = snaps / commit
785
+ if snap_dir.is_dir():
786
+ return snap_dir
787
+
788
+ # Fallback: pick the most recently modified snapshot.
789
+ candidates = [d for d in snaps.iterdir() if d.is_dir()]
790
+ if not candidates:
791
+ return None
792
+ return max(candidates, key=lambda d: d.stat().st_mtime)
793
+
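The snapshot resolution above relies on the standard Hugging Face hub cache layout; a sketch of the paths involved (illustrative repo id; only meaningful when the model is already cached under the default location).

from pathlib import Path

hub = Path.home() / ".cache" / "huggingface" / "hub"        # default when cache_dir/HF_HUB_CACHE/HF_HOME are unset
repo_dir = hub / "models--stabilityai--stable-diffusion-2-1"
commit = (repo_dir / "refs" / "main").read_text().strip()   # the ref file holds a commit hash
snapshot = repo_dir / "snapshots" / commit                   # directory containing model_index.json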
794
+ def _preflight_check_model_index(self) -> None:
795
+ snap = self._resolve_snapshot_dir()
796
+ if snap is None:
797
+ return
798
+ idx_path = snap / "model_index.json"
799
+ if not idx_path.is_file():
800
+ return
801
+
802
+ try:
803
+ import json
804
+
805
+ model_index = json.loads(idx_path.read_text(encoding="utf-8"))
806
+ except Exception:
807
+ return
808
+
809
+ class_name = str(model_index.get("_class_name") or "").strip()
810
+ if not class_name:
811
+ return
812
+
813
+ (
814
+ _DiffusionPipeline,
815
+ _AutoPipelineForText2Image,
816
+ _AutoPipelineForImage2Image,
817
+ _AutoPipelineForInpainting,
818
+ diffusers_version,
819
+ ) = _lazy_import_diffusers()
820
+
821
+ import diffusers as _diffusers # type: ignore
822
+
823
+ if not hasattr(_diffusers, class_name):
824
+ required = str(model_index.get("_diffusers_version") or "unknown")
825
+ install_hint = "pip install -U 'git+https://github.com/huggingface/diffusers@main'"
826
+ install_hint_alt = "pip install -e '.[huggingface-dev]'"
827
+ extra = ""
828
+ if class_name == "Flux2KleinPipeline":
829
+ extra = (
830
+ " Note: this model uses a different text encoder than the released Flux2Pipeline in diffusers 0.36 "
831
+ "(Klein uses Qwen3; Flux2Pipeline is built around Mistral3), so a newer diffusers is required."
832
+ )
833
+ raise ValueError(
834
+ f"Diffusers pipeline class {class_name!r} is required by this model, but is not available in your "
835
+ f"installed diffusers ({diffusers_version}). "
836
+ f"The model's model_index.json was authored for diffusers {required}. "
837
+ "This class is not available in the latest PyPI release at the time of writing. "
838
+ f"Install a newer diffusers (offline runtime is still supported): {install_hint}. "
839
+ f"If you're installing AbstractVision from a repo checkout, you can also use: {install_hint_alt}.{extra}"
840
+ )
841
+
842
+ # Optional: sanity-check that referenced Transformers classes exist to avoid late failures.
843
+ try:
844
+ import transformers # type: ignore
845
+
846
+ missing_tf: list[str] = []
847
+ for v in model_index.values():
848
+ if (
849
+ isinstance(v, list)
850
+ and len(v) == 2
851
+ and isinstance(v[0], str)
852
+ and isinstance(v[1], str)
853
+ and v[0].strip().lower() == "transformers"
854
+ ):
855
+ tf_cls = v[1].strip()
856
+ if tf_cls and not hasattr(transformers, tf_cls):
857
+ missing_tf.append(tf_cls)
858
+ if missing_tf:
859
+ tf_ver = getattr(transformers, "__version__", "unknown")
860
+ raise ValueError(
861
+ "This model references Transformers classes that are not available in your environment "
862
+ f"(transformers={tf_ver}): {', '.join(sorted(set(missing_tf)))}. "
863
+ "Upgrade transformers to a compatible version."
864
+ )
865
+ except ValueError:
866
+ raise
867
+ except Exception:
868
+ pass
869
+
870
+ def _get_or_load_pipeline(self, kind: str) -> Any:
871
+ existing = self._pipelines.get(kind)
872
+ if existing is not None:
873
+ return existing
874
+
875
+ (
876
+ DiffusionPipeline,
877
+ AutoPipelineForText2Image,
878
+ AutoPipelineForImage2Image,
879
+ AutoPipelineForInpainting,
880
+ diffusers_version,
881
+ ) = _lazy_import_diffusers()
882
+ torch = _lazy_import_torch()
883
+ device = self._effective_device(torch)
884
+ _require_device_available(torch, device)
885
+
886
+ self._preflight_check_model_index()
887
+ _maybe_patch_transformers_clip_position_ids()
888
+
889
+ torch_dtype = _torch_dtype_from_str(torch, self._cfg.torch_dtype)
890
+ if torch_dtype is None:
891
+ torch_dtype = _default_torch_dtype_for_device(torch, device)
892
+ common = self._pipeline_common_kwargs()
893
+ if bool(self._cfg.low_cpu_mem_usage):
894
+ common["low_cpu_mem_usage"] = True
895
+
896
+ # Auto-select checkpoint variants when appropriate (best-effort).
897
+ # Prefer fp16 on GPU backends (CUDA/MPS) to cut memory/disk use, but never on CPU.
898
+ auto_variant: Optional[str] = None
899
+ if not str(getattr(self._cfg, "variant", "") or "").strip() and str(device).strip().lower() != "cpu":
900
+ if torch_dtype == getattr(torch, "float16", object()):
901
+ auto_variant = "fp16"
902
+
903
+ def _looks_like_missing_variant_error(e: Exception, variant: str) -> bool:
904
+ msg = str(e or "")
905
+ m = msg.lower()
906
+ v = str(variant or "").strip().lower()
907
+ if not v:
908
+ return False
909
+ return (
910
+ (f".{v}." in m or f" {v} " in m or f"'{v}'" in m)
911
+ and (
912
+ "no such file" in m
913
+ or "does not exist" in m
914
+ or "not found" in m
915
+ or "is not present" in m
916
+ or "couldn't find" in m
917
+ or "cannot find" in m
918
+ )
919
+ )
920
+
921
+ def _from_pretrained(cls: Any) -> Any:
922
+ if auto_variant:
923
+ common2 = dict(common)
924
+ common2["variant"] = auto_variant
925
+ try:
926
+ return cls.from_pretrained(self._cfg.model_id, torch_dtype=torch_dtype, **common2)
927
+ except Exception as e:
928
+ # If the repo doesn't provide the fp16 variant, fall back to regular weights.
929
+ if _looks_like_missing_variant_error(e, auto_variant):
930
+ return cls.from_pretrained(self._cfg.model_id, torch_dtype=torch_dtype, **common)
931
+ raise
932
+ return cls.from_pretrained(self._cfg.model_id, torch_dtype=torch_dtype, **common)
933
+
934
+ def _maybe_raise_offline_missing_model(e: Exception) -> None:
935
+ model_id = str(self._cfg.model_id or "").strip()
936
+ if not model_id or "/" not in model_id:
937
+ return
938
+ # If it's not in cache, provide a clearer message than the upstream
939
+ # "does not appear to have a file named model_index.json" wording.
940
+ if self._resolve_snapshot_dir() is not None:
941
+ return
942
+ msg = str(e)
943
+ if "model_index.json" not in msg:
944
+ return
945
+ raise ValueError(
946
+ f"Model {model_id!r} is not available locally and downloads are disabled. "
947
+ "Either pre-download it (e.g. via `huggingface-cli download ...`) or enable downloads "
948
+ "(set allow_download=True; for AbstractCore Server: set ABSTRACTCORE_VISION_ALLOW_DOWNLOAD=1). "
949
+ "If the model is gated, accept its terms on Hugging Face and set `HF_TOKEN` before downloading."
950
+ ) from e
951
+
952
+ pipe = None
953
+ with _hf_offline_env(not bool(self._cfg.allow_download)):
954
+ if kind == "t2i":
955
+ # Prefer AutoPipeline when available, but fall back to DiffusionPipeline for robustness.
956
+ if AutoPipelineForText2Image is not None:
957
+ try:
958
+ pipe = _from_pretrained(AutoPipelineForText2Image)
959
+ except ValueError as e:
960
+ _maybe_raise_offline_missing_model(e)
961
+ pipe = None
962
+ if pipe is None:
963
+ try:
964
+ pipe = _from_pretrained(DiffusionPipeline)
965
+ except Exception as e:
966
+ _maybe_raise_offline_missing_model(e)
967
+ raise
968
+ elif kind == "i2i":
969
+ if AutoPipelineForImage2Image is not None:
970
+ try:
971
+ pipe = _from_pretrained(AutoPipelineForImage2Image)
972
+ except ValueError as e:
973
+ _maybe_raise_offline_missing_model(e)
974
+ pipe = None
975
+ if pipe is None:
976
+ try:
977
+ pipe = _from_pretrained(DiffusionPipeline)
978
+ except Exception as e:
979
+ _maybe_raise_offline_missing_model(e)
980
+ raise ValueError(
981
+ "Diffusers could not load an image-to-image pipeline for this model id. "
982
+ "Install/upgrade diffusers (and compatible transformers/torch), or use a model repo that "
983
+ "ships an image-to-image pipeline. "
984
+ f"(diffusers={diffusers_version})"
985
+ ) from e
986
+ elif kind == "inpaint":
987
+ if AutoPipelineForInpainting is None:
988
+ raise ValueError(
989
+ "Diffusers inpainting pipeline is not available in this environment. "
990
+ "Install/upgrade diffusers (and compatible transformers/torch). "
991
+ f"(diffusers={diffusers_version})"
992
+ )
993
+ pipe = _from_pretrained(AutoPipelineForInpainting)
994
+ else:
995
+ raise ValueError(f"Unknown pipeline kind: {kind!r}")
996
+
997
+ # Diffusers pipelines support `.to(<device>)` with a string.
998
+ pipe = pipe.to(str(device))
999
+ _maybe_cast_pipe_modules_to_dtype(pipe, dtype=torch_dtype)
1000
+ _maybe_upcast_vae_for_mps(torch, pipe, device)
1001
+ self._pipelines[kind] = pipe
1002
+ self._call_params[kind] = _call_param_names(getattr(pipe, "__call__", None))
1003
+ return pipe
1004
+
1005
+ def _pil_from_bytes(self, data: bytes):
1006
+ Image = _lazy_import_pil()
1007
+ img = Image.open(io.BytesIO(bytes(data)))
1008
+ # Many pipelines expect RGB.
1009
+ return img.convert("RGB")
1010
+
1011
+ def _png_bytes(self, img) -> bytes:
1012
+ buf = io.BytesIO()
1013
+ img.save(buf, format="PNG")
1014
+ return buf.getvalue()
1015
+
1016
+ def _seed_generator(self, seed: Optional[int]):
1017
+ if seed is None:
1018
+ return None
1019
+ torch = _lazy_import_torch()
1020
+ d = str(self._effective_device(torch) or "").strip().lower()
1021
+ gen_device = "cpu" if d == "mps" or d.startswith("mps:") else str(self._effective_device(torch))
1022
+ try:
1023
+ gen = torch.Generator(device=gen_device)
1024
+ except Exception:
1025
+ gen = torch.Generator()
1026
+ gen.manual_seed(int(seed))
1027
+ return gen
1028
+
1029
+ def _is_probably_all_black_image(self, img: Any) -> bool:
1030
+ try:
1031
+ rgb = img.convert("RGB")
1032
+ extrema = rgb.getextrema()
1033
+ if isinstance(extrema, tuple) and len(extrema) == 2 and all(isinstance(x, int) for x in extrema):
1034
+ _, mx = extrema
1035
+ return mx <= 1
1036
+ if isinstance(extrema, tuple):
1037
+ return all(isinstance(x, tuple) and len(x) == 2 and int(x[1]) <= 1 for x in extrema)
1038
+ except Exception:
1039
+ return False
1040
+ return False
1041
+
1042
+ def _pipe_call(self, pipe: Any, kwargs: Dict[str, Any]):
1043
+ import warnings
1044
+
1045
+ call_kwargs = dict(kwargs)
1046
+ if callable(kwargs.get("__abstractvision_progress_callback")):
1047
+ progress_cb = kwargs.get("__abstractvision_progress_callback")
1048
+ total_steps = kwargs.get("__abstractvision_progress_total_steps")
1049
+ try:
1050
+ call_kwargs.pop("__abstractvision_progress_callback", None)
1051
+ call_kwargs.pop("__abstractvision_progress_total_steps", None)
1052
+ except Exception:
1053
+ pass
1054
+ try:
1055
+ call_kwargs = self._inject_progress_kwargs(
1056
+ pipe=pipe,
1057
+ kwargs=call_kwargs,
1058
+ progress_callback=progress_cb,
1059
+ total_steps=int(total_steps) if total_steps is not None else None,
1060
+ )
1061
+ except Exception:
1062
+ # Best-effort: never break inference for progress reporting.
1063
+ pass
1064
+
1065
+ with warnings.catch_warnings(record=True) as w:
1066
+ warnings.simplefilter("always", RuntimeWarning)
1067
+ with _hf_offline_env(not bool(self._cfg.allow_download)):
1068
+ out = pipe(**call_kwargs)
1069
+ had_invalid_cast = any(
1070
+ issubclass(getattr(x, "category", Warning), RuntimeWarning)
1071
+ and "invalid value encountered in cast" in str(getattr(x, "message", ""))
1072
+ for x in w
1073
+ )
1074
+ return out, had_invalid_cast
1075
+
1076
+ def _pipe_progress_param_names(self, pipe: Any) -> set[str]:
1077
+ fn = getattr(pipe, "__call__", None)
1078
+ if not callable(fn):
1079
+ return set()
1080
+ try:
1081
+ sig = inspect.signature(fn)
1082
+ except Exception:
1083
+ return set()
1084
+ return {str(k) for k in sig.parameters.keys() if str(k) != "self"}
1085
+
1086
+ def _inject_progress_kwargs(
1087
+ self,
1088
+ *,
1089
+ pipe: Any,
1090
+ kwargs: Dict[str, Any],
1091
+ progress_callback: Callable[[int, Optional[int]], None],
1092
+ total_steps: Optional[int],
1093
+ ) -> Dict[str, Any]:
1094
+ names = self._pipe_progress_param_names(pipe)
1095
+ if not names:
1096
+ return kwargs
1097
+
1098
+ if "callback_on_step_end" in names:
1099
+
1100
+ def _on_step_end(*args: Any, **kw: Any) -> Any:
1101
+ # Expected signature: (pipe, step, timestep, callback_kwargs)
1102
+ step = None
1103
+ cb_kwargs = None
1104
+ try:
1105
+ if len(args) >= 2:
1106
+ step = args[1]
1107
+ if len(args) >= 4:
1108
+ cb_kwargs = args[3]
1109
+ if cb_kwargs is None:
1110
+ cb_kwargs = kw.get("callback_kwargs")
1111
+ except Exception:
1112
+ pass
1113
+ try:
1114
+ if step is not None:
1115
+ progress_callback(int(step) + 1, total_steps)
1116
+ except Exception:
1117
+ pass
1118
+ return cb_kwargs if cb_kwargs is not None else {}
1119
+
1120
+ kwargs["callback_on_step_end"] = _on_step_end
1121
+ # Avoid passing large tensors through callback_kwargs unless explicitly requested.
1122
+ if "callback_on_step_end_tensor_inputs" in names:
1123
+ kwargs.setdefault("callback_on_step_end_tensor_inputs", [])
1124
+ return kwargs
1125
+
1126
+ if "callback" in names:
1127
+
1128
+ def _callback(*args: Any, **_kw: Any) -> None:
1129
+ # Expected signature: (step, timestep, latents)
1130
+ try:
1131
+ if args:
1132
+ progress_callback(int(args[0]) + 1, total_steps)
1133
+ except Exception:
1134
+ pass
1135
+
1136
+ kwargs["callback"] = _callback
1137
+ if "callback_steps" in names:
1138
+ kwargs["callback_steps"] = 1
1139
+ return kwargs
1140
+
1141
+ return kwargs
1142
+
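Caller-side sketch of the progress plumbing above (hypothetical wiring, reusing the backend and request from the earlier sketch): the callback receives the number of completed steps plus the requested total, or None when the request does not set steps.

from typing import Optional

def on_progress(step: int, total: Optional[int]) -> None:
    print(f"denoising step {step}/{total if total is not None else '?'}")

asset = backend.generate_image_with_progress(request, progress_callback=on_progress)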
1143
+ def _maybe_retry_on_dtype_mismatch(
1144
+ self,
1145
+ *,
1146
+ kind: str,
1147
+ pipe: Any,
1148
+ kwargs: Dict[str, Any],
1149
+ error: Exception,
1150
+ progress_callback: Optional[Callable[[int, Optional[int]], None]] = None,
1151
+ total_steps: Optional[int] = None,
1152
+ ) -> Optional[Any]:
1153
+ if not bool(getattr(self._cfg, "auto_retry_fp32", False)):
1154
+ return None
1155
+ if not _looks_like_dtype_mismatch_error(error):
1156
+ return None
1157
+
1158
+ torch = _lazy_import_torch()
1159
+ device = self._effective_device(torch)
1160
+ d = str(device or "").strip().lower()
1161
+ if not (d == "mps" or d.startswith("mps:")):
1162
+ return None
1163
+
1164
+ current_dtype = getattr(pipe, "dtype", None)
1165
+ if current_dtype is None:
1166
+ current_dtype = _torch_dtype_from_str(torch, self._cfg.torch_dtype) or _default_torch_dtype_for_device(
1167
+ torch, device
1168
+ )
1169
+
1170
+ candidates: list[Any] = []
1171
+ if current_dtype == getattr(torch, "bfloat16", object()):
1172
+ candidates.append(torch.float16)
1173
+ if current_dtype != getattr(torch, "float32", object()):
1174
+ candidates.append(torch.float32)
1175
+
1176
+ for target in candidates:
1177
+ try:
1178
+ pipe2 = pipe.to(device=str(device), dtype=target)
1179
+ except Exception:
1180
+ try:
1181
+ pipe2 = pipe.to(dtype=target)
1182
+ pipe2 = pipe2.to(str(device))
1183
+ except Exception:
1184
+ continue
1185
+
1186
+ _maybe_upcast_vae_for_mps(torch, pipe2, device)
1187
+ self._pipelines[kind] = pipe2
1188
+ self._call_params[kind] = _call_param_names(getattr(pipe2, "__call__", None))
1189
+
1190
+ try:
1191
+ call_kwargs = dict(kwargs)
1192
+ if progress_callback is not None:
1193
+ call_kwargs["__abstractvision_progress_callback"] = progress_callback
1194
+ call_kwargs["__abstractvision_progress_total_steps"] = total_steps
1195
+ out2, _had_invalid_cast2 = self._pipe_call(pipe2, call_kwargs)
1196
+ return out2
1197
+ except Exception:
1198
+ continue
1199
+ return None
1200
+
1201
+ def _maybe_retry_fp32_on_invalid_output(
1202
+ self,
1203
+ *,
1204
+ kind: str,
1205
+ pipe: Any,
1206
+ kwargs: Dict[str, Any],
1207
+ progress_callback: Optional[Callable[[int, Optional[int]], None]] = None,
1208
+ total_steps: Optional[int] = None,
1209
+ ) -> Optional[Any]:
1210
+ if not bool(getattr(self._cfg, "auto_retry_fp32", False)):
1211
+ return None
1212
+ torch = _lazy_import_torch()
1213
+ device = self._effective_device(torch)
1214
+ d = str(device or "").strip().lower()
1215
+ cfg_dtype = _torch_dtype_from_str(torch, self._cfg.torch_dtype)
1216
+ if cfg_dtype is None:
1217
+ cfg_dtype = _default_torch_dtype_for_device(torch, device)
1218
+
1219
+ # Currently, we only auto-retry on Apple Silicon / MPS when running fp16,
1220
+ # because NaNs/black images are common for some models (e.g. Qwen Image).
1221
+ if not (d == "mps" or d.startswith("mps:")):
1222
+ return None
1223
+ if cfg_dtype != torch.float16:
1224
+ return None
1225
+
1226
+ try:
1227
+ pipe_fp32 = pipe.to(device=str(device), dtype=torch.float32)
1228
+ except Exception:
1229
+ try:
1230
+ pipe_fp32 = pipe.to(dtype=torch.float32)
1231
+ pipe_fp32 = pipe_fp32.to(str(device))
1232
+ except Exception:
1233
+ return None
1234
+
1235
+ _maybe_upcast_vae_for_mps(torch, pipe_fp32, device)
1236
+ self._pipelines[kind] = pipe_fp32
1237
+ self._call_params[kind] = _call_param_names(getattr(pipe_fp32, "__call__", None))
1238
+
1239
+ call_kwargs = dict(kwargs)
1240
+ if progress_callback is not None:
1241
+ call_kwargs["__abstractvision_progress_callback"] = progress_callback
1242
+ call_kwargs["__abstractvision_progress_total_steps"] = total_steps
1243
+ out2, had_invalid_cast2 = self._pipe_call(pipe_fp32, call_kwargs)
1244
+ if had_invalid_cast2:
1245
+ raise ValueError(
1246
+ "Diffusers produced invalid pixel values (NaNs) while decoding the image "
1247
+ "(resulting in an all-black output). "
1248
+ "Tried an automatic fp32 retry on MPS and it still failed. "
1249
+ "Try setting torch_dtype=float32 explicitly, increasing steps, or use the stable-diffusion.cpp backend."
1250
+ )
1251
+ return out2
1252
+
1253
+ def generate_image(self, request: ImageGenerationRequest) -> GeneratedAsset:
1254
+ return self.generate_image_with_progress(request, progress_callback=None)
1255
+
1256
+ def generate_image_with_progress(
1257
+ self,
1258
+ request: ImageGenerationRequest,
1259
+ progress_callback: Optional[Callable[[int, Optional[int]], None]] = None,
1260
+ ) -> GeneratedAsset:
1261
+ pipe = self._get_or_load_pipeline("t2i")
1262
+ call_params = self._call_params.get("t2i")
1263
+ total_steps = int(request.steps) if request.steps is not None else None
1264
+
1265
+ torch_dtype = getattr(pipe, "dtype", None)
1266
+ if torch_dtype is None:
1267
+ torch = _lazy_import_torch()
1268
+ device = self._effective_device(torch)
1269
+ torch_dtype = _torch_dtype_from_str(torch, self._cfg.torch_dtype) or _default_torch_dtype_for_device(torch, device)
1270
+ rapid_repo = self._maybe_apply_rapid_aio_transformer(pipe=pipe, extra=request.extra, torch_dtype=torch_dtype)
1271
+ lora_sig = self._apply_loras(kind="t2i", pipe=pipe, extra=request.extra)
1272
+
1273
+ kwargs: Dict[str, Any] = {
1274
+ "prompt": request.prompt,
1275
+ }
1276
+ if request.negative_prompt is not None:
1277
+ kwargs["negative_prompt"] = request.negative_prompt
1278
+ if request.width is not None:
1279
+ kwargs["width"] = int(request.width)
1280
+ if request.height is not None:
1281
+ kwargs["height"] = int(request.height)
1282
+ if request.steps is not None:
1283
+ kwargs["num_inference_steps"] = int(request.steps)
1284
+ if request.guidance_scale is not None:
1285
+ if call_params is not None and "true_cfg_scale" in call_params:
1286
+ kwargs["true_cfg_scale"] = float(request.guidance_scale)
1287
+ # Some pipelines (e.g. Qwen Image) only enable CFG when a `negative_prompt`
1288
+ # is provided (even an empty one). Make `guidance_scale` behave consistently.
1289
+ if request.negative_prompt is None and (call_params is None or "negative_prompt" in call_params):
1290
+ kwargs["negative_prompt"] = " "
1291
+ else:
1292
+ kwargs["guidance_scale"] = float(request.guidance_scale)
1293
+ gen = self._seed_generator(request.seed)
1294
+ if gen is not None:
1295
+ kwargs["generator"] = gen
1296
+
1297
+ if isinstance(request.extra, dict) and request.extra:
1298
+ kwargs.update(dict(request.extra))
1299
+
1300
+ try:
1301
+ call_kwargs = dict(kwargs)
1302
+ if progress_callback is not None:
1303
+ call_kwargs["__abstractvision_progress_callback"] = progress_callback
1304
+ call_kwargs["__abstractvision_progress_total_steps"] = total_steps
1305
+ out, had_invalid_cast = self._pipe_call(pipe, call_kwargs)
1306
+ except Exception as e:
1307
+ out2 = self._maybe_retry_on_dtype_mismatch(
1308
+ kind="t2i",
1309
+ pipe=pipe,
1310
+ kwargs=kwargs,
1311
+ error=e,
1312
+ progress_callback=progress_callback,
1313
+ total_steps=total_steps,
1314
+ )
1315
+ if out2 is None:
1316
+ raise
1317
+ out, had_invalid_cast = out2, False
1318
+ retried_fp32 = False
1319
+ images = getattr(out, "images", None)
1320
+ if not isinstance(images, list) or not images:
1321
+ raise ValueError("Diffusers pipeline returned no images")
1322
+ if self._is_probably_all_black_image(images[0]):
1323
+ out2 = self._maybe_retry_fp32_on_invalid_output(
1324
+ kind="t2i",
1325
+ pipe=pipe,
1326
+ kwargs=kwargs,
1327
+ progress_callback=progress_callback,
1328
+ total_steps=total_steps,
1329
+ )
1330
+ if out2 is not None:
1331
+ out = out2
1332
+ retried_fp32 = True
1333
+ images = getattr(out, "images", None)
1334
+ if not isinstance(images, list) or not images:
1335
+ raise ValueError("Diffusers pipeline returned no images")
1336
+ if self._is_probably_all_black_image(images[0]):
1337
+ raise ValueError(
1338
+ "Diffusers produced an all-black image output. "
1339
+ + (
1340
+ "An automatic fp32 retry was attempted and it still produced an all-black image. "
1341
+ if retried_fp32
1342
+ else "Try setting torch_dtype=float32. "
1343
+ )
1344
+ + "Try increasing steps, adjusting guidance_scale, or use the stable-diffusion.cpp backend."
1345
+ )
1346
+ png = self._png_bytes(images[0])
1347
+ meta = {"source": "diffusers", "model_id": self._cfg.model_id}
1348
+ if rapid_repo:
1349
+ meta["rapid_aio_repo"] = rapid_repo
1350
+ if lora_sig:
1351
+ meta["lora_signature"] = lora_sig
1352
+ if retried_fp32:
1353
+ meta["retried_fp32"] = True
1354
+ if had_invalid_cast:
1355
+ meta["had_invalid_cast_warning"] = True
1356
+ try:
1357
+ current_pipe = self._pipelines.get("t2i", pipe)
1358
+ dtype = getattr(current_pipe, "dtype", None)
1359
+ device = getattr(current_pipe, "device", None)
1360
+ if dtype is not None:
1361
+ meta["dtype"] = str(dtype)
1362
+ if device is not None:
1363
+ meta["device"] = str(device)
1364
+ except Exception:
1365
+ pass
1366
+ return GeneratedAsset(
1367
+ media_type="image",
1368
+ data=png,
1369
+ mime_type="image/png",
1370
+ metadata=meta,
1371
+ )
1372
+
1373
+ def edit_image(self, request: ImageEditRequest) -> GeneratedAsset:
1374
+ return self.edit_image_with_progress(request, progress_callback=None)
1375
+
1376
+ def edit_image_with_progress(
1377
+ self,
1378
+ request: ImageEditRequest,
1379
+ progress_callback: Optional[Callable[[int, Optional[int]], None]] = None,
1380
+ ) -> GeneratedAsset:
1381
+ if request.mask is not None:
1382
+ pipe = self._get_or_load_pipeline("inpaint")
1383
+ call_params = self._call_params.get("inpaint")
1384
+ kind = "inpaint"
1385
+ else:
1386
+ pipe = self._get_or_load_pipeline("i2i")
1387
+ call_params = self._call_params.get("i2i")
1388
+ kind = "i2i"
1389
+
1390
+ total_steps = int(request.steps) if request.steps is not None else None
1391
+
1392
+ torch_dtype = getattr(pipe, "dtype", None)
1393
+ if torch_dtype is None:
1394
+ torch = _lazy_import_torch()
1395
+ device = self._effective_device(torch)
1396
+ torch_dtype = _torch_dtype_from_str(torch, self._cfg.torch_dtype) or _default_torch_dtype_for_device(torch, device)
1397
+ rapid_repo = self._maybe_apply_rapid_aio_transformer(pipe=pipe, extra=request.extra, torch_dtype=torch_dtype)
1398
+ lora_sig = self._apply_loras(kind=kind, pipe=pipe, extra=request.extra)
1399
+
1400
+ img = self._pil_from_bytes(request.image)
1401
+ kwargs: Dict[str, Any] = {"prompt": request.prompt, "image": img}
1402
+ if request.mask is not None:
1403
+ kwargs["mask_image"] = self._pil_from_bytes(request.mask)
1404
+ if request.negative_prompt is not None:
1405
+ kwargs["negative_prompt"] = request.negative_prompt
1406
+ if request.steps is not None:
1407
+ kwargs["num_inference_steps"] = int(request.steps)
1408
+ if request.guidance_scale is not None:
1409
+ if call_params is not None and "true_cfg_scale" in call_params:
1410
+ kwargs["true_cfg_scale"] = float(request.guidance_scale)
1411
+ if request.negative_prompt is None and (call_params is None or "negative_prompt" in call_params):
1412
+ kwargs["negative_prompt"] = " "
1413
+ else:
1414
+ kwargs["guidance_scale"] = float(request.guidance_scale)
1415
+ gen = self._seed_generator(request.seed)
1416
+ if gen is not None:
1417
+ kwargs["generator"] = gen
1418
+
1419
+ if isinstance(request.extra, dict) and request.extra:
1420
+ kwargs.update(dict(request.extra))
1421
+
1422
+ try:
1423
+ call_kwargs = dict(kwargs)
1424
+ if progress_callback is not None:
1425
+ call_kwargs["__abstractvision_progress_callback"] = progress_callback
1426
+ call_kwargs["__abstractvision_progress_total_steps"] = total_steps
1427
+ out, had_invalid_cast = self._pipe_call(pipe, call_kwargs)
1428
+ except Exception as e:
1429
+ out2 = self._maybe_retry_on_dtype_mismatch(
1430
+ kind=kind,
1431
+ pipe=pipe,
1432
+ kwargs=kwargs,
1433
+ error=e,
1434
+ progress_callback=progress_callback,
1435
+ total_steps=total_steps,
1436
+ )
1437
+ if out2 is None:
1438
+ raise
1439
+ out, had_invalid_cast = out2, False
1440
+ retried_fp32 = False
1441
+ images = getattr(out, "images", None)
1442
+ if not isinstance(images, list) or not images:
1443
+ raise ValueError("Diffusers pipeline returned no images")
1444
+ if self._is_probably_all_black_image(images[0]):
1445
+ kind = "inpaint" if request.mask is not None else "i2i"
1446
+ out2 = self._maybe_retry_fp32_on_invalid_output(
1447
+ kind=kind,
1448
+ pipe=pipe,
1449
+ kwargs=kwargs,
1450
+ progress_callback=progress_callback,
1451
+ total_steps=total_steps,
1452
+ )
1453
+ if out2 is not None:
1454
+ out = out2
1455
+ retried_fp32 = True
1456
+ images = getattr(out, "images", None)
1457
+ if not isinstance(images, list) or not images:
1458
+ raise ValueError("Diffusers pipeline returned no images")
1459
+ if self._is_probably_all_black_image(images[0]):
1460
+ raise ValueError(
1461
+ "Diffusers produced an all-black image output. "
1462
+ + (
1463
+ "An automatic fp32 retry was attempted and it still produced an all-black image. "
1464
+ if retried_fp32
1465
+ else "Try setting torch_dtype=bfloat16 (recommended on MPS) or torch_dtype=float32. "
1466
+ )
1467
+ + "Try increasing steps, adjusting guidance_scale, or use the stable-diffusion.cpp backend."
1468
+ )
1469
+ png = self._png_bytes(images[0])
1470
+ meta = {"source": "diffusers", "model_id": self._cfg.model_id}
1471
+ if rapid_repo:
1472
+ meta["rapid_aio_repo"] = rapid_repo
1473
+ if lora_sig:
1474
+ meta["lora_signature"] = lora_sig
1475
+ if retried_fp32:
1476
+ meta["retried_fp32"] = True
1477
+ if had_invalid_cast:
1478
+ meta["had_invalid_cast_warning"] = True
1479
+ try:
1480
+ current_pipe = self._pipelines.get(kind, pipe)
1481
+ dtype = getattr(current_pipe, "dtype", None)
1482
+ device = getattr(current_pipe, "device", None)
1483
+ if dtype is not None:
1484
+ meta["dtype"] = str(dtype)
1485
+ if device is not None:
1486
+ meta["device"] = str(device)
1487
+ except Exception:
1488
+ pass
1489
+ return GeneratedAsset(
1490
+ media_type="image",
1491
+ data=png,
1492
+ mime_type="image/png",
1493
+ metadata=meta,
1494
+ )
1495
+
1496
+ def generate_angles(self, request: MultiAngleRequest) -> list[GeneratedAsset]:
1497
+ raise CapabilityNotSupportedError("HuggingFaceDiffusersVisionBackend does not implement multi-view generation.")
1498
+
1499
+ def generate_video(self, request: VideoGenerationRequest) -> GeneratedAsset:
1500
+ raise CapabilityNotSupportedError("HuggingFaceDiffusersVisionBackend does not implement text_to_video (phase 2).")
1501
+
1502
+ def image_to_video(self, request: ImageToVideoRequest) -> GeneratedAsset:
1503
+ raise CapabilityNotSupportedError("HuggingFaceDiffusersVisionBackend does not implement image_to_video (phase 2).")
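Finally, a hypothetical editing sketch against the backend from the earlier example: providing a mask routes the request through the inpainting pipeline, while omitting it uses image-to-image (the ImageEditRequest keywords are assumed to match the fields edit_image reads).

from pathlib import Path

asset = backend.edit_image(
    ImageEditRequest(
        prompt="replace the sky with a dramatic storm",
        image=Path("photo.png").read_bytes(),
        mask=Path("mask.png").read_bytes(),   # omit the mask to use the image-to-image pipeline instead
        steps=30,
    )
)
Path("edited.png").write_bytes(asset.data)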