coreml-diffusion 0.1.2__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. coreml_diffusion-0.1.3/.release-please-manifest.json +3 -0
  2. coreml_diffusion-0.1.3/CHANGELOG.md +26 -0
  3. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/PKG-INFO +7 -1
  4. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/README.md +6 -0
  5. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/__init__.py +19 -2
  6. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/cli.py +8 -0
  7. coreml_diffusion-0.1.3/coreml_diffusion/component.py +32 -0
  8. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/conversion/attention.py +7 -3
  9. coreml_diffusion-0.1.3/coreml_diffusion/conversion/text_encoder.py +85 -0
  10. coreml_diffusion-0.1.3/coreml_diffusion/conversion/vae.py +49 -0
  11. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/convert.py +233 -11
  12. coreml_diffusion-0.1.3/coreml_diffusion/inference.py +380 -0
  13. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/naming.py +56 -0
  14. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/pyproject.toml +1 -1
  15. coreml_diffusion-0.1.3/tests/m2/goldens/sd15_astronaut_full_coreml.png +0 -0
  16. coreml_diffusion-0.1.3/tests/m2/goldens/sd15_astronaut_full_coreml.sha256 +1 -0
  17. coreml_diffusion-0.1.3/tests/m2/test_inference_golden.py +152 -0
  18. coreml_diffusion-0.1.3/tests/smoke/test_coreml_adapters.py +139 -0
  19. coreml_diffusion-0.1.3/tests/smoke/test_synthetic_text_encoder.py +98 -0
  20. coreml_diffusion-0.1.3/tests/smoke/test_synthetic_vae.py +77 -0
  21. coreml_diffusion-0.1.3/tests/unit/test_characterization_component_name.py +154 -0
  22. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_cli.py +36 -0
  23. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_discovery_api.py +28 -1
  24. coreml_diffusion-0.1.3/tests/unit/test_inference_output_contract.py +43 -0
  25. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/uv.lock +1 -1
  26. coreml_diffusion-0.1.2/.release-please-manifest.json +0 -3
  27. coreml_diffusion-0.1.2/CHANGELOG.md +0 -8
  28. coreml_diffusion-0.1.2/coreml_diffusion/inference.py +0 -176
  29. coreml_diffusion-0.1.2/tests/m2/test_inference_golden.py +0 -111
  30. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.github/workflows/publish-pypi.yml +0 -0
  31. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.github/workflows/release-please.yml +0 -0
  32. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.github/workflows/tier0.yml +0 -0
  33. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.github/workflows/tier1.yml +0 -0
  34. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.github/workflows/tier2.yml +0 -0
  35. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.gitignore +0 -0
  36. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/LICENSE +0 -0
  37. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/attention.py +0 -0
  38. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/conversion/__init__.py +0 -0
  39. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/conversion/shapes.py +0 -0
  40. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/conversion/trace.py +0 -0
  41. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/conversion/unet.py +0 -0
  42. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/logger.py +0 -0
  43. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/model_version.py +0 -0
  44. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/sources.py +0 -0
  45. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/release-please-config.json +0 -0
  46. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/conftest.py +0 -0
  47. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/m2/goldens/sd15_astronaut.png +0 -0
  48. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/m2/goldens/sd15_astronaut.sha256 +0 -0
  49. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/m2/test_original_gpu.py +0 -0
  50. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/smoke/test_original_attention.py +0 -0
  51. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/smoke/test_split_einsum_attention.py +0 -0
  52. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/smoke/test_synthetic_unet.py +0 -0
  53. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_characterization_out_name.py +0 -0
  54. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_conversion_helpers.py +0 -0
  55. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_sources.py +0 -0
  56. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_tier0_purity.py +0 -0
@@ -0,0 +1,3 @@
1
+ {
2
+ ".": "0.1.3"
3
+ }
@@ -0,0 +1,26 @@
1
+ # Changelog
2
+
3
+ ## [0.1.3](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.2...v0.1.3) (2026-06-04)
4
+
5
+
6
+ ### ✨ Features
7
+
8
+ * **convert:** add VAE and CLIP text-encoder conversion ([dc1f85b](https://github.com/aszc-dev/coreml-diffusion/commit/dc1f85bafe50d36655ff7ece0c052a30fd77bb81))
9
+ * **inference:** end-to-end Core ML pipeline (VAE + text-encoder swap) ([ca08b16](https://github.com/aszc-dev/coreml-diffusion/commit/ca08b16729529afbdf610d0e8ec2d09b849080c6))
10
+
11
+
12
+ ### 🐛 Bug Fixes
13
+
14
+ * **inference:** expose .device on the Core ML adapters ([30a673e](https://github.com/aszc-dev/coreml-diffusion/commit/30a673eebe3927722214d0ab6a44fbc344d18f3a))
15
+
16
+
17
+ ### 📚 Documentation
18
+
19
+ * **readme:** link the log.aszc.dev energy benchmark writeup ([77927b5](https://github.com/aszc-dev/coreml-diffusion/commit/77927b5dd5f1311a3b3c317692f3a347e3976a54))
20
+
21
+ ## [0.1.2](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.1...v0.1.2) (2026-05-27)
22
+
23
+
24
+ ### 🐛 Bug Fixes
25
+
26
+ * **attention:** convertible fp32 ORIGINAL attention for the Core ML GPU path ([#2](https://github.com/aszc-dev/coreml-diffusion/issues/2)) ([28e56fc](https://github.com/aszc-dev/coreml-diffusion/commit/28e56fcf8c2242ebbe4c05abd05f7e796069d7d1))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: coreml-diffusion
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: Convert diffusion-model checkpoints (SD1.5/SDXL) to Core ML for Apple Neural Engine — framework-free, ComfyUI-independent.
5
5
  Project-URL: Homepage, https://github.com/aszc-dev/coreml-diffusion
6
6
  Project-URL: Repository, https://github.com/aszc-dev/coreml-diffusion
@@ -44,6 +44,12 @@ GPU-free, embeddable in a Swift/iOS app. ANE is the differentiator — this is a
44
44
  feasibility and power efficiency for SD1.5/SDXL on ANE, not a raw-throughput claim
45
45
  against desktop GPUs.
46
46
 
47
+ The power-efficiency claim is measured: in a cross-backend benchmark the coremltools 9 (ct9)
48
+ converter here runs the SD1.5 UNet on the ANE at **6-7x lower energy** than
49
+ GPU/MPS, at the same speed — see the writeup,
50
+ [The ANE runs the SD1.5 UNet at 6-7x lower energy than GPU/MPS](https://log.aszc.dev/ane-vs-gpu-mps-sd15-unet-energy/),
51
+ for the methodology and the numerical-divergence tradeoff.
52
+
47
53
  The scope is diffusion architectures generally, not Stable Diffusion specifically.
48
54
  The project aims to gather, in one place: the conversion path, a reproducible
49
55
  benchmarking suite for objective comparison, a per-model catalogue documenting the
@@ -15,6 +15,12 @@ GPU-free, embeddable in a Swift/iOS app. ANE is the differentiator — this is a
15
15
  feasibility and power efficiency for SD1.5/SDXL on ANE, not a raw-throughput claim
16
16
  against desktop GPUs.
17
17
 
18
+ The power-efficiency claim is measured: in a cross-backend benchmark the coremltools 9 (ct9)
19
+ converter here runs the SD1.5 UNet on the ANE at **6-7x lower energy** than
20
+ GPU/MPS, at the same speed — see the writeup,
21
+ [The ANE runs the SD1.5 UNet at 6-7x lower energy than GPU/MPS](https://log.aszc.dev/ane-vs-gpu-mps-sd15-unet-energy/),
22
+ for the methodology and the numerical-divergence tradeoff.
23
+
18
24
  The scope is diffusion architectures generally, not Stable Diffusion specifically.
19
25
  The project aims to gather, in one place: the conversion path, a reproducible
20
26
  benchmarking suite for objective comparison, a per-model catalogue documenting the
@@ -23,9 +23,11 @@ because a saved workflow JSON references these strings verbatim.
23
23
  from enum import Enum
24
24
 
25
25
  from coreml_diffusion.attention import ATTENTION_IMPLEMENTATIONS
26
+ from coreml_diffusion.component import CONVERTIBLE_COMPONENTS
26
27
  from coreml_diffusion.model_version import ModelVersion
27
28
  from coreml_diffusion.naming import (
28
29
  QUANT_NBITS_VALUES,
30
+ compose_component_name,
29
31
  compose_out_name,
30
32
  lora_names_from_params,
31
33
  )
@@ -36,12 +38,16 @@ __all__ = [
36
38
  "list_model_versions",
37
39
  "list_attention_impls",
38
40
  "list_quant_modes",
41
+ "list_convertible_components",
39
42
  "CONTRACT_VERSION",
40
43
  "compose_out_name",
44
+ "compose_component_name",
41
45
  "lora_names_from_params",
42
46
  "convert",
43
47
  "build_pipeline",
44
48
  "CoreMLUNet",
49
+ "CoreMLVAE",
50
+ "CoreMLTextEncoder",
45
51
  ]
46
52
 
47
53
 
@@ -91,9 +97,20 @@ def list_quant_modes() -> list[str]:
91
97
  return list(QUANT_NBITS_VALUES)
92
98
 
93
99
 
100
+ def list_convertible_components() -> list[str]:
101
+ """Convertible components, e.g. ``["unet", "vae_decoder", ...]``.
102
+
103
+ ``"unet"`` is the historical default; the rest are the VAE / text-encoder
104
+ extension. ``"text_encoder_2"`` is only meaningful for SDXL — validity per
105
+ model version is enforced at convert time, not advertised here.
106
+ """
107
+ return list(CONVERTIBLE_COMPONENTS)
108
+
109
+
94
110
  # Discovery-contract version. Bump per the additive-only rules in this module's
95
111
  # docstring and CONVERTER_EXTRACTION_SPEC.md "Interface contract".
96
- CONTRACT_VERSION = "1.0"
112
+ # 1.1: added list_convertible_components (VAE + text-encoder conversion).
113
+ CONTRACT_VERSION = "1.1"
97
114
 
98
115
 
99
116
  def __getattr__(name):
@@ -107,7 +124,7 @@ def __getattr__(name):
107
124
  from coreml_diffusion.convert import convert as _convert
108
125
 
109
126
  return _convert
110
- if name in ("build_pipeline", "CoreMLUNet"):
127
+ if name in ("build_pipeline", "CoreMLUNet", "CoreMLVAE", "CoreMLTextEncoder"):
111
128
  from coreml_diffusion import inference
112
129
 
113
130
  return getattr(inference, name)
@@ -37,6 +37,7 @@ def _convert_cmd(args):
37
37
  ckpt,
38
38
  coreml_diffusion.ModelVersion[args.model_version],
39
39
  args.out,
40
+ component=args.component,
40
41
  batch_size=args.batch_size,
41
42
  sample_size=sample_size,
42
43
  controlnet_support=args.controlnet,
@@ -99,6 +100,13 @@ def build_parser():
99
100
  choices=coreml_diffusion.list_model_versions(include_experimental=True),
100
101
  help="Model architecture (verified: SD15, SDXL; experimental otherwise)",
101
102
  )
103
+ conv.add_argument(
104
+ "--component",
105
+ choices=coreml_diffusion.list_convertible_components(),
106
+ default="unet",
107
+ help="Checkpoint component to convert (default unet). VAE/text-encoder "
108
+ "components ignore --attn-impl/--controlnet/--lora. text_encoder_2 is SDXL-only",
109
+ )
102
110
  conv.add_argument("--out", required=True, help="Output .mlpackage path to write")
103
111
  conv.add_argument(
104
112
  "--height", type=int, default=512, help="Target image height (default 512)"
@@ -0,0 +1,32 @@
1
+ """Convertible model components — framework-free leaf.
2
+
3
+ A single-file checkpoint bundles several sub-models; ``coreml_diffusion`` can
4
+ convert each into its own ``.mlpackage``. This enum is the canonical identifier
5
+ set, mirrored into the discovery contract (``list_convertible_components``) and
6
+ the naming contract (``compose_component_name``).
7
+
8
+ ``UNET`` keeps its historical conversion path and filename (``compose_out_name``);
9
+ the other components are the additive VAE / text-encoder extension. ``.value`` is
10
+ the wire string used by the CLI ``--component`` flag and the discovery list — kept
11
+ lowercase and stable, since a saved workflow / benchmark manifest references it
12
+ verbatim (same additive-only rule as ``ModelVersion``).
13
+
14
+ ``TEXT_ENCODER_2`` is SDXL-only (its second CLIP, ``CLIPTextModelWithProjection``);
15
+ on SD1.5 only ``TEXT_ENCODER`` exists. Validity per model version is enforced at
16
+ convert time, not encoded here.
17
+ """
18
+
19
+ from enum import Enum
20
+
21
+
22
+ class Component(Enum):
23
+ UNET = "unet"
24
+ VAE_DECODER = "vae_decoder"
25
+ VAE_ENCODER = "vae_encoder"
26
+ TEXT_ENCODER = "text_encoder"
27
+ TEXT_ENCODER_2 = "text_encoder_2"
28
+
29
+
30
+ # Declaration order is the discovery-list order; UNET leads to match the
31
+ # historical, primary conversion path.
32
+ CONVERTIBLE_COMPONENTS = tuple(c.value for c in Component)
@@ -120,9 +120,13 @@ def _attention_forward(
120
120
  input_ndim = hidden_states.ndim
121
121
  if input_ndim == 4:
122
122
  batch_size, channel, height, width = hidden_states.shape
123
- hidden_states = hidden_states.view(
124
- batch_size, channel, height * width
125
- ).transpose(1, 2)
123
+ # flatten(2) instead of view(B, C, height * width): the explicit
124
+ # height * width multiplies two traced size ints, emitting an aten::mul ->
125
+ # aten::Int that coremltools 9 cannot fold to a const (it fails the
126
+ # conversion). flatten collapses the spatial dims with a single reshape and
127
+ # no symbolic product. Only the 4D path (VAE self-attention) hits this; the
128
+ # UNet routes attention through a 3D tensor, so ORIGINAL there is untouched.
129
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
126
130
  else:
127
131
  batch_size, _, channel = hidden_states.shape
128
132
  height = None
@@ -0,0 +1,85 @@
1
+ """CLIP text-encoder wrapper adapting it to a flat Core ML tensor contract.
2
+
3
+ Wraps a transformers ``CLIPTextModel`` (SD1.5, SDXL encoder 1) or
4
+ ``CLIPTextModelWithProjection`` (SDXL encoder 2) so the traced graph takes a
5
+ single ``input_ids`` tensor ``(B, 77)`` and returns the embeddings the diffusion
6
+ pipeline consumes — nothing else.
7
+
8
+ Which hidden state and whether a pooled vector is emitted depends on the model:
9
+
10
+ SD1.5 : final ``last_hidden_state`` (768), no pooled
11
+ SDXL encoder 1 : penultimate ``hidden_states[-2]`` (768), no pooled
12
+ SDXL encoder 2 : penultimate ``hidden_states[-2]`` (1280) + projected pooled (1280)
13
+
14
+ The penultimate selection is SDXL's documented behaviour (it concatenates the
15
+ two encoders' penultimate states and uses encoder 2's projected pooled output as
16
+ the ``add_embeds`` conditioning). ``hidden_states_index=None`` selects the final
17
+ ``last_hidden_state``; an int indexes ``hidden_states`` directly.
18
+
19
+ The pooled vector prefers ``text_embeds`` (the projection head, encoder 2) and
20
+ falls back to ``pooler_output`` — so the same wrapper serves both CLIP variants.
21
+ ``input_ids`` is fed as int32 at the Core ML boundary (see ``convert``).
22
+ """
23
+
24
+ import contextlib
25
+
26
+ import torch
27
+
28
+
29
+ @contextlib.contextmanager
30
+ def static_causal_mask(seq_len):
31
+ """Patch CLIP's causal-mask builder to a constant for the trace duration.
32
+
33
+ transformers builds the causal mask from ``query_length + past_key_values_length``,
34
+ both traced 0-dim size tensors; the resulting ``aten::Int`` cannot be folded to a
35
+ const under coremltools 9 (conversion fails). Since the converted sequence length
36
+ is fixed at trace time, swap ``_create_4d_causal_attention_mask`` for a closure
37
+ that materialises the upper-triangular ``-inf`` mask from a Python-int ``seq_len``
38
+ — a pure constant the frontend folds away. Mirrors ``prepare_unet_for_coreml_trace``:
39
+ a trace-only enablement patch, restored on exit. Shape ``(1, 1, seq, seq)``
40
+ broadcasts over batch/heads, so no symbolic batch leaks back in.
41
+ """
42
+ from transformers.models.clip import modeling_clip
43
+
44
+ original = modeling_clip._create_4d_causal_attention_mask
45
+
46
+ def _const_mask(input_shape, dtype, device=None, *args, **kwargs):
47
+ mask = torch.full(
48
+ (seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype, device=device
49
+ )
50
+ return torch.triu(mask, diagonal=1)[None, None]
51
+
52
+ modeling_clip._create_4d_causal_attention_mask = _const_mask
53
+ try:
54
+ yield
55
+ finally:
56
+ modeling_clip._create_4d_causal_attention_mask = original
57
+
58
+
59
+ class CoreMLTextEncoderWrapper(torch.nn.Module):
60
+ """token ids ``(B, 77)`` -> embeddings (+ optional pooled)."""
61
+
62
+ def __init__(self, text_encoder, *, hidden_states_index=None, output_pooled=False):
63
+ super().__init__()
64
+ self.text_encoder = text_encoder
65
+ self.hidden_states_index = hidden_states_index
66
+ self.output_pooled = output_pooled
67
+
68
+ def forward(self, input_ids):
69
+ out = self.text_encoder(
70
+ input_ids,
71
+ output_hidden_states=self.hidden_states_index is not None,
72
+ return_dict=True,
73
+ )
74
+ if self.hidden_states_index is None:
75
+ embeds = out.last_hidden_state
76
+ else:
77
+ embeds = out.hidden_states[self.hidden_states_index]
78
+
79
+ if not self.output_pooled:
80
+ return embeds
81
+
82
+ pooled = getattr(out, "text_embeds", None)
83
+ if pooled is None:
84
+ pooled = out.pooler_output
85
+ return embeds, pooled
@@ -0,0 +1,49 @@
1
+ """VAE wrappers adapting ``AutoencoderKL`` to a flat Core ML tensor contract.
2
+
3
+ Two thin modules, one per direction, mirroring ``CoreMLUNetWrapper``: they expose
4
+ a single positional tensor in / single tensor out so the traced graph has a clean,
5
+ named Core ML I/O signature. They call the VAE *submodules* directly
6
+ (``post_quant_conv``/``decoder``, ``encoder``/``quant_conv``) rather than
7
+ ``decode``/``encode`` — the same op graph, but free of the ``return_dict``
8
+ plumbing and the ``DiagonalGaussianDistribution`` wrapper that complicate tracing.
9
+
10
+ Scaling is intentionally NOT baked in. The pipeline owns ``scaling_factor``
11
+ (``latent = latent / scaling_factor`` before decode, ``moments`` -> distribution
12
+ -> sample -> ``* scaling_factor`` after encode), keeping these artifacts 1:1 with
13
+ the reference VAE.
14
+
15
+ The mid-block self-attention is converted via the ORIGINAL (full, fp32-score)
16
+ processor; see ``convert.convert_vae_*``.
17
+ """
18
+
19
+ import torch
20
+
21
+
22
+ class CoreMLVAEDecoderWrapper(torch.nn.Module):
23
+ """latent ``(B, latent_channels, h, w)`` -> image ``(B, 3, h*8, w*8)``."""
24
+
25
+ def __init__(self, vae):
26
+ super().__init__()
27
+ self.vae = vae
28
+
29
+ def forward(self, latent):
30
+ z = self.vae.post_quant_conv(latent)
31
+ return self.vae.decoder(z)
32
+
33
+
34
+ class CoreMLVAEEncoderWrapper(torch.nn.Module):
35
+ """image ``(B, 3, h*8, w*8)`` -> latent moments ``(B, 2*latent_channels, h, w)``.
36
+
37
+ Outputs the raw moments (mean ‖ logvar) exactly as ``AutoencoderKL.encode``
38
+ produces them before wrapping in a ``DiagonalGaussianDistribution``. Sampling
39
+ (mean + std·noise) is deferred to the pipeline so the converted encoder stays
40
+ deterministic and noise-source agnostic.
41
+ """
42
+
43
+ def __init__(self, vae):
44
+ super().__init__()
45
+ self.vae = vae
46
+
47
+ def forward(self, image):
48
+ h = self.vae.encoder(image)
49
+ return self.vae.quant_conv(h)
@@ -19,10 +19,19 @@ import torch
19
19
  from diffusers import UNet2DConditionModel
20
20
 
21
21
  from coreml_diffusion.attention import ATTENTION_IMPLEMENTATIONS
22
+ from coreml_diffusion.component import Component
22
23
  from coreml_diffusion.conversion.attention import apply_attention_implementation
23
24
  from coreml_diffusion.conversion.shapes import conv2d_output_shape
25
+ from coreml_diffusion.conversion.text_encoder import (
26
+ CoreMLTextEncoderWrapper,
27
+ static_causal_mask,
28
+ )
24
29
  from coreml_diffusion.conversion.trace import prepare_unet_for_coreml_trace
25
30
  from coreml_diffusion.conversion.unet import CoreMLUNetWrapper
31
+ from coreml_diffusion.conversion.vae import (
32
+ CoreMLVAEDecoderWrapper,
33
+ CoreMLVAEEncoderWrapper,
34
+ )
26
35
  from coreml_diffusion.logger import logger
27
36
  from coreml_diffusion.model_version import ModelVersion
28
37
 
@@ -274,9 +283,17 @@ def convert_unet(
274
283
  del traced_unet
275
284
  gc.collect()
276
285
 
286
+ _palettize_and_save(coreml_unet, unet_out_path, quantize_nbits, "unet")
287
+
288
+
289
+ def _palettize_and_save(coreml_model, out_path, quantize_nbits, label):
290
+ """Optionally k-means palettize the weights, then save the ``.mlpackage``.
291
+
292
+ The default path (``quantize_nbits="none"``) saves the model untouched. Shared
293
+ by the UNet and the VAE / text-encoder conversions so every component gets the
294
+ same opt-in palettization behaviour and filename-driven cache semantics.
295
+ """
277
296
  if quantize_nbits != "none":
278
- # Opt-in k-means weight palettization. The default path
279
- # (quantize_nbits="none") leaves the traced UNet untouched.
280
297
  from coremltools.optimize.coreml import (
281
298
  OpPalettizerConfig,
282
299
  OptimizationConfig,
@@ -284,16 +301,144 @@ def convert_unet(
284
301
  )
285
302
 
286
303
  nbits = int(quantize_nbits)
287
- logger.info(f"Palettizing UNet weights to {nbits}-bit (kmeans)..")
304
+ logger.info(f"Palettizing {label} weights to {nbits}-bit (kmeans)..")
288
305
  t0 = time.time()
289
306
  cfg = OptimizationConfig(
290
307
  global_config=OpPalettizerConfig(mode="kmeans", nbits=nbits)
291
308
  )
292
- coreml_unet = palettize_weights(coreml_unet, config=cfg)
309
+ coreml_model = palettize_weights(coreml_model, config=cfg)
293
310
  logger.info(f"Palettization took {time.time() - t0:.1f}s")
294
311
 
295
- coreml_unet.save(unet_out_path)
296
- logger.info(f"Saved unet into {unet_out_path}")
312
+ coreml_model.save(out_path)
313
+ logger.info(f"Saved {label} into {out_path}")
314
+
315
+
316
+ def convert_vae_decoder(
317
+ ref_vae,
318
+ out_path: str,
319
+ *,
320
+ batch_size: int = 1,
321
+ sample_size: tuple[int, int] = (64, 64),
322
+ quantize_nbits: str = "none",
323
+ ):
324
+ """Convert ``AutoencoderKL``'s decoder (latent -> image) to a ``.mlpackage``.
325
+
326
+ ``sample_size`` is the *latent* spatial size (H/8, W/8); the image output is
327
+ 8x that. The mid-block self-attention is routed through the ORIGINAL
328
+ (full, fp32-score) processor — diffusers' stock sdpa attention fails to
329
+ convert under coremltools 9, the same reason ``ORIGINAL`` exists for the UNet.
330
+ """
331
+ apply_attention_implementation(ref_vae, "ORIGINAL")
332
+ wrapper = CoreMLVAEDecoderWrapper(ref_vae.eval()).eval()
333
+
334
+ latent_shape = (
335
+ batch_size,
336
+ ref_vae.config.latent_channels,
337
+ sample_size[0],
338
+ sample_size[1],
339
+ )
340
+ example = torch.rand(*latent_shape)
341
+
342
+ logger.info(f"JIT tracing VAE decoder (latent {latent_shape})..")
343
+ traced = torch.jit.trace(wrapper, example)
344
+
345
+ inputs = [ct.TensorType(name="latent", shape=latent_shape, dtype=np.float16)]
346
+ coreml_model = convert_to_coreml("vae_decoder", traced, inputs, ["image"], out_path)
347
+ del traced
348
+ gc.collect()
349
+ _palettize_and_save(coreml_model, out_path, quantize_nbits, "vae_decoder")
350
+
351
+
352
+ def convert_vae_encoder(
353
+ ref_vae,
354
+ out_path: str,
355
+ *,
356
+ batch_size: int = 1,
357
+ sample_size: tuple[int, int] = (64, 64),
358
+ quantize_nbits: str = "none",
359
+ ):
360
+ """Convert ``AutoencoderKL``'s encoder (image -> latent moments) to a ``.mlpackage``.
361
+
362
+ ``sample_size`` is the latent spatial size; the image input is 8x that. The
363
+ output ``latent_moments`` is the raw mean‖logvar tensor (2*latent_channels
364
+ channels) — the pipeline samples from it. Attention handling matches the
365
+ decoder.
366
+ """
367
+ apply_attention_implementation(ref_vae, "ORIGINAL")
368
+ wrapper = CoreMLVAEEncoderWrapper(ref_vae.eval()).eval()
369
+
370
+ image_shape = (batch_size, 3, sample_size[0] * 8, sample_size[1] * 8)
371
+ example = torch.rand(*image_shape)
372
+
373
+ logger.info(f"JIT tracing VAE encoder (image {image_shape})..")
374
+ traced = torch.jit.trace(wrapper, example)
375
+
376
+ inputs = [ct.TensorType(name="image", shape=image_shape, dtype=np.float16)]
377
+ coreml_model = convert_to_coreml(
378
+ "vae_encoder", traced, inputs, ["latent_moments"], out_path
379
+ )
380
+ del traced
381
+ gc.collect()
382
+ _palettize_and_save(coreml_model, out_path, quantize_nbits, "vae_encoder")
383
+
384
+
385
+ def convert_text_encoder(
386
+ ckpt_path: str,
387
+ model_version: ModelVersion,
388
+ out_path: str,
389
+ *,
390
+ which: int = 1,
391
+ batch_size: int = 1,
392
+ quantize_nbits: str = "none",
393
+ ):
394
+ """Convert a CLIP text encoder (token ids -> embeddings) to a ``.mlpackage``.
395
+
396
+ ``which`` selects the encoder: 1 = primary (SD1.5/SDXL), 2 = SDXL's second
397
+ (``CLIPTextModelWithProjection``). SDXL uses the penultimate hidden state from
398
+ both encoders plus encoder 2's projected pooled output; SD1.5 uses encoder 1's
399
+ final ``last_hidden_state``. ``input_ids`` crosses the Core ML boundary as
400
+ int32 over the fixed 77-token sequence.
401
+ """
402
+ encoders = load_text_encoders(ckpt_path, model_version)
403
+ if which == 2:
404
+ if len(encoders) < 2:
405
+ raise ValueError(
406
+ f"{model_version.name} has no second text encoder "
407
+ "(text_encoder_2 is SDXL-only)."
408
+ )
409
+ encoder = encoders[1]
410
+ hidden_states_index = -2
411
+ output_pooled = True
412
+ output_names = ["hidden_states", "pooled_embeds"]
413
+ else:
414
+ encoder = encoders[0]
415
+ is_sdxl = model_version in {ModelVersion.SDXL, ModelVersion.SDXL_REFINER}
416
+ hidden_states_index = -2 if is_sdxl else None
417
+ output_pooled = False
418
+ output_names = ["hidden_states"]
419
+
420
+ wrapper = CoreMLTextEncoderWrapper(
421
+ encoder.eval(),
422
+ hidden_states_index=hidden_states_index,
423
+ output_pooled=output_pooled,
424
+ ).eval()
425
+
426
+ ids_shape = (batch_size, TEXT_TOKEN_SEQUENCE_LENGTH)
427
+ # Trace with int64 (torch embedding gather needs long); declare int32 at the
428
+ # Core ML boundary (coremltools casts the gather index op).
429
+ example = torch.zeros(ids_shape, dtype=torch.int64)
430
+
431
+ logger.info(f"JIT tracing text encoder {which} (input_ids {ids_shape})..")
432
+ with static_causal_mask(TEXT_TOKEN_SEQUENCE_LENGTH):
433
+ traced = torch.jit.trace(wrapper, example)
434
+
435
+ inputs = [ct.TensorType(name="input_ids", shape=ids_shape, dtype=np.int32)]
436
+ coreml_model = convert_to_coreml(
437
+ f"text_encoder_{which}", traced, inputs, output_names, out_path
438
+ )
439
+ del traced
440
+ gc.collect()
441
+ _palettize_and_save(coreml_model, out_path, quantize_nbits, f"text_encoder_{which}")
297
442
 
298
443
 
299
444
  def convert(
@@ -301,6 +446,7 @@ def convert(
301
446
  model_version: ModelVersion,
302
447
  out_path: str,
303
448
  *,
449
+ component: str = Component.UNET.value,
304
450
  batch_size: int = 1,
305
451
  sample_size: tuple[int, int] = (64, 64),
306
452
  controlnet_support: bool = False,
@@ -309,16 +455,53 @@ def convert(
309
455
  config_path: str = None,
310
456
  quantize_nbits: str = "none",
311
457
  ):
312
- """Convert a single-file checkpoint's UNet to a Core ML ``.mlpackage``.
313
-
314
- Keyword-only past the three required positionals so the package can add
315
- capabilities (new keyword args) without breaking an older caller — the
316
- versioned interface contract. Writes ``out_path``; returns None.
458
+ """Convert a single-file checkpoint component to a Core ML ``.mlpackage``.
459
+
460
+ ``component`` selects what to convert (default ``"unet"``: the historical
461
+ behaviour, where all UNet-only kwargs apply). ``vae_decoder`` / ``vae_encoder`` /
462
+ ``text_encoder`` / ``text_encoder_2`` convert the corresponding sub-model and
463
+ ignore the UNet-only kwargs (``controlnet_support``, ``lora_weights``,
464
+ ``attn_impl``). Keyword-only past the three required positionals so the package
465
+ can add capabilities without breaking an older caller. Writes ``out_path``;
466
+ returns None.
317
467
  """
318
468
  if os.path.exists(out_path):
319
469
  logger.info(f"Found existing model at {out_path}! Skipping..")
320
470
  return
321
471
 
472
+ comp = Component(component)
473
+
474
+ if comp is Component.VAE_DECODER:
475
+ convert_vae_decoder(
476
+ load_vae(ckpt_path),
477
+ out_path,
478
+ batch_size=batch_size,
479
+ sample_size=sample_size,
480
+ quantize_nbits=quantize_nbits,
481
+ )
482
+ return
483
+
484
+ if comp is Component.VAE_ENCODER:
485
+ convert_vae_encoder(
486
+ load_vae(ckpt_path),
487
+ out_path,
488
+ batch_size=batch_size,
489
+ sample_size=sample_size,
490
+ quantize_nbits=quantize_nbits,
491
+ )
492
+ return
493
+
494
+ if comp in {Component.TEXT_ENCODER, Component.TEXT_ENCODER_2}:
495
+ convert_text_encoder(
496
+ ckpt_path,
497
+ model_version,
498
+ out_path,
499
+ which=2 if comp is Component.TEXT_ENCODER_2 else 1,
500
+ batch_size=batch_size,
501
+ quantize_nbits=quantize_nbits,
502
+ )
503
+ return
504
+
322
505
  if attn_impl not in ATTENTION_IMPLEMENTATIONS:
323
506
  raise ValueError(
324
507
  f"Unsupported attention implementation {attn_impl!r}. "
@@ -350,3 +533,42 @@ def load_unet(ckpt_path, config_path):
350
533
  ckpt_path,
351
534
  original_config=config_path,
352
535
  )
536
+
537
+
538
+ def load_vae(ckpt_path):
539
+ """Load ``AutoencoderKL`` from a single-file checkpoint."""
540
+ from diffusers import AutoencoderKL
541
+
542
+ return AutoencoderKL.from_single_file(ckpt_path)
543
+
544
+
545
+ # Pipeline class whose ``from_single_file`` extracts the text encoder(s) for a
546
+ # given model version. SDXL pipelines expose ``text_encoder_2``; SD15/LCM do not.
547
+ _TEXT_ENCODER_PIPELINE = {
548
+ ModelVersion.SD15: ("diffusers", "StableDiffusionPipeline"),
549
+ ModelVersion.LCM: ("diffusers", "StableDiffusionPipeline"),
550
+ ModelVersion.SDXL: ("diffusers", "StableDiffusionXLPipeline"),
551
+ ModelVersion.SDXL_REFINER: ("diffusers", "StableDiffusionXLPipeline"),
552
+ }
553
+
554
+
555
+ def load_text_encoders(ckpt_path, model_version):
556
+ """Return the checkpoint's CLIP text encoder(s) as a list (1 for SD1.5, 2 for SDXL).
557
+
558
+ Loads the matching diffusers pipeline from the single-file checkpoint and
559
+ extracts its text encoder(s). This pulls the whole pipeline (UNet/VAE
560
+ included) — acceptable for an offline, one-shot conversion and the most robust
561
+ way to recover correctly-configured CLIP weights from a single file.
562
+ """
563
+ import importlib
564
+
565
+ if model_version not in _TEXT_ENCODER_PIPELINE:
566
+ raise ValueError(f"No text-encoder pipeline mapped for {model_version!r}.")
567
+ module_name, class_name = _TEXT_ENCODER_PIPELINE[model_version]
568
+ pipeline_cls = getattr(importlib.import_module(module_name), class_name)
569
+
570
+ pipe = pipeline_cls.from_single_file(ckpt_path)
571
+ encoders = [pipe.text_encoder]
572
+ if getattr(pipe, "text_encoder_2", None) is not None:
573
+ encoders.append(pipe.text_encoder_2)
574
+ return encoders