coreml-diffusion 0.1.2__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. coreml_diffusion-0.1.4/.release-please-manifest.json +3 -0
  2. coreml_diffusion-0.1.4/CHANGELOG.md +34 -0
  3. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/PKG-INFO +7 -1
  4. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/README.md +6 -0
  5. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/__init__.py +19 -2
  6. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/cli.py +8 -0
  7. coreml_diffusion-0.1.4/coreml_diffusion/component.py +32 -0
  8. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/attention.py +7 -3
  9. coreml_diffusion-0.1.4/coreml_diffusion/conversion/state_dict.py +20 -0
  10. coreml_diffusion-0.1.4/coreml_diffusion/conversion/text_encoder.py +85 -0
  11. coreml_diffusion-0.1.4/coreml_diffusion/conversion/vae.py +49 -0
  12. coreml_diffusion-0.1.4/coreml_diffusion/convert.py +642 -0
  13. coreml_diffusion-0.1.4/coreml_diffusion/inference.py +395 -0
  14. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/naming.py +56 -0
  15. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/pyproject.toml +1 -1
  16. coreml_diffusion-0.1.4/tests/m2/goldens/sd15_astronaut_full_coreml.png +0 -0
  17. coreml_diffusion-0.1.4/tests/m2/goldens/sd15_astronaut_full_coreml.sha256 +1 -0
  18. coreml_diffusion-0.1.4/tests/m2/test_inference_golden.py +152 -0
  19. coreml_diffusion-0.1.4/tests/smoke/test_coreml_adapters.py +163 -0
  20. coreml_diffusion-0.1.4/tests/smoke/test_lcm_conversion.py +108 -0
  21. coreml_diffusion-0.1.4/tests/smoke/test_synthetic_text_encoder.py +98 -0
  22. coreml_diffusion-0.1.4/tests/smoke/test_synthetic_vae.py +77 -0
  23. coreml_diffusion-0.1.4/tests/unit/test_characterization_component_name.py +154 -0
  24. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/unit/test_cli.py +36 -0
  25. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/unit/test_discovery_api.py +28 -1
  26. coreml_diffusion-0.1.4/tests/unit/test_inference_output_contract.py +43 -0
  27. coreml_diffusion-0.1.4/tests/unit/test_state_dict_layout.py +43 -0
  28. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/uv.lock +1 -1
  29. coreml_diffusion-0.1.2/.release-please-manifest.json +0 -3
  30. coreml_diffusion-0.1.2/CHANGELOG.md +0 -8
  31. coreml_diffusion-0.1.2/coreml_diffusion/convert.py +0 -352
  32. coreml_diffusion-0.1.2/coreml_diffusion/inference.py +0 -176
  33. coreml_diffusion-0.1.2/tests/m2/test_inference_golden.py +0 -111
  34. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/.github/workflows/publish-pypi.yml +0 -0
  35. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/.github/workflows/release-please.yml +0 -0
  36. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/.github/workflows/tier0.yml +0 -0
  37. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/.github/workflows/tier1.yml +0 -0
  38. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/.github/workflows/tier2.yml +0 -0
  39. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/.gitignore +0 -0
  40. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/LICENSE +0 -0
  41. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/attention.py +0 -0
  42. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/__init__.py +0 -0
  43. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/shapes.py +0 -0
  44. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/trace.py +0 -0
  45. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/conversion/unet.py +0 -0
  46. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/logger.py +0 -0
  47. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/model_version.py +0 -0
  48. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/coreml_diffusion/sources.py +0 -0
  49. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/release-please-config.json +0 -0
  50. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/conftest.py +0 -0
  51. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/m2/goldens/sd15_astronaut.png +0 -0
  52. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/m2/goldens/sd15_astronaut.sha256 +0 -0
  53. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/m2/test_original_gpu.py +0 -0
  54. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/smoke/test_original_attention.py +0 -0
  55. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/smoke/test_split_einsum_attention.py +0 -0
  56. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/smoke/test_synthetic_unet.py +0 -0
  57. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/unit/test_characterization_out_name.py +0 -0
  58. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/unit/test_conversion_helpers.py +0 -0
  59. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/unit/test_sources.py +0 -0
  60. {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.4}/tests/unit/test_tier0_purity.py +0 -0
@@ -0,0 +1,3 @@
1
+ {
2
+ ".": "0.1.4"
3
+ }
@@ -0,0 +1,34 @@
1
+ # Changelog
2
+
3
+ ## [0.1.4](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.3...v0.1.4) (2026-06-13)
4
+
5
+
6
+ ### 🐛 Bug Fixes
7
+
8
+ * **convert:** generic LCM conversion for arbitrary checkpoints ([bedeb49](https://github.com/aszc-dev/coreml-diffusion/commit/bedeb49a84a33b4530c6c7d4ed4343a914a400e2))
9
+ * **inference:** add output_dtype to CoreMLTextEncoder ([eb6d3b5](https://github.com/aszc-dev/coreml-diffusion/commit/eb6d3b5b08b6ff34a4eb8683de3d1223fb447186))
10
+
11
+ ## [0.1.3](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.2...v0.1.3) (2026-06-04)
12
+
13
+
14
+ ### ✨ Features
15
+
16
+ * **convert:** add VAE and CLIP text-encoder conversion ([dc1f85b](https://github.com/aszc-dev/coreml-diffusion/commit/dc1f85bafe50d36655ff7ece0c052a30fd77bb81))
17
+ * **inference:** end-to-end Core ML pipeline (VAE + text-encoder swap) ([ca08b16](https://github.com/aszc-dev/coreml-diffusion/commit/ca08b16729529afbdf610d0e8ec2d09b849080c6))
18
+
19
+
20
+ ### 🐛 Bug Fixes
21
+
22
+ * **inference:** expose .device on the Core ML adapters ([30a673e](https://github.com/aszc-dev/coreml-diffusion/commit/30a673eebe3927722214d0ab6a44fbc344d18f3a))
23
+
24
+
25
+ ### 📚 Documentation
26
+
27
+ * **readme:** link the log.aszc.dev energy benchmark writeup ([77927b5](https://github.com/aszc-dev/coreml-diffusion/commit/77927b5dd5f1311a3b3c317692f3a347e3976a54))
28
+
29
+ ## [0.1.2](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.1...v0.1.2) (2026-05-27)
30
+
31
+
32
+ ### 🐛 Bug Fixes
33
+
34
+ * **attention:** convertible fp32 ORIGINAL attention for the Core ML GPU path ([#2](https://github.com/aszc-dev/coreml-diffusion/issues/2)) ([28e56fc](https://github.com/aszc-dev/coreml-diffusion/commit/28e56fcf8c2242ebbe4c05abd05f7e796069d7d1))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: coreml-diffusion
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Summary: Convert diffusion-model checkpoints (SD1.5/SDXL) to Core ML for Apple Neural Engine — framework-free, ComfyUI-independent.
5
5
  Project-URL: Homepage, https://github.com/aszc-dev/coreml-diffusion
6
6
  Project-URL: Repository, https://github.com/aszc-dev/coreml-diffusion
@@ -44,6 +44,12 @@ GPU-free, embeddable in a Swift/iOS app. ANE is the differentiator — this is a
44
44
  feasibility and power efficiency for SD1.5/SDXL on ANE, not a raw-throughput claim
45
45
  against desktop GPUs.
46
46
 
47
+ The power-efficiency claim is measured: in a cross-backend benchmark the ct9
48
+ converter here runs the SD1.5 UNet on the ANE at **6-7x lower energy** than
49
+ GPU/MPS, at the same speed — see the writeup,
50
+ [The ANE runs the SD1.5 UNet at 6-7x lower energy than GPU/MPS](https://log.aszc.dev/ane-vs-gpu-mps-sd15-unet-energy/),
51
+ for the methodology and the numerical-divergence tradeoff.
52
+
47
53
  The scope is diffusion architectures generally, not Stable Diffusion specifically.
48
54
  The project aims to gather, in one place: the conversion path, a reproducible
49
55
  benchmarking suite for objective comparison, a per-model catalogue documenting the
@@ -15,6 +15,12 @@ GPU-free, embeddable in a Swift/iOS app. ANE is the differentiator — this is a
15
15
  feasibility and power efficiency for SD1.5/SDXL on ANE, not a raw-throughput claim
16
16
  against desktop GPUs.
17
17
 
18
+ The power-efficiency claim is measured: in a cross-backend benchmark the ct9
19
+ converter here runs the SD1.5 UNet on the ANE at **6-7x lower energy** than
20
+ GPU/MPS, at the same speed — see the writeup,
21
+ [The ANE runs the SD1.5 UNet at 6-7x lower energy than GPU/MPS](https://log.aszc.dev/ane-vs-gpu-mps-sd15-unet-energy/),
22
+ for the methodology and the numerical-divergence tradeoff.
23
+
18
24
  The scope is diffusion architectures generally, not Stable Diffusion specifically.
19
25
  The project aims to gather, in one place: the conversion path, a reproducible
20
26
  benchmarking suite for objective comparison, a per-model catalogue documenting the
@@ -23,9 +23,11 @@ because a saved workflow JSON references these strings verbatim.
23
23
  from enum import Enum
24
24
 
25
25
  from coreml_diffusion.attention import ATTENTION_IMPLEMENTATIONS
26
+ from coreml_diffusion.component import CONVERTIBLE_COMPONENTS
26
27
  from coreml_diffusion.model_version import ModelVersion
27
28
  from coreml_diffusion.naming import (
28
29
  QUANT_NBITS_VALUES,
30
+ compose_component_name,
29
31
  compose_out_name,
30
32
  lora_names_from_params,
31
33
  )
@@ -36,12 +38,16 @@ __all__ = [
36
38
  "list_model_versions",
37
39
  "list_attention_impls",
38
40
  "list_quant_modes",
41
+ "list_convertible_components",
39
42
  "CONTRACT_VERSION",
40
43
  "compose_out_name",
44
+ "compose_component_name",
41
45
  "lora_names_from_params",
42
46
  "convert",
43
47
  "build_pipeline",
44
48
  "CoreMLUNet",
49
+ "CoreMLVAE",
50
+ "CoreMLTextEncoder",
45
51
  ]
46
52
 
47
53
 
@@ -91,9 +97,20 @@ def list_quant_modes() -> list[str]:
91
97
  return list(QUANT_NBITS_VALUES)
92
98
 
93
99
 
100
def list_convertible_components() -> list[str]:
    """Return the convertible component identifiers, e.g. ``["unet", "vae_decoder", ...]``.

    The list is a fresh copy of ``CONVERTIBLE_COMPONENTS``, so callers may
    mutate it freely. ``"unet"`` is the historical default; the remaining
    entries are the VAE / text-encoder extension. ``"text_encoder_2"`` only
    applies to SDXL — whether a component is valid for a given model version is
    checked at convert time and is not reflected in this list.
    """
    return [*CONVERTIBLE_COMPONENTS]
108
+
109
+
94
110
  # Discovery-contract version. Bump per the additive-only rules in this module's
95
111
  # docstring and CONVERTER_EXTRACTION_SPEC.md "Interface contract".
96
- CONTRACT_VERSION = "1.0"
112
+ # 1.1: added list_convertible_components (VAE + text-encoder conversion).
113
+ CONTRACT_VERSION = "1.1"
97
114
 
98
115
 
99
116
  def __getattr__(name):
@@ -107,7 +124,7 @@ def __getattr__(name):
107
124
  from coreml_diffusion.convert import convert as _convert
108
125
 
109
126
  return _convert
110
- if name in ("build_pipeline", "CoreMLUNet"):
127
+ if name in ("build_pipeline", "CoreMLUNet", "CoreMLVAE", "CoreMLTextEncoder"):
111
128
  from coreml_diffusion import inference
112
129
 
113
130
  return getattr(inference, name)
@@ -37,6 +37,7 @@ def _convert_cmd(args):
37
37
  ckpt,
38
38
  coreml_diffusion.ModelVersion[args.model_version],
39
39
  args.out,
40
+ component=args.component,
40
41
  batch_size=args.batch_size,
41
42
  sample_size=sample_size,
42
43
  controlnet_support=args.controlnet,
@@ -99,6 +100,13 @@ def build_parser():
99
100
  choices=coreml_diffusion.list_model_versions(include_experimental=True),
100
101
  help="Model architecture (verified: SD15, SDXL; experimental otherwise)",
101
102
  )
103
+ conv.add_argument(
104
+ "--component",
105
+ choices=coreml_diffusion.list_convertible_components(),
106
+ default="unet",
107
+ help="Checkpoint component to convert (default unet). VAE/text-encoder "
108
+ "components ignore --attn-impl/--controlnet/--lora. text_encoder_2 is SDXL-only",
109
+ )
102
110
  conv.add_argument("--out", required=True, help="Output .mlpackage path to write")
103
111
  conv.add_argument(
104
112
  "--height", type=int, default=512, help="Target image height (default 512)"
@@ -0,0 +1,32 @@
1
+ """Convertible model components — framework-free leaf.
2
+
3
+ A single-file checkpoint bundles several sub-models; ``coreml_diffusion`` can
4
+ convert each into its own ``.mlpackage``. This enum is the canonical identifier
5
+ set, mirrored into the discovery contract (``list_convertible_components``) and
6
+ the naming contract (``compose_component_name``).
7
+
8
+ ``UNET`` keeps its historical conversion path and filename (``compose_out_name``);
9
+ the other components are the additive VAE / text-encoder extension. ``.value`` is
10
+ the wire string used by the CLI ``--component`` flag and the discovery list — kept
11
+ lowercase and stable, since a saved workflow / benchmark manifest references it
12
+ verbatim (same additive-only rule as ``ModelVersion``).
13
+
14
+ ``TEXT_ENCODER_2`` is SDXL-only (its second CLIP, ``CLIPTextModelWithProjection``);
15
+ on SD1.5 only ``TEXT_ENCODER`` exists. Validity per model version is enforced at
16
+ convert time, not encoded here.
17
+ """
18
+
19
+ from enum import Enum
20
+
21
+
22
class Component(Enum):
    """Identifier for one convertible sub-model of a single-file checkpoint."""

    # ``.value`` is the lowercase wire string used by the CLI ``--component``
    # flag and the discovery list; keep existing values stable (additive-only).
    UNET = "unet"
    VAE_DECODER = "vae_decoder"
    VAE_ENCODER = "vae_encoder"
    TEXT_ENCODER = "text_encoder"
    TEXT_ENCODER_2 = "text_encoder_2"


# The discovery list follows enum declaration order, which puts UNET (the
# historical, primary conversion path) first.
CONVERTIBLE_COMPONENTS = tuple([member.value for member in Component])
@@ -120,9 +120,13 @@ def _attention_forward(
120
120
  input_ndim = hidden_states.ndim
121
121
  if input_ndim == 4:
122
122
  batch_size, channel, height, width = hidden_states.shape
123
- hidden_states = hidden_states.view(
124
- batch_size, channel, height * width
125
- ).transpose(1, 2)
123
+ # flatten(2) instead of view(B, C, height * width): the explicit
124
+ # height * width multiplies two traced size ints, emitting an aten::mul ->
125
+ # aten::Int that coremltools 9 cannot fold to a const (it fails the
126
+ # conversion). flatten collapses the spatial dims with a single reshape and
127
+ # no symbolic product. Only the 4D path (VAE self-attention) hits this; the
128
+ # UNet routes attention through a 3D tensor, so ORIGINAL there is untouched.
129
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
126
130
  else:
127
131
  batch_size, _, channel = hidden_states.shape
128
132
  height = None
@@ -0,0 +1,20 @@
1
+ """State-dict layout predicates — framework-free (no coremltools/diffusers).
2
+
3
+ Single-file checkpoints come in two layouts: original LDM (UNet keys under
4
+ ``model.diffusion_model.``) and diffusers-native UNet-only dumps (block keys at
5
+ the top level — e.g. ``LCM_Dreamshaper_v7_4k.safetensors``, the canonical
6
+ full-distill LCM artifact). diffusers' ``from_single_file`` only understands
7
+ the former and raises ``SingleFileComponentError`` on the latter;
8
+ ``convert.load_unet`` routes on this predicate.
9
+ """
10
+
11
DIFFUSERS_UNET_KEY_PREFIXES = ("down_blocks.", "up_blocks.", "mid_block.")
LDM_UNET_KEY_PREFIX = "model.diffusion_model."


def is_diffusers_unet_layout(keys) -> bool:
    """Report whether ``keys`` look like a diffusers-format, UNet-only state dict.

    The result is ``True`` only when at least one key starts with a diffusers
    block prefix and no key carries the original-LDM UNet prefix. ``keys`` may
    be any iterable of strings; it is consumed exactly once, in full.
    """
    found_diffusers_block = False
    found_ldm_key = False
    for name in keys:
        # The two prefix families are disjoint, so one test per key suffices.
        if name.startswith(LDM_UNET_KEY_PREFIX):
            found_ldm_key = True
        elif name.startswith(DIFFUSERS_UNET_KEY_PREFIXES):
            found_diffusers_block = True
    return found_diffusers_block and not found_ldm_key
@@ -0,0 +1,85 @@
1
+ """CLIP text-encoder wrapper adapting it to a flat Core ML tensor contract.
2
+
3
+ Wraps a transformers ``CLIPTextModel`` (SD1.5, SDXL encoder 1) or
4
+ ``CLIPTextModelWithProjection`` (SDXL encoder 2) so the traced graph takes a
5
+ single ``input_ids`` tensor ``(B, 77)`` and returns the embeddings the diffusion
6
+ pipeline consumes — nothing else.
7
+
8
+ Which hidden state and whether a pooled vector is emitted depends on the model:
9
+
10
+ SD1.5 : final ``last_hidden_state`` (768), no pooled
11
+ SDXL encoder 1 : penultimate ``hidden_states[-2]`` (768), no pooled
12
+ SDXL encoder 2 : penultimate ``hidden_states[-2]`` (1280) + projected pooled (1280)
13
+
14
+ The penultimate selection is SDXL's documented behaviour (it concatenates the
15
+ two encoders' penultimate states and uses encoder 2's projected pooled output as
16
+ the ``add_embeds`` conditioning). ``hidden_states_index=None`` selects the final
17
+ ``last_hidden_state``; an int indexes ``hidden_states`` directly.
18
+
19
+ The pooled vector prefers ``text_embeds`` (the projection head, encoder 2) and
20
+ falls back to ``pooler_output`` — so the same wrapper serves both CLIP variants.
21
+ ``input_ids`` is fed as int32 at the Core ML boundary (see ``convert``).
22
+ """
23
+
24
+ import contextlib
25
+
26
+ import torch
27
+
28
+
29
@contextlib.contextmanager
def static_causal_mask(seq_len):
    """Temporarily replace CLIP's causal-mask builder with a constant mask.

    When traced, transformers derives the causal mask from
    ``query_length + past_key_values_length``. Both are traced 0-dim size
    tensors, and the resulting ``aten::Int`` cannot be folded to a const by
    coremltools 9, which makes conversion fail. Because the converted sequence
    length is fixed at trace time, this context manager swaps
    ``modeling_clip._create_4d_causal_attention_mask`` for a closure that
    builds the mask from the Python int ``seq_len``. Entries above the diagonal
    hold ``torch.finfo(dtype).min`` and all others are zero, so the frontend
    sees a pure constant. The original builder is restored on exit, even if the
    body raises.

    The replacement mask has shape ``(1, 1, seq_len, seq_len)``. It broadcasts
    over batch and heads, so no symbolic batch size leaks back in.

    Args:
        seq_len: The fixed token sequence length used for the trace.
    """
    from transformers.models.clip import modeling_clip

    patched_name = "_create_4d_causal_attention_mask"
    saved_builder = getattr(modeling_clip, patched_name)

    def _const_mask(input_shape, dtype, device=None, *args, **kwargs):
        # input_shape and any extra arguments are accepted only for signature
        # compatibility; the mask size comes solely from the trace-time seq_len.
        filled = torch.full(
            (seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype, device=device
        )
        return filled.triu(diagonal=1)[None, None]

    setattr(modeling_clip, patched_name, _const_mask)
    try:
        yield
    finally:
        setattr(modeling_clip, patched_name, saved_builder)
57
+
58
+
59
class CoreMLTextEncoderWrapper(torch.nn.Module):
    """token ids ``(B, 77)`` -> embeddings (+ optional pooled).

    Args:
        text_encoder: A CLIP text model, called as
            ``text_encoder(input_ids, output_hidden_states=..., return_dict=True)``.
        hidden_states_index: ``None`` selects ``last_hidden_state``; an int
            indexes ``hidden_states`` (e.g. ``-2`` for the penultimate layer).
        output_pooled: When true, also return a pooled vector. It is taken from
            ``text_embeds`` if the output has a non-``None`` one, otherwise from
            ``pooler_output``.
    """

    def __init__(self, text_encoder, *, hidden_states_index=None, output_pooled=False):
        super().__init__()
        self.text_encoder = text_encoder
        self.hidden_states_index = hidden_states_index
        self.output_pooled = output_pooled

    def forward(self, input_ids):
        """Return ``embeds`` or ``(embeds, pooled)`` for ``input_ids``."""
        index = self.hidden_states_index
        use_final_layer = index is None
        # Only request the full hidden-state stack when an index needs it.
        result = self.text_encoder(
            input_ids,
            output_hidden_states=not use_final_layer,
            return_dict=True,
        )
        embeds = result.last_hidden_state if use_final_layer else result.hidden_states[index]

        if self.output_pooled:
            projected = getattr(result, "text_embeds", None)
            pooled = result.pooler_output if projected is None else projected
            return embeds, pooled
        return embeds
@@ -0,0 +1,49 @@
1
+ """VAE wrappers adapting ``AutoencoderKL`` to a flat Core ML tensor contract.
2
+
3
+ Two thin modules, one per direction, mirroring ``CoreMLUNetWrapper``: they expose
4
+ a single positional tensor in / single tensor out so the traced graph has a clean,
5
+ named Core ML I/O signature. They call the VAE *submodules* directly
6
+ (``post_quant_conv``/``decoder``, ``encoder``/``quant_conv``) rather than
7
+ ``decode``/``encode`` — the same op graph, but free of the ``return_dict``
8
+ plumbing and the ``DiagonalGaussianDistribution`` wrapper that complicate tracing.
9
+
10
+ Scaling is intentionally NOT baked in. The pipeline owns ``scaling_factor``
11
+ (``latent = latent / scaling_factor`` before decode, ``moments`` -> distribution
12
+ -> sample -> ``* scaling_factor`` after encode), keeping these artifacts 1:1 with
13
+ the reference VAE.
14
+
15
+ The mid-block self-attention is converted via the ORIGINAL (full, fp32-score)
16
+ processor; see ``convert.convert_vae_*``.
17
+ """
18
+
19
+ import torch
20
+
21
+
22
class CoreMLVAEDecoderWrapper(torch.nn.Module):
    """latent ``(B, latent_channels, h, w)`` -> image ``(B, 3, h*8, w*8)``.

    Runs ``post_quant_conv`` followed by ``decoder`` on the wrapped VAE. Some
    ``AutoencoderKL`` configurations (``use_post_quant_conv=False``) set
    ``post_quant_conv`` to ``None``. In that case the conv is skipped, exactly
    as ``AutoencoderKL.decode`` does, instead of failing on a ``None`` call.
    Scaling is not applied here; the caller divides by ``scaling_factor`` first.
    """

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, latent):
        post_quant_conv = self.vae.post_quant_conv
        # Resolved at trace time: a Python-level branch, not a traced op.
        z = latent if post_quant_conv is None else post_quant_conv(latent)
        return self.vae.decoder(z)
32
+
33
+
34
class CoreMLVAEEncoderWrapper(torch.nn.Module):
    """image ``(B, 3, h*8, w*8)`` -> latent moments ``(B, 2*latent_channels, h, w)``.

    Outputs the raw moments (mean ‖ logvar) exactly as ``AutoencoderKL.encode``
    produces them before wrapping in a ``DiagonalGaussianDistribution``. Sampling
    (mean + std·noise) is deferred to the pipeline, so the converted encoder stays
    deterministic and agnostic to the noise source.

    ``quant_conv`` is ``None`` on ``AutoencoderKL`` configurations built with
    ``use_quant_conv=False``. It is skipped in that case, matching
    ``AutoencoderKL.encode``, instead of failing on a ``None`` call.
    """

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, image):
        h = self.vae.encoder(image)
        quant_conv = self.vae.quant_conv
        # Resolved at trace time: a Python-level branch, not a traced op.
        return h if quant_conv is None else quant_conv(h)