mlx-taef 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mlx_taef/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """mlx-taef: Tiny AutoEncoder family ported to Apple MLX."""
2
+
3
+ import logging
4
+
5
+ from mlx_taef.api import TAEF1, TAEF2, TAESD, TAESDXL, Taef
6
+
7
+ __all__ = ["TAEF1", "TAEF2", "TAESD", "TAESDXL", "Taef"]
8
+
9
+ logging.getLogger("mlx_taef").addHandler(logging.NullHandler())
mlx_taef/_version.py ADDED
@@ -0,0 +1,24 @@
1
+ # file generated by vcs-versioning
2
+ # don't change, don't track in version control
3
+ from __future__ import annotations
4
+
5
+ __all__ = [
6
+ "__version__",
7
+ "__version_tuple__",
8
+ "version",
9
+ "version_tuple",
10
+ "__commit_id__",
11
+ "commit_id",
12
+ ]
13
+
14
+ version: str
15
+ __version__: str
16
+ __version_tuple__: tuple[int | str, ...]
17
+ version_tuple: tuple[int | str, ...]
18
+ commit_id: str | None
19
+ __commit_id__: str | None
20
+
21
+ __version__ = version = '0.1.0'
22
+ __version_tuple__ = version_tuple = (0, 1, 0)
23
+
24
+ __commit_id__ = commit_id = None
mlx_taef/api.py ADDED
@@ -0,0 +1,204 @@
1
+ """User-facing Taef family API: load weights and decode/encode latents.
2
+
3
+ Tensor layout: all public methods use NHWC mx.array.
4
+ Value space: decode() outputs [0, 1] float; encode() expects [0, 1] float.
5
+ """
6
+
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import cast
10
+
11
+ import mlx.core as mx
12
+ import mlx.nn as nn
13
+
14
+ from mlx_taef.model import make_decoder, make_encoder
15
+ from mlx_taef.variants import (
16
+ TAEF1_CONFIG,
17
+ TAEF2_CONFIG,
18
+ TAESD_CONFIG,
19
+ TAESDXL_CONFIG,
20
+ TaesdVariantConfig,
21
+ )
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class Taef(nn.Module):  # type: ignore[misc,name-defined]
    """Base class for TAESD-family models.

    Subclasses set the `_config` class variable to a `TaesdVariantConfig`.
    The base class only declares `_config` (no value is assigned here), so
    it is the concrete subclasses that are meant to be instantiated.

    Args:
        None — instantiated via `from_pretrained_local(path)` or
        `from_pretrained()`.

    Tensor layout:
        All public methods use NHWC `mx.array`:
        - latent shape: (B, H, W, latent_channels)
        - image shape: (B, H*8, W*8, 3)

    Value space:
        - `decode()` returns float in [0, 1] (clipped on output).
        - `decode_image()` returns uint8 in [0, 255].
        - Latents are RAW (post-Clamp scale), values roughly in [-3, 3].
        - `scale_latents` / `unscale_latents` convert between raw and [0, 1].
    """

    _config: TaesdVariantConfig

    def __init__(self) -> None:
        """Initialise decoder and encoder from the subclass `_config`."""
        super().__init__()
        self.decoder = make_decoder(self._config)
        self.encoder = make_encoder(self._config)

    @classmethod
    def from_pretrained_local(
        cls,
        decoder_path: Path | str,
        encoder_path: Path | str | None = None,
        *,
        dtype: mx.Dtype = mx.float32,
    ) -> "Taef":
        """Instantiate from already-converted MLX safetensors on disk.

        Args:
            decoder_path: path to converted MLX decoder safetensors.
            encoder_path: optional path to converted MLX encoder safetensors.
                When omitted, the encoder built by `__init__` is left as
                constructed (no weights are loaded into it).
            dtype: target dtype for model weights. Default fp32 to keep parity tests sharp.

        Returns:
            A loaded `Taef` instance in eval mode, ready for `decode()`.
        """
        instance = cls()

        # Keys on disk are relative to the decoder/encoder Sequential, so they
        # are re-rooted under the attribute names used by this module.
        decoder_weights = cast("dict[str, mx.array]", mx.load(str(decoder_path)))
        weights_list: list[tuple[str, mx.array]] = [
            (f"decoder.{key}", value) for key, value in decoder_weights.items()
        ]
        if encoder_path is not None and hasattr(instance, "encoder"):
            encoder_weights = cast("dict[str, mx.array]", mx.load(str(encoder_path)))
            weights_list.extend((f"encoder.{key}", value) for key, value in encoder_weights.items())

        # strict=False: a decoder-only load supplies no `encoder.*` entries,
        # so the weight list does not always cover every parameter.
        instance.load_weights(weights_list, strict=False)

        # Compare dtypes by value, not identity: a Dtype obtained elsewhere
        # (e.g. `some_array.dtype`) need not be the same Python object as the
        # `mx.float32` module attribute.
        if dtype != mx.float32:
            instance.set_dtype(dtype)
        instance.eval()
        return instance

    @classmethod
    def from_pretrained(  # pragma: no cover
        cls,
        repo_id: str | None = None,
        *,
        dtype: mx.Dtype = mx.float32,
        include_encoder: bool = True,
    ) -> "Taef":
        """Auto-download weights from HF Hub, convert to MLX, and load.

        On first call, downloads upstream weights from `cls._config.hf_repo` and
        caches converted output at ~/.cache/mlx-taef/. Subsequent calls reuse
        the cached files.

        Args:
            repo_id: optional HF repo override. If set, must match `cls._config.hf_repo`.
            dtype: target dtype for weights. Default fp32 to keep parity strict.
            include_encoder: whether to also load the encoder side (default True).

        Returns:
            A loaded `Taef` instance.

        Raises:
            ValueError: if `repo_id` is given and differs from the variant's repo.
        """
        # Imported here rather than at module top, so importing this module
        # does not pull in the download/convert machinery.
        from mlx_taef.download import get_or_convert

        config = cls._config
        if repo_id is not None and repo_id != config.hf_repo:
            raise ValueError(
                f"repo_id mismatch: requested {repo_id!r} but variant {config.name!r} "
                f"uses {config.hf_repo!r}"
            )
        decoder_path = get_or_convert(config, role="decoder")
        encoder_path = get_or_convert(config, role="encoder") if include_encoder else None
        return cls.from_pretrained_local(decoder_path, encoder_path=encoder_path, dtype=dtype)

    def decode(self, latents: mx.array) -> mx.array:
        """Decode raw latents (NHWC) to image (NHWC, [0, 1] float).

        Args:
            latents: NHWC array with shape (B, H, W, latent_channels).

        Returns:
            NHWC float array with shape (B, H*8, W*8, 3), values in [0, 1].
        """
        out = self.decoder(latents)
        return mx.clip(out, 0.0, 1.0)

    def decode_image(self, latents: mx.array) -> mx.array:
        """Decode raw latents to a uint8 NHWC image suitable for PIL/PNG.

        Args:
            latents: NHWC array with shape (B, H, W, latent_channels).

        Returns:
            NHWC uint8 array with shape (B, H*8, W*8, 3), values in [0, 255].
        """
        img_float = self.decode(latents)
        # NOTE(review): the cast presumably truncates rather than rounds;
        # confirm if exact parity with a rounding reference matters.
        return (img_float * 255.0).astype(mx.uint8)

    def encode(self, image: mx.array) -> mx.array:
        """Encode an NHWC RGB image (B, H, W, 3) in [0, 1] to a latent (B, H/8, W/8, latent_channels).

        Args:
            image: NHWC float array with shape (B, H, W, 3), values in [0, 1].

        Returns:
            NHWC latent array with shape (B, H/8, W/8, latent_channels).
        """
        return cast("mx.array", self.encoder(image))

    def scale_latents(self, raw: mx.array) -> mx.array:
        """Map raw latents to [0, 1] using config.latent_magnitude/shift.

        Args:
            raw: raw latent array, values roughly in [-3, 3].

        Returns:
            Scaled latent array clipped to [0, 1].
        """
        c = self._config
        return mx.clip(raw / (2.0 * c.latent_magnitude) + c.latent_shift, 0.0, 1.0)

    def unscale_latents(self, scaled: mx.array) -> mx.array:
        """Inverse of scale_latents: [0, 1] -> raw.

        Values that `scale_latents` clipped are not recovered.

        Args:
            scaled: latent array in [0, 1].

        Returns:
            Raw latent array.
        """
        c = self._config
        return (scaled - c.latent_shift) * (2.0 * c.latent_magnitude)
178
+
179
+
180
class TAESD(Taef):
    """Tiny autoencoder variant for Stable Diffusion 1.x latents.

    Binds the shared `Taef` implementation to `TAESD_CONFIG`.
    """

    _config = TAESD_CONFIG
184
+
185
+
186
class TAESDXL(Taef):
    """Tiny autoencoder variant for SDXL latents.

    Binds the shared `Taef` implementation to `TAESDXL_CONFIG`.
    """

    _config = TAESDXL_CONFIG
190
+
191
+
192
class TAEF1(Taef):
    """Tiny autoencoder variant for FLUX.1 latents.

    Binds the shared `Taef` implementation to `TAEF1_CONFIG`.
    """

    _config = TAEF1_CONFIG
196
+
197
+
198
class TAEF2(Taef):
    """Tiny autoencoder variant for FLUX.2 Klein latents.

    Binds the shared `Taef` implementation to `TAEF2_CONFIG`.
    """

    _config = TAEF2_CONFIG
202
+
203
+
204
+ __all__ = ["TAEF1", "TAEF2", "TAESD", "TAESDXL", "Taef"]
mlx_taef/cli.py ADDED
@@ -0,0 +1,92 @@
1
+ """mlx-taef command-line interface: convert / info / bench."""
2
+
3
+ import argparse
4
+ import logging
5
+ import time
6
+ from pathlib import Path
7
+
8
+ import mlx.core as mx
9
+
10
+ from mlx_taef.variants import ALL_VARIANTS
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def main(argv: list[str] | None = None) -> int:
    """Parse the command line and run the selected sub-command.

    Args:
        argv: argument list to parse; `None` lets argparse read `sys.argv`.

    Returns:
        Process exit code produced by the sub-command handler.
    """
    variant_choices = [variant.name for variant in ALL_VARIANTS]

    parser = argparse.ArgumentParser(
        prog="mlx-taef",
        description="Tiny AutoEncoder family for diffusion on Apple MLX.",
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    convert_parser = commands.add_parser(
        "convert", help="Download upstream weights and convert to MLX"
    )
    convert_parser.add_argument("--variant", required=True, choices=variant_choices)
    convert_parser.add_argument("--role", default="decoder", choices=["decoder", "encoder"])
    convert_parser.add_argument(
        "--dst", required=True, type=Path, help="Output .safetensors path"
    )

    info_parser = commands.add_parser(
        "info", help="Print info about a converted MLX safetensors file"
    )
    info_parser.add_argument("path", type=Path)

    bench_parser = commands.add_parser("bench", help="Decode benchmark on the current Mac")
    bench_parser.add_argument("--variant", default="taef2", choices=variant_choices)

    args = parser.parse_args(argv)

    # One handler per sub-command name registered above.
    handlers = {
        "convert": _cmd_convert,
        "info": _cmd_info,
        "bench": _cmd_bench,
    }
    handler = handlers.get(args.cmd)
    if handler is None:  # pragma: no cover
        return 1
    return handler(args)
44
+
45
+
46
def _cmd_convert(args: argparse.Namespace) -> int:  # pragma: no cover
    """Run the `convert` sub-command: write converted weights to `args.dst`.

    Args:
        args: parsed namespace carrying `variant`, `role` and `dst`.

    Returns:
        0 on success.
    """
    from mlx_taef.convert import convert_hf_decoder_to_mlx, convert_hf_encoder_to_mlx

    config = next(variant for variant in ALL_VARIANTS if variant.name == args.variant)
    # Any role other than "encoder" falls through to the decoder converter.
    converter = convert_hf_encoder_to_mlx if args.role == "encoder" else convert_hf_decoder_to_mlx
    converter(out_path=args.dst, config=config)
    print(f"Wrote {args.dst}")
    return 0
56
+
57
+
58
def _cmd_info(args: argparse.Namespace) -> int:
    """Run the `info` sub-command: print tensor and parameter counts for a file.

    Args:
        args: parsed namespace carrying `path`, the safetensors file to inspect.

    Returns:
        0 on success.
    """
    weights = mx.load(str(args.path))
    print(f"File: {args.path}")
    print(f"Total tensors: {len(weights)}")

    total_params = 0
    for tensor in weights.values():  # type: ignore[union-attr]
        total_params += int(tensor.size)
    print(f"Total params: {total_params:,}")
    return 0
65
+
66
+
67
def _cmd_bench(args: argparse.Namespace) -> int:  # pragma: no cover
    """Run the `bench` sub-command: time decoder-only decoding of one latent.

    Performs one warm-up decode followed by five timed decodes and prints
    the median wall-clock time in milliseconds.

    Args:
        args: parsed namespace carrying `variant`.

    Returns:
        0 on success.
    """
    from mlx_taef.api import TAEF1, TAEF2, TAESD, TAESDXL

    model_classes = {"taesd": TAESD, "taesdxl": TAESDXL, "taef1": TAEF1, "taef2": TAEF2}
    model_cls = model_classes[args.variant]
    config = next(variant for variant in ALL_VARIANTS if variant.name == args.variant)

    model = model_cls.from_pretrained(include_encoder=False)

    # 1024x1024 image with 8x downsample = 128x128 latent
    # NOTE(review): the latent is fp16 while the model is loaded with the
    # default dtype; confirm this mixed-precision setup is intended.
    latent_shape = (1, 128, 128, config.latent_channels)
    latent = mx.random.normal(latent_shape).astype(mx.float16)
    mx.eval(latent)

    def _timed_decode() -> float:
        # mx.eval forces the computation, so the interval covers the real work.
        started = time.perf_counter()
        mx.eval(model.decode(latent))
        return time.perf_counter() - started

    # Warm-up
    mx.eval(model.decode(latent))

    times: list[float] = [_timed_decode() for _ in range(5)]
    ordered = sorted(times)
    median_ms = ordered[len(ordered) // 2] * 1000
    print(f"{args.variant} decode median: {median_ms:.1f} ms over {len(times)} runs")
    return 0
90
+
91
+
92
+ __all__ = ["main"]
mlx_taef/convert.py ADDED
@@ -0,0 +1,201 @@
1
+ """HF safetensors -> MLX safetensors conversion.
2
+
3
+ Zero PyTorch dependency: reads source files with `safetensors.numpy.load_file`
4
+ and writes MLX safetensors directly. Runtime users never need torch.
5
+ """
6
+
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ import mlx.core as mx
12
+ import numpy as np
13
+ from huggingface_hub import hf_hub_download
14
+ from safetensors.numpy import load_file as safetensors_load_numpy
15
+
16
+ from mlx_taef.model import make_decoder
17
+ from mlx_taef.variants import TaesdVariantConfig
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def convert_diffusers_to_sequential(
    sd: dict[str, Any],
    *,
    role: str,
) -> dict[str, np.ndarray]:
    """Rewrite Diffusers-VAE keys for one role into upstream Sequential keys.

    Only keys beginning with `"<role>."` are kept. For keys of the form
    `"<role>.layers.<N>.<rest>"` the result is `"<N'>.<rest>"`, where `N'` is
    `N + 1` for the decoder and `N` for any other role. Per the TAEF2 model
    card, the decoder shift exists because the Diffusers VAE prepends one
    layer that the upstream Sequential decoder doesn't have.

    Args:
        sd: source state dict with Diffusers keys like 'decoder.layers.0.weight'.
        role: 'decoder' (apply +1 offset) or 'encoder' (no offset).

    Returns:
        State dict with upstream-Sequential keys like '0.weight', '1.weight'.
        Keys outside the role prefix are dropped; role keys that do not start
        with 'layers.' are passed through with only the role prefix removed.
    """
    role_prefix = f"{role}."
    index_shift = 1 if role == "decoder" else 0

    remapped: dict[str, np.ndarray] = {}
    for key, tensor in sd.items():
        if not key.startswith(role_prefix):
            continue
        tail = key[len(role_prefix) :]
        if not tail.startswith("layers."):  # pragma: no cover
            remapped[tail] = tensor
            continue
        _, layer_index, *rest = tail.split(".")
        remapped[f"{int(layer_index) + index_shift}." + ".".join(rest)] = tensor
    return remapped
57
+
58
+
59
def _load_role_state_dict(  # pragma: no cover
    config: TaesdVariantConfig,
    role: str,
) -> dict[str, np.ndarray]:
    """Fetch the weights for (variant, role) and return them Sequential-keyed.

    Args:
        config: variant configuration naming the HF repo and file(s).
        role: 'decoder' or 'encoder'.

    Returns:
        State dict keyed in upstream Sequential form.

    Raises:
        ValueError: if the required filename is not configured, or if
            `config.key_format` is neither 'diffusers' nor 'upstream'.
    """
    key_format = config.key_format

    if key_format == "upstream":
        # Two-file layout: the file on the Hub already uses Sequential keys.
        filename = config.hf_decoder_filename if role == "decoder" else config.hf_encoder_filename
        if filename is None:
            raise ValueError(f"Upstream variant {config.name!r} has no {role} filename")
        local_file = hf_hub_download(repo_id=config.hf_repo, filename=filename)
        return safetensors_load_numpy(local_file)

    if key_format == "diffusers":
        # Single-file layout: both roles share one file and need key remapping.
        if config.hf_filename is None:
            raise ValueError(f"Diffusers variant {config.name!r} has no hf_filename")
        local_file = hf_hub_download(repo_id=config.hf_repo, filename=config.hf_filename)
        combined = safetensors_load_numpy(local_file)
        return convert_diffusers_to_sequential(combined, role=role)

    raise ValueError(f"Unknown key_format: {config.key_format!r}")
77
+
78
+
79
def convert_hf_decoder_to_mlx(  # pragma: no cover
    *,
    out_path: Path | str,
    config: TaesdVariantConfig,
) -> None:
    """Fetch the variant's decoder weights and save them as MLX safetensors.

    Works for both upstream-Sequential and Diffusers key formats. A decoder
    is built from `config` purely to learn the expected parameter names and
    shapes; Conv2d weights are then transposed from NCHW to NHWC and written
    under MLX-flat keys such as 'layers.1.weight'.

    Args:
        out_path: where to write the MLX safetensors file.
        config: variant configuration.
    """
    source_state = _load_role_state_dict(config, role="decoder")
    expected_shapes = _flatten_module_param_shapes(make_decoder(config))
    mx.save_safetensors(
        str(out_path),
        _build_mlx_state_dict(source_state, expected_shapes=expected_shapes),
    )
99
+
100
+
101
+ def _sequential_key_to_mlx(src_key: str) -> str:
102
+ """Convert an upstream-Sequential key to an MLX-flat dotted key.
103
+
104
+ MLX's `nn.Sequential` stores its children under `.layers`, so every
105
+ integer path segment (after the first top-level layer index) must be
106
+ wrapped as `layers.<N>` rather than a bare `<N>`.
107
+
108
+ Examples::
109
+
110
+ "1.weight" -> "layers.1.weight"
111
+ "3.conv.0.weight" -> "layers.3.conv.layers.0.weight"
112
+ "3.pool.1.bias" -> "layers.3.pool.layers.1.bias"
113
+
114
+ Args:
115
+ src_key: upstream-Sequential key like '3.conv.0.weight'.
116
+
117
+ Returns:
118
+ MLX-flat dotted key like 'layers.3.conv.layers.0.weight'.
119
+ """
120
+ parts = src_key.split(".")
121
+ out = ["layers", parts[0]]
122
+ for part in parts[1:]:
123
+ if part.isdigit():
124
+ out.extend(["layers", part])
125
+ else:
126
+ out.append(part)
127
+ return ".".join(out)
128
+
129
+
130
def _build_mlx_state_dict(
    sd: dict[str, np.ndarray],
    *,
    expected_shapes: dict[str, tuple[int, ...]],
) -> dict[str, mx.array]:
    """Rename keys to MLX form and transpose conv kernels from NCHW to NHWC.

    Args:
        sd: Sequential-keyed source state dict of numpy arrays.
        expected_shapes: MLX-flat key -> parameter shape of the target module.

    Returns:
        MLX-keyed dict of `mx.array`. Source keys whose MLX name is not in
        `expected_shapes` are left out.
    """
    result: dict[str, mx.array] = {}
    for src_key, arr in sd.items():
        dst_key = _sequential_key_to_mlx(src_key)
        target_shape = expected_shapes.get(dst_key)
        if target_shape is None:
            # Skip keys that don't map to the MLX module structure
            # (e.g., extra Diffusers-specific keys we don't need)
            continue
        if arr.ndim == 4:
            # Conv2d weight: NCHW (out, in, kH, kW) -> NHWC (out, kH, kW, in).
            # Only transposed when the target shape matches the transposed
            # layout, so an already-converted array is left untouched.
            out_ch, in_ch, k_h, k_w = arr.shape
            if target_shape == (out_ch, k_h, k_w, in_ch):
                arr = np.transpose(arr, (0, 2, 3, 1)).copy()
        result[dst_key] = mx.array(arr)
    return result
154
+
155
+
156
+ def _flatten_module_param_shapes(module: Any, prefix: str = "") -> dict[str, tuple[int, ...]]:
157
+ """Walk module.parameters() and return a flat dict of dotted-key -> shape."""
158
+ out: dict[str, tuple[int, ...]] = {}
159
+
160
+ def _walk(obj: Any, p: str) -> None:
161
+ if isinstance(obj, dict):
162
+ for k, v in obj.items():
163
+ _walk(v, f"{p}.{k}" if p else k)
164
+ elif isinstance(obj, list):
165
+ for i, item in enumerate(obj):
166
+ _walk(item, f"{p}.{i}")
167
+ elif hasattr(obj, "shape"):
168
+ out[p] = tuple(obj.shape)
169
+
170
+ _walk(module.parameters(), prefix)
171
+ return out
172
+
173
+
174
def convert_hf_encoder_to_mlx(  # pragma: no cover
    *,
    out_path: Path | str,
    config: TaesdVariantConfig,
) -> None:
    """Fetch the variant's encoder weights and save them as MLX safetensors.

    Same pipeline as `convert_hf_decoder_to_mlx`, except that the expected
    parameter shapes come from an encoder built with `make_encoder`, so Conv
    weights are transposed against the correct shapes.

    Args:
        out_path: where to write the MLX safetensors file.
        config: variant configuration.
    """
    from mlx_taef.model import make_encoder

    source_state = _load_role_state_dict(config, role="encoder")
    expected_shapes = _flatten_module_param_shapes(make_encoder(config))
    mx.save_safetensors(
        str(out_path),
        _build_mlx_state_dict(source_state, expected_shapes=expected_shapes),
    )
195
+
196
+
197
+ __all__ = [
198
+ "convert_diffusers_to_sequential",
199
+ "convert_hf_decoder_to_mlx",
200
+ "convert_hf_encoder_to_mlx",
201
+ ]
mlx_taef/download.py ADDED
@@ -0,0 +1,45 @@
1
+ """HF Hub auto-download + cache. Zero PyTorch dependency at runtime."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+
6
+ from mlx_taef.convert import convert_hf_decoder_to_mlx, convert_hf_encoder_to_mlx
7
+ from mlx_taef.variants import TaesdVariantConfig
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ CACHE_ROOT = Path.home() / ".cache" / "mlx-taef"
12
+
13
+
14
def get_or_convert(config: TaesdVariantConfig, *, role: str = "decoder") -> Path:
    """Return the local path to converted MLX weights for (variant, role).

    On cache miss, triggers the full conversion pipeline (HF download + key
    remap + NHWC transpose + safetensors write). Subsequent calls return the
    cached path without running the conversion again.

    The converted file is first written to a temporary sibling path and then
    renamed into place, so an interrupted conversion never leaves a partial
    file that a later call would mistake for a valid cache entry.

    Args:
        config: variant configuration (selects HF repo + filename).
        role: 'decoder' (default) or 'encoder'.

    Returns:
        Local filesystem path to the MLX safetensors file.

    Raises:
        ValueError: if `role` is neither 'decoder' nor 'encoder'.
    """
    import os

    converters = {
        "decoder": convert_hf_decoder_to_mlx,
        "encoder": convert_hf_encoder_to_mlx,
    }
    # Validate before touching the filesystem or consulting the cache.
    if role not in converters:
        raise ValueError(f"role must be 'decoder' or 'encoder', got {role!r}")

    cache_dir = CACHE_ROOT / config.name
    cache_dir.mkdir(parents=True, exist_ok=True)
    out_path = cache_dir / f"{config.name}_{role}.safetensors"
    if out_path.exists():
        logger.debug("Using cached weights at %s", out_path)
        return out_path

    logger.info("Downloading + converting %s %s weights from %s", config.name, role, config.hf_repo)

    # Keep the ".safetensors" suffix on the temporary name so the writer
    # saves to exactly this path; the PID keeps concurrent processes apart.
    tmp_path = cache_dir / f".{config.name}_{role}.{os.getpid()}.partial.safetensors"
    try:
        converters[role](out_path=tmp_path, config=config)
        tmp_path.replace(out_path)
    finally:
        # No-op after a successful rename; removes the partial file on failure.
        tmp_path.unlink(missing_ok=True)
    return out_path
43
+
44
+
45
+ __all__ = ["CACHE_ROOT", "get_or_convert"]
@@ -0,0 +1 @@
1
+ """Optional integrations for mlx-taef (e.g. mflux)."""
@@ -0,0 +1,168 @@
1
+ """mflux integration for mlx-taef.
2
+
3
+ Provides:
4
+ - `unpack_flux2_latent`: convert mflux's packed FLUX.2 latents to TAEF2-compatible NHWC.
5
+ - `LivePreviewCallback`: drop-in mflux callback that writes preview PNGs every N steps.
6
+
7
+ Install with: `pip install "mlx-taef[mflux]"`.
8
+ """
9
+
10
+ import logging
11
+ from pathlib import Path
12
+
13
+ import mlx.core as mx
14
+ import numpy as np
15
+
16
+ try:
17
+ from mflux.callbacks.callback import InLoopCallback # type: ignore[import-untyped]
18
+ except ImportError as e: # pragma: no cover
19
+ raise ImportError(
20
+ "mflux is required for mlx_taef.integrations.mflux. "
21
+ "Install with: pip install 'mlx-taef[mflux]'"
22
+ ) from e
23
+
24
+ from mlx_taef.api import TAEF1, TAEF2, Taef
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
def unpack_flux2_latent(
    packed: mx.array,
    *,
    latent_height: int,
    latent_width: int,
    bn_mean: mx.array | None = None,
    bn_var: mx.array | None = None,
    bn_eps: float = 1e-4,
) -> mx.array:
    """Unpack mflux's packed FLUX.2 latent into NHWC `(B, lH*2, lW*2, 32)` for TAEF2.

    Mirrors mflux's `Flux2VAE.decode_packed_latents` preprocessing, carried
    out with channels kept on the last axis:
      1. Reshape the token axis back to a grid: (B, lH*lW, 128) -> (B, lH, lW, 128).
      2. BN denormalize: latents * sqrt(var + eps) + mean (128-channel stats).
      3. Unpatchify: each 128-vector is read as (32, 2, 2) and its 2x2 patch
         is scattered into space, giving (B, lH*2, lW*2, 32).

    BN denormalization is applied only when BOTH `bn_mean` and `bn_var` are
    given. Otherwise identity BN is assumed (correct shape, slight value
    offset): preview structure will still be visible but colors may shift.
    For exact preview, pass the BN stats from the loaded Flux2VAE.

    Args:
        packed: in-loop latent from mflux, shape (B, lH*lW, 128).
        latent_height: latent spatial height (image_height // 16).
        latent_width: latent spatial width.
        bn_mean: optional BN running_mean for exact value recovery (128 elements).
        bn_var: optional BN running_var for exact value recovery (128 elements).
        bn_eps: BN epsilon. Matches mflux's Flux2BatchNormStats default of 1e-4,
            verified at mflux 0.17.5.

    Returns:
        NHWC tensor of shape (B, latent_height*2, latent_width*2, 32) — ready
        for `TAEF2.decode()`.

    Raises:
        ValueError: if the last axis of `packed` is not 128.

    See `notes/mflux-latent-layout.md` for the analysis behind this transform.
    """
    batch_size, _, channels = packed.shape
    if channels != 128:
        raise ValueError(f"Expected 128-channel packed latent, got {channels}")

    # Token axis -> spatial grid; channels stay last.
    grid = packed.reshape(batch_size, latent_height, latent_width, channels)

    # BN denormalize along the channel (last) axis.
    if bn_mean is not None and bn_var is not None:
        mean = bn_mean.reshape(1, 1, 1, -1)
        std = mx.sqrt(bn_var.reshape(1, 1, 1, -1) + bn_eps)
        grid = grid * std + mean

    # Unpatchify: split 128 channels into (32, 2, 2) and interleave the 2x2
    # patch with the grid axes: (B, lH, lW, 32, 2, 2) -> (B, lH, 2, lW, 2, 32).
    patches = grid.reshape(batch_size, latent_height, latent_width, 32, 2, 2)
    patches = patches.transpose(0, 1, 4, 2, 5, 3)
    return patches.reshape(batch_size, latent_height * 2, latent_width * 2, 32)
86
+
87
+
88
class LivePreviewCallback(InLoopCallback):  # type: ignore[misc]
    """mflux callback that writes a low-quality preview PNG every N iterations.

    Drops into `mflux.Flux2Klein.generate_image(callbacks=[...])`. On the
    iterations selected by `every`, it converts the in-flight latent to NHWC,
    decodes it with the tiny decoder (`TAEF2` for FLUX.2 Klein, `TAEF1` for
    FLUX.1), and writes the first image of the batch to disk as a PNG.

    Args:
        variant: 'taef1' (for FLUX.1 latents) or 'taef2' (for FLUX.2 Klein).
        every: emit a preview every Nth iteration; must be >= 1. Default 5.
        save_to: filesystem path to write previews. Overwritten each emission.
        latent_height: latent spatial height; for 512x512 image with FLUX.2 Klein, this is 32.
        latent_width: latent spatial width.
        bn_mean: optional BN running_mean for TAEF2 (see `unpack_flux2_latent`).
        bn_var: optional BN running_var for TAEF2.

    Raises:
        ValueError: if `every` is less than 1 or `variant` is not recognised.
    """

    def __init__(
        self,
        *,
        variant: str = "taef2",
        every: int = 5,
        save_to: str | Path = "preview.png",
        latent_height: int = 32,
        latent_width: int = 32,
        bn_mean: mx.array | None = None,
        bn_var: mx.array | None = None,
    ) -> None:
        """Initialise LivePreviewCallback. See class docstring for argument descriptions."""
        # Reject a non-positive interval up front: `idx % 0` would otherwise
        # raise ZeroDivisionError on the first callback, after the weights
        # have already been loaded.
        if every < 1:
            raise ValueError(f"every must be a positive integer, got {every!r}")

        if variant == "taef1":  # pragma: no cover
            self.model: Taef = TAEF1.from_pretrained(include_encoder=False)
        elif variant == "taef2":
            self.model = TAEF2.from_pretrained(include_encoder=False)
        else:  # pragma: no cover
            raise ValueError(f"variant must be 'taef1' or 'taef2', got {variant!r}")
        self.every = every
        self.save_to = Path(save_to)
        self.latent_height = latent_height
        self.latent_width = latent_width
        self.bn_mean = bn_mean
        self.bn_var = bn_var
        # Internal call counter; the first call (index 0) always emits.
        self._iter = 0

    def call_in_loop(
        self,
        t: object,
        seed: object,
        prompt: object,
        latents: mx.array,
        config: object,
        time_steps: object,
    ) -> None:
        """Decode latent + save PNG every Nth iteration (counted by our own iter, not `t`).

        Only `latents` is used; the other arguments are accepted to match the
        callback signature and ignored.
        """
        idx = self._iter
        self._iter += 1
        if idx % self.every != 0:
            return  # pragma: no cover

        if self.model._config.name == "taef2":
            unpacked = unpack_flux2_latent(
                latents,
                latent_height=self.latent_height,
                latent_width=self.latent_width,
                bn_mean=self.bn_mean,
                bn_var=self.bn_var,
            )
        else:  # pragma: no cover
            # TAEF1 / FLUX.1: latents come in NCHW (B, 16, H, W). Transpose to NHWC.
            # NOTE(review): layout is guessed from axis 1 being 16; confirm
            # against the latents mflux actually passes for FLUX.1.
            unpacked = mx.transpose(latents, (0, 2, 3, 1)) if latents.shape[1] == 16 else latents

        img = self.model.decode_image(unpacked)
        # Only the first image of the batch is previewed.
        self._save_png(img[0])

    def _save_png(self, img_nhwc_uint8: mx.array) -> None:
        """Write one uint8 image array to `self.save_to` via PIL."""
        # Imported lazily so PIL is only needed when a preview is written.
        from PIL import Image

        arr = np.array(img_nhwc_uint8)
        Image.fromarray(arr).save(self.save_to)
166
+
167
+
168
+ __all__ = ["LivePreviewCallback", "unpack_flux2_latent"]
mlx_taef/model.py ADDED
@@ -0,0 +1,197 @@
1
+ """MLX layer modules for the TAESD-family of tiny autoencoders.
2
+
3
+ Layers are direct ports of the upstream PyTorch implementations at
4
+ https://github.com/madebyollin/taesd (MIT licensed).
5
+
6
+ All public layers accept and return NHWC `mx.array` tensors.
7
+ """
8
+
9
+ import mlx.core as mx
10
+ import mlx.nn as nn
11
+
12
+ from mlx_taef.variants import TaesdVariantConfig
13
+
14
+
15
class Clamp(nn.Module):  # type: ignore[misc,name-defined]
    """Smooth bound on activations: `3 * tanh(x / 3)`, output within [-3, 3].

    Direct port of upstream `taesd.Clamp`. It is the first layer of the
    decoder, bounding latent magnitudes before the first conv.
    """

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the soft clamp element-wise.

        Args:
            x: input tensor of any shape.

        Returns:
            Tensor of the same shape holding `tanh(x / 3) * 3`.
        """
        squashed = mx.tanh(x / 3.0)
        return squashed * 3.0
32
+
33
+
34
def make_conv(n_in: int, n_out: int, *, stride: int = 1, bias: bool = True) -> "nn.Conv2d":  # type: ignore[name-defined]
    """Build the 3x3, padding-1 convolution used throughout the network.

    Matches upstream `taesd.conv(n_in, n_out, **kwargs)`.

    Args:
        n_in: input channel count.
        n_out: output channel count.
        stride: spatial stride for the convolution (default 1).
        bias: whether to include a learned bias term (default True).

    Returns:
        Configured `mlx.nn.Conv2d` with kernel_size=3, padding=1.
    """
    conv = nn.Conv2d(  # type: ignore[attr-defined]
        n_in,
        n_out,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=bias,
    )
    return conv
47
+
48
+
49
class _Identity(nn.Module):  # type: ignore[misc,name-defined]
    """Pass-through layer: the skip path of a `Block` whose in/out channels match."""

    def __call__(self, x: mx.array) -> mx.array:
        """Return `x` unchanged."""
        return x
54
+
55
+
56
class Block(nn.Module):  # type: ignore[misc,name-defined]
    """Residual block computing `ReLU(conv_path(x) + skip(x))`.

    `conv_path` is Conv -> ReLU -> Conv -> ReLU -> Conv. Direct port of
    upstream `taesd.Block`, with an optional `midblock_gn` pool branch (used
    by the flux_2 TAEF2 variant) applied to the input first.

    Args:
        n_in: input channel count.
        n_out: output channel count.
        use_midblock_gn: when True, the input is first replaced by
            `x + pool(x)`, where `pool` is a 1x1 conv expand -> GroupNorm ->
            ReLU -> 1x1 conv contract chain. Default False.
    """

    def __init__(self, n_in: int, n_out: int, *, use_midblock_gn: bool = False) -> None:
        """Create the conv path, the skip path and, optionally, the pool branch.

        Args:
            n_in: input channel count.
            n_out: output channel count.
            use_midblock_gn: whether to build the midblock GroupNorm pool
                branch. Default False.
        """
        super().__init__()

        conv_path = [
            make_conv(n_in, n_out),
            nn.ReLU(),  # type: ignore[attr-defined]
            make_conv(n_out, n_out),
            nn.ReLU(),  # type: ignore[attr-defined]
            make_conv(n_out, n_out),
        ]
        self.conv = nn.Sequential(*conv_path)  # type: ignore[attr-defined]

        # A 1x1 projection is needed only when the channel count changes.
        if n_in == n_out:
            self.skip = _Identity()
        else:
            self.skip = nn.Conv2d(n_in, n_out, kernel_size=1, bias=False)  # type: ignore[attr-defined]

        self.fuse = nn.ReLU()  # type: ignore[attr-defined]

        if use_midblock_gn:
            expanded = n_in * 4
            # CRITICAL: pytorch_compatible=True is required for parity. MLX
            # defaults to a different group-channel ordering than PyTorch's
            # nn.GroupNorm.
            self.pool = nn.Sequential(  # type: ignore[attr-defined]
                nn.Conv2d(n_in, expanded, kernel_size=1, bias=False),  # type: ignore[attr-defined]
                nn.GroupNorm(num_groups=4, dims=expanded, pytorch_compatible=True),  # type: ignore[attr-defined]
                nn.ReLU(),  # type: ignore[attr-defined]
                nn.Conv2d(expanded, n_in, kernel_size=1, bias=False),  # type: ignore[attr-defined]
            )
        else:
            self.pool = None

    def __call__(self, x: mx.array) -> mx.array:
        """Run the block on an NHWC tensor.

        Args:
            x: NHWC input tensor with `n_in` channels.

        Returns:
            NHWC output tensor with `n_out` channels.
        """
        if self.pool is not None:
            x = x + self.pool(x)
        shortcut = self.skip(x)
        return self.fuse(self.conv(x) + shortcut)  # type: ignore[no-any-return]
116
+
117
+
118
def make_decoder(config: TaesdVariantConfig) -> nn.Sequential:  # type: ignore[name-defined]
    """Assemble the decoder `nn.Sequential` for a TAESD-family variant.

    Mirrors upstream `taesd.Decoder(latent_channels, use_midblock_gn)`:
    Clamp -> Conv -> ReLU -> Block x3 -> Up -> Conv -> Block x3 -> Up -> Conv ->
    Block x3 -> Up -> Conv -> Block -> Conv.

    Only the first three Blocks receive `config.use_midblock_gn`; all later
    Blocks are vanilla. Layer order fixes each layer's index in the
    Sequential, which the converted weight keys depend on.

    Args:
        config: variant configuration that selects `latent_channels` and
            whether the first three Blocks use midblock_gn.

    Returns:
        `nn.Sequential` decoder ready for weight loading.
    """
    midblock_gn = config.use_midblock_gn

    # Stem: bound the latent, lift to 64 channels, three (optionally GN) blocks.
    layers: list[nn.Module] = [  # type: ignore[name-defined]
        Clamp(),
        make_conv(config.latent_channels, 64),
        nn.ReLU(),  # type: ignore[attr-defined]
    ]
    layers.extend(Block(64, 64, use_midblock_gn=midblock_gn) for _ in range(3))

    # Two full upsampling stages: Up -> Conv -> Block x3.
    for _ in range(2):
        layers.append(nn.Upsample(scale_factor=2, mode="nearest"))  # type: ignore[attr-defined]
        layers.append(make_conv(64, 64, bias=False))
        layers.extend(Block(64, 64) for _ in range(3))

    # Final upsampling stage with a single Block, then projection to RGB.
    layers.append(nn.Upsample(scale_factor=2, mode="nearest"))  # type: ignore[attr-defined]
    layers.append(make_conv(64, 64, bias=False))
    layers.append(Block(64, 64))
    layers.append(make_conv(64, 3))

    return nn.Sequential(*layers)  # type: ignore[attr-defined]
159
+
160
+
161
def make_encoder(config: TaesdVariantConfig) -> nn.Sequential:  # type: ignore[name-defined]
    """Assemble the encoder `nn.Sequential` for a TAESD-family variant.

    Mirrors upstream `taesd.Encoder(latent_channels, use_midblock_gn)`:
    Conv -> Block -> (StridedConv -> Block x3) x3 -> Conv.
    Three strided convs each halve spatial dims (8x downsample total).

    Only the LAST three Blocks receive `config.use_midblock_gn`; all earlier
    Blocks are vanilla.

    Args:
        config: variant configuration that selects `latent_channels` and
            whether the last three Blocks use midblock_gn.

    Returns:
        `nn.Sequential` encoder ready for weight loading.
    """
    midblock_gn = config.use_midblock_gn

    layers: list[nn.Module] = [  # type: ignore[name-defined]
        make_conv(3, 64),
        Block(64, 64),
    ]

    # Three downsampling stages; only the final stage may use midblock_gn.
    for stage_uses_gn in (False, False, midblock_gn):
        layers.append(make_conv(64, 64, stride=2, bias=False))
        layers.extend(Block(64, 64, use_midblock_gn=stage_uses_gn) for _ in range(3))

    layers.append(make_conv(64, config.latent_channels))
    return nn.Sequential(*layers)  # type: ignore[attr-defined]
195
+
196
+
197
+ __all__ = ["Block", "Clamp", "make_conv", "make_decoder", "make_encoder"]
mlx_taef/py.typed ADDED
File without changes
mlx_taef/variants.py ADDED
@@ -0,0 +1,115 @@
1
+ """Variant configurations for the TAESD-family of tiny autoencoders.
2
+
3
+ Field reference:
4
+ - `key_format`: "upstream" means the HF safetensors uses upstream Sequential keys
5
+ like "0.weight", "1.weight"; "diffusers" means the HF safetensors uses Diffusers
6
+ VAE keys like "encoder.layers.0.weight" and requires a +1 decoder index offset.
7
+ - `arch_variant`: None | "flux_2" | "f32". "flux_2" enables midblock_gn pool
8
+ branches in three blocks of encoder and decoder. "f32" is reserved for
9
+ TAESANA (future).
10
+ - `latent_magnitude`, `latent_shift`: from upstream TAESD.scale_latents /
11
+ unscale_latents.
12
+ - `hf_filename` (Diffusers single-file format) vs `hf_decoder_filename` /
13
+ `hf_encoder_filename` (upstream two-file format) — depends on the actual
14
+ repo layout for each variant.
15
+ """
16
+
17
+ import logging
18
+ from dataclasses import dataclass
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ @dataclass(frozen=True, slots=True, kw_only=True)  # immutable, slotted; every field is keyword-only
24
+ class TaesdVariantConfig:
25
+ """Configuration for a single TAESD-family variant.
26
+
27
+ Attributes:
28
+ name: short identifier like "taef2".
29
+ latent_channels: latent channel count (4 for SD1/SDXL, 16 for FLUX.1, 32 for FLUX.2 Klein).
30
+ arch_variant: None for the standard arch, "flux_2" for TAEF2's midblock_gn, "f32" for TAESANA.
31
+ key_format: "upstream" (Sequential keys) or "diffusers" (requires remap).
32
+ hf_repo: e.g. "madebyollin/taef2".
33
+ hf_filename: filename within hf_repo when key_format == "diffusers" (single file).
34
+ hf_decoder_filename: decoder filename when key_format == "upstream".
35
+ hf_encoder_filename: encoder filename when key_format == "upstream".
36
+ latent_magnitude: scale factor for scale_latents/unscale_latents.
37
+ latent_shift: shift factor for scale_latents/unscale_latents.
38
+ """
39
+
40
+ name: str
41
+ latent_channels: int
42
+ arch_variant: str | None  # plain str: the value is not validated in this class
43
+ key_format: str  # plain str: not checked against "upstream"/"diffusers" in this class
44
+ hf_repo: str
45
+ hf_filename: str | None  # no default: pass None explicitly when unused
46
+ hf_decoder_filename: str | None  # no default: pass None explicitly when unused
47
+ hf_encoder_filename: str | None  # no default: pass None explicitly when unused
48
+ latent_magnitude: float = 3.0  # one of the two fields with a default
49
+ latent_shift: float = 0.5  # the other field with a default
50
+
51
+ @property
52
+ def use_midblock_gn(self) -> bool:
53
+ """Whether the variant's Block layers use the midblock GroupNorm pool branch."""
54
+ return self.arch_variant == "flux_2"  # None and "f32" both evaluate to False
55
+
56
+
57
+ TAESD_CONFIG = TaesdVariantConfig(  # Stable Diffusion 1.x
58
+ name="taesd",
59
+ latent_channels=4,
60
+ arch_variant=None,
61
+ key_format="upstream",  # two-file layout: decoder and encoder filenames below
62
+ hf_repo="madebyollin/taesd",
63
+ hf_filename=None,  # single-file name applies only to key_format == "diffusers"
64
+ hf_decoder_filename="taesd_decoder.safetensors",
65
+ hf_encoder_filename="taesd_encoder.safetensors",
66
+ )  # latent_magnitude / latent_shift keep the dataclass defaults (3.0 / 0.5)
67
+
68
+ TAESDXL_CONFIG = TaesdVariantConfig(  # Stable Diffusion XL
69
+ name="taesdxl",
70
+ latent_channels=4,
71
+ arch_variant=None,
72
+ key_format="upstream",  # two-file layout: decoder and encoder filenames below
73
+ hf_repo="madebyollin/taesdxl",
74
+ hf_filename=None,  # single-file name applies only to key_format == "diffusers"
75
+ hf_decoder_filename="taesdxl_decoder.safetensors",
76
+ hf_encoder_filename="taesdxl_encoder.safetensors",
77
+ )  # latent_magnitude / latent_shift keep the dataclass defaults (3.0 / 0.5)
78
+
79
+ TAEF1_CONFIG = TaesdVariantConfig(  # FLUX.1
80
+ name="taef1",
81
+ latent_channels=16,
82
+ arch_variant=None,
83
+ key_format="diffusers",  # single-file layout: see hf_filename
84
+ hf_repo="madebyollin/taef1",
85
+ hf_filename="diffusion_pytorch_model.safetensors",
86
+ hf_decoder_filename=None,  # split filenames apply only to key_format == "upstream"
87
+ hf_encoder_filename=None,
88
+ )  # latent_magnitude / latent_shift keep the dataclass defaults (3.0 / 0.5)
89
+
90
+ TAEF2_CONFIG = TaesdVariantConfig(  # FLUX.2 Klein
91
+ name="taef2",
92
+ latent_channels=32,
93
+ arch_variant="flux_2",  # the only config here for which use_midblock_gn is True
94
+ key_format="diffusers",  # single-file layout: see hf_filename
95
+ hf_repo="madebyollin/taef2",
96
+ hf_filename="taef2.safetensors",
97
+ hf_decoder_filename=None,  # split filenames apply only to key_format == "upstream"
98
+ hf_encoder_filename=None,
99
+ )  # latent_magnitude / latent_shift keep the dataclass defaults (3.0 / 0.5)
100
+
101
+ ALL_VARIANTS: tuple[TaesdVariantConfig, ...] = (  # every config defined above, in definition order
102
+ TAESD_CONFIG,
103
+ TAESDXL_CONFIG,
104
+ TAEF1_CONFIG,
105
+ TAEF2_CONFIG,
106
+ )
107
+
108
+ __all__ = [  # public names of this module, kept in ASCII sort order
109
+ "ALL_VARIANTS",
110
+ "TAEF1_CONFIG",
111
+ "TAEF2_CONFIG",
112
+ "TAESDXL_CONFIG",  # sorts before "TAESD_CONFIG" because "X" < "_" in ASCII
113
+ "TAESD_CONFIG",
114
+ "TaesdVariantConfig",
115
+ ]
mlx_taef-0.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,125 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlx-taef
3
+ Version: 0.1.0
4
+ Summary: Tiny AutoEncoders for diffusion (TAESD family) on Apple MLX.
5
+ Project-URL: Homepage, https://github.com/ionden/mlx-taef
6
+ Project-URL: Documentation, https://ionden.github.io/mlx-taef
7
+ Project-URL: Issues, https://github.com/ionden/mlx-taef/issues
8
+ Project-URL: Changelog, https://github.com/ionden/mlx-taef/blob/main/CHANGELOG.md
9
+ Author-email: Denis Ineshin <denis.ineshin@gmail.com>
10
+ License-Expression: MIT
11
+ Keywords: apple-silicon,autoencoder,flux,mlx,stable-diffusion,taesd
12
+ Classifier: Development Status :: 3 - Alpha
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Operating System :: MacOS :: MacOS X
17
+ Classifier: Programming Language :: Python :: 3 :: Only
18
+ Classifier: Programming Language :: Python :: 3.11
19
+ Classifier: Programming Language :: Python :: 3.12
20
+ Classifier: Programming Language :: Python :: 3.13
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Classifier: Typing :: Typed
23
+ Requires-Python: >=3.11
24
+ Requires-Dist: huggingface-hub>=0.24
25
+ Requires-Dist: mlx>=0.20
26
+ Requires-Dist: numpy>=1.26
27
+ Requires-Dist: safetensors>=0.4
28
+ Provides-Extra: image
29
+ Requires-Dist: pillow>=10.0; extra == 'image'
30
+ Provides-Extra: mflux
31
+ Requires-Dist: mflux>=0.17; extra == 'mflux'
32
+ Description-Content-Type: text/markdown
33
+
34
+ # mlx-taef
35
+
36
+ Tiny AutoEncoders for diffusion latents on Apple Silicon, in pure MLX.
37
+
38
+ `mlx-taef` is the first MLX port of the TAESD family — TAESD (SD1.x), TAESDXL (SDXL), TAEF1 (FLUX.1), TAEF2 (FLUX.2 Klein) — distilled mini-autoencoders that decode diffusion latents to RGB in milliseconds using a few-MB model instead of multi-GB full VAEs.
39
+
40
+ Use it for:
41
+ - **Live previews** during long generations on Mac — see each step refresh in roughly 100 ms (1024×1024 on an M1 Max) instead of waiting 30 s for the full VAE.
42
+ - **Low-memory fallbacks** when the full VAE OOMs on 16 GB Macs (TAEF2 peaks at ~1 GB for 1024×1024 vs ~9.6 GB for the full Flux VAE).
43
+ - **Quick latent inspection** in notebooks and ML research.
44
+
45
+ ```python
46
+ import mlx.core as mx
47
+ from mlx_taef import TAEF2
48
+
49
+ taef = TAEF2.from_pretrained() # downloads + converts on first call
50
+ img = taef.decode(latents) # latents: an NHWC mx.array you supply (32 channels for TAEF2); returns NHWC float in [0, 1]
51
+ img_uint8 = taef.decode_image(latents) # uint8 NHWC ready for PIL
52
+ ```
53
+
54
+ ## Install
55
+
56
+ ```bash
57
+ pip install mlx-taef
58
+ # With mflux preview callback:
59
+ pip install "mlx-taef[mflux]"
60
+ ```
61
+
62
+ Requires Python ≥ 3.11 and Apple Silicon. The runtime install has **no PyTorch dependency**.
63
+
64
+ ## Variants
65
+
66
+ | Variant | latent_channels | For | HF source |
67
+ |---|---|---|---|
68
+ | `TAESD` | 4 | Stable Diffusion 1.x | [madebyollin/taesd](https://huggingface.co/madebyollin/taesd) |
69
+ | `TAESDXL` | 4 | Stable Diffusion XL | [madebyollin/taesdxl](https://huggingface.co/madebyollin/taesdxl) |
70
+ | `TAEF1` | 16 | FLUX.1 | [madebyollin/taef1](https://huggingface.co/madebyollin/taef1) |
71
+ | `TAEF2` | 32 | FLUX.2 Klein | [madebyollin/taef2](https://huggingface.co/madebyollin/taef2) |
72
+
73
+ All four share one API.
74
+
75
+ ## Benchmarks (M1 Max, fp16)
76
+
77
+ | Metric | TAEF2 (this library) | Full Flux VAE (reference) | Win |
78
+ |---|---|---|---|
79
+ | Decode latency 1024×1024 | **~100 ms** | seconds | 50–100× |
80
+ | Peak unified memory 1024×1024 | **~1 GB** | ~9.6 GB | **9.4×** |
81
+ | Output cosine sim vs PyTorch reference | > 0.999 | — | (parity verified) |
82
+
83
+ Numbers from `tests/test_perf.py` on M1 Max 32 GB. See `notes/phase1-benchmarks.md` for details.
84
+
85
+ ## mflux live previews
86
+
87
+ ```python
88
+ from mflux.models.flux2 import Flux2Klein
89
+ from mlx_taef.integrations.mflux import LivePreviewCallback
90
+
91
+ model = Flux2Klein.from_pretrained("4bit")
92
+ preview = LivePreviewCallback(
93
+ variant="taef2",
94
+ every=5,
95
+ save_to="preview.png",
96
+ latent_height=32, # 512 / 16
97
+ latent_width=32,
98
+ )
99
+ model.callbacks.register(preview)
100
+ model.generate_image(
101
+ prompt="a red apple on a wooden table",
102
+ num_inference_steps=25,
103
+ width=512,
104
+ height=512,
105
+ seed=42,
106
+ )
107
+ ```
108
+
109
+ For exact value-space recovery, also pass `bn_mean=flux2_vae.bn.running_mean, bn_var=flux2_vae.bn.running_var` to the callback. Without them, previews show correct structure but colors may shift.
110
+
111
+ See `docs/manual-verification.md` for the full verification recipe.
112
+
113
+ ## Status
114
+
115
+ - v0.1.0 — initial public release. All four variants, encoder + decoder, mflux integration, CI, 100% test coverage.
116
+
117
+ ## License
118
+
119
+ MIT, mirroring the license of upstream [madebyollin/taesd](https://github.com/madebyollin/taesd). Pretrained weights belong to their respective authors (madebyollin).
120
+
121
+ ## Acknowledgements
122
+
123
+ - [madebyollin](https://github.com/madebyollin) for the upstream TAESD-family models and weights.
124
+ - [Apple ML Explore](https://github.com/ml-explore/mlx) for MLX.
125
+ - [filipstrand/mflux](https://github.com/filipstrand/mflux) for the MLX-native FLUX runner this library integrates with.
mlx_taef-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,15 @@
1
+ mlx_taef/__init__.py,sha256=hn12r3kNI6DnIKnzNpx6EjU6zfYXpHT034qtk3OQ8ic,261
2
+ mlx_taef/_version.py,sha256=n_5vdJsPNu7wZ57LGuRL585uvll-hiuvZUBWzdG0RQU,520
3
+ mlx_taef/api.py,sha256=_Si8Mi0uVkjbr5n9Ni_kxES-CVhyW3M0nNGedwmtaQo,6502
4
+ mlx_taef/cli.py,sha256=-8nvCxga9tf4YWU6CJADuv0ch7qi2fa9MkwneKJrg0Q,3253
5
+ mlx_taef/convert.py,sha256=5mLO-nfat6HuUpUoQFbIVu2Vi6wbEW7qlGetybL7syo,7002
6
+ mlx_taef/download.py,sha256=QDrOz54aRtByHj_x4idkcAV3k1X-LjgNsFwl3f-OSQw,1608
7
+ mlx_taef/model.py,sha256=FtI8PeDUQi_n01YGvOz_hlJoNkirprQPawPZ2lnSBQM,7340
8
+ mlx_taef/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ mlx_taef/variants.py,sha256=jAsiU0C2jG38fwazE8XfKf-6gVkh1A0AaBz6n8jnoNI,3655
10
+ mlx_taef/integrations/__init__.py,sha256=XS1-Ml6Mj8iPVF9E3CvFbK18mg-o6dBDiYvg4-MIo08,55
11
+ mlx_taef/integrations/mflux.py,sha256=PyDtgZp1Bgsb3ZYzKYQWg3DvJz_MUxnSsSXaWv2hs2k,6214
12
+ mlx_taef-0.1.0.dist-info/METADATA,sha256=cNI3F-seg2WdGCcdvp5VVRkrhH1AyStHSN3RppctkPA,4883
13
+ mlx_taef-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
14
+ mlx_taef-0.1.0.dist-info/entry_points.txt,sha256=x8owUzlm7qAoK-CsSc0ScgAnTzu0Ipkn9evR_kTso1I,47
15
+ mlx_taef-0.1.0.dist-info/RECORD,,
mlx_taef-0.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.29.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
mlx_taef-0.1.0.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ mlx-taef = mlx_taef.cli:main