xinference 0.16.0__py3-none-any.whl → 0.16.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.

Potentially problematic release: this version of xinference has been flagged as possibly problematic.
Files changed (62)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +48 -0
  3. xinference/client/restful/restful_client.py +19 -0
  4. xinference/constants.py +1 -0
  5. xinference/core/chat_interface.py +5 -1
  6. xinference/core/image_interface.py +5 -1
  7. xinference/core/model.py +106 -16
  8. xinference/core/scheduler.py +1 -1
  9. xinference/core/worker.py +3 -1
  10. xinference/deploy/supervisor.py +0 -4
  11. xinference/model/audio/chattts.py +25 -14
  12. xinference/model/audio/core.py +6 -2
  13. xinference/model/audio/model_spec.json +1 -1
  14. xinference/model/audio/model_spec_modelscope.json +1 -1
  15. xinference/model/core.py +3 -1
  16. xinference/model/embedding/core.py +6 -2
  17. xinference/model/embedding/model_spec.json +1 -1
  18. xinference/model/image/core.py +65 -6
  19. xinference/model/image/model_spec.json +24 -3
  20. xinference/model/image/model_spec_modelscope.json +25 -3
  21. xinference/model/image/ocr/__init__.py +13 -0
  22. xinference/model/image/ocr/got_ocr2.py +79 -0
  23. xinference/model/image/scheduler/flux.py +1 -1
  24. xinference/model/image/stable_diffusion/core.py +2 -3
  25. xinference/model/image/stable_diffusion/mlx.py +221 -0
  26. xinference/model/llm/__init__.py +33 -0
  27. xinference/model/llm/core.py +3 -1
  28. xinference/model/llm/llm_family.json +9 -0
  29. xinference/model/llm/llm_family.py +68 -2
  30. xinference/model/llm/llm_family_modelscope.json +11 -0
  31. xinference/model/llm/llm_family_openmind_hub.json +1359 -0
  32. xinference/model/rerank/core.py +9 -1
  33. xinference/model/utils.py +7 -0
  34. xinference/model/video/core.py +6 -2
  35. xinference/thirdparty/mlx/__init__.py +13 -0
  36. xinference/thirdparty/mlx/flux/__init__.py +15 -0
  37. xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
  38. xinference/thirdparty/mlx/flux/clip.py +154 -0
  39. xinference/thirdparty/mlx/flux/datasets.py +75 -0
  40. xinference/thirdparty/mlx/flux/flux.py +247 -0
  41. xinference/thirdparty/mlx/flux/layers.py +302 -0
  42. xinference/thirdparty/mlx/flux/lora.py +76 -0
  43. xinference/thirdparty/mlx/flux/model.py +134 -0
  44. xinference/thirdparty/mlx/flux/sampler.py +56 -0
  45. xinference/thirdparty/mlx/flux/t5.py +244 -0
  46. xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
  47. xinference/thirdparty/mlx/flux/trainer.py +98 -0
  48. xinference/thirdparty/mlx/flux/utils.py +179 -0
  49. xinference/web/ui/build/asset-manifest.json +3 -3
  50. xinference/web/ui/build/index.html +1 -1
  51. xinference/web/ui/build/static/js/{main.f7da0140.js → main.2f269bb3.js} +3 -3
  52. xinference/web/ui/build/static/js/main.2f269bb3.js.map +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +1 -0
  54. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/METADATA +16 -9
  55. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/RECORD +60 -42
  56. xinference/web/ui/build/static/js/main.f7da0140.js.map +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
  58. /xinference/web/ui/build/static/js/{main.f7da0140.js.LICENSE.txt → main.2f269bb3.js.LICENSE.txt} +0 -0
  59. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/LICENSE +0 -0
  60. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/WHEEL +0 -0
  61. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/entry_points.txt +0 -0
  62. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/top_level.txt +0 -0
xinference/thirdparty/mlx/flux/flux.py
@@ -0,0 +1,247 @@
+# Copyright © 2024 Apple Inc.
+
+from typing import Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
+from tqdm import tqdm
+
+from .lora import LoRALinear
+from .sampler import FluxSampler
+from .utils import (
+    load_ae,
+    load_clip,
+    load_clip_tokenizer,
+    load_flow_model,
+    load_t5,
+    load_t5_tokenizer,
+)
+
+
+class FluxPipeline:
+    def __init__(self, name: str, model_path: str, t5_padding: bool = True):
+        self.dtype = mx.bfloat16
+        self.name = name
+        self.t5_padding = t5_padding
+
+        self.model_path = model_path
+        self.ae = load_ae(name, model_path)
+        self.flow = load_flow_model(name, model_path)
+        self.clip = load_clip(name, model_path)
+        self.clip_tokenizer = load_clip_tokenizer(name, model_path)
+        self.t5 = load_t5(name, model_path)
+        self.t5_tokenizer = load_t5_tokenizer(name, model_path)
+        self.sampler = FluxSampler(name)
+
+    def ensure_models_are_loaded(self):
+        mx.eval(
+            self.ae.parameters(),
+            self.flow.parameters(),
+            self.clip.parameters(),
+            self.t5.parameters(),
+        )
+
+    def reload_text_encoders(self):
+        self.t5 = load_t5(self.name, self.model_path)
+        self.clip = load_clip(self.name, self.model_path)
+
+    def tokenize(self, text):
+        t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
+        clip_tokens = self.clip_tokenizer.encode(text)
+        return t5_tokens, clip_tokens
+
+    def _prepare_latent_images(self, x):
+        b, h, w, c = x.shape
+
+        # Pack the latent image to 2x2 patches
+        x = x.reshape(b, h // 2, 2, w // 2, 2, c)
+        x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
+
+        # Create position ids used to positionally encode each patch. Due to
+        # the way RoPE works, this results in an interesting positional
+        # encoding where parts of the feature are holding different positional
+        # information. Namely, the first part holds information independent of
+        # the spatial position (hence 0s), the 2nd part holds vertical spatial
+        # information and the last one horizontal.
+        i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
+        j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
+        x_ids = mx.stack([i, j, k], axis=-1)
+        x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
+
+        return x, x_ids
+
+    def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
+        # Prepare the text features
+        txt = self.t5(t5_tokens)
+        if len(txt) == 1 and n_images > 1:
+            txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
+        txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
+
+        # Prepare the clip text features
+        vec = self.clip(clip_tokens).pooled_output
+        if len(vec) == 1 and n_images > 1:
+            vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
+
+        return txt, txt_ids, vec
+
+    def _denoising_loop(
+        self,
+        x_t,
+        x_ids,
+        txt,
+        txt_ids,
+        vec,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        start: float = 1,
+        stop: float = 0,
+    ):
+        B = len(x_t)
+
+        def scalar(x):
+            return mx.full((B,), x, dtype=self.dtype)
+
+        guidance = scalar(guidance)
+        timesteps = self.sampler.timesteps(
+            num_steps,
+            x_t.shape[1],
+            start=start,
+            stop=stop,
+        )
+        for i in range(num_steps):
+            t = timesteps[i]
+            t_prev = timesteps[i + 1]
+
+            pred = self.flow(
+                img=x_t,
+                img_ids=x_ids,
+                txt=txt,
+                txt_ids=txt_ids,
+                y=vec,
+                timesteps=scalar(t),
+                guidance=guidance,
+            )
+            x_t = self.sampler.step(pred, x_t, t, t_prev)
+
+            yield x_t
+
+    def generate_latents(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        latent_size: Tuple[int, int] = (64, 64),
+        seed=None,
+    ):
+        # Set the PRNG state
+        if seed is not None:
+            mx.random.seed(seed)
+
+        # Create the latent variables
+        x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
+        x_T, x_ids = self._prepare_latent_images(x_T)
+
+        # Get the conditioning
+        t5_tokens, clip_tokens = self.tokenize(text)
+        txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
+
+        # Yield the conditioning for controlled evaluation by the caller
+        yield (x_T, x_ids, txt, txt_ids, vec)
+
+        # Yield the latent sequences from the denoising loop
+        yield from self._denoising_loop(
+            x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
+        )
+
+    def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
+        h, w = latent_size
+        x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
+        x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
+        x = self.ae.decode(x)
+        return mx.clip(x + 1, 0, 2) * 0.5
+
+    def generate_images(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        latent_size: Tuple[int, int] = (64, 64),
+        seed=None,
+        reload_text_encoders: bool = True,
+        progress: bool = True,
+    ):
+        latents = self.generate_latents(
+            text, n_images, num_steps, guidance, latent_size, seed
+        )
+        mx.eval(next(latents))
+
+        if reload_text_encoders:
+            self.reload_text_encoders()
+
+        for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
+            mx.eval(x_t)
+
+        images = []
+        for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
+            images.append(self.decode(x_t[i : i + 1]))
+            mx.eval(images[-1])
+        images = mx.concatenate(images, axis=0)
+        mx.eval(images)
+
+        return images
+
+    def training_loss(
+        self,
+        x_0: mx.array,
+        t5_features: mx.array,
+        clip_features: mx.array,
+        guidance: mx.array,
+    ):
+        # Get the text conditioning
+        txt = t5_features
+        txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
+        vec = clip_features
+
+        # Prepare the latent input
+        x_0, x_ids = self._prepare_latent_images(x_0)
+
+        # Forward process
+        t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
+        eps = mx.random.normal(x_0.shape, dtype=self.dtype)
+        x_t = self.sampler.add_noise(x_0, t, noise=eps)
+        x_t = mx.stop_gradient(x_t)
+
+        # Do the denoising
+        pred = self.flow(
+            img=x_t,
+            img_ids=x_ids,
+            txt=txt,
+            txt_ids=txt_ids,
+            y=vec,
+            timesteps=t,
+            guidance=guidance,
+        )
+
+        return (pred + x_0 - eps).square().mean()
+
+    def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
+        """Swap the linear layers in the transformer blocks with LoRA layers."""
+        all_blocks = self.flow.double_blocks + self.flow.single_blocks
+        all_blocks.reverse()
+        num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
+        for i, block in zip(range(num_blocks), all_blocks):
+            loras = []
+            for name, module in block.named_modules():
+                if isinstance(module, nn.Linear):
+                    loras.append((name, LoRALinear.from_base(module, r=rank)))
+            block.update_modules(tree_unflatten(loras))
+
+    def fuse_lora_layers(self):
+        fused_layers = []
+        for name, module in self.flow.named_modules():
+            if isinstance(module, LoRALinear):
+                fused_layers.append((name, module.fuse()))
+        self.flow.update_modules(tree_unflatten(fused_layers))
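
For context, a minimal usage sketch of the FluxPipeline added above. This assumes the new xinference/thirdparty/mlx/flux/__init__.py re-exports FluxPipeline and that FLUX weights are available locally; the model name and path are placeholders, not values confirmed by this diff.

    import mlx.core as mx
    from xinference.thirdparty.mlx.flux import FluxPipeline  # re-export assumed

    # "flux-schnell" and the weights path are hypothetical; the load_* helpers
    # in utils.py resolve the actual checkpoint layout.
    pipeline = FluxPipeline("flux-schnell", model_path="/path/to/flux-weights")
    pipeline.ensure_models_are_loaded()

    # A (64, 64) latent grid decodes to a 512x512 image: decode() re-expands
    # the 2x2 latent patches and the autoencoder upsamples by 8x.
    images = pipeline.generate_images(
        "a watercolor lighthouse at dusk",
        n_images=1,
        num_steps=4,  # schnell-style models need few steps; the default is 35
        guidance=4.0,
        latent_size=(64, 64),
        seed=0,
    )
    print(images.shape)  # (1, 512, 512, 3), pixel values clipped to [0, 1]
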
xinference/thirdparty/mlx/flux/layers.py
@@ -0,0 +1,302 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from functools import partial
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+def _rope(pos: mx.array, dim: int, theta: float):
+    scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
+    omega = 1.0 / (theta**scale)
+    x = pos[..., None] * omega
+    cosx = mx.cos(x)
+    sinx = mx.sin(x)
+    pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
+    pe = pe.reshape(*pe.shape[:-1], 2, 2)
+
+    return pe
+
+
+@partial(mx.compile, shapeless=True)
+def _ab_plus_cd(a, b, c, d):
+    return a * b + c * d
+
+
+def _apply_rope(x, pe):
+    s = x.shape
+    x = x.reshape(*s[:-1], -1, 1, 2)
+    x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
+    return x.reshape(s)
+
+
+def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
+    B, H, L, D = q.shape
+
+    q = _apply_rope(q, pe)
+    k = _apply_rope(k, pe)
+    x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))
+
+    return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+
+def timestep_embedding(
+    t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
+):
+    half = dim // 2
+    freqs = mx.arange(0, half, dtype=mx.float32) / half
+    freqs = freqs * (-math.log(max_period))
+    freqs = mx.exp(freqs)
+
+    x = (time_factor * t)[:, None] * freqs[None]
+    x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
+
+    return x.astype(t.dtype)
+
+
+class EmbedND(nn.Module):
+    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
+        super().__init__()
+
+        self.dim = dim
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def __call__(self, ids: mx.array):
+        n_axes = ids.shape[-1]
+        pe = mx.concatenate(
+            [_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            axis=-3,
+        )
+
+        return pe[:, None]
+
+
+class MLPEmbedder(nn.Module):
+    def __init__(self, in_dim: int, hidden_dim: int):
+        super().__init__()
+        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.out_layer(nn.silu(self.in_layer(x)))
+
+
+class QKNorm(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.query_norm = nn.RMSNorm(dim)
+        self.key_norm = nn.RMSNorm(dim)
+
+    def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
+        return self.query_norm(q), self.key_norm(k)
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.norm = QKNorm(head_dim)
+        self.proj = nn.Linear(dim, dim)
+
+    def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
+        H = self.num_heads
+        B, L, _ = x.shape
+        qkv = self.qkv(x)
+        q, k, v = mx.split(qkv, 3, axis=-1)
+        q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        q, k = self.norm(q, k)
+        x = _attention(q, k, v, pe)
+        x = self.proj(x)
+        return x
+
+
+@dataclass
+class ModulationOut:
+    shift: mx.array
+    scale: mx.array
+    gate: mx.array
+
+
+class Modulation(nn.Module):
+    def __init__(self, dim: int, double: bool):
+        super().__init__()
+        self.is_double = double
+        self.multiplier = 6 if double else 3
+        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+    def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
+        x = self.lin(nn.silu(x))
+        xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
+
+        mod1 = ModulationOut(*xs[:3])
+        mod2 = ModulationOut(*xs[3:]) if self.is_double else None
+
+        return mod1, mod2
+
+
+class DoubleStreamBlock(nn.Module):
+    def __init__(
+        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
+    ):
+        super().__init__()
+
+        mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.img_mod = Modulation(hidden_size, double=True)
+        self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.img_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+        )
+
+        self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.img_mlp = nn.Sequential(
+            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approx="tanh"),
+            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+
+        self.txt_mod = Modulation(hidden_size, double=True)
+        self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.txt_attn = SelfAttention(
+            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
+        )
+
+        self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.txt_mlp = nn.Sequential(
+            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+            nn.GELU(approx="tanh"),
+            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+        )
+
+    def __call__(
+        self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
+    ) -> Tuple[mx.array, mx.array]:
+        B, L, _ = img.shape
+        _, S, _ = txt.shape
+        H = self.num_heads
+
+        img_mod1, img_mod2 = self.img_mod(vec)
+        txt_mod1, txt_mod2 = self.txt_mod(vec)
+
+        # prepare image for attention
+        img_modulated = self.img_norm1(img)
+        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_qkv = self.img_attn.qkv(img_modulated)
+        img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
+        img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        img_q, img_k = self.img_attn.norm(img_q, img_k)
+
+        # prepare txt for attention
+        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_qkv = self.txt_attn.qkv(txt_modulated)
+        txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
+        txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+        txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+        txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
+        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
+
+        # run actual attention
+        q = mx.concatenate([txt_q, img_q], axis=2)
+        k = mx.concatenate([txt_k, img_k], axis=2)
+        v = mx.concatenate([txt_v, img_v], axis=2)
+
+        attn = _attention(q, k, v, pe)
+        txt_attn, img_attn = mx.split(attn, [S], axis=1)
+
+        # calculate the img blocks
+        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+        img = img + img_mod2.gate * self.img_mlp(
+            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
+        )
+
+        # calculate the txt blocks
+        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt = txt + txt_mod2.gate * self.txt_mlp(
+            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
+        )
+
+        return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qk_scale: Optional[float] = None,
+    ):
+        super().__init__()
+        self.hidden_dim = hidden_size
+        self.num_heads = num_heads
+        head_dim = hidden_size // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        # qkv and mlp_in
+        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+        # proj and mlp_out
+        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+        self.norm = QKNorm(head_dim)
+
+        self.hidden_size = hidden_size
+        self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+
+        self.mlp_act = nn.GELU(approx="tanh")
+        self.modulation = Modulation(hidden_size, double=False)
+
+    def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
+        B, L, _ = x.shape
+        H = self.num_heads
+
+        mod, _ = self.modulation(vec)
+        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+
+        q, k, v, mlp = mx.split(
+            self.linear1(x_mod),
+            [self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
+            axis=-1,
+        )
+        q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
+        q, k = self.norm(q, k)
+
+        # compute attention
+        y = _attention(q, k, v, pe)
+
+        # compute activation in mlp stream, cat again and run second linear layer
+        y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
+        return x + mod.gate * y
+
+
+class LastLayer(nn.Module):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
+        self.linear = nn.Linear(
+            hidden_size, patch_size * patch_size * out_channels, bias=True
+        )
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+        )
+
+    def __call__(self, x: mx.array, vec: mx.array):
+        shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
+        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+        x = self.linear(x)
+        return x
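
A quick shape check for the RoPE helpers in layers.py above. This is a standalone sketch; the head dimensions and the axes_dim split are illustrative values, not the model's configured ones.

    import mlx.core as mx
    from xinference.thirdparty.mlx.flux.layers import EmbedND, _attention

    B, L, H, D = 2, 16, 8, 64  # batch, tokens, heads, head dim (illustrative)

    # One (extra, row, column) id triple per token, as flux.py builds for patches.
    ids = mx.zeros((B, L, 3), dtype=mx.int32)

    # axes_dim must sum to the head dim; each axis gets its own RoPE frequencies.
    embed = EmbedND(dim=D, theta=10000, axes_dim=[16, 24, 24])
    pe = embed(ids)
    print(pe.shape)  # (2, 1, 16, 32, 2, 2): one 2x2 rotation per feature pair

    q = k = v = mx.random.normal((B, H, L, D))
    out = _attention(q, k, v, pe)
    print(out.shape)  # (2, 16, 512): heads merged back into H * D features
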
xinference/thirdparty/mlx/flux/lora.py
@@ -0,0 +1,76 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class LoRALinear(nn.Module):
+    @staticmethod
+    def from_base(
+        linear: nn.Linear,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 1.0,
+    ):
+        output_dims, input_dims = linear.weight.shape
+        lora_lin = LoRALinear(
+            input_dims=input_dims,
+            output_dims=output_dims,
+            r=r,
+            dropout=dropout,
+            scale=scale,
+        )
+        lora_lin.linear = linear
+        return lora_lin
+
+    def fuse(self):
+        linear = self.linear
+        bias = "bias" in linear
+        weight = linear.weight
+        dtype = weight.dtype
+
+        output_dims, input_dims = weight.shape
+        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+        lora_b = self.scale * self.lora_b.T
+        lora_a = self.lora_a.T
+        fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
+        if bias:
+            fused_linear.bias = linear.bias
+
+        return fused_linear
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 1.0,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        # Regular linear layer weights
+        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Scale for low-rank update
+        self.scale = scale
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(input_dims)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(input_dims, r),
+        )
+        self.lora_b = mx.zeros(shape=(r, output_dims))
+
+    def __call__(self, x):
+        y = self.linear(x)
+        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
+        return y + (self.scale * z).astype(x.dtype)
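
And a round-trip sketch for the LoRALinear layer above: a freshly wrapped layer matches its base (lora_b starts at zero), and fuse() folds the low-rank update back into a plain nn.Linear. The import path follows the file location in this diff.

    import mlx.core as mx
    import mlx.nn as nn
    from xinference.thirdparty.mlx.flux.lora import LoRALinear

    base = nn.Linear(64, 64)
    lora = LoRALinear.from_base(base, r=8)

    x = mx.random.normal((2, 64))
    # lora_b is initialized to zeros, so the low-rank delta is initially a no-op.
    assert mx.allclose(lora(x), base(x)).item()

    # fuse() returns a plain nn.Linear whose weight is W + scale * (lora_a @ lora_b).T
    fused = lora.fuse()
    assert mx.allclose(fused(x), lora(x), atol=1e-5).item()
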