coreml-diffusion 0.1.3__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. coreml_diffusion-0.1.4/.release-please-manifest.json +3 -0
  2. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/CHANGELOG.md +8 -0
  3. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/PKG-INFO +1 -1
  4. coreml_diffusion-0.1.4/coreml_diffusion/conversion/state_dict.py +20 -0
  5. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/convert.py +71 -3
  6. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/inference.py +22 -7
  7. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/pyproject.toml +1 -1
  8. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/smoke/test_coreml_adapters.py +24 -0
  9. coreml_diffusion-0.1.4/tests/smoke/test_lcm_conversion.py +108 -0
  10. coreml_diffusion-0.1.4/tests/unit/test_state_dict_layout.py +43 -0
  11. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/uv.lock +1 -1
  12. coreml_diffusion-0.1.3/.release-please-manifest.json +0 -3
  13. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/.github/workflows/publish-pypi.yml +0 -0
  14. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/.github/workflows/release-please.yml +0 -0
  15. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/.github/workflows/tier0.yml +0 -0
  16. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/.github/workflows/tier1.yml +0 -0
  17. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/.github/workflows/tier2.yml +0 -0
  18. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/.gitignore +0 -0
  19. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/LICENSE +0 -0
  20. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/README.md +0 -0
  21. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/__init__.py +0 -0
  22. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/attention.py +0 -0
  23. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/cli.py +0 -0
  24. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/component.py +0 -0
  25. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/__init__.py +0 -0
  26. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/attention.py +0 -0
  27. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/shapes.py +0 -0
  28. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/text_encoder.py +0 -0
  29. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/trace.py +0 -0
  30. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/unet.py +0 -0
  31. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/vae.py +0 -0
  32. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/logger.py +0 -0
  33. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/model_version.py +0 -0
  34. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/naming.py +0 -0
  35. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/coreml_diffusion/sources.py +0 -0
  36. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/release-please-config.json +0 -0
  37. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/conftest.py +0 -0
  38. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/m2/goldens/sd15_astronaut.png +0 -0
  39. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/m2/goldens/sd15_astronaut.sha256 +0 -0
  40. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/m2/goldens/sd15_astronaut_full_coreml.png +0 -0
  41. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/m2/goldens/sd15_astronaut_full_coreml.sha256 +0 -0
  42. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/m2/test_inference_golden.py +0 -0
  43. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/m2/test_original_gpu.py +0 -0
  44. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/smoke/test_original_attention.py +0 -0
  45. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/smoke/test_split_einsum_attention.py +0 -0
  46. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/smoke/test_synthetic_text_encoder.py +0 -0
  47. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/smoke/test_synthetic_unet.py +0 -0
  48. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/smoke/test_synthetic_vae.py +0 -0
  49. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/unit/test_characterization_component_name.py +0 -0
  50. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/unit/test_characterization_out_name.py +0 -0
  51. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/unit/test_cli.py +0 -0
  52. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/unit/test_conversion_helpers.py +0 -0
  53. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/unit/test_discovery_api.py +0 -0
  54. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/unit/test_inference_output_contract.py +0 -0
  55. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/unit/test_sources.py +0 -0
  56. {coreml_diffusion-0.1.3 → coreml_diffusion-0.1.4}/tests/unit/test_tier0_purity.py +0 -0
@@ -0,0 +1,3 @@
1
+ {
2
+ ".": "0.1.4"
3
+ }
@@ -1,5 +1,13 @@
1
1
  # Changelog
2
2
 
3
+ ## [0.1.4](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.3...v0.1.4) (2026-06-13)
4
+
5
+
6
+ ### 🐛 Bug Fixes
7
+
8
+ * **convert:** generic LCM conversion for arbitrary checkpoints ([bedeb49](https://github.com/aszc-dev/coreml-diffusion/commit/bedeb49a84a33b4530c6c7d4ed4343a914a400e2))
9
+ * **inference:** add output_dtype to CoreMLTextEncoder ([eb6d3b5](https://github.com/aszc-dev/coreml-diffusion/commit/eb6d3b5b08b6ff34a4eb8683de3d1223fb447186))
10
+
3
11
  ## [0.1.3](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.2...v0.1.3) (2026-06-04)
4
12
 
5
13
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: coreml-diffusion
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: Convert diffusion-model checkpoints (SD1.5/SDXL) to Core ML for Apple Neural Engine — framework-free, ComfyUI-independent.
5
5
  Project-URL: Homepage, https://github.com/aszc-dev/coreml-diffusion
6
6
  Project-URL: Repository, https://github.com/aszc-dev/coreml-diffusion
@@ -0,0 +1,20 @@
1
+ """State-dict layout predicates — framework-free (no coremltools/diffusers).
2
+
3
+ Single-file checkpoints come in two layouts: original LDM (UNet keys under
4
+ ``model.diffusion_model.``) and diffusers-native UNet-only dumps (block keys at
5
+ the top level — e.g. ``LCM_Dreamshaper_v7_4k.safetensors``, the canonical
6
+ full-distill LCM artifact). diffusers' ``from_single_file`` only understands
7
+ the former and raises ``SingleFileComponentError`` on the latter;
8
+ ``convert.load_unet`` routes on this predicate.
9
+ """
10
+
11
+ DIFFUSERS_UNET_KEY_PREFIXES = ("down_blocks.", "up_blocks.", "mid_block.")
12
+ LDM_UNET_KEY_PREFIX = "model.diffusion_model."
13
+
14
+
15
+ def is_diffusers_unet_layout(keys) -> bool:
16
+ """True when ``keys`` form a diffusers-format UNet-only state dict."""
17
+ keys = list(keys)
18
+ has_diffusers_blocks = any(k.startswith(DIFFUSERS_UNET_KEY_PREFIXES) for k in keys)
19
+ has_ldm_prefix = any(k.startswith(LDM_UNET_KEY_PREFIX) for k in keys)
20
+ return has_diffusers_blocks and not has_ldm_prefix
@@ -22,6 +22,7 @@ from coreml_diffusion.attention import ATTENTION_IMPLEMENTATIONS
22
22
  from coreml_diffusion.component import Component
23
23
  from coreml_diffusion.conversion.attention import apply_attention_implementation
24
24
  from coreml_diffusion.conversion.shapes import conv2d_output_shape
25
+ from coreml_diffusion.conversion.state_dict import is_diffusers_unet_layout
25
26
  from coreml_diffusion.conversion.text_encoder import (
26
27
  CoreMLTextEncoderWrapper,
27
28
  static_causal_mask,
@@ -139,9 +140,17 @@ def get_sample_input(
139
140
  return sample_unet_inputs
140
141
 
141
142
 
142
- def lcm_inputs(sample_unet_inputs):
143
+ def lcm_inputs(sample_unet_inputs, ref_unet=None):
144
+ """Build the LCM guidance-embedding (``timestep_cond``) trace input.
145
+
146
+ The embedding dim comes from ``ref_unet.config.time_cond_proj_dim``. The
147
+ 256 fallback preserves the legacy ref-unet-less call shape (the Suite's LCM
148
+ converter calls this with one argument); 256 is correct for every known
149
+ SD1.5 full-distill LCM.
150
+ """
143
151
  batch_size = sample_unet_inputs["sample"].shape[0]
144
- return {"timestep_cond": torch.randn(batch_size, 256).to(torch.float32)}
152
+ dim = 256 if ref_unet is None else ref_unet.config.time_cond_proj_dim
153
+ return {"timestep_cond": torch.randn(batch_size, dim).to(torch.float32)}
145
154
 
146
155
 
147
156
  def sdxl_inputs(sample_unet_inputs, ref_unet, model_version):
@@ -257,7 +266,15 @@ def convert_unet(
257
266
  )
258
267
 
259
268
  if model_version == ModelVersion.LCM:
260
- sample_inputs |= lcm_inputs(sample_inputs)
269
+ if ref_unet.config.time_cond_proj_dim is None:
270
+ raise ValueError(
271
+ "model_version=LCM requires a UNet with a guidance embedding "
272
+ "(config.time_cond_proj_dim), but this checkpoint has none — "
273
+ "it is an LCM-LoRA merge with plain SD1.5 architecture. "
274
+ "Convert it with model_version=SD15 and use an LCM scheduler "
275
+ "at sampling time."
276
+ )
277
+ sample_inputs |= lcm_inputs(sample_inputs, ref_unet)
261
278
 
262
279
  if model_version in {ModelVersion.SDXL, ModelVersion.SDXL_REFINER}:
263
280
  sample_inputs |= sdxl_inputs(sample_inputs, ref_unet, model_version)
@@ -529,12 +546,63 @@ def convert(
529
546
 
530
547
 
531
548
  def load_unet(ckpt_path, config_path):
549
+ """Load the UNet from a single-file checkpoint, routing on state-dict layout.
550
+
551
+ LDM-layout files (the Civitai norm) go through ``from_single_file``.
552
+ Diffusers-layout UNet-only dumps (the canonical full-distill LCM artifacts,
553
+ e.g. ``LCM_Dreamshaper_v7_4k.safetensors``) are rejected by
554
+ ``from_single_file`` outright, so they get a direct state-dict load.
555
+ """
556
+ keys = _safetensors_keys(ckpt_path)
557
+ if keys is not None and is_diffusers_unet_layout(keys):
558
+ return load_unet_from_diffusers_state_dict(ckpt_path)
532
559
  return UNet2DConditionModel.from_single_file(
533
560
  ckpt_path,
534
561
  original_config=config_path,
535
562
  )
536
563
 
537
564
 
565
+ def _safetensors_keys(ckpt_path):
566
+ """The file's key list when it is safetensors, else None.
567
+
568
+ Probes by content, not filename — a resolved checkpoint path may point at
569
+ an extension-less blob (e.g. inside the Hugging Face cache).
570
+ """
571
+ from safetensors import SafetensorError, safe_open
572
+
573
+ try:
574
+ with safe_open(ckpt_path, framework="pt") as f:
575
+ return list(f.keys())
576
+ except SafetensorError:
577
+ return None
578
+
579
+
580
+ def load_unet_from_diffusers_state_dict(ckpt_path, **config_overrides):
581
+ """Load a diffusers-layout UNet-only safetensors dump (SD1.5-class).
582
+
583
+ SD1.5 architecture is assumed (``UNet2DConditionModel`` defaults); the
584
+ cross-attention and guidance-embedding dims are read from the weights, so
585
+ both full-distill LCM dumps (``time_embedding.cond_proj`` present) and
586
+ plain SD1.5 UNet dumps load correctly. A non-SD1.5-class dump fails the
587
+ strict ``load_state_dict`` with an explicit shape/key error.
588
+ ``config_overrides`` exists for tests exercising miniature architectures.
589
+ """
590
+ from safetensors.torch import load_file
591
+
592
+ state_dict = load_file(ckpt_path)
593
+ cond_proj = state_dict.get("time_embedding.cond_proj.weight")
594
+ config_kwargs = {
595
+ "sample_size": 64,
596
+ "cross_attention_dim": state_dict[
597
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight"
598
+ ].shape[1],
599
+ "time_cond_proj_dim": None if cond_proj is None else cond_proj.shape[1],
600
+ } | config_overrides
601
+ unet = UNet2DConditionModel(**config_kwargs)
602
+ unet.load_state_dict(state_dict)
603
+ return unet
604
+
605
+
538
606
  def load_vae(ckpt_path):
539
607
  """Load ``AutoencoderKL`` from a single-file checkpoint."""
540
608
  from diffusers import AutoencoderKL
@@ -268,13 +268,24 @@ class CoreMLTextEncoder(torch.nn.Module):
268
268
  """
269
269
 
270
270
  def __init__(
271
- self, mlpackage_path, ref_text_encoder, *, compute_unit=DEFAULT_COMPUTE_UNIT
271
+ self,
272
+ mlpackage_path,
273
+ ref_text_encoder,
274
+ *,
275
+ compute_unit=DEFAULT_COMPUTE_UNIT,
276
+ output_dtype=torch.float16,
272
277
  ):
273
278
  super().__init__()
274
279
  import coremltools as ct
275
280
 
276
281
  self.config = ref_text_encoder.config
277
- self.dtype = torch.float16
282
+ # The package runs fp16 internally. ``output_dtype`` is the dtype of the
283
+ # embeddings handed back to the pipeline; keep fp16 (default) for an
284
+ # all-Core ML pipeline, or set fp32 when the embeddings feed a torch fp32
285
+ # component (e.g. an OAT config with a Core ML text encoder + torch UNet),
286
+ # where diffusers propagates ``prompt_embeds.dtype`` to the latents and time
287
+ # embedding and a Half/Float mismatch would otherwise crash the UNet.
288
+ self.dtype = output_dtype
278
289
  unit = _resolve_unit(compute_unit)
279
290
  logger.info(f"Loading text encoder {mlpackage_path} to {unit.name}")
280
291
  self.model = ct.models.MLModel(mlpackage_path, compute_units=unit)
@@ -294,14 +305,18 @@ class CoreMLTextEncoder(torch.nn.Module):
294
305
  **_ignored,
295
306
  ):
296
307
  prediction = self.model.predict({"input_ids": _i32(input_ids)})
297
- embeds = torch.from_numpy(np.ascontiguousarray(prediction["hidden_states"])).to(
298
- input_ids.device
308
+ embeds = (
309
+ torch.from_numpy(np.ascontiguousarray(prediction["hidden_states"]))
310
+ .to(input_ids.device)
311
+ .to(self.dtype)
299
312
  )
300
313
  pooled = None
301
314
  if self._pooled_name is not None:
302
- pooled = torch.from_numpy(
303
- np.ascontiguousarray(prediction[self._pooled_name])
304
- ).to(input_ids.device)
315
+ pooled = (
316
+ torch.from_numpy(np.ascontiguousarray(prediction[self._pooled_name]))
317
+ .to(input_ids.device)
318
+ .to(self.dtype)
319
+ )
305
320
  return _CoreMLTextEncoderOutput(embeds, pooled)
306
321
 
307
322
 
@@ -1,7 +1,7 @@
1
1
  [project]
2
2
  name = "coreml-diffusion"
3
3
  description = "Convert diffusion-model checkpoints (SD1.5/SDXL) to Core ML for Apple Neural Engine — framework-free, ComfyUI-independent."
4
- version = "0.1.3"
4
+ version = "0.1.4"
5
5
  license = "MIT"
6
6
  license-files = ["LICENSE"]
7
7
  requires-python = ">=3.12,<3.13"
@@ -118,6 +118,30 @@ def test_coreml_text_encoder_sd15_shape(tmp_path):
118
118
 
119
119
  assert out[0].shape == (1, SEQ_LEN, HIDDEN) # no pooled -> [0] is embeds
120
120
  assert out.hidden_states[-2].shape == (1, SEQ_LEN, HIDDEN)
121
+ assert out[0].dtype == torch.float16 # default: fp16 for an all-Core ML pipeline
122
+
123
+
124
+ def test_coreml_text_encoder_output_dtype_fp32(tmp_path):
125
+ # output_dtype=fp32 bridges a Core ML text encoder feeding a torch fp32 UNet
126
+ # (an OAT config): the embeddings must come back fp32 so diffusers does not
127
+ # propagate fp16 into the fp32 UNet and crash on a Half/Float mismatch.
128
+ from transformers import CLIPTextModel
129
+
130
+ from coreml_diffusion.inference import CoreMLTextEncoder
131
+
132
+ torch.manual_seed(0)
133
+ encoder = CLIPTextModel(_tiny_clip_config())
134
+ pkg = tmp_path / "text_encoder.mlpackage"
135
+ _convert_text_encoder(encoder, ["hidden_states"], pkg, index=None, pooled=False)
136
+
137
+ coreml_te = CoreMLTextEncoder(
138
+ str(pkg), encoder, compute_unit="CPU_ONLY", output_dtype=torch.float32
139
+ )
140
+ out = coreml_te(torch.zeros(1, SEQ_LEN, dtype=torch.long))
141
+
142
+ assert coreml_te.dtype == torch.float32
143
+ assert out[0].dtype == torch.float32
144
+ assert out.last_hidden_state.dtype == torch.float32
121
145
 
122
146
 
123
147
  def test_coreml_text_encoder_sdxl_pooled(tmp_path):
@@ -0,0 +1,108 @@
1
+ """Tier 1 smoke: the generic LCM conversion path on a synthetic micro-UNet.
2
+
3
+ Locks the E-LCM fixes: the ``timestep_cond`` trace input takes its dim from
4
+ ``config.time_cond_proj_dim`` (not a hardcoded 256), an LCM-LoRA merge (no
5
+ guidance embedding) is rejected with a clear error, and diffusers-layout
6
+ UNet-only dumps (the canonical full-distill LCM artifact shape) load through
7
+ ``load_unet_from_diffusers_state_dict``.
8
+
9
+ Auto-skips on non-Apple-Silicon hosts so Tier 0 CI on Linux ignores it.
10
+ """
11
+
12
+ import platform
13
+
14
+ import pytest
15
+ import torch
16
+
17
+ pytestmark = pytest.mark.skipif(
18
+ platform.system() != "Darwin" or platform.machine() != "arm64",
19
+ reason="Tier 1 requires macOS on Apple Silicon",
20
+ )
21
+
22
+ GUIDANCE_DIM = 16 # deliberately != 256 to prove the dim comes from the config
23
+
24
+ TINY_ARCH = dict(
25
+ sample_size=8,
26
+ layers_per_block=1,
27
+ block_out_channels=(32, 64),
28
+ down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
29
+ up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
30
+ cross_attention_dim=32,
31
+ attention_head_dim=8,
32
+ norm_num_groups=8,
33
+ )
34
+
35
+
36
+ def _tiny_unet(time_cond_proj_dim=None):
37
+ from diffusers import UNet2DConditionModel
38
+
39
+ return UNet2DConditionModel(**TINY_ARCH, time_cond_proj_dim=time_cond_proj_dim)
40
+
41
+
42
+ def test_lcm_unet_converts_with_config_guidance_dim(tmp_path):
43
+ import coremltools as ct
44
+
45
+ from coreml_diffusion.convert import convert_unet
46
+ from coreml_diffusion.model_version import ModelVersion
47
+
48
+ torch.manual_seed(0)
49
+ out_path = str(tmp_path / "lcm_unet.mlpackage")
50
+ convert_unet(
51
+ _tiny_unet(time_cond_proj_dim=GUIDANCE_DIM),
52
+ ModelVersion.LCM,
53
+ out_path,
54
+ sample_size=(8, 8),
55
+ )
56
+
57
+ spec = ct.models.MLModel(out_path, skip_model_load=True).get_spec()
58
+ inputs = {
59
+ i.name: tuple(i.type.multiArrayType.shape) for i in spec.description.input
60
+ }
61
+ assert "timestep_cond" in inputs
62
+ assert inputs["timestep_cond"] == (1, GUIDANCE_DIM)
63
+
64
+
65
+ def test_lcm_merge_checkpoint_is_rejected(tmp_path):
66
+ from coreml_diffusion.convert import convert_unet
67
+ from coreml_diffusion.model_version import ModelVersion
68
+
69
+ torch.manual_seed(0)
70
+ with pytest.raises(ValueError, match="LCM-LoRA merge"):
71
+ convert_unet(
72
+ _tiny_unet(time_cond_proj_dim=None),
73
+ ModelVersion.LCM,
74
+ str(tmp_path / "merge.mlpackage"),
75
+ sample_size=(8, 8),
76
+ )
77
+
78
+
79
+ def test_diffusers_layout_dump_loads_with_guidance_embedding(tmp_path):
80
+ from safetensors.torch import save_file
81
+
82
+ from coreml_diffusion.convert import load_unet_from_diffusers_state_dict
83
+
84
+ torch.manual_seed(0)
85
+ source = _tiny_unet(time_cond_proj_dim=GUIDANCE_DIM)
86
+ dump = tmp_path / "tiny_lcm_dump.safetensors"
87
+ save_file(source.state_dict(), str(dump))
88
+
89
+ # cross_attention_dim and time_cond_proj_dim must be read from the weights;
90
+ # only the miniature architecture is supplied as overrides.
91
+ loaded = load_unet_from_diffusers_state_dict(
92
+ str(dump), **{k: v for k, v in TINY_ARCH.items() if k != "cross_attention_dim"}
93
+ )
94
+
95
+ assert loaded.config.cross_attention_dim == TINY_ARCH["cross_attention_dim"]
96
+ assert loaded.config.time_cond_proj_dim == GUIDANCE_DIM
97
+ assert torch.equal(
98
+ loaded.time_embedding.cond_proj.weight, source.time_embedding.cond_proj.weight
99
+ )
100
+
101
+
102
+ def test_lcm_inputs_legacy_call_keeps_256():
103
+ # The Suite's LCM converter calls lcm_inputs without a ref UNet; the legacy
104
+ # 256 fallback is part of the import contract until the Suite migrates.
105
+ from coreml_diffusion.convert import lcm_inputs
106
+
107
+ inputs = lcm_inputs({"sample": torch.rand(2, 4, 8, 8)})
108
+ assert inputs["timestep_cond"].shape == (2, 256)
@@ -0,0 +1,43 @@
1
+ """Tier 0: state-dict layout detection routing ``load_unet``.
2
+
3
+ Locks the predicate that decides whether a single-file checkpoint is an LDM
4
+ checkpoint (``from_single_file`` path) or a diffusers-layout UNet-only dump
5
+ (direct state-dict load — the canonical full-distill LCM artifact shape).
6
+ """
7
+
8
+ from coreml_diffusion.conversion.state_dict import is_diffusers_unet_layout
9
+
10
+ DIFFUSERS_UNET_KEYS = [
11
+ "conv_in.weight",
12
+ "time_embedding.linear_1.weight",
13
+ "time_embedding.cond_proj.weight",
14
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight",
15
+ "mid_block.resnets.0.conv1.weight",
16
+ "up_blocks.3.resnets.2.conv2.weight",
17
+ ]
18
+
19
+ LDM_CHECKPOINT_KEYS = [
20
+ "model.diffusion_model.time_embed.0.weight",
21
+ "model.diffusion_model.input_blocks.0.0.weight",
22
+ "first_stage_model.decoder.conv_in.weight",
23
+ "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
24
+ ]
25
+
26
+
27
+ def test_diffusers_unet_dump_is_detected():
28
+ assert is_diffusers_unet_layout(DIFFUSERS_UNET_KEYS)
29
+
30
+
31
+ def test_ldm_checkpoint_is_not_diffusers_layout():
32
+ assert not is_diffusers_unet_layout(LDM_CHECKPOINT_KEYS)
33
+
34
+
35
+ def test_mixed_prefixes_resolve_to_ldm():
36
+ # An LDM file whose extras coincidentally include diffusers-looking keys
37
+ # must still go through from_single_file.
38
+ assert not is_diffusers_unet_layout(LDM_CHECKPOINT_KEYS + DIFFUSERS_UNET_KEYS)
39
+
40
+
41
+ def test_unrelated_keys_are_not_diffusers_layout():
42
+ assert not is_diffusers_unet_layout(["text_model.encoder.layers.0.mlp.fc1.weight"])
43
+ assert not is_diffusers_unet_layout([])
@@ -106,7 +106,7 @@ wheels = [
106
106
 
107
107
  [[package]]
108
108
  name = "coreml-diffusion"
109
- version = "0.1.2"
109
+ version = "0.1.3"
110
110
  source = { editable = "." }
111
111
  dependencies = [
112
112
  { name = "coremltools" },
@@ -1,3 +0,0 @@
1
- {
2
- ".": "0.1.3"
3
- }