coreml-diffusion 0.1.2__tar.gz → 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coreml_diffusion-0.1.3/.release-please-manifest.json +3 -0
- coreml_diffusion-0.1.3/CHANGELOG.md +26 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/PKG-INFO +7 -1
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/README.md +6 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/__init__.py +19 -2
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/cli.py +8 -0
- coreml_diffusion-0.1.3/coreml_diffusion/component.py +32 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/conversion/attention.py +7 -3
- coreml_diffusion-0.1.3/coreml_diffusion/conversion/text_encoder.py +85 -0
- coreml_diffusion-0.1.3/coreml_diffusion/conversion/vae.py +49 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/convert.py +233 -11
- coreml_diffusion-0.1.3/coreml_diffusion/inference.py +380 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/naming.py +56 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/pyproject.toml +1 -1
- coreml_diffusion-0.1.3/tests/m2/goldens/sd15_astronaut_full_coreml.png +0 -0
- coreml_diffusion-0.1.3/tests/m2/goldens/sd15_astronaut_full_coreml.sha256 +1 -0
- coreml_diffusion-0.1.3/tests/m2/test_inference_golden.py +152 -0
- coreml_diffusion-0.1.3/tests/smoke/test_coreml_adapters.py +139 -0
- coreml_diffusion-0.1.3/tests/smoke/test_synthetic_text_encoder.py +98 -0
- coreml_diffusion-0.1.3/tests/smoke/test_synthetic_vae.py +77 -0
- coreml_diffusion-0.1.3/tests/unit/test_characterization_component_name.py +154 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_cli.py +36 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_discovery_api.py +28 -1
- coreml_diffusion-0.1.3/tests/unit/test_inference_output_contract.py +43 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/uv.lock +1 -1
- coreml_diffusion-0.1.2/.release-please-manifest.json +0 -3
- coreml_diffusion-0.1.2/CHANGELOG.md +0 -8
- coreml_diffusion-0.1.2/coreml_diffusion/inference.py +0 -176
- coreml_diffusion-0.1.2/tests/m2/test_inference_golden.py +0 -111
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.github/workflows/publish-pypi.yml +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.github/workflows/release-please.yml +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.github/workflows/tier0.yml +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.github/workflows/tier1.yml +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.github/workflows/tier2.yml +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/.gitignore +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/LICENSE +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/attention.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/conversion/__init__.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/conversion/shapes.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/conversion/trace.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/conversion/unet.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/logger.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/model_version.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/coreml_diffusion/sources.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/release-please-config.json +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/conftest.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/m2/goldens/sd15_astronaut.png +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/m2/goldens/sd15_astronaut.sha256 +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/m2/test_original_gpu.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/smoke/test_original_attention.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/smoke/test_split_einsum_attention.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/smoke/test_synthetic_unet.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_characterization_out_name.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_conversion_helpers.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_sources.py +0 -0
- {coreml_diffusion-0.1.2 → coreml_diffusion-0.1.3}/tests/unit/test_tier0_purity.py +0 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
## [0.1.3](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.2...v0.1.3) (2026-06-04)
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
### ✨ Features
|
|
7
|
+
|
|
8
|
+
* **convert:** add VAE and CLIP text-encoder conversion ([dc1f85b](https://github.com/aszc-dev/coreml-diffusion/commit/dc1f85bafe50d36655ff7ece0c052a30fd77bb81))
|
|
9
|
+
* **inference:** end-to-end Core ML pipeline (VAE + text-encoder swap) ([ca08b16](https://github.com/aszc-dev/coreml-diffusion/commit/ca08b16729529afbdf610d0e8ec2d09b849080c6))
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
### 🐛 Bug Fixes
|
|
13
|
+
|
|
14
|
+
* **inference:** expose .device on the Core ML adapters ([30a673e](https://github.com/aszc-dev/coreml-diffusion/commit/30a673eebe3927722214d0ab6a44fbc344d18f3a))
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
### 📚 Documentation
|
|
18
|
+
|
|
19
|
+
* **readme:** link the log.aszc.dev energy benchmark writeup ([77927b5](https://github.com/aszc-dev/coreml-diffusion/commit/77927b5dd5f1311a3b3c317692f3a347e3976a54))
|
|
20
|
+
|
|
21
|
+
## [0.1.2](https://github.com/aszc-dev/coreml-diffusion/compare/v0.1.1...v0.1.2) (2026-05-27)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
### 🐛 Bug Fixes
|
|
25
|
+
|
|
26
|
+
* **attention:** convertible fp32 ORIGINAL attention for the Core ML GPU path ([#2](https://github.com/aszc-dev/coreml-diffusion/issues/2)) ([28e56fc](https://github.com/aszc-dev/coreml-diffusion/commit/28e56fcf8c2242ebbe4c05abd05f7e796069d7d1))
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: coreml-diffusion
|
|
3
|
-
Version: 0.1.2
|
|
3
|
+
Version: 0.1.3
|
|
4
4
|
Summary: Convert diffusion-model checkpoints (SD1.5/SDXL) to Core ML for Apple Neural Engine — framework-free, ComfyUI-independent.
|
|
5
5
|
Project-URL: Homepage, https://github.com/aszc-dev/coreml-diffusion
|
|
6
6
|
Project-URL: Repository, https://github.com/aszc-dev/coreml-diffusion
|
|
@@ -44,6 +44,12 @@ GPU-free, embeddable in a Swift/iOS app. ANE is the differentiator — this is a
|
|
|
44
44
|
feasibility and power efficiency for SD1.5/SDXL on ANE, not a raw-throughput claim
|
|
45
45
|
against desktop GPUs.
|
|
46
46
|
|
|
47
|
+
The power-efficiency claim is measured: in a cross-backend benchmark the ct9
|
|
48
|
+
converter here runs the SD1.5 UNet on the ANE at **6-7x lower energy** than
|
|
49
|
+
GPU/MPS, at the same speed — see the writeup,
|
|
50
|
+
[The ANE runs the SD1.5 UNet at 6-7x lower energy than GPU/MPS](https://log.aszc.dev/ane-vs-gpu-mps-sd15-unet-energy/),
|
|
51
|
+
for the methodology and the numerical-divergence tradeoff.
|
|
52
|
+
|
|
47
53
|
The scope is diffusion architectures generally, not Stable Diffusion specifically.
|
|
48
54
|
The project aims to gather, in one place: the conversion path, a reproducible
|
|
49
55
|
benchmarking suite for objective comparison, a per-model catalogue documenting the
|
|
@@ -15,6 +15,12 @@ GPU-free, embeddable in a Swift/iOS app. ANE is the differentiator — this is a
|
|
|
15
15
|
feasibility and power efficiency for SD1.5/SDXL on ANE, not a raw-throughput claim
|
|
16
16
|
against desktop GPUs.
|
|
17
17
|
|
|
18
|
+
The power-efficiency claim is measured: in a cross-backend benchmark the ct9
|
|
19
|
+
converter here runs the SD1.5 UNet on the ANE at **6-7x lower energy** than
|
|
20
|
+
GPU/MPS, at the same speed — see the writeup,
|
|
21
|
+
[The ANE runs the SD1.5 UNet at 6-7x lower energy than GPU/MPS](https://log.aszc.dev/ane-vs-gpu-mps-sd15-unet-energy/),
|
|
22
|
+
for the methodology and the numerical-divergence tradeoff.
|
|
23
|
+
|
|
18
24
|
The scope is diffusion architectures generally, not Stable Diffusion specifically.
|
|
19
25
|
The project aims to gather, in one place: the conversion path, a reproducible
|
|
20
26
|
benchmarking suite for objective comparison, a per-model catalogue documenting the
|
|
@@ -23,9 +23,11 @@ because a saved workflow JSON references these strings verbatim.
|
|
|
23
23
|
from enum import Enum
|
|
24
24
|
|
|
25
25
|
from coreml_diffusion.attention import ATTENTION_IMPLEMENTATIONS
|
|
26
|
+
from coreml_diffusion.component import CONVERTIBLE_COMPONENTS
|
|
26
27
|
from coreml_diffusion.model_version import ModelVersion
|
|
27
28
|
from coreml_diffusion.naming import (
|
|
28
29
|
QUANT_NBITS_VALUES,
|
|
30
|
+
compose_component_name,
|
|
29
31
|
compose_out_name,
|
|
30
32
|
lora_names_from_params,
|
|
31
33
|
)
|
|
@@ -36,12 +38,16 @@ __all__ = [
|
|
|
36
38
|
"list_model_versions",
|
|
37
39
|
"list_attention_impls",
|
|
38
40
|
"list_quant_modes",
|
|
41
|
+
"list_convertible_components",
|
|
39
42
|
"CONTRACT_VERSION",
|
|
40
43
|
"compose_out_name",
|
|
44
|
+
"compose_component_name",
|
|
41
45
|
"lora_names_from_params",
|
|
42
46
|
"convert",
|
|
43
47
|
"build_pipeline",
|
|
44
48
|
"CoreMLUNet",
|
|
49
|
+
"CoreMLVAE",
|
|
50
|
+
"CoreMLTextEncoder",
|
|
45
51
|
]
|
|
46
52
|
|
|
47
53
|
|
|
@@ -91,9 +97,20 @@ def list_quant_modes() -> list[str]:
|
|
|
91
97
|
return list(QUANT_NBITS_VALUES)
|
|
92
98
|
|
|
93
99
|
|
|
100
|
+
def list_convertible_components() -> list[str]:
|
|
101
|
+
"""Convertible components, e.g. ``["unet", "vae_decoder", ...]``.
|
|
102
|
+
|
|
103
|
+
``"unet"`` is the historical default; the rest are the VAE / text-encoder
|
|
104
|
+
extension. ``"text_encoder_2"`` is only meaningful for SDXL — validity per
|
|
105
|
+
model version is enforced at convert time, not advertised here.
|
|
106
|
+
"""
|
|
107
|
+
return list(CONVERTIBLE_COMPONENTS)
|
|
108
|
+
|
|
109
|
+
|
|
94
110
|
# Discovery-contract version. Bump per the additive-only rules in this module's
|
|
95
111
|
# docstring and CONVERTER_EXTRACTION_SPEC.md "Interface contract".
|
|
96
|
-
|
|
112
|
+
# 1.1: added list_convertible_components (VAE + text-encoder conversion).
|
|
113
|
+
CONTRACT_VERSION = "1.1"
|
|
97
114
|
|
|
98
115
|
|
|
99
116
|
def __getattr__(name):
|
|
@@ -107,7 +124,7 @@ def __getattr__(name):
|
|
|
107
124
|
from coreml_diffusion.convert import convert as _convert
|
|
108
125
|
|
|
109
126
|
return _convert
|
|
110
|
-
if name in ("build_pipeline", "CoreMLUNet"):
|
|
127
|
+
if name in ("build_pipeline", "CoreMLUNet", "CoreMLVAE", "CoreMLTextEncoder"):
|
|
111
128
|
from coreml_diffusion import inference
|
|
112
129
|
|
|
113
130
|
return getattr(inference, name)
|
|
@@ -37,6 +37,7 @@ def _convert_cmd(args):
|
|
|
37
37
|
ckpt,
|
|
38
38
|
coreml_diffusion.ModelVersion[args.model_version],
|
|
39
39
|
args.out,
|
|
40
|
+
component=args.component,
|
|
40
41
|
batch_size=args.batch_size,
|
|
41
42
|
sample_size=sample_size,
|
|
42
43
|
controlnet_support=args.controlnet,
|
|
@@ -99,6 +100,13 @@ def build_parser():
|
|
|
99
100
|
choices=coreml_diffusion.list_model_versions(include_experimental=True),
|
|
100
101
|
help="Model architecture (verified: SD15, SDXL; experimental otherwise)",
|
|
101
102
|
)
|
|
103
|
+
conv.add_argument(
|
|
104
|
+
"--component",
|
|
105
|
+
choices=coreml_diffusion.list_convertible_components(),
|
|
106
|
+
default="unet",
|
|
107
|
+
help="Checkpoint component to convert (default unet). VAE/text-encoder "
|
|
108
|
+
"components ignore --attn-impl/--controlnet/--lora. text_encoder_2 is SDXL-only",
|
|
109
|
+
)
|
|
102
110
|
conv.add_argument("--out", required=True, help="Output .mlpackage path to write")
|
|
103
111
|
conv.add_argument(
|
|
104
112
|
"--height", type=int, default=512, help="Target image height (default 512)"
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Convertible model components — framework-free leaf.
|
|
2
|
+
|
|
3
|
+
A single-file checkpoint bundles several sub-models; ``coreml_diffusion`` can
|
|
4
|
+
convert each into its own ``.mlpackage``. This enum is the canonical identifier
|
|
5
|
+
set, mirrored into the discovery contract (``list_convertible_components``) and
|
|
6
|
+
the naming contract (``compose_component_name``).
|
|
7
|
+
|
|
8
|
+
``UNET`` keeps its historical conversion path and filename (``compose_out_name``);
|
|
9
|
+
the other components are the additive VAE / text-encoder extension. ``.value`` is
|
|
10
|
+
the wire string used by the CLI ``--component`` flag and the discovery list — kept
|
|
11
|
+
lowercase and stable, since a saved workflow / benchmark manifest references it
|
|
12
|
+
verbatim (same additive-only rule as ``ModelVersion``).
|
|
13
|
+
|
|
14
|
+
``TEXT_ENCODER_2`` is SDXL-only (its second CLIP, ``CLIPTextModelWithProjection``);
|
|
15
|
+
on SD1.5 only ``TEXT_ENCODER`` exists. Validity per model version is enforced at
|
|
16
|
+
convert time, not encoded here.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from enum import Enum
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Component(Enum):
|
|
23
|
+
UNET = "unet"
|
|
24
|
+
VAE_DECODER = "vae_decoder"
|
|
25
|
+
VAE_ENCODER = "vae_encoder"
|
|
26
|
+
TEXT_ENCODER = "text_encoder"
|
|
27
|
+
TEXT_ENCODER_2 = "text_encoder_2"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Declaration order is the discovery-list order; UNET leads to match the
|
|
31
|
+
# historical, primary conversion path.
|
|
32
|
+
CONVERTIBLE_COMPONENTS = tuple(c.value for c in Component)
|
|
@@ -120,9 +120,13 @@ def _attention_forward(
|
|
|
120
120
|
input_ndim = hidden_states.ndim
|
|
121
121
|
if input_ndim == 4:
|
|
122
122
|
batch_size, channel, height, width = hidden_states.shape
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
123
|
+
# flatten(2) instead of view(B, C, height * width): the explicit
|
|
124
|
+
# height * width multiplies two traced size ints, emitting an aten::mul ->
|
|
125
|
+
# aten::Int that coremltools 9 cannot fold to a const (it fails the
|
|
126
|
+
# conversion). flatten collapses the spatial dims with a single reshape and
|
|
127
|
+
# no symbolic product. Only the 4D path (VAE self-attention) hits this; the
|
|
128
|
+
# UNet routes attention through a 3D tensor, so ORIGINAL there is untouched.
|
|
129
|
+
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
|
126
130
|
else:
|
|
127
131
|
batch_size, _, channel = hidden_states.shape
|
|
128
132
|
height = None
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""CLIP text-encoder wrapper adapting it to a flat Core ML tensor contract.
|
|
2
|
+
|
|
3
|
+
Wraps a transformers ``CLIPTextModel`` (SD1.5, SDXL encoder 1) or
|
|
4
|
+
``CLIPTextModelWithProjection`` (SDXL encoder 2) so the traced graph takes a
|
|
5
|
+
single ``input_ids`` tensor ``(B, 77)`` and returns the embeddings the diffusion
|
|
6
|
+
pipeline consumes — nothing else.
|
|
7
|
+
|
|
8
|
+
Which hidden state and whether a pooled vector is emitted depends on the model:
|
|
9
|
+
|
|
10
|
+
SD1.5 : final ``last_hidden_state`` (768), no pooled
|
|
11
|
+
SDXL encoder 1 : penultimate ``hidden_states[-2]`` (768), no pooled
|
|
12
|
+
SDXL encoder 2 : penultimate ``hidden_states[-2]`` (1280) + projected pooled (1280)
|
|
13
|
+
|
|
14
|
+
The penultimate selection is SDXL's documented behaviour (it concatenates the
|
|
15
|
+
two encoders' penultimate states and uses encoder 2's projected pooled output as
|
|
16
|
+
the ``add_embeds`` conditioning). ``hidden_states_index=None`` selects the final
|
|
17
|
+
``last_hidden_state``; an int indexes ``hidden_states`` directly.
|
|
18
|
+
|
|
19
|
+
The pooled vector prefers ``text_embeds`` (the projection head, encoder 2) and
|
|
20
|
+
falls back to ``pooler_output`` — so the same wrapper serves both CLIP variants.
|
|
21
|
+
``input_ids`` is fed as int32 at the Core ML boundary (see ``convert``).
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
import contextlib
|
|
25
|
+
|
|
26
|
+
import torch
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@contextlib.contextmanager
|
|
30
|
+
def static_causal_mask(seq_len):
|
|
31
|
+
"""Patch CLIP's causal-mask builder to a constant for the trace duration.
|
|
32
|
+
|
|
33
|
+
transformers builds the causal mask from ``query_length + past_key_values_length``,
|
|
34
|
+
both traced 0-dim size tensors; the resulting ``aten::Int`` cannot be folded to a
|
|
35
|
+
const under coremltools 9 (conversion fails). Since the converted sequence length
|
|
36
|
+
is fixed at trace time, swap ``_create_4d_causal_attention_mask`` for a closure
|
|
37
|
+
that materialises the upper-triangular ``-inf`` mask from a Python-int ``seq_len``
|
|
38
|
+
— a pure constant the frontend folds away. Mirrors ``prepare_unet_for_coreml_trace``:
|
|
39
|
+
a trace-only enablement patch, restored on exit. Shape ``(1, 1, seq, seq)``
|
|
40
|
+
broadcasts over batch/heads, so no symbolic batch leaks back in.
|
|
41
|
+
"""
|
|
42
|
+
from transformers.models.clip import modeling_clip
|
|
43
|
+
|
|
44
|
+
original = modeling_clip._create_4d_causal_attention_mask
|
|
45
|
+
|
|
46
|
+
def _const_mask(input_shape, dtype, device=None, *args, **kwargs):
|
|
47
|
+
mask = torch.full(
|
|
48
|
+
(seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype, device=device
|
|
49
|
+
)
|
|
50
|
+
return torch.triu(mask, diagonal=1)[None, None]
|
|
51
|
+
|
|
52
|
+
modeling_clip._create_4d_causal_attention_mask = _const_mask
|
|
53
|
+
try:
|
|
54
|
+
yield
|
|
55
|
+
finally:
|
|
56
|
+
modeling_clip._create_4d_causal_attention_mask = original
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class CoreMLTextEncoderWrapper(torch.nn.Module):
|
|
60
|
+
"""token ids ``(B, 77)`` -> embeddings (+ optional pooled)."""
|
|
61
|
+
|
|
62
|
+
def __init__(self, text_encoder, *, hidden_states_index=None, output_pooled=False):
|
|
63
|
+
super().__init__()
|
|
64
|
+
self.text_encoder = text_encoder
|
|
65
|
+
self.hidden_states_index = hidden_states_index
|
|
66
|
+
self.output_pooled = output_pooled
|
|
67
|
+
|
|
68
|
+
def forward(self, input_ids):
|
|
69
|
+
out = self.text_encoder(
|
|
70
|
+
input_ids,
|
|
71
|
+
output_hidden_states=self.hidden_states_index is not None,
|
|
72
|
+
return_dict=True,
|
|
73
|
+
)
|
|
74
|
+
if self.hidden_states_index is None:
|
|
75
|
+
embeds = out.last_hidden_state
|
|
76
|
+
else:
|
|
77
|
+
embeds = out.hidden_states[self.hidden_states_index]
|
|
78
|
+
|
|
79
|
+
if not self.output_pooled:
|
|
80
|
+
return embeds
|
|
81
|
+
|
|
82
|
+
pooled = getattr(out, "text_embeds", None)
|
|
83
|
+
if pooled is None:
|
|
84
|
+
pooled = out.pooler_output
|
|
85
|
+
return embeds, pooled
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""VAE wrappers adapting ``AutoencoderKL`` to a flat Core ML tensor contract.
|
|
2
|
+
|
|
3
|
+
Two thin modules, one per direction, mirroring ``CoreMLUNetWrapper``: they expose
|
|
4
|
+
a single positional tensor in / single tensor out so the traced graph has a clean,
|
|
5
|
+
named Core ML I/O signature. They call the VAE *submodules* directly
|
|
6
|
+
(``post_quant_conv``/``decoder``, ``encoder``/``quant_conv``) rather than
|
|
7
|
+
``decode``/``encode`` — the same op graph, but free of the ``return_dict``
|
|
8
|
+
plumbing and the ``DiagonalGaussianDistribution`` wrapper that complicate tracing.
|
|
9
|
+
|
|
10
|
+
Scaling is intentionally NOT baked in. The pipeline owns ``scaling_factor``
|
|
11
|
+
(``latent = latent / scaling_factor`` before decode, ``moments`` -> distribution
|
|
12
|
+
-> sample -> ``* scaling_factor`` after encode), keeping these artifacts 1:1 with
|
|
13
|
+
the reference VAE.
|
|
14
|
+
|
|
15
|
+
The mid-block self-attention is converted via the ORIGINAL (full, fp32-score)
|
|
16
|
+
processor; see ``convert.convert_vae_*``.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class CoreMLVAEDecoderWrapper(torch.nn.Module):
|
|
23
|
+
"""latent ``(B, latent_channels, h, w)`` -> image ``(B, 3, h*8, w*8)``."""
|
|
24
|
+
|
|
25
|
+
def __init__(self, vae):
|
|
26
|
+
super().__init__()
|
|
27
|
+
self.vae = vae
|
|
28
|
+
|
|
29
|
+
def forward(self, latent):
|
|
30
|
+
z = self.vae.post_quant_conv(latent)
|
|
31
|
+
return self.vae.decoder(z)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class CoreMLVAEEncoderWrapper(torch.nn.Module):
|
|
35
|
+
"""image ``(B, 3, h*8, w*8)`` -> latent moments ``(B, 2*latent_channels, h, w)``.
|
|
36
|
+
|
|
37
|
+
Outputs the raw moments (mean ‖ logvar) exactly as ``AutoencoderKL.encode``
|
|
38
|
+
produces them before wrapping in a ``DiagonalGaussianDistribution``. Sampling
|
|
39
|
+
(mean + std·noise) is deferred to the pipeline so the converted encoder stays
|
|
40
|
+
deterministic and noise-source agnostic.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
def __init__(self, vae):
|
|
44
|
+
super().__init__()
|
|
45
|
+
self.vae = vae
|
|
46
|
+
|
|
47
|
+
def forward(self, image):
|
|
48
|
+
h = self.vae.encoder(image)
|
|
49
|
+
return self.vae.quant_conv(h)
|
|
@@ -19,10 +19,19 @@ import torch
|
|
|
19
19
|
from diffusers import UNet2DConditionModel
|
|
20
20
|
|
|
21
21
|
from coreml_diffusion.attention import ATTENTION_IMPLEMENTATIONS
|
|
22
|
+
from coreml_diffusion.component import Component
|
|
22
23
|
from coreml_diffusion.conversion.attention import apply_attention_implementation
|
|
23
24
|
from coreml_diffusion.conversion.shapes import conv2d_output_shape
|
|
25
|
+
from coreml_diffusion.conversion.text_encoder import (
|
|
26
|
+
CoreMLTextEncoderWrapper,
|
|
27
|
+
static_causal_mask,
|
|
28
|
+
)
|
|
24
29
|
from coreml_diffusion.conversion.trace import prepare_unet_for_coreml_trace
|
|
25
30
|
from coreml_diffusion.conversion.unet import CoreMLUNetWrapper
|
|
31
|
+
from coreml_diffusion.conversion.vae import (
|
|
32
|
+
CoreMLVAEDecoderWrapper,
|
|
33
|
+
CoreMLVAEEncoderWrapper,
|
|
34
|
+
)
|
|
26
35
|
from coreml_diffusion.logger import logger
|
|
27
36
|
from coreml_diffusion.model_version import ModelVersion
|
|
28
37
|
|
|
@@ -274,9 +283,17 @@ def convert_unet(
|
|
|
274
283
|
del traced_unet
|
|
275
284
|
gc.collect()
|
|
276
285
|
|
|
286
|
+
_palettize_and_save(coreml_unet, unet_out_path, quantize_nbits, "unet")
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def _palettize_and_save(coreml_model, out_path, quantize_nbits, label):
|
|
290
|
+
"""Optionally k-means palettize the weights, then save the ``.mlpackage``.
|
|
291
|
+
|
|
292
|
+
The default path (``quantize_nbits="none"``) saves the model untouched. Shared
|
|
293
|
+
by the UNet and the VAE / text-encoder conversions so every component gets the
|
|
294
|
+
same opt-in palettization behaviour and filename-driven cache semantics.
|
|
295
|
+
"""
|
|
277
296
|
if quantize_nbits != "none":
|
|
278
|
-
# Opt-in k-means weight palettization. The default path
|
|
279
|
-
# (quantize_nbits="none") leaves the traced UNet untouched.
|
|
280
297
|
from coremltools.optimize.coreml import (
|
|
281
298
|
OpPalettizerConfig,
|
|
282
299
|
OptimizationConfig,
|
|
@@ -284,16 +301,144 @@ def convert_unet(
|
|
|
284
301
|
)
|
|
285
302
|
|
|
286
303
|
nbits = int(quantize_nbits)
|
|
287
|
-
logger.info(f"Palettizing
|
|
304
|
+
logger.info(f"Palettizing {label} weights to {nbits}-bit (kmeans)..")
|
|
288
305
|
t0 = time.time()
|
|
289
306
|
cfg = OptimizationConfig(
|
|
290
307
|
global_config=OpPalettizerConfig(mode="kmeans", nbits=nbits)
|
|
291
308
|
)
|
|
292
|
-
|
|
309
|
+
coreml_model = palettize_weights(coreml_model, config=cfg)
|
|
293
310
|
logger.info(f"Palettization took {time.time() - t0:.1f}s")
|
|
294
311
|
|
|
295
|
-
|
|
296
|
-
logger.info(f"Saved
|
|
312
|
+
coreml_model.save(out_path)
|
|
313
|
+
logger.info(f"Saved {label} into {out_path}")
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def convert_vae_decoder(
|
|
317
|
+
ref_vae,
|
|
318
|
+
out_path: str,
|
|
319
|
+
*,
|
|
320
|
+
batch_size: int = 1,
|
|
321
|
+
sample_size: tuple[int, int] = (64, 64),
|
|
322
|
+
quantize_nbits: str = "none",
|
|
323
|
+
):
|
|
324
|
+
"""Convert ``AutoencoderKL``'s decoder (latent -> image) to a ``.mlpackage``.
|
|
325
|
+
|
|
326
|
+
``sample_size`` is the *latent* spatial size (H/8, W/8); the image output is
|
|
327
|
+
8x that. The mid-block self-attention is routed through the ORIGINAL
|
|
328
|
+
(full, fp32-score) processor — diffusers' stock sdpa attention fails to
|
|
329
|
+
convert under coremltools 9, the same reason ``ORIGINAL`` exists for the UNet.
|
|
330
|
+
"""
|
|
331
|
+
apply_attention_implementation(ref_vae, "ORIGINAL")
|
|
332
|
+
wrapper = CoreMLVAEDecoderWrapper(ref_vae.eval()).eval()
|
|
333
|
+
|
|
334
|
+
latent_shape = (
|
|
335
|
+
batch_size,
|
|
336
|
+
ref_vae.config.latent_channels,
|
|
337
|
+
sample_size[0],
|
|
338
|
+
sample_size[1],
|
|
339
|
+
)
|
|
340
|
+
example = torch.rand(*latent_shape)
|
|
341
|
+
|
|
342
|
+
logger.info(f"JIT tracing VAE decoder (latent {latent_shape})..")
|
|
343
|
+
traced = torch.jit.trace(wrapper, example)
|
|
344
|
+
|
|
345
|
+
inputs = [ct.TensorType(name="latent", shape=latent_shape, dtype=np.float16)]
|
|
346
|
+
coreml_model = convert_to_coreml("vae_decoder", traced, inputs, ["image"], out_path)
|
|
347
|
+
del traced
|
|
348
|
+
gc.collect()
|
|
349
|
+
_palettize_and_save(coreml_model, out_path, quantize_nbits, "vae_decoder")
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def convert_vae_encoder(
|
|
353
|
+
ref_vae,
|
|
354
|
+
out_path: str,
|
|
355
|
+
*,
|
|
356
|
+
batch_size: int = 1,
|
|
357
|
+
sample_size: tuple[int, int] = (64, 64),
|
|
358
|
+
quantize_nbits: str = "none",
|
|
359
|
+
):
|
|
360
|
+
"""Convert ``AutoencoderKL``'s encoder (image -> latent moments) to a ``.mlpackage``.
|
|
361
|
+
|
|
362
|
+
``sample_size`` is the latent spatial size; the image input is 8x that. The
|
|
363
|
+
output ``latent_moments`` is the raw mean‖logvar tensor (2*latent_channels
|
|
364
|
+
channels) — the pipeline samples from it. Attention handling matches the
|
|
365
|
+
decoder.
|
|
366
|
+
"""
|
|
367
|
+
apply_attention_implementation(ref_vae, "ORIGINAL")
|
|
368
|
+
wrapper = CoreMLVAEEncoderWrapper(ref_vae.eval()).eval()
|
|
369
|
+
|
|
370
|
+
image_shape = (batch_size, 3, sample_size[0] * 8, sample_size[1] * 8)
|
|
371
|
+
example = torch.rand(*image_shape)
|
|
372
|
+
|
|
373
|
+
logger.info(f"JIT tracing VAE encoder (image {image_shape})..")
|
|
374
|
+
traced = torch.jit.trace(wrapper, example)
|
|
375
|
+
|
|
376
|
+
inputs = [ct.TensorType(name="image", shape=image_shape, dtype=np.float16)]
|
|
377
|
+
coreml_model = convert_to_coreml(
|
|
378
|
+
"vae_encoder", traced, inputs, ["latent_moments"], out_path
|
|
379
|
+
)
|
|
380
|
+
del traced
|
|
381
|
+
gc.collect()
|
|
382
|
+
_palettize_and_save(coreml_model, out_path, quantize_nbits, "vae_encoder")
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def convert_text_encoder(
|
|
386
|
+
ckpt_path: str,
|
|
387
|
+
model_version: ModelVersion,
|
|
388
|
+
out_path: str,
|
|
389
|
+
*,
|
|
390
|
+
which: int = 1,
|
|
391
|
+
batch_size: int = 1,
|
|
392
|
+
quantize_nbits: str = "none",
|
|
393
|
+
):
|
|
394
|
+
"""Convert a CLIP text encoder (token ids -> embeddings) to a ``.mlpackage``.
|
|
395
|
+
|
|
396
|
+
``which`` selects the encoder: 1 = primary (SD1.5/SDXL), 2 = SDXL's second
|
|
397
|
+
(``CLIPTextModelWithProjection``). SDXL uses the penultimate hidden state from
|
|
398
|
+
both encoders plus encoder 2's projected pooled output; SD1.5 uses encoder 1's
|
|
399
|
+
final ``last_hidden_state``. ``input_ids`` crosses the Core ML boundary as
|
|
400
|
+
int32 over the fixed 77-token sequence.
|
|
401
|
+
"""
|
|
402
|
+
encoders = load_text_encoders(ckpt_path, model_version)
|
|
403
|
+
if which == 2:
|
|
404
|
+
if len(encoders) < 2:
|
|
405
|
+
raise ValueError(
|
|
406
|
+
f"{model_version.name} has no second text encoder "
|
|
407
|
+
"(text_encoder_2 is SDXL-only)."
|
|
408
|
+
)
|
|
409
|
+
encoder = encoders[1]
|
|
410
|
+
hidden_states_index = -2
|
|
411
|
+
output_pooled = True
|
|
412
|
+
output_names = ["hidden_states", "pooled_embeds"]
|
|
413
|
+
else:
|
|
414
|
+
encoder = encoders[0]
|
|
415
|
+
is_sdxl = model_version in {ModelVersion.SDXL, ModelVersion.SDXL_REFINER}
|
|
416
|
+
hidden_states_index = -2 if is_sdxl else None
|
|
417
|
+
output_pooled = False
|
|
418
|
+
output_names = ["hidden_states"]
|
|
419
|
+
|
|
420
|
+
wrapper = CoreMLTextEncoderWrapper(
|
|
421
|
+
encoder.eval(),
|
|
422
|
+
hidden_states_index=hidden_states_index,
|
|
423
|
+
output_pooled=output_pooled,
|
|
424
|
+
).eval()
|
|
425
|
+
|
|
426
|
+
ids_shape = (batch_size, TEXT_TOKEN_SEQUENCE_LENGTH)
|
|
427
|
+
# Trace with int64 (torch embedding gather needs long); declare int32 at the
|
|
428
|
+
# Core ML boundary (coremltools casts the gather index op).
|
|
429
|
+
example = torch.zeros(ids_shape, dtype=torch.int64)
|
|
430
|
+
|
|
431
|
+
logger.info(f"JIT tracing text encoder {which} (input_ids {ids_shape})..")
|
|
432
|
+
with static_causal_mask(TEXT_TOKEN_SEQUENCE_LENGTH):
|
|
433
|
+
traced = torch.jit.trace(wrapper, example)
|
|
434
|
+
|
|
435
|
+
inputs = [ct.TensorType(name="input_ids", shape=ids_shape, dtype=np.int32)]
|
|
436
|
+
coreml_model = convert_to_coreml(
|
|
437
|
+
f"text_encoder_{which}", traced, inputs, output_names, out_path
|
|
438
|
+
)
|
|
439
|
+
del traced
|
|
440
|
+
gc.collect()
|
|
441
|
+
_palettize_and_save(coreml_model, out_path, quantize_nbits, f"text_encoder_{which}")
|
|
297
442
|
|
|
298
443
|
|
|
299
444
|
def convert(
|
|
@@ -301,6 +446,7 @@ def convert(
|
|
|
301
446
|
model_version: ModelVersion,
|
|
302
447
|
out_path: str,
|
|
303
448
|
*,
|
|
449
|
+
component: str = Component.UNET.value,
|
|
304
450
|
batch_size: int = 1,
|
|
305
451
|
sample_size: tuple[int, int] = (64, 64),
|
|
306
452
|
controlnet_support: bool = False,
|
|
@@ -309,16 +455,53 @@ def convert(
|
|
|
309
455
|
config_path: str = None,
|
|
310
456
|
quantize_nbits: str = "none",
|
|
311
457
|
):
|
|
312
|
-
"""Convert a single-file checkpoint
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
458
|
+
"""Convert a single-file checkpoint component to a Core ML ``.mlpackage``.
|
|
459
|
+
|
|
460
|
+
``component`` selects what to convert (default ``"unet"`` — historical
|
|
461
|
+
behaviour, all UNet-only kwargs apply). ``vae_decoder`` / ``vae_encoder`` /
|
|
462
|
+
``text_encoder`` / ``text_encoder_2`` convert the corresponding sub-model and
|
|
463
|
+
ignore the UNet-only kwargs (``controlnet_support``, ``lora_weights``,
|
|
464
|
+
``attn_impl``). Keyword-only past the three required positionals so the package
|
|
465
|
+
can add capabilities without breaking an older caller. Writes ``out_path``;
|
|
466
|
+
returns None.
|
|
317
467
|
"""
|
|
318
468
|
if os.path.exists(out_path):
|
|
319
469
|
logger.info(f"Found existing model at {out_path}! Skipping..")
|
|
320
470
|
return
|
|
321
471
|
|
|
472
|
+
comp = Component(component)
|
|
473
|
+
|
|
474
|
+
if comp is Component.VAE_DECODER:
|
|
475
|
+
convert_vae_decoder(
|
|
476
|
+
load_vae(ckpt_path),
|
|
477
|
+
out_path,
|
|
478
|
+
batch_size=batch_size,
|
|
479
|
+
sample_size=sample_size,
|
|
480
|
+
quantize_nbits=quantize_nbits,
|
|
481
|
+
)
|
|
482
|
+
return
|
|
483
|
+
|
|
484
|
+
if comp is Component.VAE_ENCODER:
|
|
485
|
+
convert_vae_encoder(
|
|
486
|
+
load_vae(ckpt_path),
|
|
487
|
+
out_path,
|
|
488
|
+
batch_size=batch_size,
|
|
489
|
+
sample_size=sample_size,
|
|
490
|
+
quantize_nbits=quantize_nbits,
|
|
491
|
+
)
|
|
492
|
+
return
|
|
493
|
+
|
|
494
|
+
if comp in {Component.TEXT_ENCODER, Component.TEXT_ENCODER_2}:
|
|
495
|
+
convert_text_encoder(
|
|
496
|
+
ckpt_path,
|
|
497
|
+
model_version,
|
|
498
|
+
out_path,
|
|
499
|
+
which=2 if comp is Component.TEXT_ENCODER_2 else 1,
|
|
500
|
+
batch_size=batch_size,
|
|
501
|
+
quantize_nbits=quantize_nbits,
|
|
502
|
+
)
|
|
503
|
+
return
|
|
504
|
+
|
|
322
505
|
if attn_impl not in ATTENTION_IMPLEMENTATIONS:
|
|
323
506
|
raise ValueError(
|
|
324
507
|
f"Unsupported attention implementation {attn_impl!r}. "
|
|
@@ -350,3 +533,42 @@ def load_unet(ckpt_path, config_path):
|
|
|
350
533
|
ckpt_path,
|
|
351
534
|
original_config=config_path,
|
|
352
535
|
)
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
def load_vae(ckpt_path):
|
|
539
|
+
"""Load ``AutoencoderKL`` from a single-file checkpoint."""
|
|
540
|
+
from diffusers import AutoencoderKL
|
|
541
|
+
|
|
542
|
+
return AutoencoderKL.from_single_file(ckpt_path)
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
# Pipeline class whose ``from_single_file`` extracts the text encoder(s) for a
|
|
546
|
+
# given model version. SDXL pipelines expose ``text_encoder_2``; SD15/LCM do not.
|
|
547
|
+
_TEXT_ENCODER_PIPELINE = {
|
|
548
|
+
ModelVersion.SD15: ("diffusers", "StableDiffusionPipeline"),
|
|
549
|
+
ModelVersion.LCM: ("diffusers", "StableDiffusionPipeline"),
|
|
550
|
+
ModelVersion.SDXL: ("diffusers", "StableDiffusionXLPipeline"),
|
|
551
|
+
ModelVersion.SDXL_REFINER: ("diffusers", "StableDiffusionXLPipeline"),
|
|
552
|
+
}
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
def load_text_encoders(ckpt_path, model_version):
|
|
556
|
+
"""Return the checkpoint's CLIP text encoder(s) as a list (1 for SD1.5, 2 for SDXL).
|
|
557
|
+
|
|
558
|
+
Loads the matching diffusers pipeline from the single-file checkpoint and
|
|
559
|
+
extracts its text encoder(s). This pulls the whole pipeline (UNet/VAE
|
|
560
|
+
included) — acceptable for an offline, one-shot conversion and the most robust
|
|
561
|
+
way to recover correctly-configured CLIP weights from a single file.
|
|
562
|
+
"""
|
|
563
|
+
import importlib
|
|
564
|
+
|
|
565
|
+
if model_version not in _TEXT_ENCODER_PIPELINE:
|
|
566
|
+
raise ValueError(f"No text-encoder pipeline mapped for {model_version!r}.")
|
|
567
|
+
module_name, class_name = _TEXT_ENCODER_PIPELINE[model_version]
|
|
568
|
+
pipeline_cls = getattr(importlib.import_module(module_name), class_name)
|
|
569
|
+
|
|
570
|
+
pipe = pipeline_cls.from_single_file(ckpt_path)
|
|
571
|
+
encoders = [pipe.text_encoder]
|
|
572
|
+
if getattr(pipe, "text_encoder_2", None) is not None:
|
|
573
|
+
encoders.append(pipe.text_encoder_2)
|
|
574
|
+
return encoders
|