coreml-diffusion 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coreml_diffusion/__init__.py +108 -0
- coreml_diffusion/attention.py +5 -0
- coreml_diffusion/cli.py +114 -0
- coreml_diffusion/conversion/__init__.py +9 -0
- coreml_diffusion/conversion/attention.py +245 -0
- coreml_diffusion/conversion/shapes.py +20 -0
- coreml_diffusion/conversion/trace.py +61 -0
- coreml_diffusion/conversion/unet.py +54 -0
- coreml_diffusion/convert.py +348 -0
- coreml_diffusion/logger.py +5 -0
- coreml_diffusion/model_version.py +8 -0
- coreml_diffusion/naming.py +73 -0
- coreml_diffusion-0.1.0.dist-info/METADATA +98 -0
- coreml_diffusion-0.1.0.dist-info/RECORD +17 -0
- coreml_diffusion-0.1.0.dist-info/WHEEL +4 -0
- coreml_diffusion-0.1.0.dist-info/entry_points.txt +2 -0
- coreml_diffusion-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""coreml_diffusion — framework-free Core ML diffusion conversion.
|
|
2
|
+
|
|
3
|
+
Converts diffusion-model checkpoints (SD1.5/SDXL today) into Core ML
|
|
4
|
+
``.mlpackage`` artifacts for Apple Neural Engine, with no ComfyUI dependency.
|
|
5
|
+
Usable as a library, via the ``coreml-diffusion`` CLI, or embedded in on-device
|
|
6
|
+
(iOS) tooling. The public surface is the discovery API below plus ``convert``,
|
|
7
|
+
``compose_out_name`` and ``ModelVersion``.
|
|
8
|
+
|
|
9
|
+
This package MUST stay free of ``comfy`` / ``folder_paths`` / ``comfy_extras``;
|
|
10
|
+
``import coreml_diffusion`` works in a comfy-free environment.
|
|
11
|
+
|
|
12
|
+
Discovery contract (consumed by ComfyUI-CoreMLSuite): the node populates its
|
|
13
|
+
dropdowns by calling ``list_*`` here, so installing a newer ``coreml_diffusion``
|
|
14
|
+
surfaces new conversion types in the old node with no Suite change and no Suite
|
|
15
|
+
version bump. The identifiers returned here are an ADDITIVE-ONLY contract:
|
|
16
|
+
|
|
17
|
+
- adding an identifier, or promoting EXPERIMENTAL -> VERIFIED => minor bump
|
|
18
|
+
- removing/renaming an identifier, or demoting VERIFIED => MAJOR bump + note
|
|
19
|
+
|
|
20
|
+
because a saved workflow JSON references these strings verbatim.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from enum import Enum
|
|
24
|
+
|
|
25
|
+
from coreml_diffusion.attention import ATTENTION_IMPLEMENTATIONS
|
|
26
|
+
from coreml_diffusion.model_version import ModelVersion
|
|
27
|
+
from coreml_diffusion.naming import (
|
|
28
|
+
QUANT_NBITS_VALUES,
|
|
29
|
+
compose_out_name,
|
|
30
|
+
lora_names_from_params,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Public names exported by ``from coreml_diffusion import *``.
# NOTE: "convert" is not a real attribute of this module; it is served by the
# lazy module-level ``__getattr__``. A star-import resolves every name listed
# here, so it also triggers the heavy ``coreml_diffusion.convert`` import.
__all__ = [
    "ModelVersion",
    "Status",
    "list_model_versions",
    "list_attention_impls",
    "list_quant_modes",
    "CONTRACT_VERSION",
    "compose_out_name",
    "lora_names_from_params",
    "convert",
]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class Status(Enum):
    """Verification status of a convertible model version.

    ``list_model_versions`` returns VERIFIED versions by default and also
    EXPERIMENTAL ones when ``include_experimental=True``.
    """

    VERIFIED = "verified"  # has a golden anchor + passing [M2-ANE] check
    EXPERIMENTAL = "experimental"  # convertible, not yet anchored/verified
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Single source of truth for which conversions the Suite may surface. The Suite
# gates on this status, NOT on a hardcoded node list: promoting a model to
# VERIFIED expands the node's dropdown with no Suite change.
#
# Keyed by the ModelVersion MEMBER so ``list_model_versions`` can emit ``.name``
# ("SD15", "SDXL"). The node reverses the dropdown string via ``ModelVersion[...]``
# (name lookup, nodes.py), so emitting ``.value`` ("sd15") would raise KeyError on
# every saved workflow. See seam.md §5.
#
# Insertion order is the order in which ``list_model_versions`` returns names.
_MODEL_STATUS = {
    ModelVersion.SD15: Status.VERIFIED,
    ModelVersion.SDXL: Status.VERIFIED,
    ModelVersion.SDXL_REFINER: Status.EXPERIMENTAL,  # -> VERIFIED after a refiner golden anchor
    ModelVersion.LCM: Status.EXPERIMENTAL,  # -> VERIFIED after E-LCM golden anchor
}
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def list_model_versions(include_experimental: bool = False) -> list[str]:
    """Return model-version member names (e.g. ``["SD15", "SDXL"]``).

    Only VERIFIED versions are returned unless ``include_experimental`` is
    true, in which case EXPERIMENTAL (convertible but not golden-anchored)
    versions are included too. Order follows ``_MODEL_STATUS``.
    """
    wanted = {Status.VERIFIED}
    if include_experimental:
        wanted.add(Status.EXPERIMENTAL)
    return [member.name for member, state in _MODEL_STATUS.items() if state in wanted]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def list_attention_impls() -> list[str]:
    """Return a fresh list of supported attention implementations, e.g. ``["SPLIT_EINSUM", ...]``."""
    return [*ATTENTION_IMPLEMENTATIONS]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def list_quant_modes() -> list[str]:
    """Return a fresh list of palettization modes, e.g. ``["none", "8", "6", "4"]`` ("none" = unquantized)."""
    return [*QUANT_NBITS_VALUES]
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Discovery-contract version. Bump per the additive-only rules in this module's
# docstring and CONVERTER_EXTRACTION_SPEC.md "Interface contract".
# Stored as a "MAJOR.MINOR" string: additive changes bump MINOR, removals or
# renames bump MAJOR.
CONTRACT_VERSION = "1.0"
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def __getattr__(name):
    """Module-level attribute fallback (PEP 562) that lazily provides ``convert``.

    ``coreml_diffusion.convert`` imports coremltools/diffusers. Deferring that
    import to first access keeps ``import coreml_diffusion`` light enough for
    the Tier-0 (Linux, framework-free) lane. Every other unknown name raises
    ``AttributeError`` as normal.
    """
    if name != "convert":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from coreml_diffusion.convert import convert as heavy_convert

    return heavy_convert
|
coreml_diffusion/cli.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
"""Command-line entry point for coreml_diffusion.
|
|
2
|
+
|
|
3
|
+
Mirrors ``coreml_diffusion.convert`` so the package can produce a Core ML
|
|
4
|
+
``.mlpackage`` with no ComfyUI involved — for headless and on-device (iOS)
|
|
5
|
+
conversion workflows. The heavy import (coremltools/diffusers, pulled by
|
|
6
|
+
``convert``) is deferred into the handler, so ``--help`` and argument parsing
|
|
7
|
+
stay light and the arg→call mapping is testable on plain Linux.
|
|
8
|
+
|
|
9
|
+
Example:
|
|
10
|
+
coreml-diffusion convert --ckpt model.safetensors --model-version SD15 \\
|
|
11
|
+
--out unet.mlpackage --height 512 --width 512 --attn-impl SPLIT_EINSUM
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
|
|
16
|
+
import coreml_diffusion
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _parse_lora(spec):
    """Parse a ``PATH[:STRENGTH]`` LoRA spec into ``(path, float_strength)``.

    The split is on the LAST ':', and the suffix is used as the strength only
    if it parses as a float. Otherwise the whole spec is the path and the
    strength defaults to 1.0. This keeps Windows drive letters working both
    with and without a strength:

        "C:\\loras\\x.safetensors"      -> ("C:\\loras\\x.safetensors", 1.0)
        "C:\\loras\\x.safetensors:0.7"  -> ("C:\\loras\\x.safetensors", 0.7)

    A path whose final ':'-separated part is numeric (e.g. "dir:5") is read
    as a strength; append an explicit ":1.0" in that case.
    """
    path, sep, strength = spec.rpartition(":")
    if not sep:
        return spec, 1.0
    try:
        return path, float(strength)
    except ValueError:
        # The colon belongs to the path (e.g. a drive letter), not a strength.
        return spec, 1.0
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _convert_cmd(args):
    """Run ``coreml_diffusion.convert`` for the parsed ``convert`` subcommand.

    The UNet is traced at ``sample_size = (height // 8, width // 8)``. A height
    or width that is not a positive multiple of 8 therefore cannot be honoured
    exactly, so it is rejected rather than silently rounded down to a different
    resolution.

    Raises:
        SystemExit: if ``--height`` or ``--width`` is not a positive multiple of 8.
    """
    for flag, value in (("--height", args.height), ("--width", args.width)):
        if value <= 0 or value % 8:
            raise SystemExit(f"{flag} must be a positive multiple of 8 (got {value})")

    sample_size = (args.height // 8, args.width // 8)
    lora_weights = [_parse_lora(spec) for spec in (args.lora or [])]
    coreml_diffusion.convert(
        args.ckpt,
        coreml_diffusion.ModelVersion[args.model_version],
        args.out,
        batch_size=args.batch_size,
        sample_size=sample_size,
        controlnet_support=args.controlnet,
        lora_weights=lora_weights or None,
        attn_impl=args.attn_impl,
        config_path=args.config,
        quantize_nbits=args.quantize,
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def build_parser():
    """Build the ``coreml-diffusion`` argument parser.

    Choices and defaults come from the discovery API (``list_model_versions``,
    ``list_attention_impls``, ``list_quant_modes``). Help strings that name
    verified versions or the default attention implementation are derived
    from the same calls, so promoting a model to VERIFIED or reordering the
    attention implementations cannot leave ``--help`` stale.

    Returns:
        An ``argparse.ArgumentParser`` with a required ``convert`` subcommand
        whose handler is stored on the parsed namespace as ``func``.
    """
    all_versions = coreml_diffusion.list_model_versions(include_experimental=True)
    verified_versions = coreml_diffusion.list_model_versions()
    attn_impls = coreml_diffusion.list_attention_impls()

    parser = argparse.ArgumentParser(
        prog="coreml-diffusion",
        description="Convert diffusion checkpoints to Core ML for Apple Neural Engine.",
    )
    sub = parser.add_subparsers(dest="command", required=True)

    conv = sub.add_parser("convert", help="Convert a checkpoint's UNet to a .mlpackage")
    conv.add_argument(
        "--ckpt", required=True, help="Path to the source .safetensors checkpoint"
    )
    conv.add_argument(
        "--model-version",
        required=True,
        # include experimental: the CLI is the power-user path. Experimental
        # versions (currently LCM, SDXL_REFINER) convert but are not golden-verified.
        choices=all_versions,
        help=(
            f"Model architecture (verified: {', '.join(verified_versions)}; "
            "experimental otherwise)"
        ),
    )
    conv.add_argument("--out", required=True, help="Output .mlpackage path to write")
    conv.add_argument(
        "--height", type=int, default=512, help="Target image height (default 512)"
    )
    conv.add_argument(
        "--width", type=int, default=512, help="Target image width (default 512)"
    )
    conv.add_argument(
        "--batch-size", type=int, default=1, help="Batch size (default 1)"
    )
    conv.add_argument(
        "--attn-impl",
        choices=attn_impls,
        default=attn_impls[0],
        help=f"Attention implementation (default {attn_impls[0]})",
    )
    conv.add_argument(
        "--controlnet",
        action="store_true",
        help="Add ControlNet residual inputs to the converted UNet",
    )
    conv.add_argument(
        "--lora",
        action="append",
        metavar="PATH[:STRENGTH]",
        help="LoRA to fuse before conversion; repeatable. STRENGTH defaults to 1.0",
    )
    conv.add_argument(
        "--config", default=None, help="Optional original-config YAML path"
    )
    conv.add_argument(
        "--quantize",
        choices=coreml_diffusion.list_quant_modes(),
        default="none",
        help="K-means weight palettization bits (default none = unquantized)",
    )
    conv.set_defaults(func=_convert_cmd)
    return parser
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def main(argv=None):
    """CLI entry point: parse ``argv`` and dispatch to the selected subcommand.

    ``argv=None`` makes argparse read ``sys.argv[1:]``. Returns whatever the
    subcommand handler returns.
    """
    namespace = build_parser().parse_args(argv)
    handler = namespace.func
    return handler(namespace)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
if __name__ == "__main__":
    # Use main()'s return value as the process exit status, as console-script
    # wrappers typically do (None -> exit code 0).
    raise SystemExit(main())
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Core ML conversion helpers.
|
|
2
|
+
|
|
3
|
+
The conversion approach originates from Apple's ml-stable-diffusion
|
|
4
|
+
(https://github.com/apple/ml-stable-diffusion). This implementation has since
|
|
5
|
+
diverged: it runs natively on diffusers' UNet2DConditionModel with its own
|
|
6
|
+
SPLIT_EINSUM / SPLIT_EINSUM_V2 attention processors and no longer depends on
|
|
7
|
+
that package. The intent is to keep iterating on these methods independently
|
|
8
|
+
while tracking current tooling.
|
|
9
|
+
"""
|
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
logger = logging.getLogger(__name__)

# Query-sequence chunk length used by ``split_einsum_v2``. Query sequences
# shorter than this fall back to ``split_einsum``.
CHUNK_SIZE = 512
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def apply_attention_implementation(unet, attention_implementation):
    """Install the requested attention processor on ``unet`` and return it.

    ``"ORIGINAL"`` leaves the UNet untouched. ``"SPLIT_EINSUM"`` and
    ``"SPLIT_EINSUM_V2"`` call ``unet.set_attn_processor`` with the matching
    processor, so the same ``unet`` object is mutated and returned.

    Raises:
        ValueError: for any other implementation name.
    """
    if attention_implementation == "ORIGINAL":
        return unet

    if attention_implementation == "SPLIT_EINSUM":
        processor_cls = SplitEinsumAttnProcessor
    elif attention_implementation == "SPLIT_EINSUM_V2":
        processor_cls = SplitEinsumV2AttnProcessor
    else:
        raise ValueError(
            f"Unsupported attention implementation: {attention_implementation}"
        )

    unet.set_attn_processor(processor_cls())
    return unet
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SplitEinsumAttnProcessor:
    """Attention processor that computes attention with ``split_einsum``.

    Matches the diffusers processor call signature. Extra positional and
    keyword arguments are accepted and ignored.
    """

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        return _attention_forward(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            temb,
            attention_fn=split_einsum,
        )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class SplitEinsumV2AttnProcessor:
    """Attention processor that computes attention with ``split_einsum_v2``.

    Same call signature as ``SplitEinsumAttnProcessor``, but it also chunks
    the query sequence. Extra positional and keyword arguments are ignored.
    """

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        return _attention_forward(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            temb,
            attention_fn=split_einsum_v2,
        )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _attention_forward(
    attn,
    hidden_states,
    encoder_hidden_states,
    attention_mask,
    temb,
    attention_fn,
):
    """Shared forward pass for the split-einsum attention processors.

    Projects q/k/v with ``attn``'s linear layers and reshapes each from
    ``(B, S, C)`` to ``(B, C, 1, S)`` via ``_linear_projection_to_bchw``. It
    then delegates the attention math to ``attention_fn(q, k, v, mask, heads,
    dim_head)``, applies ``attn.to_out[0]`` and ``attn.to_out[1]``, and
    restores a 4-D input layout if one was given.

    Args:
        attn: diffusers ``Attention`` module (reads ``spatial_norm``,
            ``group_norm``, ``to_q/k/v``, ``to_out``, ``heads``, ``inner_dim``,
            ``inner_kv_dim``, ``norm_cross``, ``residual_connection``,
            ``rescale_output_factor``, ``prepare_attention_mask``).
        hidden_states: 4-D ``(B, C, H, W)`` or 3-D ``(B, S, C)`` tensor.
        encoder_hidden_states: cross-attention context, or None for
            self-attention.
        attention_mask: optional additive mask; reshaped by
            ``_prepare_split_einsum_mask``.
        temb: passed to ``attn.spatial_norm`` when present.
        attention_fn: ``split_einsum`` or ``split_einsum_v2``.
    """
    residual = hidden_states

    if attn.spatial_norm is not None:
        hidden_states = attn.spatial_norm(hidden_states, temb)

    input_ndim = hidden_states.ndim
    if input_ndim == 4:
        # (B, C, H, W) -> (B, H*W, C) token layout.
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(
            batch_size, channel, height * width
        ).transpose(1, 2)
    else:
        batch_size, _, channel = hidden_states.shape
        height = None
        width = None

    # Key length comes from the context for cross-attention, otherwise from
    # hidden_states itself (self-attention).
    batch_size, key_sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(
            attention_mask,
            key_sequence_length,
            batch_size,
        )
        attention_mask = _prepare_split_einsum_mask(
            attention_mask,
            batch_size,
            attn.heads,
            key_sequence_length,
        )

    if attn.group_norm is not None:
        # group_norm normalises over channels, so move C to dim 1 and back.
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    key = attn.to_k(encoder_hidden_states)
    value = attn.to_v(encoder_hidden_states)

    batch_size = query.shape[0]
    # NOTE(review): the query is also split with this kv-derived head size,
    # which assumes attn.inner_dim == attn.inner_kv_dim (no separate kv_heads).
    dim_head = attn.inner_kv_dim // attn.heads

    # (B, S, C) -> (B, C, 1, S), the layout the split-einsum kernels expect.
    query = _linear_projection_to_bchw(query)
    key = _linear_projection_to_bchw(key)
    value = _linear_projection_to_bchw(value)

    hidden_states = attention_fn(
        query,
        key,
        value,
        attention_mask,
        attn.heads,
        dim_head,
    )
    # (B, C, 1, S_q) -> (B, S_q, C) for the output projection.
    hidden_states = hidden_states.squeeze(2).transpose(1, 2)
    hidden_states = hidden_states.reshape(batch_size, -1, attn.inner_dim)

    # Output projection; to_out[1] is presumably diffusers' dropout — confirm.
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)

    if input_ndim == 4:
        # (B, H*W, C) -> (B, C, H, W) to match the original input layout.
        hidden_states = hidden_states.transpose(-1, -2).reshape(
            batch_size,
            channel,
            height,
            width,
        )

    if attn.residual_connection:
        hidden_states = hidden_states + residual

    hidden_states = hidden_states / attn.rescale_output_factor
    return hidden_states
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def split_einsum(q, k, v, mask, heads, dim_head):
    """Multi-head attention computed one head at a time with ``torch.einsum``.

    Args:
        q: queries laid out ``(B, heads * dim_head, 1, Q)`` (assumed; heads
            are sliced along dim 1).
        k: keys laid out ``(B, heads * dim_head, 1, K)`` (assumed).
        v: values laid out ``(B, heads * dim_head, 1, K)`` (assumed).
        mask: optional additive mask that broadcasts against the per-head
            weights of shape ``(B, K, 1, Q)``.
        heads: number of attention heads.
        dim_head: channels per head.

    Returns:
        ``(B, heads * dim_head, 1, Q)``: the per-head outputs concatenated on dim 1.
    """
    scale = dim_head**-0.5
    spans = [(idx * dim_head, (idx + 1) * dim_head) for idx in range(heads)]

    q_parts = [q[:, lo:hi, :, :] for lo, hi in spans]
    k_last = k.transpose(1, 3)
    k_parts = [k_last[:, :, :, lo:hi] for lo, hi in spans]
    v_parts = [v[:, lo:hi, :, :] for lo, hi in spans]

    scores = [
        torch.einsum("bchq,bkhc->bkhq", q_part, k_part) * scale
        for q_part, k_part in zip(q_parts, k_parts)
    ]
    if mask is not None:
        scores = [score + mask for score in scores]

    # Softmax over the key axis (dim 1 in the bkhq layout).
    probs = [score.softmax(dim=1) for score in scores]
    per_head = [
        torch.einsum("bkhq,bchk->bchq", prob, v_part)
        for prob, v_part in zip(probs, v_parts)
    ]
    return torch.cat(per_head, dim=1)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def split_einsum_v2(q, k, v, mask, heads, dim_head):
    """Split-einsum attention that also processes the queries in chunks.

    Layouts match ``split_einsum``: q/k/v are ``(B, heads * dim_head, 1, S)``
    (assumed), and the result is ``(B, heads * dim_head, 1, Q)``. Each head's
    queries are handled in ``CHUNK_SIZE``-long slices along the last axis.

    Query sequences shorter than ``CHUNK_SIZE`` fall back to ``split_einsum``.
    When the query length is not a multiple of ``CHUNK_SIZE``, the final chunk
    is shorter and carries the remainder, so every query position is produced.

    NOTE: the mask is added unchanged to every chunk, so it must broadcast
    over the query axis (e.g. ``(B, K, 1, 1)``).
    """
    query_length = q.size(3)
    if query_length < CHUNK_SIZE:
        logger.info(
            "SPLIT_EINSUM_V2 query sequence is shorter than %s; using SPLIT_EINSUM.",
            CHUNK_SIZE,
        )
        return split_einsum(q, k, v, mask, heads, dim_head)

    # Ceil division. Floor division dropped the trailing
    # ``query_length % CHUNK_SIZE`` query positions, which produced an output
    # shorter than the input and broke the reshape/residual add downstream.
    num_chunks = -(-query_length // CHUNK_SIZE)

    q_heads = _split_heads(q, heads, dim_head)
    q_chunks = [
        [
            head[..., chunk_idx * CHUNK_SIZE : (chunk_idx + 1) * CHUNK_SIZE]
            for chunk_idx in range(num_chunks)
        ]
        for head in q_heads
    ]

    k = k.transpose(1, 3)
    k_heads = [
        k[:, :, :, head_idx * dim_head : (head_idx + 1) * dim_head]
        for head_idx in range(heads)
    ]
    v_heads = _split_heads(v, heads, dim_head)

    head_outputs = []
    for query_chunks, key, value in zip(q_chunks, k_heads, v_heads):
        chunk_outputs = []
        for query_chunk in query_chunks:
            weights = torch.einsum("bchq,bkhc->bkhq", query_chunk, key)
            weights = weights * (dim_head**-0.5)
            if mask is not None:
                weights = weights + mask
            weights = weights.softmax(dim=1)
            chunk_outputs.append(torch.einsum("bkhq,bchk->bchq", weights, value))
        head_outputs.append(torch.cat(chunk_outputs, dim=3))

    return torch.cat(head_outputs, dim=1)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _split_heads(x, heads, dim_head):
    """Slice a 4-D tensor into ``heads`` consecutive ``dim_head``-wide pieces along dim 1."""
    pieces = []
    for head_idx in range(heads):
        start = head_idx * dim_head
        pieces.append(x[:, start : start + dim_head, :, :])
    return pieces
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _linear_projection_to_bchw(x):
    """Swap dims 1 and 2 and insert a singleton at dim 2: ``(B, S, C) -> (B, C, 1, S)``."""
    swapped = torch.transpose(x, 1, 2)
    return torch.unsqueeze(swapped, 2)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def _prepare_split_einsum_mask(mask, batch_size, heads, key_sequence_length):
    """Reshape an additive attention mask to broadcast against split-einsum weights.

    The split-einsum kernels build per-head weights laid out ``(B, K, 1, Q)``
    (einsum ``"bchq,bkhc->bkhq"``) and add the mask directly. The mask must
    therefore be ``(B, K, 1, 1)`` (same for every query) or ``(B, K, 1, Q)``.

    Args:
        mask: ``(B, K)``, ``(B, q, K)``, ``(B * heads, K)`` or
            ``(B * heads, q, K)`` where ``q`` is 1 or the query length. 4-D
            masks are assumed to be in the target layout already.
        batch_size: B.
        heads: number of attention heads.
        key_sequence_length: K.

    Returns:
        The mask as ``(B, K, 1, q)`` (or unchanged if it was already 4-D).
    """
    if mask.ndim == 2:
        # (B, K) -> (B, 1, K)
        mask = mask[:, None, :]
    if mask.shape[0] == batch_size * heads:
        # Assumes the mask is identical for every head (e.g. repeated per
        # head); only head 0 is kept.
        mask = mask.reshape(batch_size, heads, -1, key_sequence_length)
        mask = mask[:, 0]
    if mask.ndim == 3:
        # (B, q, K) -> (B, K, 1, q). The previous ``mask[:, :, None, None]``
        # produced a 5-D (B, q, 1, 1, K) tensor whose trailing key axis lined
        # up with the weights' query axis, so it could not broadcast.
        mask = mask.transpose(1, 2).unsqueeze(2)
    return mask
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
def conv2d_output_shape(height, width, conv):
    """Return the spatial ``(out_h, out_w)`` output of a torch.nn.Conv2d-like module.

    Reads ``kernel_size``, ``stride``, ``padding`` and ``dilation`` from
    ``conv``. Each may be an int or an ``(h, w)`` pair. ``padding`` may also be
    one of torch's string modes: ``"valid"`` (no padding) or ``"same"``
    (output size equals input size; torch only allows it with stride 1).

    Raises:
        ValueError: for any other string padding mode.
    """
    padding = conv.padding
    if isinstance(padding, str):
        if padding == "same":
            return height, width
        if padding != "valid":
            raise ValueError(f"Unsupported Conv2d padding mode: {padding!r}")
        padding = 0

    kernel_h, kernel_w = _pair(conv.kernel_size)
    stride_h, stride_w = _pair(conv.stride)
    pad_h, pad_w = _pair(padding)
    dilation_h, dilation_w = _pair(conv.dilation)

    out_h = _conv_output_dim(height, kernel_h, stride_h, pad_h, dilation_h)
    out_w = _conv_output_dim(width, kernel_w, stride_w, pad_w, dilation_w)
    return out_h, out_w
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _conv_output_dim(size, kernel, stride, padding, dilation):
    """Output length of one convolution axis (PyTorch Conv2d formula)."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _pair(value):
    """Normalise an int-or-pair conv hyper-parameter to a 2-tuple.

    Tuples and lists are returned as a tuple. Previously a list such as
    ``[3, 3]`` was treated as a scalar and became ``([3, 3], [3, 3])``. Any
    other value is duplicated.
    """
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return value, value
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from types import MethodType
|
|
2
|
+
|
|
3
|
+
from diffusers.models.transformers.transformer_2d import Transformer2DModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def prepare_unet_for_coreml_trace(unet):
    """Patch every ``Transformer2DModel`` inside ``unet`` with this module's helpers.

    Binds ``_operate_on_continuous_inputs`` and
    ``_get_output_for_continuous_inputs`` onto each transformer instance. The
    instance attributes shadow the class methods. Presumably this is done so
    tracing sees static reshapes (the replacements use ``self.inner_dim`` /
    ``flatten`` rather than reading runtime shapes) — confirm. The UNet is
    mutated in place and returned.
    """
    transformers = (
        module for module in unet.modules() if isinstance(module, Transformer2DModel)
    )
    for transformer in transformers:
        transformer._operate_on_continuous_inputs = MethodType(
            _operate_on_continuous_inputs, transformer
        )
        transformer._get_output_for_continuous_inputs = MethodType(
            _get_output_for_continuous_inputs, transformer
        )
    return unet
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _operate_on_continuous_inputs(self, hidden_states):
    """Normalise a feature map and turn it into a token sequence.

    ``hidden_states`` is assumed to be ``(B, C, H, W)``. It is flattened to
    ``(B, H*W, C')`` either after ``proj_in`` (conv projection) or before it
    (linear projection).

    Returns:
        ``(tokens, inner_dim)``. For the conv path ``inner_dim`` is
        ``self.inner_dim``; for the linear path it is the normalised input's
        channel count.
    """
    normed = self.norm(hidden_states)
    if self.use_linear_projection:
        inner_dim = normed.shape[1]
        tokens = self.proj_in(normed.flatten(2).transpose(1, 2))
    else:
        inner_dim = self.inner_dim
        tokens = self.proj_in(normed).flatten(2).transpose(1, 2)
    return tokens, inner_dim
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _get_output_for_continuous_inputs(
    self,
    hidden_states,
    residual,
    batch_size,
    height,
    width,
    inner_dim,
):
    """Map a ``(B, H*W, inner_dim)`` token sequence back to a feature map and add ``residual``.

    ``proj_out`` runs after the reshape to ``(B, inner_dim, H, W)`` for the
    conv projection, and before it for the linear projection.
    """

    def to_spatial(tokens):
        return tokens.transpose(1, 2).reshape(batch_size, inner_dim, height, width)

    if self.use_linear_projection:
        spatial = to_spatial(self.proj_out(hidden_states))
    else:
        spatial = self.proj_out(to_spatial(hidden_states))
    return spatial + residual
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class CoreMLUNetWrapper(torch.nn.Module):
    """Adapt diffusers UNet inputs to CoreMLSuite's stable Core ML contract.

    The wrapper takes positional inputs only: ``sample``, ``timestep`` and
    ``encoder_hidden_states``, followed by, in this order:

    * ``timestep_cond`` (LCM only);
    * ``time_ids`` then ``text_embeds`` (SDXL and SDXL_REFINER only);
    * zero or more ControlNet residuals. All but the last are down-block
      residuals; the last is the mid-block residual.

    Returns the first element of the UNet's tuple output.
    """

    def __init__(self, unet, model_version):
        super().__init__()
        self.unet = unet
        self.model_version = model_version

    def forward(self, sample, timestep, encoder_hidden_states, *extra_inputs):
        pending = list(extra_inputs)

        timestep_cond = pending.pop(0) if self._is_lcm else None

        if self._is_sdxl:
            time_ids = pending.pop(0)
            text_embeds = pending.pop(0)
            added_cond_kwargs = {"time_ids": time_ids, "text_embeds": text_embeds}
        else:
            added_cond_kwargs = None

        # Whatever is left are ControlNet residuals (possibly none).
        if pending:
            down_residuals = tuple(pending[:-1])
            mid_residual = pending[-1]
        else:
            down_residuals = None
            mid_residual = None

        unet_outputs = self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            timestep_cond=timestep_cond,
            added_cond_kwargs=added_cond_kwargs,
            down_block_additional_residuals=down_residuals,
            mid_block_additional_residual=mid_residual,
            return_dict=False,
        )
        return unet_outputs[0]

    @property
    def _is_lcm(self):
        return self.model_version.name == "LCM"

    @property
    def _is_sdxl(self):
        return self.model_version.name in {"SDXL", "SDXL_REFINER"}
|
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
"""Core ML UNet conversion mechanics — framework-free.
|
|
2
|
+
|
|
3
|
+
Moved from ``coreml_suite/converter.py`` in extraction phase E2. This module
|
|
4
|
+
produces a ``.mlpackage`` on disk and stops there: it must NOT import ``comfy``,
|
|
5
|
+
``folder_paths``, or ``comfy_extras``. Output paths are inputs, not resolved here.
|
|
6
|
+
|
|
7
|
+
``get_sample_input`` carries an optional ``scheduler`` so the LCM path (which
|
|
8
|
+
derives the trace timestep from an LCM scheduler) shares this single
|
|
9
|
+
implementation instead of keeping a near-duplicate copy.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import gc
|
|
13
|
+
import os
|
|
14
|
+
import time
|
|
15
|
+
|
|
16
|
+
import coremltools as ct
|
|
17
|
+
import numpy as np
|
|
18
|
+
import torch
|
|
19
|
+
from diffusers import UNet2DConditionModel
|
|
20
|
+
|
|
21
|
+
from coreml_diffusion.attention import ATTENTION_IMPLEMENTATIONS
|
|
22
|
+
from coreml_diffusion.conversion.attention import apply_attention_implementation
|
|
23
|
+
from coreml_diffusion.conversion.shapes import conv2d_output_shape
|
|
24
|
+
from coreml_diffusion.conversion.trace import prepare_unet_for_coreml_trace
|
|
25
|
+
from coreml_diffusion.conversion.unet import CoreMLUNetWrapper
|
|
26
|
+
from coreml_diffusion.logger import logger
|
|
27
|
+
from coreml_diffusion.model_version import ModelVersion
|
|
28
|
+
|
|
29
|
+
# Timestep used for the JIT-trace sample input when no scheduler is supplied
# (see ``get_sample_input``).
DEFAULT_TRACE_TIMESTEP = 999.0
# Sequence length of the ``encoder_hidden_states`` sample input (see
# ``get_encoder_hidden_states_shape``); presumably the CLIP text context
# length — confirm.
TEXT_TOKEN_SEQUENCE_LENGTH = 77
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_unet(model_version: ModelVersion, ref_unet, attention_implementation):
    """Prepare ``ref_unet`` for tracing and wrap it in ``CoreMLUNetWrapper``.

    Patches the transformer blocks for tracing, switches to eval mode and
    installs the requested attention processor. ``ref_unet`` is modified in
    place.
    """
    traceable = prepare_unet_for_coreml_trace(ref_unet)
    patched = apply_attention_implementation(
        traceable.eval(),
        attention_implementation,
    )
    return CoreMLUNetWrapper(patched, model_version)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_encoder_hidden_states_shape(ref_unet, batch_size):
    """Shape ``(batch, TEXT_TOKEN_SEQUENCE_LENGTH, cross_attention_dim)`` of the text-context input."""
    cross_attention_dim = ref_unet.config.cross_attention_dim
    return (batch_size, TEXT_TOKEN_SEQUENCE_LENGTH, cross_attention_dim)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_coreml_inputs(sample_inputs):
    """Describe trace sample inputs as Core ML ``TensorType`` inputs.

    Every input is declared float16 with the sample's shape, whatever the
    sample's own dtype (e.g. the int64 SDXL ``time_ids`` are declared float16
    too). Order follows ``sample_inputs`` iteration order.

    The previous version cast every sample to a float16 numpy copy just to
    read its shape and dtype, and it carried a dead ``isinstance(v,
    torch.Tensor)`` branch (the values were always numpy arrays by then). The
    declared shape and dtype are unchanged.

    Args:
        sample_inputs: mapping of input name -> tensor (anything with ``.shape``).

    Returns:
        A list of ``ct.TensorType``.
    """
    float16 = np.dtype(np.float16)
    return [
        ct.TensorType(name=name, shape=tuple(tensor.shape), dtype=float16)
        for name, tensor in sample_inputs.items()
    ]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def load_coreml_model(out_path):
    """Load an existing Core ML model from ``out_path``, logging how long it took."""
    logger.info(f"Loading model from {out_path}")

    began = time.time()
    model = ct.models.MLModel(out_path)
    elapsed = time.time() - began
    logger.info(f"Loading {out_path} took {elapsed:.1f} seconds")

    return model
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def convert_to_coreml(
    submodule_name, torchscript_module, sample_inputs, output_names, out_path
):
    """Convert a traced module to an ML Program, or reuse an existing artifact.

    If ``out_path`` already exists it is loaded and returned, and no
    conversion happens. Otherwise ``ct.convert`` targets macOS 13 with float32
    outputs named ``output_names`` and ``skip_model_load=True``. This function
    does not save the result.
    """
    if os.path.exists(out_path):
        logger.info(f"Skipping export because {out_path} already exists")
        return load_coreml_model(out_path)

    logger.info(f"Converting {submodule_name} to CoreML..")
    output_types = [ct.TensorType(name=name, dtype=np.float32) for name in output_names]
    coreml_model = ct.convert(
        torchscript_module,
        convert_to="mlprogram",
        minimum_deployment_target=ct.target.macOS13,
        inputs=sample_inputs,
        outputs=output_types,
        skip_model_load=True,
    )

    # Drop this frame's reference to the traced module before collecting.
    del torchscript_module
    gc.collect()

    return coreml_model
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_sample_input(
    batch_size, encoder_hidden_states_shape, sample_shape, scheduler=None
):
    """Build the example inputs used to JIT-trace the UNet.

    The trace timestep is ``scheduler.timesteps[0]`` when a scheduler is given
    (the LCM path), otherwise ``DEFAULT_TRACE_TIMESTEP``. Keys are, in order,
    ``sample``, ``timestep``, ``encoder_hidden_states``. Only shapes, dtypes
    and order matter to the traced graph; the random values are placeholders.
    """
    if scheduler is None:
        trace_timestep = DEFAULT_TRACE_TIMESTEP
    else:
        trace_timestep = scheduler.timesteps[0].item()

    sample = torch.rand(*sample_shape)
    timesteps = torch.tensor([trace_timestep] * batch_size).to(torch.float32)
    context = torch.rand(*encoder_hidden_states_shape)
    return {
        "sample": sample,
        "timestep": timesteps,
        "encoder_hidden_states": context,
    }
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def lcm_inputs(sample_unet_inputs):
    """Extra LCM trace input: a random float32 ``timestep_cond`` of shape ``(batch, 256)``."""
    batch = sample_unet_inputs["sample"].shape[0]
    cond = torch.randn(batch, 256).to(torch.float32)
    return dict(timestep_cond=cond)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def sdxl_inputs(sample_unet_inputs, ref_unet, model_version):
    """Extra SDXL trace inputs: ``time_ids`` (int64) and random ``text_embeds`` (float32).

    ``time_ids`` holds the image size (latent size * 8) and a (0, 0) crop
    origin, followed by the image size again as the target size for the base
    model, or an aesthetic score of 6.0 for the refiner. The ``text_embeds``
    width is whatever ``get_sdxl_text_embeds_dim`` leaves after the time-id
    embeddings.
    """
    sample_shape = sample_unet_inputs["sample"].shape
    batch_size = sample_shape[0]
    image_size = (sample_shape[2] * 8, sample_shape[3] * 8)
    crop_origin = (0, 0)

    if model_version == ModelVersion.SDXL_REFINER:
        conditioning_tail = (6.0,)  # aesthetic score
    else:
        conditioning_tail = image_size  # target size

    time_ids_list = list(image_size + crop_origin + conditioning_tail)
    time_ids = torch.tensor(time_ids_list).repeat(batch_size, 1).to(torch.int64)
    embeds_dim = get_sdxl_text_embeds_dim(ref_unet, len(time_ids_list))

    return {
        "time_ids": time_ids,
        "text_embeds": torch.randn(batch_size, embeds_dim).to(torch.float32),
    }
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def get_sdxl_text_embeds_dim(ref_unet, time_ids_dim):
    """Width of ``text_embeds``: the projection input dim minus ``time_ids_dim`` time embeddings."""
    config = ref_unet.config
    time_part = time_ids_dim * config.addition_time_embed_dim
    return config.projection_class_embeddings_input_dim - time_part
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def get_inputs_spec(inputs):
    """Map each input name to its ``(shape, dtype)`` pair, preserving order."""
    spec = {}
    for name, tensor in inputs.items():
        spec[name] = (tensor.shape, tensor.dtype)
    return spec
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def add_cnet_support(sample_shape, reference_unet):
    """Create random ControlNet residual inputs shaped for ``reference_unet``.

    Residual shapes are collected in this order: the ``conv_in`` output, every
    down-block resnet output, each down-block's output after its downsamplers
    (when it has any), and finally the last mid-block resnet output. Spatial
    size is propagated through ``conv_in`` and each downsampler's ``conv`` with
    ``conv2d_output_shape``.

    Args:
        sample_shape: Shape of the traced ``sample`` input, read as
            ``(batch, channels, H, W)``.
        reference_unet: The UNet whose blocks define the residuals.

    Returns:
        Dict of ``torch.rand`` tensors keyed ``additional_residual_<i>`` in
        collection order.
    """
    batch = sample_shape[0]
    height, width = sample_shape[2:]

    conv_in = reference_unet.conv_in
    height, width = conv2d_output_shape(height, width, conv_in)
    residual_shapes = [(batch, conv_in.out_channels, height, width)]

    for block in reference_unet.down_blocks:
        residual_shapes.extend(
            (batch, resnet.out_channels, height, width) for resnet in block.resnets
        )
        downsamplers = getattr(block, "downsamplers", None)
        if downsamplers is None:
            continue
        for downsampler in downsamplers:
            height, width = conv2d_output_shape(height, width, downsampler.conv)
        residual_shapes.append(
            (batch, downsamplers[-1].conv.out_channels, height, width)
        )

    mid_channels = reference_unet.mid_block.resnets[-1].out_channels
    residual_shapes.append((batch, mid_channels, height, width))

    return {
        f"additional_residual_{index}": torch.rand(*shape)
        for index, shape in enumerate(residual_shapes)
    }
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def convert_unet(
    ref_unet,
    model_version: ModelVersion,
    unet_out_path: str,
    batch_size: int = 1,
    sample_size: tuple[int, int] = (64, 64),
    controlnet_support: bool = False,
    attention_implementation: str = ATTENTION_IMPLEMENTATIONS[0],
    quantize_nbits: str = "none",
):
    """Trace ``ref_unet`` and save it to ``unet_out_path`` as a Core ML model.

    Steps:
      1. Wrap the UNet with ``get_unet`` for the chosen attention implementation.
      2. Build example inputs: the base ones from ``get_sample_input``, plus
         LCM / SDXL / ControlNet extras when applicable. Insertion order
         matters because the dict's values are passed to ``torch.jit.trace``
         positionally.
      3. JIT-trace, convert with ``convert_to_coreml`` (output ``noise_pred``).
      4. Optionally k-means palettize weights when ``quantize_nbits`` is not
         ``"none"``, then save.

    Args:
        sample_size: Latent ``(H, W)`` of the ``sample`` input.
        quantize_nbits: ``"none"`` or a bit-width string passed to ``int()``.
    """
    wrapped_unet = get_unet(model_version, ref_unet, attention_implementation)

    latent_shape = (
        batch_size,
        ref_unet.config.in_channels,
        sample_size[0],  # H
        sample_size[1],  # W
    )
    example_inputs = get_sample_input(
        batch_size,
        get_encoder_hidden_states_shape(ref_unet, batch_size),
        latent_shape,
    )

    # Variant-specific extras, merged in a fixed order.
    if model_version == ModelVersion.LCM:
        example_inputs |= lcm_inputs(example_inputs)
    if model_version in {ModelVersion.SDXL, ModelVersion.SDXL_REFINER}:
        example_inputs |= sdxl_inputs(example_inputs, ref_unet, model_version)
    if controlnet_support:
        example_inputs |= add_cnet_support(latent_shape, ref_unet)

    logger.info(f"Sample UNet inputs spec: {get_inputs_spec(example_inputs)}")
    logger.info("JIT tracing..")
    traced = torch.jit.trace(
        wrapped_unet, example_inputs=list(example_inputs.values())
    )
    logger.info("Done.")

    mlmodel = convert_to_coreml(
        "unet",
        traced,
        get_coreml_inputs(example_inputs),
        ["noise_pred"],
        unet_out_path,
    )

    # Drop the torch-side graph and wrapper before palettizing/saving to
    # keep peak memory down.
    del traced, wrapped_unet
    gc.collect()

    if quantize_nbits != "none":
        # Opt-in k-means weight palettization; imported lazily so the default
        # path never touches coremltools.optimize.
        from coremltools.optimize.coreml import (
            OpPalettizerConfig,
            OptimizationConfig,
            palettize_weights,
        )

        nbits = int(quantize_nbits)
        logger.info(f"Palettizing UNet weights to {nbits}-bit (kmeans)..")
        started = time.time()
        mlmodel = palettize_weights(
            mlmodel,
            config=OptimizationConfig(
                global_config=OpPalettizerConfig(mode="kmeans", nbits=nbits)
            ),
        )
        logger.info(f"Palettization took {time.time() - started:.1f}s")

    mlmodel.save(unet_out_path)
    logger.info(f"Saved unet into {unet_out_path}")
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def convert(
    ckpt_path: str,
    model_version: ModelVersion,
    out_path: str,
    *,
    batch_size: int = 1,
    sample_size: tuple[int, int] = (64, 64),
    controlnet_support: bool = False,
    lora_weights: list[tuple[str | os.PathLike, float]] | None = None,
    attn_impl: str = ATTENTION_IMPLEMENTATIONS[0],
    config_path: str | None = None,
    quantize_nbits: str = "none",
):
    """Convert a single-file checkpoint's UNet to a Core ML ``.mlpackage``.

    Keyword-only past the three required positionals so the package can add
    capabilities (new keyword args) without breaking an older caller — the
    versioned interface contract. Writes ``out_path``; returns None.

    Args:
        ckpt_path: Single-file checkpoint, loaded via ``load_unet``.
        model_version: Selects the extra example inputs traced with the UNet
            (LCM ``timestep_cond``; SDXL / refiner ``time_ids`` + ``text_embeds``).
        out_path: Destination ``.mlpackage``. If it already exists the call
            logs and returns without converting.
        batch_size: Batch dimension baked into the traced inputs.
        sample_size: Latent ``(H, W)`` of the traced ``sample`` input.
        controlnet_support: Add ``additional_residual_<i>`` inputs.
        lora_weights: ``(path, strength)`` pairs; each is loaded with
            ``load_lora_adapter`` and fused into the UNet, in list order,
            before tracing.
        attn_impl: One of ``ATTENTION_IMPLEMENTATIONS``.
        config_path: Optional original-config YAML, forwarded as
            ``original_config``.
        quantize_nbits: ``"none"`` (no palettization) or one of the bit widths
            in ``naming.QUANT_NBITS_VALUES``. Ints such as ``8`` are accepted
            and normalized to their string form.

    Raises:
        ValueError: ``attn_impl`` or ``quantize_nbits`` is unsupported.
    """
    # Local import keeps this change self-contained; naming.py only depends
    # on ``typing``, so there is no import-cycle or cost concern.
    from .naming import QUANT_NBITS_VALUES

    if os.path.exists(out_path):
        logger.info(f"Found existing model at {out_path}! Skipping..")
        return

    if attn_impl not in ATTENTION_IMPLEMENTATIONS:
        raise ValueError(
            f"Unsupported attention implementation {attn_impl!r}. "
            f"Expected one of {ATTENTION_IMPLEMENTATIONS}."
        )

    # Validate before the expensive load/trace/convert. Without this, a bad
    # value only failed at ``int()`` after tracing had finished, and widths
    # outside the naming/discovery contract (e.g. "2") were silently used.
    quantize_nbits = str(quantize_nbits)
    if quantize_nbits not in QUANT_NBITS_VALUES:
        raise ValueError(
            f"Unsupported quantize_nbits {quantize_nbits!r}. "
            f"Expected one of {QUANT_NBITS_VALUES}."
        )

    ref_unet = load_unet(ckpt_path, config_path)

    for i, (lora_path, strength) in enumerate(lora_weights or []):
        adapter_name = f"lora_{i}"
        ref_unet.load_lora_adapter(lora_path, adapter_name=adapter_name)
        ref_unet.set_adapters([adapter_name], weights=[strength])
        ref_unet.fuse_lora()

    convert_unet(
        ref_unet,
        model_version,
        out_path,
        batch_size,
        sample_size,
        controlnet_support,
        attention_implementation=attn_impl,
        quantize_nbits=quantize_nbits,
    )
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def load_unet(ckpt_path, config_path):
    """Load a ``UNet2DConditionModel`` from a single-file checkpoint.

    ``config_path`` is forwarded unchanged as ``original_config`` (it may be
    ``None``, in which case config resolution is left to diffusers).
    """
    loader = UNet2DConditionModel.from_single_file
    return loader(ckpt_path, original_config=config_path)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Pure out_name composition for the Core ML UNet artifact.
|
|
2
|
+
|
|
3
|
+
Extracted from CoreMLConverter.convert so the filename contract
|
|
4
|
+
can be tested + reused without instantiating the node. The string is the
|
|
5
|
+
cache key: every workflow that references a converted .mlpackage depends
|
|
6
|
+
on it staying byte-for-byte identical.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Iterable, Tuple
|
|
10
|
+
|
|
11
|
+
# Filename tag for each attention implementation; part of the out_name cache
# key, so existing entries must never change.
ATTN_SUFFIX = dict(
    SPLIT_EINSUM="se",
    SPLIT_EINSUM_V2="se2",
    ORIGINAL="orig",
)

# Palettization bit widths accepted by compose_out_name. "none" (the default)
# means no quantization and adds nothing to the name, so unquantized filenames,
# and the cached .mlpackages existing workflows point at, stay unchanged.
# Any numeric value appends a `_q<bits>` suffix.
QUANT_NBITS_VALUES = ("none", "8", "6", "4")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def compose_out_name(
    *,
    ckpt_name: str,
    batch_size: int,
    width: int,
    height: int,
    controlnet_support: bool,
    attention_implementation: str,
    lora_names: Iterable[str] = (),
    quantize_nbits: str = "none",
) -> str:
    """Return the .mlpackage stem for a set of convert() parameters.

    Locked behaviour (characterization tests):
      - first '.' in ckpt_name wins (`a.b.c.safetensors` -> `a`)
      - spaces collapse to underscores
      - LoRA names are sorted, reduced to their stems (text before the first
        '.'), and each is prefixed with '_' (caller is expected to pass a
        sorted list; we sort defensively)
      - layout: `<stem><loras>_<batch>x<width>x<height>[_cn]_<attn>[_q<bits>]`
      - attn suffix is `_se` | `_se2` | `_orig`

    Quantization:
      - quantize_nbits "none" (default) appends nothing — existing
        unquantized .mlpackages keep the old filename
      - "4" / "6" / "8" appends `_q<bits>` after the attn suffix

    Raises:
        ValueError: quantize_nbits is not in QUANT_NBITS_VALUES.
        KeyError: attention_implementation is not a key of ATTN_SUFFIX.
    """
    if quantize_nbits not in QUANT_NBITS_VALUES:
        raise ValueError(
            f"quantize_nbits={quantize_nbits!r} not in {QUANT_NBITS_VALUES}"
        )

    stem = ckpt_name.split(".")[0]
    # "_a_b" for [a, b], "" when there are no LoRAs.
    lora_part = "".join("_" + name.split(".")[0] for name in sorted(lora_names))
    cn_part = "_cn" if controlnet_support else ""
    attn_part = f"_{ATTN_SUFFIX[attention_implementation]}"
    quant_part = "" if quantize_nbits == "none" else f"_q{quantize_nbits}"
    dims_part = f"_{batch_size}x{width}x{height}"

    name = "".join((stem, lora_part, dims_part, cn_part, attn_part, quant_part))
    return name.replace(" ", "_")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def lora_names_from_params(lora_params: Iterable[Tuple[str, float]]) -> list[str]:
    """Return LoRA names ordered by name, mirroring CoreMLConverter.convert's sort.

    Strengths are dropped; the sort is stable, so equal names keep input order.
    """
    by_name = sorted(lora_params, key=lambda pair: pair[0])
    return [lora_name for lora_name, _strength in by_name]
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: coreml-diffusion
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Convert diffusion-model checkpoints (SD1.5/SDXL) to Core ML for Apple Neural Engine — framework-free, ComfyUI-independent.
|
|
5
|
+
Project-URL: Homepage, https://github.com/aszc-dev/coreml-diffusion
|
|
6
|
+
Project-URL: Repository, https://github.com/aszc-dev/coreml-diffusion
|
|
7
|
+
Project-URL: Issues, https://github.com/aszc-dev/coreml-diffusion/issues
|
|
8
|
+
Author-email: Adrian Szczepański <hi@aszc.dev>
|
|
9
|
+
License-Expression: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Keywords: ane,apple-neural-engine,comfyui,core-ml,coreml,diffusers,diffusion,sdxl,stable-diffusion
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Operating System :: MacOS
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Topic :: Multimedia :: Graphics :: Graphics Conversion
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Classifier: Typing :: Typed
|
|
20
|
+
Requires-Python: <3.13,>=3.12
|
|
21
|
+
Requires-Dist: coremltools<10,>=9
|
|
22
|
+
Requires-Dist: diffusers>=0.30
|
|
23
|
+
Requires-Dist: numpy<3,>=2
|
|
24
|
+
Requires-Dist: omegaconf>=2.3
|
|
25
|
+
Requires-Dist: peft>=0.13
|
|
26
|
+
Requires-Dist: torch<2.8,>=2.7
|
|
27
|
+
Requires-Dist: transformers>=4.44
|
|
28
|
+
Description-Content-Type: text/markdown
|
|
29
|
+
|
|
30
|
+
# coreml-diffusion
|
|
31
|
+
|
|
32
|
+
Convert diffusion-model checkpoints into Core ML `.mlpackage` artifacts for the
|
|
33
|
+
Apple Neural Engine (ANE) — framework-free and independent of ComfyUI.
|
|
34
|
+
|
|
35
|
+
`coreml-diffusion` takes a single-file Stable Diffusion checkpoint and produces a
|
|
36
|
+
Core ML UNet you can run on-device (macOS/iOS) or load back into ComfyUI via
|
|
37
|
+
[ComfyUI-CoreMLSuite](https://github.com/aszc-dev/ComfyUI-CoreMLSuite), which
|
|
38
|
+
depends on this package for its conversion path.
|
|
39
|
+
|
|
40
|
+
## Positioning
|
|
41
|
+
|
|
42
|
+
The niche is **diffusion models on the Apple Neural Engine via Core ML** — inside
|
|
43
|
+
ComfyUI and on-device. ANE is the differentiator: low-power, GPU-free, embeddable
|
|
44
|
+
in a Swift/iOS app. This is about feasibility and power efficiency for SD1.5/SDXL
|
|
45
|
+
on ANE, not a raw-throughput claim against desktop GPUs.
|
|
46
|
+
|
|
47
|
+
Supported today: SD1.5 and SDXL (verified). SDXL refiner and LCM convert but are
|
|
48
|
+
not yet golden-verified (experimental). The scope is diffusion architectures
|
|
49
|
+
generally, not Stable Diffusion specifically.
|
|
50
|
+
|
|
51
|
+
## Install
|
|
52
|
+
|
|
53
|
+
```sh
|
|
54
|
+
uv pip install coreml-diffusion # from PyPI
|
|
55
|
+
uv pip install -e . # from a checkout
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
Requires Python 3.12 and (for conversion) `coremltools` 9 — conversion runs on
|
|
59
|
+
macOS; the package imports and its CLI parse on any platform.
|
|
60
|
+
|
|
61
|
+
## CLI
|
|
62
|
+
|
|
63
|
+
```sh
|
|
64
|
+
coreml-diffusion convert \
|
|
65
|
+
--ckpt path/to/model.safetensors \
|
|
66
|
+
--model-version SD15 \
|
|
67
|
+
--out unet.mlpackage \
|
|
68
|
+
--height 512 --width 512 \
|
|
69
|
+
--attn-impl SPLIT_EINSUM \
|
|
70
|
+
--quantize none
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
Options: `--batch-size`, `--controlnet`, `--lora PATH[:STRENGTH]` (repeatable),
|
|
74
|
+
`--config` (original-config YAML). `--quantize {none,8,6,4}` applies k-means
|
|
75
|
+
weight palettization. Run `coreml-diffusion convert --help` for the full list.
|
|
76
|
+
|
|
77
|
+
The output `.mlpackage` is the deliverable: load it natively in Swift/Core ML, or
|
|
78
|
+
through ComfyUI-CoreMLSuite.
|
|
79
|
+
|
|
80
|
+
## Library
|
|
81
|
+
|
|
82
|
+
```python
|
|
83
|
+
import coreml_diffusion
|
|
84
|
+
from coreml_diffusion import ModelVersion
|
|
85
|
+
|
|
86
|
+
coreml_diffusion.convert(
|
|
87
|
+
"model.safetensors", ModelVersion.SD15, "unet.mlpackage",
|
|
88
|
+
sample_size=(64, 64), attn_impl="SPLIT_EINSUM",  # latent (H, W) = 512x512 px / 8
|
|
89
|
+
)
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
Discovery API (`list_model_versions`, `list_attention_impls`, `list_quant_modes`,
|
|
93
|
+
`CONTRACT_VERSION`) reports what this build can convert; the identifiers are an
|
|
94
|
+
additive-only contract (removing/renaming one is a major bump).
|
|
95
|
+
|
|
96
|
+
## License
|
|
97
|
+
|
|
98
|
+
MIT
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
coreml_diffusion/__init__.py,sha256=IVtKAGpIHR25M-9r2kF4RIGZf599PGzuSdBykA_7jD8,4118
|
|
2
|
+
coreml_diffusion/attention.py,sha256=adbE3AmV7uR-FRyILYGy-_3tBtFnSny4ZGutIVfOnPc,91
|
|
3
|
+
coreml_diffusion/cli.py,sha256=TIoClej1JMdlNh9BXtmSls2NcCTAXTZogiFnrWfEt9s,3896
|
|
4
|
+
coreml_diffusion/convert.py,sha256=8Q4IcXbaTQ9AGD89J_eeCI1h1f55CsypsL5OHYg6c1Y,11153
|
|
5
|
+
coreml_diffusion/logger.py,sha256=PE9S6WFmT-3UxG4830IqAeCDOdw0laUNwwu_Q65pYPw,105
|
|
6
|
+
coreml_diffusion/model_version.py,sha256=wjjfLqMRU8LITzylY7k1o0r9L5gEtgg117DC50fBcDY,136
|
|
7
|
+
coreml_diffusion/naming.py,sha256=SDLAI2PnBCyPnZ0ufjUR1oagRJbwur2M1PYNFYulI2w,2589
|
|
8
|
+
coreml_diffusion/conversion/__init__.py,sha256=veWEFzP7tsjSIukFeDIL1H1H6BRMF38rzyd7XE5E2TQ,443
|
|
9
|
+
coreml_diffusion/conversion/attention.py,sha256=VxoO2-unK8iXYMdj8oFNereZUHfk9AGH2xmj5RMLhBA,6710
|
|
10
|
+
coreml_diffusion/conversion/shapes.py,sha256=kJP0lIh5ty2JwLc70va67Neovu-wtP6LXMQDycTPhDM,726
|
|
11
|
+
coreml_diffusion/conversion/trace.py,sha256=iiIh0ZzaULyz5PP8EUN14rlogZ5a9jmAMBJ9LKxWUug,1707
|
|
12
|
+
coreml_diffusion/conversion/unet.py,sha256=nljZgNMY667vbAxDZPC_dS2fF861fzgVSsDygXwEPpU,1701
|
|
13
|
+
coreml_diffusion-0.1.0.dist-info/METADATA,sha256=zjlX9MaUEoShe7Ks-Z6sagvM3EPOpMzOjNZn7a3Exio,3600
|
|
14
|
+
coreml_diffusion-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
|
|
15
|
+
coreml_diffusion-0.1.0.dist-info/entry_points.txt,sha256=oYMr6Rre4ErwzBzfgxeFQ1isiwKOeGHswipc0IDB38o,63
|
|
16
|
+
coreml_diffusion-0.1.0.dist-info/licenses/LICENSE,sha256=0L46frKmxey5OMCRWgckyvNBwVT1t4YXMNLs0ZUh5bI,1081
|
|
17
|
+
coreml_diffusion-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2023-2026 Adrian Szczepański
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|