coreml-diffusion 0.1.4__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. coreml_diffusion-0.1.6/.release-please-manifest.json +3 -0
  2. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/CHANGELOG.md +14 -0
  3. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/PKG-INFO +3 -2
  4. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/__init__.py +17 -0
  5. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/cli.py +12 -5
  6. coreml_diffusion-0.1.6/coreml_diffusion/conversion/state_dict.py +99 -0
  7. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/convert.py +23 -26
  8. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/pyproject.toml +3 -2
  9. coreml_diffusion-0.1.6/tests/unit/test_detect_model_version.py +77 -0
  10. coreml_diffusion-0.1.6/uv.lock +1217 -0
  11. coreml_diffusion-0.1.4/.release-please-manifest.json +0 -3
  12. coreml_diffusion-0.1.4/coreml_diffusion/conversion/state_dict.py +0 -20
  13. coreml_diffusion-0.1.4/uv.lock +0 -900
  14. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/.github/workflows/publish-pypi.yml +0 -0
  15. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/.github/workflows/release-please.yml +0 -0
  16. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/.github/workflows/tier0.yml +0 -0
  17. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/.github/workflows/tier1.yml +0 -0
  18. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/.github/workflows/tier2.yml +0 -0
  19. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/.gitignore +0 -0
  20. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/LICENSE +0 -0
  21. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/README.md +0 -0
  22. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/attention.py +0 -0
  23. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/component.py +0 -0
  24. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/conversion/__init__.py +0 -0
  25. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/conversion/attention.py +0 -0
  26. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/conversion/shapes.py +0 -0
  27. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/conversion/text_encoder.py +0 -0
  28. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/conversion/trace.py +0 -0
  29. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/conversion/unet.py +0 -0
  30. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/conversion/vae.py +0 -0
  31. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/inference.py +0 -0
  32. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/logger.py +0 -0
  33. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/model_version.py +0 -0
  34. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/naming.py +0 -0
  35. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/coreml_diffusion/sources.py +0 -0
  36. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/release-please-config.json +0 -0
  37. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/conftest.py +0 -0
  38. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/m2/goldens/sd15_astronaut.png +0 -0
  39. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/m2/goldens/sd15_astronaut.sha256 +0 -0
  40. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/m2/goldens/sd15_astronaut_full_coreml.png +0 -0
  41. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/m2/goldens/sd15_astronaut_full_coreml.sha256 +0 -0
  42. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/m2/test_inference_golden.py +0 -0
  43. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/m2/test_original_gpu.py +0 -0
  44. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/smoke/test_coreml_adapters.py +0 -0
  45. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/smoke/test_lcm_conversion.py +0 -0
  46. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/smoke/test_original_attention.py +0 -0
  47. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/smoke/test_split_einsum_attention.py +0 -0
  48. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/smoke/test_synthetic_text_encoder.py +0 -0
  49. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/smoke/test_synthetic_unet.py +0 -0
  50. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/smoke/test_synthetic_vae.py +0 -0
  51. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/unit/test_characterization_component_name.py +0 -0
  52. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/unit/test_characterization_out_name.py +0 -0
  53. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/unit/test_cli.py +0 -0
  54. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/unit/test_conversion_helpers.py +0 -0
  55. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/unit/test_discovery_api.py +0 -0
  56. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/unit/test_inference_output_contract.py +0 -0
  57. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/unit/test_sources.py +0 -0
  58. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/unit/test_state_dict_layout.py +0 -0
  59. {coreml_diffusion-0.1.4 → coreml_diffusion-0.1.6}/tests/unit/test_tier0_purity.py +0 -0
@@ -0,0 +1,3 @@
1
+ {
2
+ ".": "0.1.6"
3
+ }
@@ -1,5 +1,19 @@
1
1
  # Changelog
2
2
 
3
+ ## [0.1.6](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.5...v0.1.6) (2026-06-13)
4
+
5
+
6
+ ### 🐛 Bug Fixes
7
+
8
+ * **deps:** drop the stale <3.13 Python cap (allow >=3.12) ([9b44a5a](https://github.com/aszc-dev/coreml-diffusion/commit/9b44a5a4a4118a99c924ccb8301d4e1a300c3b01))
9
+
10
+ ## [0.1.5](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.4...v0.1.5) (2026-06-13)
11
+
12
+
13
+ ### ✨ Features
14
+
15
+ * **convert:** auto-detect model version from the checkpoint ([2a24d4e](https://github.com/aszc-dev/coreml-diffusion/commit/2a24d4efd196100dbdd0bf9d5dd61c6cce31d2ac))
16
+
3
17
  ## [0.1.4](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.3...v0.1.4) (2026-06-13)
4
18
 
5
19
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: coreml-diffusion
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: Convert diffusion-model checkpoints (SD1.5/SDXL) to Core ML for Apple Neural Engine — framework-free, ComfyUI-independent.
5
5
  Project-URL: Homepage, https://github.com/aszc-dev/coreml-diffusion
6
6
  Project-URL: Repository, https://github.com/aszc-dev/coreml-diffusion
@@ -14,10 +14,11 @@ Classifier: Intended Audience :: Developers
14
14
  Classifier: Operating System :: MacOS
15
15
  Classifier: Programming Language :: Python :: 3
16
16
  Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
17
18
  Classifier: Topic :: Multimedia :: Graphics :: Graphics Conversion
18
19
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
20
  Classifier: Typing :: Typed
20
- Requires-Python: <3.13,>=3.12
21
+ Requires-Python: >=3.12
21
22
  Requires-Dist: coremltools<10,>=9
22
23
  Requires-Dist: diffusers>=0.30
23
24
  Requires-Dist: numpy<3,>=2
@@ -44,6 +44,7 @@ __all__ = [
44
44
  "compose_component_name",
45
45
  "lora_names_from_params",
46
46
  "convert",
47
+ "detect_model_version",
47
48
  "build_pipeline",
48
49
  "CoreMLUNet",
49
50
  "CoreMLVAE",
@@ -123,7 +124,23 @@ def __getattr__(name):
123
124
  if name == "convert":
124
125
  from coreml_diffusion.convert import convert as _convert
125
126
 
127
+ # Importing the submodule binds ``coreml_diffusion.convert`` to the
128
+ # MODULE as a side effect, which shadows this function on every later
129
+ # access (a module object isn't callable). Rebind the package attribute
130
+ # to the function so ``coreml_diffusion.convert(...)`` stays callable in
131
+ # long-lived processes (e.g. a ComfyUI server doing >1 conversion).
132
+ globals()["convert"] = _convert
126
133
  return _convert
134
+ if name == "detect_model_version":
135
+ # Lives in the framework-free state_dict module (reads only the
136
+ # safetensors header), so exposing it never drags coremltools/diffusers
137
+ # into the import path.
138
+ from coreml_diffusion.conversion.state_dict import (
139
+ detect_model_version as _detect,
140
+ )
141
+
142
+ globals()["detect_model_version"] = _detect
143
+ return _detect
127
144
  if name in ("build_pipeline", "CoreMLUNet", "CoreMLVAE", "CoreMLTextEncoder"):
128
145
  from coreml_diffusion import inference
129
146
 
@@ -33,9 +33,14 @@ def _convert_cmd(args):
33
33
  sample_size = (args.height // 8, args.width // 8)
34
34
  lora_weights = [_parse_lora(spec) for spec in (args.lora or [])]
35
35
  ckpt = sources.resolve_checkpoint(args.ckpt, args.source)
36
+ model_version = (
37
+ coreml_diffusion.ModelVersion[args.model_version]
38
+ if args.model_version
39
+ else None
40
+ )
36
41
  coreml_diffusion.convert(
37
42
  ckpt,
38
- coreml_diffusion.ModelVersion[args.model_version],
43
+ model_version,
39
44
  args.out,
40
45
  component=args.component,
41
46
  batch_size=args.batch_size,
@@ -94,11 +99,13 @@ def build_parser():
94
99
  )
95
100
  conv.add_argument(
96
101
  "--model-version",
97
- required=True,
98
- # include experimental: the CLI is the power-user path. Experimental
99
- # versions (LCM, SDXL_REFINER) convert but are not golden-verified.
102
+ default=None,
103
+ # Auto-detected from the checkpoint when omitted. Choices stay available
104
+ # as an explicit override (the CLI is the power-user path; experimental
105
+ # versions convert but are not golden-verified).
100
106
  choices=coreml_diffusion.list_model_versions(include_experimental=True),
101
- help="Model architecture (verified: SD15, SDXL; experimental otherwise)",
107
+ help="Model architecture; auto-detected from the checkpoint when omitted "
108
+ "(verified: SD15, SDXL; experimental otherwise)",
102
109
  )
103
110
  conv.add_argument(
104
111
  "--component",
@@ -0,0 +1,99 @@
1
+ """State-dict layout predicates — framework-free (no coremltools/diffusers).
2
+
3
+ Single-file checkpoints come in two layouts: original LDM (UNet keys under
4
+ ``model.diffusion_model.``) and diffusers-native UNet-only dumps (block keys at
5
+ the top level — e.g. ``LCM_Dreamshaper_v7_4k.safetensors``, the canonical
6
+ full-distill LCM artifact). diffusers' ``from_single_file`` only understands
7
+ the former and raises ``SingleFileComponentError`` on the latter;
8
+ ``convert.load_unet`` routes on this predicate.
9
+
10
+ Also home to ``detect_model_version``: it reads only the safetensors header (no
11
+ coremltools/diffusers), so the conversion entrypoint can auto-pick the model
12
+ version without dragging the heavy stack into the discovery path.
13
+ """
14
+
15
+ from coreml_diffusion.model_version import ModelVersion
16
+
17
+ DIFFUSERS_UNET_KEY_PREFIXES = ("down_blocks.", "up_blocks.", "mid_block.")
18
+ LDM_UNET_KEY_PREFIX = "model.diffusion_model."
19
+
20
+ # cross_attention_dim -> model version. The context (key/value) dim of the UNet's
21
+ # cross-attention is the architecture fingerprint: SD1.5 conditions on a single
22
+ # 768-dim CLIP, SDXL on the 2048-dim concat of both encoders, the SDXL refiner on
23
+ # the 1280-dim OpenCLIP-bigG alone. A guidance embedding (time_embedding.cond_proj)
24
+ # on top of the 768-dim SD1.5 stack marks a full-distill LCM.
25
+ _CROSS_ATTENTION_DIM_TO_VERSION = {
26
+ 768: ModelVersion.SD15,
27
+ 2048: ModelVersion.SDXL,
28
+ 1280: ModelVersion.SDXL_REFINER,
29
+ }
30
+
31
+
32
+ def is_diffusers_unet_layout(keys) -> bool:
33
+ """True when ``keys`` form a diffusers-format UNet-only state dict."""
34
+ keys = list(keys)
35
+ has_diffusers_blocks = any(k.startswith(DIFFUSERS_UNET_KEY_PREFIXES) for k in keys)
36
+ has_ldm_prefix = any(k.startswith(LDM_UNET_KEY_PREFIX) for k in keys)
37
+ return has_diffusers_blocks and not has_ldm_prefix
38
+
39
+
40
+ def safetensors_keys(ckpt_path):
41
+ """The file's key list when it is safetensors, else None.
42
+
43
+ Probes by content, not filename — a resolved checkpoint path may point at
44
+ an extension-less blob (e.g. inside the Hugging Face cache).
45
+ """
46
+ from safetensors import SafetensorError, safe_open
47
+
48
+ try:
49
+ with safe_open(ckpt_path, framework="pt") as f:
50
+ return list(f.keys())
51
+ except SafetensorError:
52
+ return None
53
+
54
+
55
+ def detect_model_version(ckpt_path):
56
+ """Infer the ``ModelVersion`` from a checkpoint's UNet weights.
57
+
58
+ Reads two architecture fingerprints straight from the safetensors header
59
+ (no full model load): the cross-attention context dim (``attn2.to_k``) and
60
+ whether a guidance embedding (``cond_proj``) is present. Works for both LDM
61
+ and diffusers key layouts. Raises ``ValueError`` carrying the observed
62
+ evidence when the architecture is unrecognised, so a bad guess is debuggable
63
+ rather than silent.
64
+ """
65
+ keys = safetensors_keys(ckpt_path)
66
+ if keys is None:
67
+ raise ValueError(
68
+ f"Cannot auto-detect model version from {ckpt_path!r}: not a readable "
69
+ "safetensors file. Pass model_version explicitly."
70
+ )
71
+ cross_attn_key = next((k for k in keys if k.endswith("attn2.to_k.weight")), None)
72
+ if cross_attn_key is None:
73
+ raise ValueError(
74
+ f"Cannot auto-detect model version from {ckpt_path!r}: no cross-attention "
75
+ "(attn2.to_k) weights found. Pass model_version explicitly."
76
+ )
77
+ from safetensors import safe_open
78
+
79
+ with safe_open(ckpt_path, framework="pt") as f:
80
+ cross_attention_dim = f.get_slice(cross_attn_key).get_shape()[1]
81
+ has_guidance_embedding = any(k.endswith("cond_proj.weight") for k in keys)
82
+
83
+ if has_guidance_embedding:
84
+ if cross_attention_dim == 768:
85
+ return ModelVersion.LCM
86
+ raise ValueError(
87
+ f"Cannot auto-detect model version from {ckpt_path!r}: found a guidance "
88
+ f"embedding (LCM) but cross_attention_dim={cross_attention_dim}; only "
89
+ "SD1.5-class LCM (cross_attention_dim=768) is supported. Pass "
90
+ "model_version explicitly."
91
+ )
92
+ version = _CROSS_ATTENTION_DIM_TO_VERSION.get(cross_attention_dim)
93
+ if version is None:
94
+ raise ValueError(
95
+ f"Cannot auto-detect model version from {ckpt_path!r}: unrecognised "
96
+ f"cross_attention_dim={cross_attention_dim}. Supported: 768 (SD15/LCM), "
97
+ "2048 (SDXL), 1280 (SDXL_REFINER). Pass model_version explicitly."
98
+ )
99
+ return version
@@ -22,7 +22,11 @@ from coreml_diffusion.attention import ATTENTION_IMPLEMENTATIONS
22
22
  from coreml_diffusion.component import Component
23
23
  from coreml_diffusion.conversion.attention import apply_attention_implementation
24
24
  from coreml_diffusion.conversion.shapes import conv2d_output_shape
25
- from coreml_diffusion.conversion.state_dict import is_diffusers_unet_layout
25
+ from coreml_diffusion.conversion.state_dict import (
26
+ detect_model_version,
27
+ is_diffusers_unet_layout,
28
+ safetensors_keys,
29
+ )
26
30
  from coreml_diffusion.conversion.text_encoder import (
27
31
  CoreMLTextEncoderWrapper,
28
32
  static_causal_mask,
@@ -460,8 +464,8 @@ def convert_text_encoder(
460
464
 
461
465
  def convert(
462
466
  ckpt_path: str,
463
- model_version: ModelVersion,
464
- out_path: str,
467
+ model_version: ModelVersion = None,
468
+ out_path: str = None,
465
469
  *,
466
470
  component: str = Component.UNET.value,
467
471
  batch_size: int = 1,
@@ -474,20 +478,28 @@ def convert(
474
478
  ):
475
479
  """Convert a single-file checkpoint component to a Core ML ``.mlpackage``.
476
480
 
477
- ``component`` selects what to convert (default ``"unet"`` — historical
478
- behaviour, all UNet-only kwargs apply). ``vae_decoder`` / ``vae_encoder`` /
479
- ``text_encoder`` / ``text_encoder_2`` convert the corresponding sub-model and
480
- ignore the UNet-only kwargs (``controlnet_support``, ``lora_weights``,
481
- ``attn_impl``). Keyword-only past the three required positionals so the package
482
- can add capabilities without breaking an older caller. Writes ``out_path``;
483
- returns None.
481
+ ``model_version`` is auto-detected from the checkpoint when left ``None``
482
+ (the architecture fully determines the conversion); pass it explicitly only
483
+ to override a misdetection. ``component`` selects what to convert (default
484
+ ``"unet"`` historical behaviour, all UNet-only kwargs apply).
485
+ ``vae_decoder`` / ``vae_encoder`` / ``text_encoder`` / ``text_encoder_2``
486
+ convert the corresponding sub-model and ignore the UNet-only kwargs
487
+ (``controlnet_support``, ``lora_weights``, ``attn_impl``). Keyword-only past
488
+ the leading positionals so the package can add capabilities without breaking
489
+ an older caller. Writes ``out_path``; returns None.
484
490
  """
491
+ if out_path is None:
492
+ raise TypeError("convert() requires out_path")
485
493
  if os.path.exists(out_path):
486
494
  logger.info(f"Found existing model at {out_path}! Skipping..")
487
495
  return
488
496
 
489
497
  comp = Component(component)
490
498
 
499
+ if model_version is None:
500
+ model_version = detect_model_version(ckpt_path)
501
+ logger.info(f"Auto-detected model version: {model_version.name}")
502
+
491
503
  if comp is Component.VAE_DECODER:
492
504
  convert_vae_decoder(
493
505
  load_vae(ckpt_path),
@@ -553,7 +565,7 @@ def load_unet(ckpt_path, config_path):
553
565
  e.g. ``LCM_Dreamshaper_v7_4k.safetensors``) are rejected by
554
566
  ``from_single_file`` outright, so they get a direct state-dict load.
555
567
  """
556
- keys = _safetensors_keys(ckpt_path)
568
+ keys = safetensors_keys(ckpt_path)
557
569
  if keys is not None and is_diffusers_unet_layout(keys):
558
570
  return load_unet_from_diffusers_state_dict(ckpt_path)
559
571
  return UNet2DConditionModel.from_single_file(
@@ -562,21 +574,6 @@ def load_unet(ckpt_path, config_path):
562
574
  )
563
575
 
564
576
 
565
- def _safetensors_keys(ckpt_path):
566
- """The file's key list when it is safetensors, else None.
567
-
568
- Probes by content, not filename — a resolved checkpoint path may point at
569
- an extension-less blob (e.g. inside the Hugging Face cache).
570
- """
571
- from safetensors import SafetensorError, safe_open
572
-
573
- try:
574
- with safe_open(ckpt_path, framework="pt") as f:
575
- return list(f.keys())
576
- except SafetensorError:
577
- return None
578
-
579
-
580
577
  def load_unet_from_diffusers_state_dict(ckpt_path, **config_overrides):
581
578
  """Load a diffusers-layout UNet-only safetensors dump (SD1.5-class).
582
579
 
@@ -1,10 +1,10 @@
1
1
  [project]
2
2
  name = "coreml-diffusion"
3
3
  description = "Convert diffusion-model checkpoints (SD1.5/SDXL) to Core ML for Apple Neural Engine — framework-free, ComfyUI-independent."
4
- version = "0.1.4"
4
+ version = "0.1.6"
5
5
  license = "MIT"
6
6
  license-files = ["LICENSE"]
7
- requires-python = ">=3.12,<3.13"
7
+ requires-python = ">=3.12"
8
8
  readme = "README.md"
9
9
  authors = [{ name = "Adrian Szczepański", email = "hi@aszc.dev" }]
10
10
  keywords = [
@@ -24,6 +24,7 @@ classifiers = [
24
24
  "Operating System :: MacOS",
25
25
  "Programming Language :: Python :: 3",
26
26
  "Programming Language :: Python :: 3.12",
27
+ "Programming Language :: Python :: 3.13",
27
28
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
28
29
  "Topic :: Multimedia :: Graphics :: Graphics Conversion",
29
30
  "Typing :: Typed",
@@ -0,0 +1,77 @@
1
+ """Tier 0: model-version auto-detection from checkpoint weights.
2
+
3
+ Locks the architecture fingerprinting that lets ``convert(model_version=None)``
4
+ pick the right conversion path: cross-attention context dim (attn2.to_k) plus
5
+ the presence of a guidance embedding (cond_proj). Synthetic safetensors files
6
+ carry only the two keys the detector reads, so the test stays framework-free.
7
+ """
8
+
9
+ import pytest
10
+ import torch
11
+ from safetensors.torch import save_file
12
+
13
+ from coreml_diffusion.conversion.state_dict import detect_model_version
14
+ from coreml_diffusion.model_version import ModelVersion
15
+
16
+
17
+ def _write_ckpt(path, cross_attention_dim, *, guidance=False, with_cross_attn=True):
18
+ tensors = {}
19
+ if with_cross_attn:
20
+ # attn2.to_k maps the context dim -> inner dim; shape[1] is what we read.
21
+ key = "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight"
22
+ tensors[key] = torch.zeros(320, cross_attention_dim)
23
+ if guidance:
24
+ tensors["time_embedding.cond_proj.weight"] = torch.zeros(320, 256)
25
+ if not tensors:
26
+ tensors["dummy"] = torch.zeros(1)
27
+ save_file(tensors, str(path))
28
+ return str(path)
29
+
30
+
31
+ @pytest.mark.parametrize(
32
+ "cross_attention_dim, guidance, expected",
33
+ [
34
+ (768, False, ModelVersion.SD15),
35
+ (768, True, ModelVersion.LCM),
36
+ (2048, False, ModelVersion.SDXL),
37
+ (1280, False, ModelVersion.SDXL_REFINER),
38
+ ],
39
+ )
40
+ def test_detects_known_architectures(tmp_path, cross_attention_dim, guidance, expected):
41
+ ckpt = _write_ckpt(
42
+ tmp_path / "m.safetensors", cross_attention_dim, guidance=guidance
43
+ )
44
+ assert detect_model_version(ckpt) is expected
45
+
46
+
47
+ def test_lcm_lora_merge_detects_as_sd15(tmp_path):
48
+ # An LCM-LoRA merge is plain SD1.5 architecture (no guidance embedding); it
49
+ # must NOT be mistaken for a full-distill LCM.
50
+ ckpt = _write_ckpt(tmp_path / "merge.safetensors", 768, guidance=False)
51
+ assert detect_model_version(ckpt) is ModelVersion.SD15
52
+
53
+
54
+ def test_guidance_with_non_sd15_dim_is_rejected(tmp_path):
55
+ ckpt = _write_ckpt(tmp_path / "sdxl_lcm.safetensors", 2048, guidance=True)
56
+ with pytest.raises(ValueError, match="only SD1.5-class LCM"):
57
+ detect_model_version(ckpt)
58
+
59
+
60
+ def test_unknown_cross_attention_dim_is_rejected(tmp_path):
61
+ # 1024 is SD2.x — unsupported; the error must name the observed dim.
62
+ ckpt = _write_ckpt(tmp_path / "sd2.safetensors", 1024)
63
+ with pytest.raises(ValueError, match="cross_attention_dim=1024"):
64
+ detect_model_version(ckpt)
65
+
66
+
67
+ def test_no_cross_attention_weights_is_rejected(tmp_path):
68
+ ckpt = _write_ckpt(tmp_path / "weird.safetensors", 768, with_cross_attn=False)
69
+ with pytest.raises(ValueError, match="no cross-attention"):
70
+ detect_model_version(ckpt)
71
+
72
+
73
+ def test_non_safetensors_is_rejected(tmp_path):
74
+ bogus = tmp_path / "model.ckpt"
75
+ bogus.write_bytes(b"not safetensors")
76
+ with pytest.raises(ValueError, match="not a readable safetensors"):
77
+ detect_model_version(str(bogus))