xinference 0.16.0__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of xinference might be problematic; see the registry listing for details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +48 -0
- xinference/client/restful/restful_client.py +19 -0
- xinference/core/chat_interface.py +5 -1
- xinference/core/image_interface.py +5 -1
- xinference/core/model.py +106 -16
- xinference/core/scheduler.py +1 -1
- xinference/deploy/supervisor.py +0 -4
- xinference/model/audio/chattts.py +25 -14
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/audio/model_spec_modelscope.json +1 -1
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/image/core.py +59 -4
- xinference/model/image/model_spec.json +24 -3
- xinference/model/image/model_spec_modelscope.json +25 -3
- xinference/model/image/ocr/__init__.py +13 -0
- xinference/model/image/ocr/got_ocr2.py +76 -0
- xinference/model/image/scheduler/flux.py +1 -1
- xinference/model/image/stable_diffusion/core.py +2 -3
- xinference/model/image/stable_diffusion/mlx.py +221 -0
- xinference/model/llm/llm_family.json +9 -0
- xinference/model/llm/llm_family_modelscope.json +11 -0
- xinference/thirdparty/mlx/__init__.py +13 -0
- xinference/thirdparty/mlx/flux/__init__.py +15 -0
- xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
- xinference/thirdparty/mlx/flux/clip.py +154 -0
- xinference/thirdparty/mlx/flux/datasets.py +75 -0
- xinference/thirdparty/mlx/flux/flux.py +247 -0
- xinference/thirdparty/mlx/flux/layers.py +302 -0
- xinference/thirdparty/mlx/flux/lora.py +76 -0
- xinference/thirdparty/mlx/flux/model.py +134 -0
- xinference/thirdparty/mlx/flux/sampler.py +56 -0
- xinference/thirdparty/mlx/flux/t5.py +244 -0
- xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
- xinference/thirdparty/mlx/flux/trainer.py +98 -0
- xinference/thirdparty/mlx/flux/utils.py +179 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.f7da0140.js → main.b76aeeb7.js} +3 -3
- xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/METADATA +15 -8
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/RECORD +48 -31
- xinference/web/ui/build/static/js/main.f7da0140.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
- /xinference/web/ui/build/static/js/{main.f7da0140.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/mlx/flux/autoencoder.py
@@ -0,0 +1,357 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import List
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.nn.layers.upsample import upsample_nearest
+
+
+@dataclass
+class AutoEncoderParams:
+    resolution: int
+    in_channels: int
+    ch: int
+    out_ch: int
+    ch_mult: List[int]
+    num_res_blocks: int
+    z_channels: int
+    scale_factor: float
+    shift_factor: float
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = nn.GroupNorm(
+            num_groups=32,
+            dims=in_channels,
+            eps=1e-6,
+            affine=True,
+            pytorch_compatible=True,
+        )
+        self.q = nn.Linear(in_channels, in_channels)
+        self.k = nn.Linear(in_channels, in_channels)
+        self.v = nn.Linear(in_channels, in_channels)
+        self.proj_out = nn.Linear(in_channels, in_channels)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        B, H, W, C = x.shape
+
+        y = x.reshape(B, 1, -1, C)
+        y = self.norm(y)
+        q = self.q(y)
+        k = self.k(y)
+        v = self.v(y)
+        y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
+        y = self.proj_out(y)
+
+        return x + y.reshape(B, H, W, C)
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+
+        self.norm1 = nn.GroupNorm(
+            num_groups=32,
+            dims=in_channels,
+            eps=1e-6,
+            affine=True,
+            pytorch_compatible=True,
+        )
+        self.conv1 = nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        self.norm2 = nn.GroupNorm(
+            num_groups=32,
+            dims=out_channels,
+            eps=1e-6,
+            affine=True,
+            pytorch_compatible=True,
+        )
+        self.conv2 = nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut = nn.Linear(in_channels, out_channels)
+
+    def __call__(self, x):
+        h = x
+        h = self.norm1(h)
+        h = nn.silu(h)
+        h = self.conv1(h)
+
+        h = self.norm2(h)
+        h = nn.silu(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            x = self.nin_shortcut(x)
+
+        return x + h
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels, in_channels, kernel_size=3, stride=2, padding=0
+        )
+
+    def __call__(self, x: mx.array):
+        x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
+        x = self.conv(x)
+        return x
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels, in_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def __call__(self, x: mx.array):
+        x = upsample_nearest(x, (2, 2))
+        x = self.conv(x)
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        resolution: int,
+        in_channels: int,
+        ch: int,
+        ch_mult: list[int],
+        num_res_blocks: int,
+        z_channels: int,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        # downsampling
+        self.conv_in = nn.Conv2d(
+            in_channels, self.ch, kernel_size=3, stride=1, padding=1
+        )
+
+        curr_res = resolution
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = []
+        block_in = self.ch
+        for i_level in range(self.num_resolutions):
+            block = []
+            attn = []  # TODO: Remove the attn, nobody appends anything to it
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_in = block_out
+            down = {}
+            down["block"] = block
+            down["attn"] = attn
+            if i_level != self.num_resolutions - 1:
+                down["downsample"] = Downsample(block_in)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = {}
+        self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid["attn_1"] = AttnBlock(block_in)
+        self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+        # end
+        self.norm_out = nn.GroupNorm(
+            num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+        )
+        self.conv_out = nn.Conv2d(
+            block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def __call__(self, x: mx.array):
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level]["block"][i_block](hs[-1])
+
+                # TODO: Remove the attn
+                if len(self.down[i_level]["attn"]) > 0:
+                    h = self.down[i_level]["attn"][i_block](h)
+
+                hs.append(h)
+
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level]["downsample"](hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid["block_1"](h)
+        h = self.mid["attn_1"](h)
+        h = self.mid["block_2"](h)
+
+        # end
+        h = self.norm_out(h)
+        h = nn.silu(h)
+        h = self.conv_out(h)
+
+        return h
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        ch: int,
+        out_ch: int,
+        ch_mult: list[int],
+        num_res_blocks: int,
+        in_channels: int,
+        resolution: int,
+        z_channels: int,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.ffactor = 2 ** (self.num_resolutions - 1)
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, curr_res, curr_res)
+
+        # z to block_in
+        self.conv_in = nn.Conv2d(
+            z_channels, block_in, kernel_size=3, stride=1, padding=1
+        )
+
+        # middle
+        self.mid = {}
+        self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid["attn_1"] = AttnBlock(block_in)
+        self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+        # upsampling
+        self.up = []
+        for i_level in reversed(range(self.num_resolutions)):
+            block = []
+            attn = []  # TODO: Remove the attn, nobody appends anything to it
+
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks + 1):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_in = block_out
+            up = {}
+            up["block"] = block
+            up["attn"] = attn
+            if i_level != 0:
+                up["upsample"] = Upsample(block_in)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = nn.GroupNorm(
+            num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+        )
+        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+    def __call__(self, z: mx.array):
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid["block_1"](h)
+        h = self.mid["attn_1"](h)
+        h = self.mid["block_2"](h)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level]["block"][i_block](h)
+
+                # TODO: Remove the attn
+                if len(self.up[i_level]["attn"]) > 0:
+                    h = self.up[i_level]["attn"][i_block](h)
+
+            if i_level != 0:
+                h = self.up[i_level]["upsample"](h)
+
+        # end
+        h = self.norm_out(h)
+        h = nn.silu(h)
+        h = self.conv_out(h)
+
+        return h
+
+
+class DiagonalGaussian(nn.Module):
+    def __call__(self, z: mx.array):
+        mean, logvar = mx.split(z, 2, axis=-1)
+        if self.training:
+            std = mx.exp(0.5 * logvar)
+            eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
+            return mean + std * eps
+        else:
+            return mean
+
+
+class AutoEncoder(nn.Module):
+    def __init__(self, params: AutoEncoderParams):
+        super().__init__()
+        self.encoder = Encoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.decoder = Decoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            out_ch=params.out_ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.reg = DiagonalGaussian()
+
+        self.scale_factor = params.scale_factor
+        self.shift_factor = params.shift_factor
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for k, w in weights.items():
+            if w.ndim == 4:
+                w = w.transpose(0, 2, 3, 1)
+                w = w.reshape(-1).reshape(w.shape)
+                if w.shape[1:3] == (1, 1):
+                    w = w.squeeze((1, 2))
+            new_weights[k] = w
+        return new_weights
+
+    def encode(self, x: mx.array):
+        z = self.reg(self.encoder(x))
+        z = self.scale_factor * (z - self.shift_factor)
+        return z
+
+    def decode(self, z: mx.array):
+        z = z / self.scale_factor + self.shift_factor
+        return self.decoder(z)
+
+    def __call__(self, x: mx.array):
+        return self.decode(self.encode(x))
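For orientation: this vendored module is an MLX port of the Flux-style VAE. `encode` maps an NHWC image tensor to a scaled, shifted latent and `decode` inverts it; with a `ch_mult` of length 4 the spatial resolution drops by 8x. A minimal roundtrip sketch follows. The hyperparameter values below are the Flux reference VAE config and are illustrative only; in practice the values come from the model checkpoint's config:

    import mlx.core as mx
    from xinference.thirdparty.mlx.flux.autoencoder import AutoEncoder, AutoEncoderParams

    # Illustrative values (Flux reference VAE); real values come from the checkpoint.
    params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )
    ae = AutoEncoder(params)
    ae.eval()  # take the deterministic (mean) latent rather than sampling

    x = mx.random.normal(shape=(1, 256, 256, 3))  # MLX convolutions expect NHWC
    z = ae.encode(x)  # (1, 32, 32, 16): 8x downsampled, scale/shift applied
    y = ae.decode(z)  # (1, 256, 256, 3)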
xinference/thirdparty/mlx/flux/clip.py
@@ -0,0 +1,154 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
+
+
+@dataclass
+class CLIPTextModelConfig:
+    num_layers: int = 23
+    model_dims: int = 1024
+    num_heads: int = 16
+    max_length: int = 77
+    vocab_size: int = 49408
+    hidden_act: str = "quick_gelu"
+
+    @classmethod
+    def from_dict(cls, config):
+        return cls(
+            num_layers=config["num_hidden_layers"],
+            model_dims=config["hidden_size"],
+            num_heads=config["num_attention_heads"],
+            max_length=config["max_position_embeddings"],
+            vocab_size=config["vocab_size"],
+            hidden_act=config["hidden_act"],
+        )
+
+
+@dataclass
+class CLIPOutput:
+    # The last_hidden_state indexed at the EOS token and possibly projected if
+    # the model has a projection layer
+    pooled_output: Optional[mx.array] = None
+
+    # The full sequence output of the transformer after the final layernorm
+    last_hidden_state: Optional[mx.array] = None
+
+    # A list of hidden states corresponding to the outputs of the transformer layers
+    hidden_states: Optional[List[mx.array]] = None
+
+
+class CLIPEncoderLayer(nn.Module):
+    """The transformer encoder layer from CLIP."""
+
+    def __init__(self, model_dims: int, num_heads: int, activation: str):
+        super().__init__()
+
+        self.layer_norm1 = nn.LayerNorm(model_dims)
+        self.layer_norm2 = nn.LayerNorm(model_dims)
+
+        self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
+
+        self.linear1 = nn.Linear(model_dims, 4 * model_dims)
+        self.linear2 = nn.Linear(4 * model_dims, model_dims)
+
+        self.act = _ACTIVATIONS[activation]
+
+    def __call__(self, x, attn_mask=None):
+        y = self.layer_norm1(x)
+        y = self.attention(y, y, y, attn_mask)
+        x = y + x
+
+        y = self.layer_norm2(x)
+        y = self.linear1(y)
+        y = self.act(y)
+        y = self.linear2(y)
+        x = y + x
+
+        return x
+
+
+class CLIPTextModel(nn.Module):
+    """Implements the text encoder transformer from CLIP."""
+
+    def __init__(self, config: CLIPTextModelConfig):
+        super().__init__()
+
+        self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
+        self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
+        self.layers = [
+            CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
+            for i in range(config.num_layers)
+        ]
+        self.final_layer_norm = nn.LayerNorm(config.model_dims)
+
+    def _get_mask(self, N, dtype):
+        indices = mx.arange(N)
+        mask = indices[:, None] < indices[None]
+        mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
+        return mask
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for key, w in weights.items():
+            # Remove prefixes
+            if key.startswith("text_model."):
+                key = key[11:]
+            if key.startswith("embeddings."):
+                key = key[11:]
+            if key.startswith("encoder."):
+                key = key[8:]
+
+            # Map attention layers
+            if "self_attn." in key:
+                key = key.replace("self_attn.", "attention.")
+            if "q_proj." in key:
+                key = key.replace("q_proj.", "query_proj.")
+            if "k_proj." in key:
+                key = key.replace("k_proj.", "key_proj.")
+            if "v_proj." in key:
+                key = key.replace("v_proj.", "value_proj.")
+
+            # Map ffn layers
+            if "mlp.fc1" in key:
+                key = key.replace("mlp.fc1", "linear1")
+            if "mlp.fc2" in key:
+                key = key.replace("mlp.fc2", "linear2")
+
+            new_weights[key] = w
+
+        return new_weights
+
+    def __call__(self, x):
+        # Extract some shapes
+        B, N = x.shape
+        eos_tokens = x.argmax(-1)
+
+        # Compute the embeddings
+        x = self.token_embedding(x)
+        x = x + self.position_embedding.weight[:N]
+
+        # Compute the features from the transformer
+        mask = self._get_mask(N, x.dtype)
+        hidden_states = []
+        for l in self.layers:
+            x = l(x, mask)
+            hidden_states.append(x)
+
+        # Apply the final layernorm and return
+        x = self.final_layer_norm(x)
+        last_hidden_state = x
+
+        # Select the EOS token
+        pooled_output = x[mx.arange(len(x)), eos_tokens]
+
+        return CLIPOutput(
+            pooled_output=pooled_output,
+            last_hidden_state=last_hidden_state,
+            hidden_states=hidden_states,
+        )
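A short usage sketch for the text tower, assuming token ids already padded to max_length (in the real pipeline they come from the tokenizer in tokenizers.py). Note the EOS position is recovered with argmax, which works because CLIP's EOS id (49407) is the largest id in the vocabulary; the token ids between BOS and EOS below are made up:

    import mlx.core as mx
    from xinference.thirdparty.mlx.flux.clip import CLIPTextModel, CLIPTextModelConfig

    model = CLIPTextModel(CLIPTextModelConfig())  # defaults: 23 layers, 1024 dims

    # BOS (49406), a few hypothetical token ids, EOS (49407), zero padding to 77.
    tokens = mx.array([[49406, 320, 1125, 539, 320, 2368, 49407] + [0] * 70])
    out = model(tokens)

    out.last_hidden_state.shape  # (1, 77, 1024): full sequence after final layernorm
    out.pooled_output.shape      # (1, 1024): hidden state at the EOS position
    len(out.hidden_states)       # 23: one entry per encoder layer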
xinference/thirdparty/mlx/flux/datasets.py
@@ -0,0 +1,75 @@
+import json
+from pathlib import Path
+
+from PIL import Image
+
+
+class Dataset:
+    def __getitem__(self, index: int):
+        raise NotImplementedError()
+
+    def __len__(self):
+        raise NotImplementedError()
+
+
+class LocalDataset(Dataset):
+    prompt_key = "prompt"
+
+    def __init__(self, dataset: str, data_file):
+        self.dataset_base = Path(dataset)
+        with open(data_file, "r") as fid:
+            self._data = [json.loads(l) for l in fid]
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, index: int):
+        item = self._data[index]
+        image = Image.open(self.dataset_base / item["image"])
+        return image, item[self.prompt_key]
+
+
+class LegacyDataset(LocalDataset):
+    prompt_key = "text"
+
+    def __init__(self, dataset: str):
+        self.dataset_base = Path(dataset)
+        with open(self.dataset_base / "index.json") as f:
+            self._data = json.load(f)["data"]
+
+
+class HuggingFaceDataset(Dataset):
+
+    def __init__(self, dataset: str):
+        from datasets import load_dataset as hf_load_dataset
+
+        self._df = hf_load_dataset(dataset)["train"]
+
+    def __len__(self):
+        return len(self._df)
+
+    def __getitem__(self, index: int):
+        item = self._df[index]
+        return item["image"], item["prompt"]
+
+
+def load_dataset(dataset: str):
+    dataset_base = Path(dataset)
+    data_file = dataset_base / "train.jsonl"
+    legacy_file = dataset_base / "index.json"
+
+    if data_file.exists():
+        print(f"Load the local dataset {data_file} .", flush=True)
+        dataset = LocalDataset(dataset, data_file)
+    elif legacy_file.exists():
+        print(f"Load the local dataset {legacy_file} .")
+        print()
+        print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
+        print(" See the README for details.")
+        print(flush=True)
+        dataset = LegacyDataset(dataset)
+    else:
+        print(f"Load the Hugging Face dataset {dataset} .", flush=True)
+        dataset = HuggingFaceDataset(dataset)
+
+    return dataset
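load_dataset resolves its argument in order: a local directory containing train.jsonl, a legacy directory containing index.json, and finally a Hugging Face dataset name. A minimal sketch of the expected local layout, with a hypothetical directory name, image path, and prompt:

    # my-dataset/train.jsonl -- one JSON object per line:
    #   {"image": "img/0001.png", "prompt": "a photo of a corgi"}

    from xinference.thirdparty.mlx.flux.datasets import load_dataset

    ds = load_dataset("my-dataset")  # falls through to the HF hub if no local files
    image, prompt = ds[0]            # (PIL.Image.Image, str)
    print(len(ds), prompt)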