mlx-taef 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_taef/__init__.py +9 -0
- mlx_taef/_version.py +24 -0
- mlx_taef/api.py +204 -0
- mlx_taef/cli.py +92 -0
- mlx_taef/convert.py +201 -0
- mlx_taef/download.py +45 -0
- mlx_taef/integrations/__init__.py +1 -0
- mlx_taef/integrations/mflux.py +168 -0
- mlx_taef/model.py +197 -0
- mlx_taef/py.typed +0 -0
- mlx_taef/variants.py +115 -0
- mlx_taef-0.1.0.dist-info/METADATA +125 -0
- mlx_taef-0.1.0.dist-info/RECORD +15 -0
- mlx_taef-0.1.0.dist-info/WHEEL +4 -0
- mlx_taef-0.1.0.dist-info/entry_points.txt +2 -0
mlx_taef/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""mlx-taef: Tiny AutoEncoder family ported to Apple MLX."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
from mlx_taef.api import TAEF1, TAEF2, TAESD, TAESDXL, Taef
|
|
6
|
+
|
|
7
|
+
# Public API re-exported from mlx_taef.api.
__all__ = [
    "TAEF1",
    "TAEF2",
    "TAESD",
    "TAESDXL",
    "Taef",
]

# Library convention: attach a NullHandler to the package logger so that, when the
# host application configures no logging, records emitted by mlx_taef loggers are
# not printed through logging's last-resort handler.
logging.getLogger(name="mlx_taef").addHandler(logging.NullHandler())
|
mlx_taef/_version.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# file generated by vcs-versioning
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Release metadata for this build; dunder and plain names are aliases of each other.
version: str = "0.1.0"
__version__: str = version

version_tuple: tuple[int | str, ...] = (0, 1, 0)
__version_tuple__: tuple[int | str, ...] = version_tuple

# No VCS commit identifier was recorded for this build.
commit_id: str | None = None
__commit_id__: str | None = commit_id
|
mlx_taef/api.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
"""User-facing Taef family API: load weights and decode/encode latents.
|
|
2
|
+
|
|
3
|
+
Tensor layout: all public methods use NHWC mx.array.
|
|
4
|
+
Value space: decode() outputs [0, 1] float; encode() expects [0, 1] float.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import cast
|
|
10
|
+
|
|
11
|
+
import mlx.core as mx
|
|
12
|
+
import mlx.nn as nn
|
|
13
|
+
|
|
14
|
+
from mlx_taef.model import make_decoder, make_encoder
|
|
15
|
+
from mlx_taef.variants import (
|
|
16
|
+
TAEF1_CONFIG,
|
|
17
|
+
TAEF2_CONFIG,
|
|
18
|
+
TAESD_CONFIG,
|
|
19
|
+
TAESDXL_CONFIG,
|
|
20
|
+
TaesdVariantConfig,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Module-level logger named after this module (a child of the "mlx_taef" logger).
logger = logging.getLogger(name=__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Taef(nn.Module):  # type: ignore[misc,name-defined]
    """Base class for TAESD-family tiny autoencoders.

    Subclasses set the `_config` class variable to a `TaesdVariantConfig`; the
    constructor builds the decoder and encoder from it. Weights are normally
    attached via `from_pretrained_local(path)` or `from_pretrained()`.

    Tensor layout:
        All public methods use NHWC `mx.array`:
        - latent shape: (B, H, W, latent_channels)
        - image shape: (B, H*8, W*8, 3)

    Value space:
        - `decode()` returns float in [0, 1] (clipped on output).
        - `decode_image()` returns uint8 in [0, 255] (rounded to nearest).
        - Latents are RAW (post-Clamp scale), values roughly in [-3, 3].
        - `scale_latents` / `unscale_latents` convert between raw and [0, 1].
    """

    _config: TaesdVariantConfig

    def __init__(self) -> None:
        """Build the decoder and encoder sub-modules from the subclass `_config`."""
        super().__init__()
        self.decoder = make_decoder(self._config)
        self.encoder = make_encoder(self._config)

    @classmethod
    def from_pretrained_local(
        cls,
        decoder_path: Path | str,
        encoder_path: Path | str | None = None,
        *,
        dtype: mx.Dtype = mx.float32,
    ) -> "Taef":
        """Instantiate from already-converted MLX safetensors on disk.

        Weights are loaded non-strictly, because the encoder file is optional.
        Any decoder/encoder parameter that the loaded files do not provide is
        reported with a warning instead of silently keeping its random init.

        Args:
            decoder_path: path to converted MLX decoder safetensors.
            encoder_path: optional path to converted MLX encoder safetensors.
            dtype: target dtype for model weights. Default fp32 to keep parity tests sharp.

        Returns:
            A loaded `Taef` instance in eval mode, ready for `decode()`.
        """
        from mlx.utils import tree_flatten

        instance = cls()

        # (attribute name on the model, weight file) pairs, decoder first.
        sources: list[tuple[str, Path | str]] = [("decoder", decoder_path)]
        if encoder_path is not None:
            sources.append(("encoder", encoder_path))

        # Converted files store keys relative to the sub-module ('layers.0.weight'),
        # so prefix each with the attribute it belongs to on this model.
        weights_list: list[tuple[str, mx.array]] = []
        for role, path in sources:
            role_weights = cast("dict[str, mx.array]", mx.load(str(path)))
            weights_list.extend((f"{role}.{k}", v) for k, v in role_weights.items())
        instance.load_weights(weights_list, strict=False)

        # strict=False would otherwise hide a wrong or incomplete weight file.
        loaded_keys = {k for k, _ in weights_list}
        for role, _path in sources:
            missing = [
                f"{role}.{k}"
                for k, _ in tree_flatten(getattr(instance, role).parameters())
                if f"{role}.{k}" not in loaded_keys
            ]
            if missing:
                logger.warning(
                    "%s: %d %s parameter(s) not found in weight file; they keep "
                    "their random initialisation (first few: %s)",
                    cls.__name__,
                    len(missing),
                    role,
                    ", ".join(missing[:5]),
                )

        # Compare by value: a Dtype obtained elsewhere need not be the same object.
        if dtype != mx.float32:
            instance.set_dtype(dtype)
        instance.eval()
        return instance

    @classmethod
    def from_pretrained(  # pragma: no cover
        cls,
        repo_id: str | None = None,
        *,
        dtype: mx.Dtype = mx.float32,
        include_encoder: bool = True,
    ) -> "Taef":
        """Auto-download weights from HF Hub, convert to MLX, and load.

        On first call, downloads upstream weights from `cls._config.hf_repo` and
        caches converted output under ~/.cache/mlx-taef/. Subsequent calls load
        from that cache without converting again.

        Args:
            repo_id: optional HF repo override. If set, must match `cls._config.hf_repo`.
            dtype: target dtype for weights. Default fp32 to keep parity strict.
            include_encoder: whether to also load the encoder side (default True).

        Returns:
            A loaded `Taef` instance.

        Raises:
            ValueError: if `repo_id` differs from the variant's configured repo.
        """
        from mlx_taef.download import get_or_convert

        config = cls._config
        if repo_id is not None and repo_id != config.hf_repo:
            raise ValueError(
                f"repo_id mismatch: requested {repo_id!r} but variant {config.name!r} "
                f"uses {config.hf_repo!r}"
            )
        decoder_path = get_or_convert(config, role="decoder")
        encoder_path = get_or_convert(config, role="encoder") if include_encoder else None
        return cls.from_pretrained_local(decoder_path, encoder_path=encoder_path, dtype=dtype)

    def decode(self, latents: mx.array) -> mx.array:
        """Decode raw latents (NHWC) to image (NHWC, [0, 1] float).

        Args:
            latents: NHWC array with shape (B, H, W, latent_channels).

        Returns:
            NHWC float array with shape (B, H*8, W*8, 3), values clipped to [0, 1].
        """
        out = self.decoder(latents)
        return mx.clip(out, 0.0, 1.0)

    def decode_image(self, latents: mx.array) -> mx.array:
        """Decode raw latents to a uint8 NHWC image suitable for PIL/PNG.

        Args:
            latents: NHWC array with shape (B, H, W, latent_channels).

        Returns:
            NHWC uint8 array with shape (B, H*8, W*8, 3), values in [0, 255].
        """
        img_float = self.decode(latents)
        # Round to nearest before the cast: a plain astype truncates, which maps
        # e.g. 0.999 to 254 instead of 255 and darkens the image on average.
        return mx.round(img_float * 255.0).astype(mx.uint8)

    def encode(self, image: mx.array) -> mx.array:
        """Encode an NHWC RGB image (B, H, W, 3) in [0, 1] to a latent (B, H/8, W/8, latent_channels).

        Args:
            image: NHWC float array with shape (B, H, W, 3), values in [0, 1].

        Returns:
            NHWC latent array with shape (B, H/8, W/8, latent_channels).
        """
        return cast("mx.array", self.encoder(image))

    def scale_latents(self, raw: mx.array) -> mx.array:
        """Map raw latents to [0, 1] using config.latent_magnitude/shift.

        Args:
            raw: raw latent array, values roughly in [-3, 3].

        Returns:
            `raw / (2 * latent_magnitude) + latent_shift`, clipped to [0, 1].
        """
        c = self._config
        return mx.clip(raw / (2.0 * c.latent_magnitude) + c.latent_shift, 0.0, 1.0)

    def unscale_latents(self, scaled: mx.array) -> mx.array:
        """Inverse of scale_latents (ignoring its clipping): [0, 1] -> raw.

        Args:
            scaled: latent array in [0, 1].

        Returns:
            `(scaled - latent_shift) * (2 * latent_magnitude)`.
        """
        c = self._config
        return (scaled - c.latent_shift) * (2.0 * c.latent_magnitude)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class TAESD(Taef):
    """Taef variant for Stable Diffusion 1.x latents (configured by `TAESD_CONFIG`)."""

    _config = TAESD_CONFIG
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class TAESDXL(Taef):
    """Taef variant for SDXL latents (configured by `TAESDXL_CONFIG`)."""

    _config = TAESDXL_CONFIG
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class TAEF1(Taef):
    """Taef variant for FLUX.1 latents (configured by `TAEF1_CONFIG`)."""

    _config = TAEF1_CONFIG
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class TAEF2(Taef):
    """Taef variant for FLUX.2 Klein latents (configured by `TAEF2_CONFIG`)."""

    _config = TAEF2_CONFIG
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
# Public names of this module: the base class plus one class per variant.
__all__ = [
    "TAEF1",
    "TAEF2",
    "TAESD",
    "TAESDXL",
    "Taef",
]
|
mlx_taef/cli.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""mlx-taef command-line interface: convert / info / bench."""
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import logging
|
|
5
|
+
import time
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import mlx.core as mx
|
|
9
|
+
|
|
10
|
+
from mlx_taef.variants import ALL_VARIANTS
|
|
11
|
+
|
|
12
|
+
# Module-level logger named after this module (a child of the "mlx_taef" logger).
logger = logging.getLogger(name=__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def main(argv: list[str] | None = None) -> int:
    """Parse command-line arguments and run the selected sub-command.

    Args:
        argv: argument list without the program name; ``None`` lets argparse
            read ``sys.argv``.

    Returns:
        The exit code returned by the sub-command handler (1 if none matched).
    """
    names = [cfg.name for cfg in ALL_VARIANTS]

    parser = argparse.ArgumentParser(
        prog="mlx-taef",
        description="Tiny AutoEncoder family for diffusion on Apple MLX.",
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    convert_cmd = commands.add_parser("convert", help="Download upstream weights and convert to MLX")
    convert_cmd.add_argument("--variant", required=True, choices=names)
    convert_cmd.add_argument("--role", default="decoder", choices=["decoder", "encoder"])
    convert_cmd.add_argument("--dst", required=True, type=Path, help="Output .safetensors path")

    info_cmd = commands.add_parser("info", help="Print info about a converted MLX safetensors file")
    info_cmd.add_argument("path", type=Path)

    bench_cmd = commands.add_parser("bench", help="Decode benchmark on the current Mac")
    bench_cmd.add_argument("--variant", default="taef2", choices=names)

    parsed = parser.parse_args(argv)
    handlers = {"convert": _cmd_convert, "info": _cmd_info, "bench": _cmd_bench}
    handler = handlers.get(parsed.cmd)
    if handler is None:  # pragma: no cover - subparsers are required
        return 1
    return handler(parsed)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _cmd_convert(args: argparse.Namespace) -> int:  # pragma: no cover
    """Handle ``mlx-taef convert``: convert one (variant, role) and write it to ``args.dst``."""
    from mlx_taef.convert import convert_hf_decoder_to_mlx, convert_hf_encoder_to_mlx

    variant_config = next(cfg for cfg in ALL_VARIANTS if cfg.name == args.variant)
    converter = convert_hf_encoder_to_mlx if args.role == "encoder" else convert_hf_decoder_to_mlx
    converter(out_path=args.dst, config=variant_config)
    print(f"Wrote {args.dst}")
    return 0
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _cmd_info(args: argparse.Namespace) -> int:
    """Handle ``mlx-taef info``: print the tensor count and parameter count of a weight file."""
    tensors = mx.load(str(args.path))
    print(f"File: {args.path}")
    print(f"Total tensors: {len(tensors)}")
    param_count = 0
    for tensor in tensors.values():  # type: ignore[misc,union-attr]
        param_count += int(tensor.size)
    print(f"Total params: {param_count:,}")
    return 0
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _cmd_bench(args: argparse.Namespace) -> int:  # pragma: no cover
    """Handle ``mlx-taef bench``: time decoder-only inference for one variant.

    Decodes a random (1, 128, 128, latent_channels) float16 latent, which is the
    latent size of a 1024x1024 image at 8x downsampling. One untimed warm-up
    decode runs first, then five timed decodes; the median time is printed.
    """
    from mlx_taef.api import TAEF1, TAEF2, TAESD, TAESDXL

    model_cls = {"taesd": TAESD, "taesdxl": TAESDXL, "taef1": TAEF1, "taef2": TAEF2}[args.variant]
    variant_config = next(cfg for cfg in ALL_VARIANTS if cfg.name == args.variant)

    model = model_cls.from_pretrained(include_encoder=False)
    latent = mx.random.normal((1, 128, 128, variant_config.latent_channels)).astype(mx.float16)
    mx.eval(latent)

    # Untimed warm-up so one-time setup costs do not skew the measurements.
    mx.eval(model.decode(latent))

    n_runs = 5
    durations: list[float] = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        mx.eval(model.decode(latent))
        durations.append(time.perf_counter() - t0)

    median_ms = sorted(durations)[n_runs // 2] * 1000
    print(f"{args.variant} decode median: {median_ms:.1f} ms over {len(durations)} runs")
    return 0
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Only the console entry point is public.
__all__ = [
    "main",
]
|
mlx_taef/convert.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""HF safetensors -> MLX safetensors conversion.
|
|
2
|
+
|
|
3
|
+
Zero PyTorch dependency: reads source files with `safetensors.numpy.load_file`
|
|
4
|
+
and writes MLX safetensors directly. Runtime users never need torch.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import mlx.core as mx
|
|
12
|
+
import numpy as np
|
|
13
|
+
from huggingface_hub import hf_hub_download
|
|
14
|
+
from safetensors.numpy import load_file as safetensors_load_numpy
|
|
15
|
+
|
|
16
|
+
from mlx_taef.model import make_decoder
|
|
17
|
+
from mlx_taef.variants import TaesdVariantConfig
|
|
18
|
+
|
|
19
|
+
# Module-level logger named after this module (a child of the "mlx_taef" logger).
logger = logging.getLogger(name=__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def convert_diffusers_to_sequential(
    sd: dict[str, Any],
    *,
    role: str,
) -> dict[str, np.ndarray]:
    """Rewrite Diffusers-VAE keys for one role into upstream Sequential-style keys.

    Only keys beginning with ``"{role}."`` are kept, and that prefix is removed.
    A ``layers.<i>.`` segment becomes ``<i+1>.`` for the decoder, because per the
    TAEF2 model card the upstream Sequential decoder has one extra leading layer.
    For the encoder it becomes ``<i>.``. Other keys keep their remaining suffix.

    Args:
        sd: source state dict with Diffusers keys like 'decoder.layers.0.weight'.
        role: 'decoder' (index shifted by +1) or 'encoder' (no shift).

    Returns:
        New dict keyed like '1.weight', '13.conv.4.bias'; values are passed
        through untouched.
    """
    prefix = f"{role}."
    shift = 1 if role == "decoder" else 0
    remapped: dict[str, np.ndarray] = {}
    for key, tensor in sd.items():
        if not key.startswith(prefix):
            continue
        rest = key[len(prefix) :]
        if rest.startswith("layers."):
            _, index, *tail = rest.split(".")
            rest = f"{int(index) + shift}." + ".".join(tail)
        remapped[rest] = tensor
    return remapped
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _load_role_state_dict(  # pragma: no cover
    config: TaesdVariantConfig,
    role: str,
) -> dict[str, np.ndarray]:
    """Download the HF checkpoint for ``config`` and return ``role`` weights with Sequential keys.

    'diffusers' variants ship one combined file (``hf_filename``) that is remapped
    with `convert_diffusers_to_sequential`. 'upstream' variants ship one file per
    role (``hf_decoder_filename`` / ``hf_encoder_filename``) that is already keyed
    Sequential-style. Any role other than 'decoder' selects the encoder file.

    Raises:
        ValueError: if the needed filename is missing or ``key_format`` is unknown.
    """
    fmt = config.key_format

    if fmt == "diffusers":
        if config.hf_filename is None:
            raise ValueError(f"Diffusers variant {config.name!r} has no hf_filename")
        combined_path = hf_hub_download(repo_id=config.hf_repo, filename=config.hf_filename)
        return convert_diffusers_to_sequential(safetensors_load_numpy(combined_path), role=role)

    if fmt != "upstream":
        raise ValueError(f"Unknown key_format: {config.key_format!r}")

    role_filename = config.hf_encoder_filename if role != "decoder" else config.hf_decoder_filename
    if role_filename is None:
        raise ValueError(f"Upstream variant {config.name!r} has no {role} filename")
    return safetensors_load_numpy(hf_hub_download(repo_id=config.hf_repo, filename=role_filename))
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def convert_hf_decoder_to_mlx(  # pragma: no cover
    *,
    out_path: Path | str,
    config: TaesdVariantConfig,
) -> None:
    """Fetch upstream decoder weights for ``config`` and save them as MLX safetensors.

    Works for both the upstream-Sequential and Diffusers key formats. A freshly
    built MLX decoder provides the target parameter shapes, which decide which
    keys are kept and which 4-D conv weights are transposed NCHW -> NHWC. Saved
    keys are MLX-flat, e.g. 'layers.0.weight'.

    Args:
        out_path: destination of the MLX safetensors file.
        config: variant configuration.
    """
    source_weights = _load_role_state_dict(config, role="decoder")
    target_shapes = _flatten_module_param_shapes(make_decoder(config))
    mlx_weights = _build_mlx_state_dict(source_weights, expected_shapes=target_shapes)
    mx.save_safetensors(str(out_path), mlx_weights)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _sequential_key_to_mlx(src_key: str) -> str:
    """Translate an upstream-Sequential key into MLX's dotted parameter path.

    MLX's `nn.Sequential` keeps its children in a ``layers`` list. Every numeric
    path segment therefore becomes ``layers.<N>``. The first segment is always
    treated as a top-level layer index.

    Examples::

        "1.weight"        -> "layers.1.weight"
        "3.conv.0.weight" -> "layers.3.conv.layers.0.weight"
        "3.pool.1.bias"   -> "layers.3.pool.layers.1.bias"

    Args:
        src_key: upstream-Sequential key like '3.conv.0.weight'.

    Returns:
        MLX-flat dotted key like 'layers.3.conv.layers.0.weight'.
    """
    head, *rest = src_key.split(".")
    segments = ["layers", head]
    for seg in rest:
        segments += ["layers", seg] if seg.isdigit() else [seg]
    return ".".join(segments)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _build_mlx_state_dict(
    sd: dict[str, np.ndarray],
    *,
    expected_shapes: dict[str, tuple[int, ...]],
) -> dict[str, mx.array]:
    """Remap Sequential keys to MLX keys, transpose conv weights, and check shapes.

    Args:
        sd: source weights keyed upstream-Sequential style ('3.conv.0.weight').
        expected_shapes: MLX-flat key -> parameter shape of the target module.

    Returns:
        Dict of MLX-flat key -> `mx.array`. It contains only keys that map onto
        `expected_shapes`; unmapped source keys are skipped and debug-logged.

    Raises:
        ValueError: if a mapped tensor's shape (after any transpose) differs from
            the target parameter shape. Weights are later loaded with
            ``strict=False``, so a mismatch would otherwise go unnoticed.
    """
    converted: dict[str, mx.array] = {}
    for src_key, arr in sd.items():
        dst_key = _sequential_key_to_mlx(src_key)
        expected = expected_shapes.get(dst_key)
        if expected is None:
            # Keys with no counterpart in the MLX module (e.g. extra Diffusers entries).
            logger.debug("Skipping source key %r (no MLX parameter %r)", src_key, dst_key)
            continue
        # Conv2d weight: NCHW (out, in, kH, kW) -> NHWC (out, kH, kW, in). Detected when
        # the source is 4-D and the MLX shape equals the transposed source shape.
        if arr.ndim == 4 and expected == (arr.shape[0], arr.shape[2], arr.shape[3], arr.shape[1]):
            arr = np.transpose(arr, (0, 2, 3, 1)).copy()
        if tuple(arr.shape) != tuple(expected):
            raise ValueError(
                f"Shape mismatch for {src_key!r} -> {dst_key!r}: "
                f"got {tuple(arr.shape)}, expected {tuple(expected)}"
            )
        converted[dst_key] = mx.array(arr)
    return converted
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _flatten_module_param_shapes(module: Any, prefix: str = "") -> dict[str, tuple[int, ...]]:
    """Flatten ``module.parameters()`` into a dict of dotted key -> shape tuple.

    Nested dicts contribute their keys and lists/tuples their indices as path
    segments. Any leaf with a ``shape`` attribute is recorded. Segments are joined
    with '.', and no leading dot is produced when the current path is empty
    (e.g. ``prefix=""``).

    Args:
        module: object whose ``parameters()`` returns a nested dict/list tree.
        prefix: optional path prepended to every key.

    Returns:
        Mapping like ``{"layers.0.weight": (64, 3, 3, 4), ...}``.
    """
    shapes: dict[str, tuple[int, ...]] = {}

    def _join(base: str, segment: object) -> str:
        # Avoid a leading '.' when starting from an empty path.
        return f"{base}.{segment}" if base else str(segment)

    def _walk(node: Any, path: str) -> None:
        if isinstance(node, dict):
            for key, child in node.items():
                _walk(child, _join(path, key))
        elif isinstance(node, (list, tuple)):
            for index, child in enumerate(node):
                _walk(child, _join(path, index))
        elif hasattr(node, "shape"):
            shapes[path] = tuple(node.shape)

    _walk(module.parameters(), prefix)
    return shapes
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def convert_hf_encoder_to_mlx(  # pragma: no cover
    *,
    out_path: Path | str,
    config: TaesdVariantConfig,
) -> None:
    """Fetch upstream encoder weights for ``config`` and save them as MLX safetensors.

    Encoder counterpart of `convert_hf_decoder_to_mlx`. Target shapes come from
    a freshly built MLX encoder, so conv weights are transposed against the
    encoder's own layout.

    Args:
        out_path: destination of the MLX safetensors file.
        config: variant configuration.
    """
    from mlx_taef.model import make_encoder

    source_weights = _load_role_state_dict(config, role="encoder")
    target_shapes = _flatten_module_param_shapes(make_encoder(config))
    mlx_weights = _build_mlx_state_dict(source_weights, expected_shapes=target_shapes)
    mx.save_safetensors(str(out_path), mlx_weights)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# Public conversion helpers; the underscore-prefixed functions stay private.
__all__ = ["convert_diffusers_to_sequential", "convert_hf_decoder_to_mlx", "convert_hf_encoder_to_mlx"]
|
mlx_taef/download.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""HF Hub auto-download + cache. Zero PyTorch dependency at runtime."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from mlx_taef.convert import convert_hf_decoder_to_mlx, convert_hf_encoder_to_mlx
|
|
7
|
+
from mlx_taef.variants import TaesdVariantConfig
|
|
8
|
+
|
|
9
|
+
# Module-level logger named after this module (a child of the "mlx_taef" logger).
logger = logging.getLogger(name=__name__)

# Root directory for converted MLX weights: ~/.cache/mlx-taef/<variant>/...
CACHE_ROOT = Path.home().joinpath(".cache", "mlx-taef")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_or_convert(config: TaesdVariantConfig, *, role: str = "decoder") -> Path:
    """Return the local path to converted MLX weights for (variant, role).

    The cache file is ``CACHE_ROOT/<name>/<name>_<role>.safetensors``. If it is
    missing, the full conversion pipeline runs (HF download + key remap + NHWC
    transpose + safetensors write). The output is first written to a temporary
    file and then atomically renamed into place. An interrupted or failed
    conversion therefore never leaves a truncated file that later calls would
    mistake for a valid cache entry. Cache hits do not access the network.

    Args:
        config: variant configuration (selects HF repo + filename).
        role: 'decoder' (default) or 'encoder'.

    Returns:
        Local filesystem path to the MLX safetensors file.

    Raises:
        ValueError: if ``role`` is not 'decoder' or 'encoder' (checked before any
            filesystem access).
    """
    import os

    # Validate before touching the filesystem, so a bad role creates no cache dir.
    if role not in ("decoder", "encoder"):
        raise ValueError(f"role must be 'decoder' or 'encoder', got {role!r}")

    cache_dir = CACHE_ROOT / config.name
    cache_dir.mkdir(parents=True, exist_ok=True)
    out_path = cache_dir / f"{config.name}_{role}.safetensors"
    if out_path.exists():
        logger.debug("Using cached weights at %s", out_path)
        return out_path

    logger.info("Downloading + converting %s %s weights from %s", config.name, role, config.hf_repo)
    converter = convert_hf_decoder_to_mlx if role == "decoder" else convert_hf_encoder_to_mlx
    # Process-unique temp name in the same directory, so the final rename stays on
    # one filesystem. It keeps the ".safetensors" suffix because
    # mx.save_safetensors may append that extension to other names.
    tmp_path = cache_dir / f".{out_path.stem}.{os.getpid()}.partial.safetensors"
    try:
        converter(out_path=tmp_path, config=config)
        tmp_path.replace(out_path)  # atomic rename over any concurrent writer's result
    finally:
        # No-op after a successful rename; removes the partial file on failure.
        tmp_path.unlink(missing_ok=True)
    return out_path
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Public names: the cache location and the cache-or-convert entry point.
__all__ = [
    "CACHE_ROOT",
    "get_or_convert",
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Optional integrations for mlx-taef (e.g. mflux)."""
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""mflux integration for mlx-taef.
|
|
2
|
+
|
|
3
|
+
Provides:
|
|
4
|
+
- `unpack_flux2_latent`: convert mflux's packed FLUX.2 latents to TAEF2-compatible NHWC.
|
|
5
|
+
- `LivePreviewCallback`: drop-in mflux callback that writes preview PNGs every N steps.
|
|
6
|
+
|
|
7
|
+
Install with: `pip install "mlx-taef[mflux]"`.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import mlx.core as mx
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
from mflux.callbacks.callback import InLoopCallback # type: ignore[import-untyped]
|
|
18
|
+
except ImportError as e: # pragma: no cover
|
|
19
|
+
raise ImportError(
|
|
20
|
+
"mflux is required for mlx_taef.integrations.mflux. "
|
|
21
|
+
"Install with: pip install 'mlx-taef[mflux]'"
|
|
22
|
+
) from e
|
|
23
|
+
|
|
24
|
+
from mlx_taef.api import TAEF1, TAEF2, Taef
|
|
25
|
+
|
|
26
|
+
# Module-level logger named after this module (a child of the "mlx_taef" logger).
logger = logging.getLogger(name=__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def unpack_flux2_latent(
    packed: mx.array,
    *,
    latent_height: int,
    latent_width: int,
    bn_mean: mx.array | None = None,
    bn_var: mx.array | None = None,
    bn_eps: float = 1e-4,
) -> mx.array:
    """Unpack mflux's packed FLUX.2 latent into NHWC `(B, lH*2, lW*2, 32)` for TAEF2.

    Mirrors mflux's `Flux2VAE.decode_packed_latents` preprocessing:
      1. Reshape (B, lH*lW, 128) -> NCHW (B, 128, lH, lW).
      2. BN denormalize: latents = latents * sqrt(var + eps) + mean (128-channel stats).
      3. Unpatchify: (B, 128, lH, lW) -> (B, 32, lH*2, lW*2).
      4. NCHW -> NHWC transpose.

    Without `bn_mean`/`bn_var`, step 2 is skipped (identity BN: mean=0, var=1).
    The output has the correct shape but slightly offset values: preview
    structure is still visible, though colors may shift. For an exact preview,
    pass the BN stats from the loaded Flux2VAE.

    Args:
        packed: in-loop latent from mflux, shape (B, lH*lW, 128).
        latent_height: latent spatial height (image_height // 16).
        latent_width: latent spatial width.
        bn_mean: optional BN running_mean for exact value recovery (128 elements).
        bn_var: optional BN running_var for exact value recovery (128 elements).
        bn_eps: BN epsilon. Matches mflux's Flux2BatchNormStats default of 1e-4,
            verified at mflux 0.17.5.

    Returns:
        NHWC tensor of shape (B, latent_height*2, latent_width*2, 32) — ready
        for `TAEF2.decode()`.

    Raises:
        ValueError: if `packed` is not 3-D, or does not have 128 channels, or its
            sequence length is not `latent_height * latent_width`, or only one of
            `bn_mean` / `bn_var` is given.

    See `notes/mflux-latent-layout.md` for the analysis behind this transform.
    """
    if packed.ndim != 3:
        raise ValueError(f"Expected packed latent of shape (B, lH*lW, 128), got {tuple(packed.shape)}")
    b, seq_len, c = packed.shape
    if c != 128:
        raise ValueError(f"Expected 128-channel packed latent, got {c}")
    if seq_len != latent_height * latent_width:
        raise ValueError(
            f"Packed sequence length {seq_len} does not match latent_height * latent_width "
            f"({latent_height} * {latent_width} = {latent_height * latent_width})"
        )
    # Using only one statistic would silently skip denormalization; reject it.
    if (bn_mean is None) != (bn_var is None):
        raise ValueError("bn_mean and bn_var must be given together (or both omitted)")

    # Step 1: (B, lH*lW, 128) -> (B, lH, lW, 128) -> NCHW (B, 128, lH, lW)
    latents = packed.reshape(b, latent_height, latent_width, c).transpose(0, 3, 1, 2)

    # Step 2: BN denormalize (skipped = identity BN when no stats are given).
    if bn_mean is not None and bn_var is not None:
        mean = bn_mean.reshape(1, -1, 1, 1)
        std = mx.sqrt(bn_var.reshape(1, -1, 1, 1) + bn_eps)
        latents = latents * std + mean

    # Step 3: unpatchify. The 128 channels are split as (32, 2, 2), then permuted:
    # (B, 32, 2, 2, lH, lW) -> (B, 32, lH, 2, lW, 2) -> (B, 32, lH*2, lW*2)
    batch, _, h, w = latents.shape
    latents = latents.reshape(batch, 32, 2, 2, h, w)
    latents = latents.transpose(0, 1, 4, 2, 5, 3)
    latents = latents.reshape(batch, 32, h * 2, w * 2)

    # Step 4: NCHW -> NHWC for TAEF2
    return latents.transpose(0, 2, 3, 1)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class LivePreviewCallback(InLoopCallback):  # type: ignore[misc]
    """mflux callback that writes a low-quality preview PNG every N iterations.

    Drops into `mflux.Flux2Klein.generate_image(callbacks=[...])`. On every
    `every`-th in-loop call (counting from the first), it converts the in-flight
    latent to NHWC and decodes it with the selected tiny autoencoder: TAEF2 for
    FLUX.2 Klein, TAEF1 for FLUX.1. It then writes the first image of the batch
    as a PNG to `save_to`.

    Args:
        variant: 'taef1' (for FLUX.1 latents) or 'taef2' (for FLUX.2 Klein).
        every: emit a preview every Nth iteration; must be >= 1. Default 5.
        save_to: filesystem path to write previews. Overwritten each emission.
        latent_height: latent spatial height; for 512x512 image with FLUX.2 Klein, this is 32.
        latent_width: latent spatial width.
        bn_mean: optional BN running_mean for TAEF2 (see `unpack_flux2_latent`).
        bn_var: optional BN running_var for TAEF2.

    Raises:
        ValueError: if `every` < 1 or `variant` is not 'taef1'/'taef2'.
    """

    def __init__(
        self,
        *,
        variant: str = "taef2",
        every: int = 5,
        save_to: str | Path = "preview.png",
        latent_height: int = 32,
        latent_width: int = 32,
        bn_mean: mx.array | None = None,
        bn_var: mx.array | None = None,
    ) -> None:
        """Initialise LivePreviewCallback. See class docstring for argument descriptions."""
        # Checked before loading weights: every <= 0 would otherwise fail with
        # ZeroDivisionError (0) or never match (negative) on the first in-loop call.
        if every < 1:
            raise ValueError(f"every must be >= 1, got {every!r}")
        if variant == "taef1":  # pragma: no cover
            self.model: Taef = TAEF1.from_pretrained(include_encoder=False)
        elif variant == "taef2":
            self.model = TAEF2.from_pretrained(include_encoder=False)
        else:  # pragma: no cover
            raise ValueError(f"variant must be 'taef1' or 'taef2', got {variant!r}")
        self.every = every
        self.save_to = Path(save_to)
        self.latent_height = latent_height
        self.latent_width = latent_width
        self.bn_mean = bn_mean
        self.bn_var = bn_var
        self._iter = 0

    def call_in_loop(
        self,
        t: object,
        seed: object,
        prompt: object,
        latents: mx.array,
        config: object,
        time_steps: object,
    ) -> None:
        """Decode latent + save PNG every Nth iteration (counted by our own iter, not `t`)."""
        idx = self._iter
        self._iter += 1
        if idx % self.every != 0:
            return  # pragma: no cover

        if self.model._config.name == "taef2":
            unpacked = unpack_flux2_latent(
                latents,
                latent_height=self.latent_height,
                latent_width=self.latent_width,
                bn_mean=self.bn_mean,
                bn_var=self.bn_var,
            )
        else:  # pragma: no cover
            # NOTE(review): assumes FLUX.1 in-loop latents arrive as NCHW (B, 16, H, W);
            # anything else is passed through unchanged — confirm against mflux's layout.
            unpacked = mx.transpose(latents, (0, 2, 3, 1)) if latents.shape[1] == 16 else latents

        img = self.model.decode_image(unpacked)
        self._save_png(img[0])  # first image of the batch only

    def _save_png(self, img_nhwc_uint8: mx.array) -> None:
        """Write one (H, W, 3) uint8 image to `self.save_to` via PIL (imported lazily)."""
        from PIL import Image

        arr = np.array(img_nhwc_uint8)
        Image.fromarray(arr).save(self.save_to)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# Public names of the mflux integration.
__all__ = [
    "LivePreviewCallback",
    "unpack_flux2_latent",
]
|
mlx_taef/model.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""MLX layer modules for the TAESD-family of tiny autoencoders.
|
|
2
|
+
|
|
3
|
+
Layers are direct ports of the upstream PyTorch implementations at
|
|
4
|
+
https://github.com/madebyollin/taesd (MIT licensed).
|
|
5
|
+
|
|
6
|
+
All public layers accept and return NHWC `mx.array` tensors.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import mlx.core as mx
|
|
10
|
+
import mlx.nn as nn
|
|
11
|
+
|
|
12
|
+
from mlx_taef.variants import TaesdVariantConfig
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Clamp(nn.Module):  # type: ignore[misc,name-defined]
    """Smoothly squash values into [-3, 3] using `3 * tanh(x / 3)`.

    Port of `taesd.Clamp` from the upstream PyTorch code. The decoder applies it
    as its very first layer so latent magnitudes are bounded before the first
    convolution. The layer defines no parameters.
    """

    def __call__(self, x: mx.array) -> mx.array:
        """Return the element-wise soft clamp of `x`.

        Args:
            x: tensor of any shape.

        Returns:
            `3 * tanh(x / 3)`, with the same shape as `x`.
        """
        scaled = x / 3.0
        return 3.0 * mx.tanh(scaled)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def make_conv(n_in: int, n_out: int, *, stride: int = 1, bias: bool = True) -> "nn.Conv2d":  # type: ignore[name-defined]
    """Build the 3x3, padding-1 convolution used throughout the TAESD family.

    Equivalent to upstream `taesd.conv(n_in, n_out, **kwargs)`. With the default
    stride of 1 the spatial size is preserved; the encoder passes `stride=2` to
    downsample by a factor of two.

    Args:
        n_in: number of input channels.
        n_out: number of output channels.
        stride: convolution stride (default 1).
        bias: whether the layer learns an additive bias (default True).

    Returns:
        A freshly initialised `mlx.nn.Conv2d`.
    """
    return nn.Conv2d(  # type: ignore[attr-defined]
        in_channels=n_in,
        out_channels=n_out,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=bias,
    )
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class _Identity(nn.Module):  # type: ignore[misc,name-defined]
    """Pass-through layer that `Block` uses as its `skip` path when `n_in == n_out`.

    It defines no parameters of its own.
    """

    def __call__(self, x: mx.array) -> mx.array:
        """Return the input object itself, unchanged and uncopied."""
        return x
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class Block(nn.Module):  # type: ignore[misc,name-defined]
    """TAESD residual block computing `ReLU(conv_path(x) + skip(x))`.

    `conv_path` is Conv3x3 -> ReLU -> Conv3x3 -> ReLU -> Conv3x3. `skip` is a
    bias-free 1x1 conv when the channel count changes, otherwise a pass-through.
    Port of upstream `taesd.Block`.

    With `use_midblock_gn=True` (TAEF2, arch_variant "flux_2") an extra residual
    "pool" branch runs first, `x <- x + pool(x)`, where `pool` is
    1x1 conv (n_in -> 4*n_in) -> GroupNorm(4 groups) -> ReLU -> 1x1 conv (4*n_in -> n_in).

    Attribute names (`conv`, `skip`, `fuse`, `pool`) and the layer order inside each
    Sequential determine the MLX parameter paths; the checkpoint converter
    presumably relies on them, so keep them stable.

    Args:
        n_in: input channel count.
        n_out: output channel count.
        use_midblock_gn: add the GroupNorm pool pre-branch. Default False.
    """

    def __init__(self, n_in: int, n_out: int, *, use_midblock_gn: bool = False) -> None:
        """Create the block's sub-layers.

        Args:
            n_in: channels of the incoming tensor.
            n_out: channels produced by the block.
            use_midblock_gn: when True, build the GroupNorm pool pre-branch
                (used by TAEF2 only). Default False.
        """
        super().__init__()
        self.conv = nn.Sequential(  # type: ignore[attr-defined]
            make_conv(n_in, n_out),
            nn.ReLU(),  # type: ignore[attr-defined]
            make_conv(n_out, n_out),
            nn.ReLU(),  # type: ignore[attr-defined]
            make_conv(n_out, n_out),
        )
        if n_in == n_out:
            self.skip = _Identity()
        else:
            self.skip = nn.Conv2d(n_in, n_out, kernel_size=1, bias=False)  # type: ignore[attr-defined]
        self.fuse = nn.ReLU()  # type: ignore[attr-defined]
        if use_midblock_gn:
            hidden = 4 * n_in
            # pytorch_compatible=True is required for parity: without it MLX groups
            # channels differently from PyTorch's nn.GroupNorm.
            self.pool = nn.Sequential(  # type: ignore[attr-defined]
                nn.Conv2d(n_in, hidden, kernel_size=1, bias=False),  # type: ignore[attr-defined]
                nn.GroupNorm(num_groups=4, dims=hidden, pytorch_compatible=True),  # type: ignore[attr-defined]
                nn.ReLU(),  # type: ignore[attr-defined]
                nn.Conv2d(hidden, n_in, kernel_size=1, bias=False),  # type: ignore[attr-defined]
            )
        else:
            self.pool = None

    def __call__(self, x: mx.array) -> mx.array:
        """Run the block.

        Args:
            x: NHWC tensor with `n_in` channels.

        Returns:
            NHWC tensor with `n_out` channels.
        """
        h = x if self.pool is None else x + self.pool(x)
        return self.fuse(self.conv(h) + self.skip(h))  # type: ignore[no-any-return]
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def make_decoder(config: TaesdVariantConfig) -> nn.Sequential:  # type: ignore[name-defined]
    """Assemble the TAESD-family decoder (latent -> RGB) for `config`.

    Layer order follows upstream `taesd.Decoder(latent_channels, use_midblock_gn)`:

    1. Clamp, Conv(latent_channels -> 64), ReLU.
    2. Three upsampling stages of [Block x3, nearest 2x Upsample, Conv(64 -> 64, no bias)],
       giving an 8x spatial enlargement overall.
    3. A final Block and Conv(64 -> 3).

    Only the Blocks of the first stage receive `use_midblock_gn`, and only when the
    variant enables it (arch_variant "flux_2", i.e. TAEF2).

    Weights are addressed by flat position (e.g. `layers.N`), so the order below
    must stay exactly as it is.

    Args:
        config: variant whose `latent_channels` and `use_midblock_gn` shape the decoder.

    Returns:
        A 20-layer `nn.Sequential`, ready for weight loading.
    """
    first_stage_gn = config.use_midblock_gn
    layers: list[nn.Module] = [  # type: ignore[name-defined]
        Clamp(),
        make_conv(config.latent_channels, 64),
        nn.ReLU(),  # type: ignore[attr-defined]
    ]
    for stage_gn in (first_stage_gn, False, False):
        layers.extend(Block(64, 64, use_midblock_gn=stage_gn) for _ in range(3))
        layers.append(nn.Upsample(scale_factor=2, mode="nearest"))  # type: ignore[attr-defined]
        layers.append(make_conv(64, 64, bias=False))
    layers.append(Block(64, 64))
    layers.append(make_conv(64, 3))
    return nn.Sequential(*layers)  # type: ignore[attr-defined]
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def make_encoder(config: TaesdVariantConfig) -> nn.Sequential:  # type: ignore[name-defined]
    """Assemble the TAESD-family encoder (RGB -> latent) for `config`.

    Layer order follows upstream `taesd.Encoder(latent_channels, use_midblock_gn)`:

    1. Conv(3 -> 64), Block.
    2. Three downsampling stages of [Conv(64 -> 64, stride 2, no bias), Block x3],
       giving an 8x spatial reduction overall.
    3. A final Conv(64 -> latent_channels).

    Only the Blocks of the last stage receive `use_midblock_gn`, and only when the
    variant enables it (arch_variant "flux_2", i.e. TAEF2).

    Weights are addressed by flat position (e.g. `layers.N`), so the order below
    must stay exactly as it is.

    Args:
        config: variant whose `latent_channels` and `use_midblock_gn` shape the encoder.

    Returns:
        A 15-layer `nn.Sequential`, ready for weight loading.
    """
    last_stage_gn = config.use_midblock_gn
    layers: list[nn.Module] = [make_conv(3, 64), Block(64, 64)]  # type: ignore[name-defined]
    for stage_gn in (False, False, last_stage_gn):
        layers.append(make_conv(64, 64, stride=2, bias=False))
        layers.extend(Block(64, 64, use_midblock_gn=stage_gn) for _ in range(3))
    layers.append(make_conv(64, config.latent_channels))
    return nn.Sequential(*layers)  # type: ignore[attr-defined]
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# Public surface of the layer module; `_Identity` stays private.
__all__ = [
    "Block",
    "Clamp",
    "make_conv",
    "make_decoder",
    "make_encoder",
]
|
mlx_taef/py.typed
ADDED
|
File without changes
|
mlx_taef/variants.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Variant configurations for the TAESD-family of tiny autoencoders.
|
|
2
|
+
|
|
3
|
+
Field reference:
|
|
4
|
+
- `key_format`: "upstream" means the HF safetensors uses upstream Sequential keys
|
|
5
|
+
like "0.weight", "1.weight"; "diffusers" means the HF safetensors uses Diffusers
|
|
6
|
+
VAE keys like "encoder.layers.0.weight" and requires a +1 decoder index offset.
|
|
7
|
+
- `arch_variant`: None | "flux_2" | "f32". "flux_2" enables midblock_gn pool
|
|
8
|
+
branches in three blocks of encoder and decoder. "f32" is reserved for
|
|
9
|
+
TAESANA (future).
|
|
10
|
+
- `latent_magnitude`, `latent_shift`: from upstream TAESD.scale_latents /
|
|
11
|
+
unscale_latents.
|
|
12
|
+
- `hf_filename` (Diffusers single-file format) vs `hf_decoder_filename` /
|
|
13
|
+
`hf_encoder_filename` (upstream two-file format) — depends on the actual
|
|
14
|
+
repo layout for each variant.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(frozen=True, slots=True, kw_only=True)
|
|
24
|
+
class TaesdVariantConfig:
|
|
25
|
+
"""Configuration for a single TAESD-family variant.
|
|
26
|
+
|
|
27
|
+
Attributes:
|
|
28
|
+
name: short identifier like "taef2".
|
|
29
|
+
latent_channels: latent channel count (4 for SD1/SDXL, 16 for FLUX.1, 32 for FLUX.2 Klein).
|
|
30
|
+
arch_variant: None for the standard arch, "flux_2" for TAEF2's midblock_gn, "f32" for TAESANA.
|
|
31
|
+
key_format: "upstream" (Sequential keys) or "diffusers" (requires remap).
|
|
32
|
+
hf_repo: e.g. "madebyollin/taef2".
|
|
33
|
+
hf_filename: filename within hf_repo when key_format == "diffusers" (single file).
|
|
34
|
+
hf_decoder_filename: decoder filename when key_format == "upstream".
|
|
35
|
+
hf_encoder_filename: encoder filename when key_format == "upstream".
|
|
36
|
+
latent_magnitude: scale factor for scale_latents/unscale_latents.
|
|
37
|
+
latent_shift: shift factor for scale_latents/unscale_latents.
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
name: str
|
|
41
|
+
latent_channels: int
|
|
42
|
+
arch_variant: str | None
|
|
43
|
+
key_format: str
|
|
44
|
+
hf_repo: str
|
|
45
|
+
hf_filename: str | None
|
|
46
|
+
hf_decoder_filename: str | None
|
|
47
|
+
hf_encoder_filename: str | None
|
|
48
|
+
latent_magnitude: float = 3.0
|
|
49
|
+
latent_shift: float = 0.5
|
|
50
|
+
|
|
51
|
+
@property
|
|
52
|
+
def use_midblock_gn(self) -> bool:
|
|
53
|
+
"""Whether the variant's Block layers use the midblock GroupNorm pool branch."""
|
|
54
|
+
return self.arch_variant == "flux_2"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Stable Diffusion 1.x: 4-channel latents, separate upstream-format encoder/decoder files.
TAESD_CONFIG = TaesdVariantConfig(
    name="taesd",
    hf_repo="madebyollin/taesd",
    key_format="upstream",
    hf_filename=None,
    hf_decoder_filename="taesd_decoder.safetensors",
    hf_encoder_filename="taesd_encoder.safetensors",
    latent_channels=4,
    arch_variant=None,
)

# Stable Diffusion XL: same layout as TAESD, different weights.
TAESDXL_CONFIG = TaesdVariantConfig(
    name="taesdxl",
    hf_repo="madebyollin/taesdxl",
    key_format="upstream",
    hf_filename=None,
    hf_decoder_filename="taesdxl_decoder.safetensors",
    hf_encoder_filename="taesdxl_encoder.safetensors",
    latent_channels=4,
    arch_variant=None,
)

# FLUX.1: 16-channel latents, one Diffusers-format weights file.
TAEF1_CONFIG = TaesdVariantConfig(
    name="taef1",
    hf_repo="madebyollin/taef1",
    key_format="diffusers",
    hf_filename="diffusion_pytorch_model.safetensors",
    hf_decoder_filename=None,
    hf_encoder_filename=None,
    latent_channels=16,
    arch_variant=None,
)

# FLUX.2 Klein: 32-channel latents, Diffusers-format file, midblock GroupNorm arch.
TAEF2_CONFIG = TaesdVariantConfig(
    name="taef2",
    hf_repo="madebyollin/taef2",
    key_format="diffusers",
    hf_filename="taef2.safetensors",
    hf_decoder_filename=None,
    hf_encoder_filename=None,
    latent_channels=32,
    arch_variant="flux_2",
)

# Every built-in variant, in declaration order.
ALL_VARIANTS: tuple[TaesdVariantConfig, ...] = (
    TAESD_CONFIG,
    TAESDXL_CONFIG,
    TAEF1_CONFIG,
    TAEF2_CONFIG,
)
|
|
107
|
+
|
|
108
|
+
# Exported names: the config dataclass, the four built-in variants and their tuple.
__all__ = [
    "ALL_VARIANTS", "TAEF1_CONFIG", "TAEF2_CONFIG",
    "TAESDXL_CONFIG", "TAESD_CONFIG", "TaesdVariantConfig",
]
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mlx-taef
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Tiny AutoEncoders for diffusion (TAESD family) on Apple MLX.
|
|
5
|
+
Project-URL: Homepage, https://github.com/ionden/mlx-taef
|
|
6
|
+
Project-URL: Documentation, https://ionden.github.io/mlx-taef
|
|
7
|
+
Project-URL: Issues, https://github.com/ionden/mlx-taef/issues
|
|
8
|
+
Project-URL: Changelog, https://github.com/ionden/mlx-taef/blob/main/CHANGELOG.md
|
|
9
|
+
Author-email: Denis Ineshin <denis.ineshin@gmail.com>
|
|
10
|
+
License-Expression: MIT
|
|
11
|
+
Keywords: apple-silicon,autoencoder,flux,mlx,stable-diffusion,taesd
|
|
12
|
+
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Operating System :: MacOS :: MacOS X
|
|
17
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
22
|
+
Classifier: Typing :: Typed
|
|
23
|
+
Requires-Python: >=3.11
|
|
24
|
+
Requires-Dist: huggingface-hub>=0.24
|
|
25
|
+
Requires-Dist: mlx>=0.20
|
|
26
|
+
Requires-Dist: numpy>=1.26
|
|
27
|
+
Requires-Dist: safetensors>=0.4
|
|
28
|
+
Provides-Extra: image
|
|
29
|
+
Requires-Dist: pillow>=10.0; extra == 'image'
|
|
30
|
+
Provides-Extra: mflux
|
|
31
|
+
Requires-Dist: mflux>=0.17; extra == 'mflux'
|
|
32
|
+
Description-Content-Type: text/markdown
|
|
33
|
+
|
|
34
|
+
# mlx-taef
|
|
35
|
+
|
|
36
|
+
Tiny AutoEncoders for diffusion latents on Apple Silicon, in pure MLX.
|
|
37
|
+
|
|
38
|
+
`mlx-taef` is the first MLX port of the TAESD family — TAESD (SD1.x), TAESDXL (SDXL), TAEF1 (FLUX.1), TAEF2 (FLUX.2 Klein) — distilled mini-autoencoders that decode diffusion latents to RGB in milliseconds using a few-MB model instead of a full VAE that can need several GB of memory to decode.
|
|
39
|
+
|
|
40
|
+
Use it for:
|
|
41
|
+
- **Live previews** during long generations on Mac — see each step refresh in <100 ms instead of waiting 30 s for the full VAE.
|
|
42
|
+
- **Low-memory fallbacks** when the full VAE OOMs on 16 GB Macs (TAEF2 peaks at ~1 GB for 1024×1024 vs ~9.6 GB for the full Flux VAE).
|
|
43
|
+
- **Quick latent inspection** in notebooks and ML research.
|
|
44
|
+
|
|
45
|
+
```python
|
|
46
|
+
import mlx.core as mx
|
|
47
|
+
from mlx_taef import TAEF2
|
|
48
|
+
|
|
49
|
+
taef = TAEF2.from_pretrained() # downloads + converts on first call
|
|
50
|
+
img = taef.decode(latents) # NHWC float in [0, 1]
|
|
51
|
+
img_uint8 = taef.decode_image(latents) # uint8 NHWC ready for PIL
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
## Install
|
|
55
|
+
|
|
56
|
+
```bash
|
|
57
|
+
pip install mlx-taef
|
|
58
|
+
# With mflux preview callback:
|
|
59
|
+
pip install "mlx-taef[mflux]"
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
Requires Python ≥ 3.11 and Apple Silicon. Runtime install has **zero PyTorch dependency**.
|
|
63
|
+
|
|
64
|
+
## Variants
|
|
65
|
+
|
|
66
|
+
| Variant | latent_channels | For | HF source |
|
|
67
|
+
|---|---|---|---|
|
|
68
|
+
| `TAESD` | 4 | Stable Diffusion 1.x | [madebyollin/taesd](https://huggingface.co/madebyollin/taesd) |
|
|
69
|
+
| `TAESDXL` | 4 | Stable Diffusion XL | [madebyollin/taesdxl](https://huggingface.co/madebyollin/taesdxl) |
|
|
70
|
+
| `TAEF1` | 16 | FLUX.1 | [madebyollin/taef1](https://huggingface.co/madebyollin/taef1) |
|
|
71
|
+
| `TAEF2` | 32 | FLUX.2 Klein | [madebyollin/taef2](https://huggingface.co/madebyollin/taef2) |
|
|
72
|
+
|
|
73
|
+
All four share one API.
|
|
74
|
+
|
|
75
|
+
## Benchmarks (M1 Max, fp16)
|
|
76
|
+
|
|
77
|
+
| Metric | TAEF2 (this library) | Full Flux VAE (reference) | Win |
|
|
78
|
+
|---|---|---|---|
|
|
79
|
+
| Decode latency 1024×1024 | **~100 ms** | seconds | 50–100× |
|
|
80
|
+
| Peak unified memory 1024×1024 | **~1 GB** | ~9.6 GB | **9.4×** |
|
|
81
|
+
| Output cosine sim vs PyTorch reference | > 0.999 | — | (parity verified) |
|
|
82
|
+
|
|
83
|
+
Numbers from `tests/test_perf.py` on M1 Max 32 GB. See `notes/phase1-benchmarks.md` for details.
|
|
84
|
+
|
|
85
|
+
## mflux live previews
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
from mflux.models.flux2 import Flux2Klein
|
|
89
|
+
from mlx_taef.integrations.mflux import LivePreviewCallback
|
|
90
|
+
|
|
91
|
+
model = Flux2Klein.from_pretrained("4bit")
|
|
92
|
+
preview = LivePreviewCallback(
|
|
93
|
+
variant="taef2",
|
|
94
|
+
every=5,
|
|
95
|
+
save_to="preview.png",
|
|
96
|
+
latent_height=32, # 512 / 16
|
|
97
|
+
latent_width=32,
|
|
98
|
+
)
|
|
99
|
+
model.callbacks.register(preview)
|
|
100
|
+
model.generate_image(
|
|
101
|
+
prompt="a red apple on a wooden table",
|
|
102
|
+
num_inference_steps=25,
|
|
103
|
+
width=512,
|
|
104
|
+
height=512,
|
|
105
|
+
seed=42,
|
|
106
|
+
)
|
|
107
|
+
```
|
|
108
|
+
|
|
109
|
+
For exact value-space recovery, also pass `bn_mean=flux2_vae.bn.running_mean, bn_var=flux2_vae.bn.running_var` to the callback. Without them, previews show correct structure but colors may shift.
|
|
110
|
+
|
|
111
|
+
See `docs/manual-verification.md` for the full verification recipe.
|
|
112
|
+
|
|
113
|
+
## Status
|
|
114
|
+
|
|
115
|
+
- v0.1.0 — initial public release. All four variants, encoder + decoder, mflux integration, CI, 100% test coverage.
|
|
116
|
+
|
|
117
|
+
## License
|
|
118
|
+
|
|
119
|
+
MIT. Mirrors upstream [madebyollin/taesd](https://github.com/madebyollin/taesd) license. Pretrained weights belong to their respective authors (madebyollin).
|
|
120
|
+
|
|
121
|
+
## Acknowledgements
|
|
122
|
+
|
|
123
|
+
- [madebyollin](https://github.com/madebyollin) for the upstream TAESD-family models and weights.
|
|
124
|
+
- [Apple's ml-explore project](https://github.com/ml-explore/mlx) for MLX.
|
|
125
|
+
- [filipstrand/mflux](https://github.com/filipstrand/mflux) for the MLX-native FLUX runner this library integrates with.
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
mlx_taef/__init__.py,sha256=hn12r3kNI6DnIKnzNpx6EjU6zfYXpHT034qtk3OQ8ic,261
|
|
2
|
+
mlx_taef/_version.py,sha256=n_5vdJsPNu7wZ57LGuRL585uvll-hiuvZUBWzdG0RQU,520
|
|
3
|
+
mlx_taef/api.py,sha256=_Si8Mi0uVkjbr5n9Ni_kxES-CVhyW3M0nNGedwmtaQo,6502
|
|
4
|
+
mlx_taef/cli.py,sha256=-8nvCxga9tf4YWU6CJADuv0ch7qi2fa9MkwneKJrg0Q,3253
|
|
5
|
+
mlx_taef/convert.py,sha256=5mLO-nfat6HuUpUoQFbIVu2Vi6wbEW7qlGetybL7syo,7002
|
|
6
|
+
mlx_taef/download.py,sha256=QDrOz54aRtByHj_x4idkcAV3k1X-LjgNsFwl3f-OSQw,1608
|
|
7
|
+
mlx_taef/model.py,sha256=FtI8PeDUQi_n01YGvOz_hlJoNkirprQPawPZ2lnSBQM,7340
|
|
8
|
+
mlx_taef/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
|
+
mlx_taef/variants.py,sha256=jAsiU0C2jG38fwazE8XfKf-6gVkh1A0AaBz6n8jnoNI,3655
|
|
10
|
+
mlx_taef/integrations/__init__.py,sha256=XS1-Ml6Mj8iPVF9E3CvFbK18mg-o6dBDiYvg4-MIo08,55
|
|
11
|
+
mlx_taef/integrations/mflux.py,sha256=PyDtgZp1Bgsb3ZYzKYQWg3DvJz_MUxnSsSXaWv2hs2k,6214
|
|
12
|
+
mlx_taef-0.1.0.dist-info/METADATA,sha256=cNI3F-seg2WdGCcdvp5VVRkrhH1AyStHSN3RppctkPA,4883
|
|
13
|
+
mlx_taef-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
|
|
14
|
+
mlx_taef-0.1.0.dist-info/entry_points.txt,sha256=x8owUzlm7qAoK-CsSc0ScgAnTzu0Ipkn9evR_kTso1I,47
|
|
15
|
+
mlx_taef-0.1.0.dist-info/RECORD,,
|