xinference 0.16.0__py3-none-any.whl → 0.16.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +48 -0
- xinference/client/restful/restful_client.py +19 -0
- xinference/constants.py +1 -0
- xinference/core/chat_interface.py +5 -1
- xinference/core/image_interface.py +5 -1
- xinference/core/model.py +106 -16
- xinference/core/scheduler.py +1 -1
- xinference/core/worker.py +3 -1
- xinference/deploy/supervisor.py +0 -4
- xinference/model/audio/chattts.py +25 -14
- xinference/model/audio/core.py +6 -2
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/audio/model_spec_modelscope.json +1 -1
- xinference/model/core.py +3 -1
- xinference/model/embedding/core.py +6 -2
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/image/core.py +65 -6
- xinference/model/image/model_spec.json +24 -3
- xinference/model/image/model_spec_modelscope.json +25 -3
- xinference/model/image/ocr/__init__.py +13 -0
- xinference/model/image/ocr/got_ocr2.py +79 -0
- xinference/model/image/scheduler/flux.py +1 -1
- xinference/model/image/stable_diffusion/core.py +2 -3
- xinference/model/image/stable_diffusion/mlx.py +221 -0
- xinference/model/llm/__init__.py +33 -0
- xinference/model/llm/core.py +3 -1
- xinference/model/llm/llm_family.json +9 -0
- xinference/model/llm/llm_family.py +68 -2
- xinference/model/llm/llm_family_modelscope.json +11 -0
- xinference/model/llm/llm_family_openmind_hub.json +1359 -0
- xinference/model/rerank/core.py +9 -1
- xinference/model/utils.py +7 -0
- xinference/model/video/core.py +6 -2
- xinference/thirdparty/mlx/__init__.py +13 -0
- xinference/thirdparty/mlx/flux/__init__.py +15 -0
- xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
- xinference/thirdparty/mlx/flux/clip.py +154 -0
- xinference/thirdparty/mlx/flux/datasets.py +75 -0
- xinference/thirdparty/mlx/flux/flux.py +247 -0
- xinference/thirdparty/mlx/flux/layers.py +302 -0
- xinference/thirdparty/mlx/flux/lora.py +76 -0
- xinference/thirdparty/mlx/flux/model.py +134 -0
- xinference/thirdparty/mlx/flux/sampler.py +56 -0
- xinference/thirdparty/mlx/flux/t5.py +244 -0
- xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
- xinference/thirdparty/mlx/flux/trainer.py +98 -0
- xinference/thirdparty/mlx/flux/utils.py +179 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.f7da0140.js → main.2f269bb3.js} +3 -3
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +1 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/METADATA +16 -9
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/RECORD +60 -42
- xinference/web/ui/build/static/js/main.f7da0140.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
- /xinference/web/ui/build/static/js/{main.f7da0140.js.LICENSE.txt → main.2f269bb3.js.LICENSE.txt} +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/LICENSE +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/WHEEL +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/top_level.txt +0 -0
xinference/thirdparty/mlx/flux/flux.py
@@ -0,0 +1,247 @@
+# Copyright © 2024 Apple Inc.
+
+from typing import Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
+from tqdm import tqdm
+
+from .lora import LoRALinear
+from .sampler import FluxSampler
+from .utils import (
+    load_ae,
+    load_clip,
+    load_clip_tokenizer,
+    load_flow_model,
+    load_t5,
+    load_t5_tokenizer,
+)
+
+
+class FluxPipeline:
+    def __init__(self, name: str, model_path: str, t5_padding: bool = True):
+        self.dtype = mx.bfloat16
+        self.name = name
+        self.t5_padding = t5_padding
+
+        self.model_path = model_path
+        self.ae = load_ae(name, model_path)
+        self.flow = load_flow_model(name, model_path)
+        self.clip = load_clip(name, model_path)
+        self.clip_tokenizer = load_clip_tokenizer(name, model_path)
+        self.t5 = load_t5(name, model_path)
+        self.t5_tokenizer = load_t5_tokenizer(name, model_path)
+        self.sampler = FluxSampler(name)
+
+    def ensure_models_are_loaded(self):
+        mx.eval(
+            self.ae.parameters(),
+            self.flow.parameters(),
+            self.clip.parameters(),
+            self.t5.parameters(),
+        )
+
+    def reload_text_encoders(self):
+        self.t5 = load_t5(self.name, self.model_path)
+        self.clip = load_clip(self.name, self.model_path)
+
+    def tokenize(self, text):
+        t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
+        clip_tokens = self.clip_tokenizer.encode(text)
+        return t5_tokens, clip_tokens
+
+    def _prepare_latent_images(self, x):
+        b, h, w, c = x.shape
+
+        # Pack the latent image to 2x2 patches
+        x = x.reshape(b, h // 2, 2, w // 2, 2, c)
+        x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
+
+        # Create position ids used to positionally encode each patch. Due to
+        # the way RoPE works, this results in an interesting positional
+        # encoding where parts of the feature are holding different positional
+        # information. Namely, the first part holds information independent of
+        # the spatial position (hence 0s), the 2nd part holds vertical spatial
+        # information and the last one horizontal.
+        i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
+        j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
+        x_ids = mx.stack([i, j, k], axis=-1)
+        x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
+
+        return x, x_ids
+
+    def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
+        # Prepare the text features
+        txt = self.t5(t5_tokens)
+        if len(txt) == 1 and n_images > 1:
+            txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
+        txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
+
+        # Prepare the clip text features
+        vec = self.clip(clip_tokens).pooled_output
+        if len(vec) == 1 and n_images > 1:
+            vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
+
+        return txt, txt_ids, vec
+
+    def _denoising_loop(
+        self,
+        x_t,
+        x_ids,
+        txt,
+        txt_ids,
+        vec,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        start: float = 1,
+        stop: float = 0,
+    ):
+        B = len(x_t)
+
+        def scalar(x):
+            return mx.full((B,), x, dtype=self.dtype)
+
+        guidance = scalar(guidance)
+        timesteps = self.sampler.timesteps(
+            num_steps,
+            x_t.shape[1],
+            start=start,
+            stop=stop,
+        )
+        for i in range(num_steps):
+            t = timesteps[i]
+            t_prev = timesteps[i + 1]
+
+            pred = self.flow(
+                img=x_t,
+                img_ids=x_ids,
+                txt=txt,
+                txt_ids=txt_ids,
+                y=vec,
+                timesteps=scalar(t),
+                guidance=guidance,
+            )
+            x_t = self.sampler.step(pred, x_t, t, t_prev)
+
+            yield x_t
+
+    def generate_latents(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        latent_size: Tuple[int, int] = (64, 64),
+        seed=None,
+    ):
+        # Set the PRNG state
+        if seed is not None:
+            mx.random.seed(seed)
+
+        # Create the latent variables
+        x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
+        x_T, x_ids = self._prepare_latent_images(x_T)
+
+        # Get the conditioning
+        t5_tokens, clip_tokens = self.tokenize(text)
+        txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
+
+        # Yield the conditioning for controlled evaluation by the caller
+        yield (x_T, x_ids, txt, txt_ids, vec)
+
+        # Yield the latent sequences from the denoising loop
+        yield from self._denoising_loop(
+            x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
+        )
+
+    def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
+        h, w = latent_size
+        x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
+        x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
+        x = self.ae.decode(x)
+        return mx.clip(x + 1, 0, 2) * 0.5
+
+    def generate_images(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        latent_size: Tuple[int, int] = (64, 64),
+        seed=None,
+        reload_text_encoders: bool = True,
+        progress: bool = True,
+    ):
+        latents = self.generate_latents(
+            text, n_images, num_steps, guidance, latent_size, seed
+        )
+        mx.eval(next(latents))
+
+        if reload_text_encoders:
+            self.reload_text_encoders()
+
+        for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
+            mx.eval(x_t)
+
+        images = []
+        for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
+            images.append(self.decode(x_t[i : i + 1]))
+            mx.eval(images[-1])
+        images = mx.concatenate(images, axis=0)
+        mx.eval(images)
+
+        return images
+
+    def training_loss(
+        self,
+        x_0: mx.array,
+        t5_features: mx.array,
+        clip_features: mx.array,
+        guidance: mx.array,
+    ):
+        # Get the text conditioning
+        txt = t5_features
+        txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
+        vec = clip_features
+
+        # Prepare the latent input
+        x_0, x_ids = self._prepare_latent_images(x_0)
+
+        # Forward process
+        t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
+        eps = mx.random.normal(x_0.shape, dtype=self.dtype)
+        x_t = self.sampler.add_noise(x_0, t, noise=eps)
+        x_t = mx.stop_gradient(x_t)
+
+        # Do the denoising
+        pred = self.flow(
+            img=x_t,
+            img_ids=x_ids,
+            txt=txt,
+            txt_ids=txt_ids,
+            y=vec,
+            timesteps=t,
+            guidance=guidance,
+        )
+
+        return (pred + x_0 - eps).square().mean()
+
+    def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
+        """Swap the linear layers in the transformer blocks with LoRA layers."""
+        all_blocks = self.flow.double_blocks + self.flow.single_blocks
+        all_blocks.reverse()
+        num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
+        for i, block in zip(range(num_blocks), all_blocks):
+            loras = []
+            for name, module in block.named_modules():
+                if isinstance(module, nn.Linear):
+                    loras.append((name, LoRALinear.from_base(module, r=rank)))
+            block.update_modules(tree_unflatten(loras))
+
+    def fuse_lora_layers(self):
+        fused_layers = []
+        for name, module in self.flow.named_modules():
+            if isinstance(module, LoRALinear):
+                fused_layers.append((name, module.fuse()))
+        self.flow.update_modules(tree_unflatten(fused_layers))
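For orientation, here is a minimal usage sketch of the FluxPipeline added above. It is not part of the diff: the import path follows the new module layout, and the model name and weights path are hypothetical placeholders.

# Minimal sketch, not part of the diff. Assumes FLUX weights have already
# been downloaded to model_path and that the import path resolves as laid
# out in the file list above (both assumptions).
from xinference.thirdparty.mlx.flux.flux import FluxPipeline

pipeline = FluxPipeline(name="flux-schnell", model_path="/path/to/flux-schnell")
images = pipeline.generate_images(
    "an astronaut riding a horse on the moon",
    n_images=1,
    num_steps=4,           # schnell-style distilled checkpoints need few steps; 35 is the default
    latent_size=(64, 64),  # latent grid; the VAE decodes this to a larger pixel image
    seed=0,
)
# Note: generate_latents() first yields the conditioning so the caller can
# mx.eval() it and free the text encoders before the denoising loop runs;
# generate_images() uses exactly that to reload the encoders lazily.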
xinference/thirdparty/mlx/flux/layers.py
@@ -0,0 +1,302 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from functools import partial
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+def _rope(pos: mx.array, dim: int, theta: float):
+    scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
+    omega = 1.0 / (theta**scale)
+    x = pos[..., None] * omega
+    cosx = mx.cos(x)
+    sinx = mx.sin(x)
+    pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
+    pe = pe.reshape(*pe.shape[:-1], 2, 2)
+
+    return pe
+
+
+@partial(mx.compile, shapeless=True)
+def _ab_plus_cd(a, b, c, d):
+    return a * b + c * d
+
+
+def _apply_rope(x, pe):
+    s = x.shape
+    x = x.reshape(*s[:-1], -1, 1, 2)
+    x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
+    return x.reshape(s)
+
+
+def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
+    B, H, L, D = q.shape
+
+    q = _apply_rope(q, pe)
+    k = _apply_rope(k, pe)
+    x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))
+
+    return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+
+def timestep_embedding(
+    t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
+):
+    half = dim // 2
+    freqs = mx.arange(0, half, dtype=mx.float32) / half
+    freqs = freqs * (-math.log(max_period))
+    freqs = mx.exp(freqs)
+
+    x = (time_factor * t)[:, None] * freqs[None]
+    x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
+
+    return x.astype(t.dtype)
+
+
+class EmbedND(nn.Module):
+    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
+        super().__init__()
+
+        self.dim = dim
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def __call__(self, ids: mx.array):
+        n_axes = ids.shape[-1]
+        pe = mx.concatenate(
+            [_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            axis=-3,
+        )
+
+        return pe[:, None]
+
+
+class MLPEmbedder(nn.Module):
+    def __init__(self, in_dim: int, hidden_dim: int):
+        super().__init__()
+        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.out_layer(nn.silu(self.in_layer(x)))
+
+
+class QKNorm(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.query_norm = nn.RMSNorm(dim)
+        self.key_norm = nn.RMSNorm(dim)
+
+    def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
+        return self.query_norm(q), self.key_norm(k)
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.norm = QKNorm(head_dim)
+        self.proj = nn.Linear(dim, dim)
+
+    def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
+        H = self.num_heads
+        B, L, _ = x.shape
+        qkv = self.qkv(x)
+        q, k, v = mx.split(qkv, 3, axis=-1)
+        q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        q, k = self.norm(q, k)
+        x = _attention(q, k, v, pe)
+        x = self.proj(x)
+        return x
+
+
+@dataclass
+class ModulationOut:
+    shift: mx.array
+    scale: mx.array
+    gate: mx.array
+
+
+class Modulation(nn.Module):
+    def __init__(self, dim: int, double: bool):
+        super().__init__()
+        self.is_double = double
+        self.multiplier = 6 if double else 3
+        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+    def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
+        x = self.lin(nn.silu(x))
+        xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
+
+        mod1 = ModulationOut(*xs[:3])
+        mod2 = ModulationOut(*xs[3:]) if self.is_double else None
+
+        return mod1, mod2
+
+
+class DoubleStreamBlock(nn.Module):
+    def __init__(
+        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
+    ):
+        super().__init__()
+
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.img_mod = Modulation(hidden_size, double=True)
+        self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.img_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+        )
+
+        self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.img_mlp = nn.Sequential(
+            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approx="tanh"),
+            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+
+        self.txt_mod = Modulation(hidden_size, double=True)
+        self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.txt_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+        )
+
+        self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.txt_mlp = nn.Sequential(
+            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approx="tanh"),
+            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+
+    def __call__(
+        self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
+    ) -> Tuple[mx.array, mx.array]:
+        B, L, _ = img.shape
+        _, S, _ = txt.shape
+        H = self.num_heads
+
+        img_mod1, img_mod2 = self.img_mod(vec)
+        txt_mod1, txt_mod2 = self.txt_mod(vec)
+
+        # prepare image for attention
+        img_modulated = self.img_norm1(img)
+        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_qkv = self.img_attn.qkv(img_modulated)
+        img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
+        img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        img_q, img_k = self.img_attn.norm(img_q, img_k)
+
+        # prepare txt for attention
+        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_qkv = self.txt_attn.qkv(txt_modulated)
+        txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
+        txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+        txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+        txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
+
+        # run actual attention
+        q = mx.concatenate([txt_q, img_q], axis=2)
+        k = mx.concatenate([txt_k, img_k], axis=2)
+        v = mx.concatenate([txt_v, img_v], axis=2)
+
+        attn = _attention(q, k, v, pe)
+        txt_attn, img_attn = mx.split(attn, [S], axis=1)
+
+        # calculate the img blocks
+        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+        img = img + img_mod2.gate * self.img_mlp(
+            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
+        )
+
+        # calculate the txt blocks
+        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt = txt + txt_mod2.gate * self.txt_mlp(
+            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
+        )
+
+        return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qk_scale: Optional[float] = None,
+    ):
+        super().__init__()
+        self.hidden_dim = hidden_size
+        self.num_heads = num_heads
+        head_dim = hidden_size // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        # qkv and mlp_in
+        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+        # proj and mlp_out
+        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+        self.norm = QKNorm(head_dim)
+
+        self.hidden_size = hidden_size
+        self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+
+        self.mlp_act = nn.GELU(approx="tanh")
+        self.modulation = Modulation(hidden_size, double=False)
+
+    def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
+        B, L, _ = x.shape
+        H = self.num_heads
+
+        mod, _ = self.modulation(vec)
+        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+
+        q, k, v, mlp = mx.split(
+            self.linear1(x_mod),
+            [self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
+            axis=-1,
+        )
+        q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        q, k = self.norm(q, k)
+
+        # compute attention
+        y = _attention(q, k, v, pe)
+
+        # compute activation in mlp stream, cat again and run second linear layer
+        y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
+        return x + mod.gate * y
+
+
+class LastLayer(nn.Module):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.linear = nn.Linear(
+            hidden_size, patch_size * patch_size * out_channels, bias=True
+        )
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+        )
+
+    def __call__(self, x: mx.array, vec: mx.array):
+        shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
+        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = self.linear(x)
+        return x
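A small shape check for the RoPE helpers above may help when reading the attention code. This snippet is not part of the diff and is illustrative only; it assumes the module path shown, and _rope and timestep_embedding are module-private by convention.

# Illustrative shape check, not part of the diff.
import mlx.core as mx

from xinference.thirdparty.mlx.flux.layers import _rope, timestep_embedding

pos = mx.arange(8, dtype=mx.float32)[None]  # (1, 8) token positions
pe = _rope(pos, dim=16, theta=10000.0)      # (1, 8, 8, 2, 2): one 2x2 rotation
                                            # matrix per position/frequency pair
t = mx.array([0.5])
emb = timestep_embedding(t, dim=256)        # (1, 256): cos half then sin half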
xinference/thirdparty/mlx/flux/lora.py
@@ -0,0 +1,76 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class LoRALinear(nn.Module):
+    @staticmethod
+    def from_base(
+        linear: nn.Linear,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 1.0,
+    ):
+        output_dims, input_dims = linear.weight.shape
+        lora_lin = LoRALinear(
+            input_dims=input_dims,
+            output_dims=output_dims,
+            r=r,
+            dropout=dropout,
+            scale=scale,
+        )
+        lora_lin.linear = linear
+        return lora_lin
+
+    def fuse(self):
+        linear = self.linear
+        bias = "bias" in linear
+        weight = linear.weight
+        dtype = weight.dtype
+
+        output_dims, input_dims = weight.shape
+        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+        lora_b = self.scale * self.lora_b.T
+        lora_a = self.lora_a.T
+        fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
+        if bias:
+            fused_linear.bias = linear.bias
+
+        return fused_linear
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 1.0,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        # Regular linear layer weights
+        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Scale for low-rank update
+        self.scale = scale
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(input_dims)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(input_dims, r),
+        )
+        self.lora_b = mx.zeros(shape=(r, output_dims))
+
+    def __call__(self, x):
+        y = self.linear(x)
+        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
+        return y + (self.scale * z).astype(x.dtype)
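And a round-trip sketch for the LoRALinear above: lora_b is zero-initialized, so a freshly wrapped layer matches its base, and fuse() folds the (still zero) low-rank update back into a plain nn.Linear. Not part of the diff; the dimensions are arbitrary.

# Round-trip sketch, not part of the diff.
import mlx.core as mx
import mlx.nn as nn

from xinference.thirdparty.mlx.flux.lora import LoRALinear

base = nn.Linear(512, 512)
lora = LoRALinear.from_base(base, r=8)  # wraps base; lora_b starts at zero
x = mx.random.normal((2, 512))
assert mx.allclose(lora(x), base(x))    # identical output before any training

fused = lora.fuse()                     # folds scale * lora_b.T @ lora_a.T into the weight
assert mx.allclose(fused(x), base(x))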