xinference 0.16.0__py3-none-any.whl → 0.16.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.
Files changed (62)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +48 -0
  3. xinference/client/restful/restful_client.py +19 -0
  4. xinference/constants.py +1 -0
  5. xinference/core/chat_interface.py +5 -1
  6. xinference/core/image_interface.py +5 -1
  7. xinference/core/model.py +106 -16
  8. xinference/core/scheduler.py +1 -1
  9. xinference/core/worker.py +3 -1
  10. xinference/deploy/supervisor.py +0 -4
  11. xinference/model/audio/chattts.py +25 -14
  12. xinference/model/audio/core.py +6 -2
  13. xinference/model/audio/model_spec.json +1 -1
  14. xinference/model/audio/model_spec_modelscope.json +1 -1
  15. xinference/model/core.py +3 -1
  16. xinference/model/embedding/core.py +6 -2
  17. xinference/model/embedding/model_spec.json +1 -1
  18. xinference/model/image/core.py +65 -6
  19. xinference/model/image/model_spec.json +24 -3
  20. xinference/model/image/model_spec_modelscope.json +25 -3
  21. xinference/model/image/ocr/__init__.py +13 -0
  22. xinference/model/image/ocr/got_ocr2.py +79 -0
  23. xinference/model/image/scheduler/flux.py +1 -1
  24. xinference/model/image/stable_diffusion/core.py +2 -3
  25. xinference/model/image/stable_diffusion/mlx.py +221 -0
  26. xinference/model/llm/__init__.py +33 -0
  27. xinference/model/llm/core.py +3 -1
  28. xinference/model/llm/llm_family.json +9 -0
  29. xinference/model/llm/llm_family.py +68 -2
  30. xinference/model/llm/llm_family_modelscope.json +11 -0
  31. xinference/model/llm/llm_family_openmind_hub.json +1359 -0
  32. xinference/model/rerank/core.py +9 -1
  33. xinference/model/utils.py +7 -0
  34. xinference/model/video/core.py +6 -2
  35. xinference/thirdparty/mlx/__init__.py +13 -0
  36. xinference/thirdparty/mlx/flux/__init__.py +15 -0
  37. xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
  38. xinference/thirdparty/mlx/flux/clip.py +154 -0
  39. xinference/thirdparty/mlx/flux/datasets.py +75 -0
  40. xinference/thirdparty/mlx/flux/flux.py +247 -0
  41. xinference/thirdparty/mlx/flux/layers.py +302 -0
  42. xinference/thirdparty/mlx/flux/lora.py +76 -0
  43. xinference/thirdparty/mlx/flux/model.py +134 -0
  44. xinference/thirdparty/mlx/flux/sampler.py +56 -0
  45. xinference/thirdparty/mlx/flux/t5.py +244 -0
  46. xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
  47. xinference/thirdparty/mlx/flux/trainer.py +98 -0
  48. xinference/thirdparty/mlx/flux/utils.py +179 -0
  49. xinference/web/ui/build/asset-manifest.json +3 -3
  50. xinference/web/ui/build/index.html +1 -1
  51. xinference/web/ui/build/static/js/{main.f7da0140.js → main.2f269bb3.js} +3 -3
  52. xinference/web/ui/build/static/js/main.2f269bb3.js.map +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +1 -0
  54. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/METADATA +16 -9
  55. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/RECORD +60 -42
  56. xinference/web/ui/build/static/js/main.f7da0140.js.map +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
  58. /xinference/web/ui/build/static/js/{main.f7da0140.js.LICENSE.txt → main.2f269bb3.js.LICENSE.txt} +0 -0
  59. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/LICENSE +0 -0
  60. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/WHEEL +0 -0
  61. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/entry_points.txt +0 -0
  62. {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/top_level.txt +0 -0
xinference/model/rerank/core.py CHANGED
@@ -268,6 +268,12 @@ class RerankModel:
         similarity_scores = self._model.compute_score(sentence_combinations)
         if not isinstance(similarity_scores, Sequence):
             similarity_scores = [similarity_scores]
+        elif (
+            isinstance(similarity_scores, list)
+            and len(similarity_scores) > 0
+            and isinstance(similarity_scores[0], Sequence)
+        ):
+            similarity_scores = similarity_scores[0]
 
         sim_scores_argsort = list(reversed(np.argsort(similarity_scores)))
         if top_n is not None:
@@ -341,7 +347,9 @@ def create_rerank_model_instance(
     devices: List[str],
     model_uid: str,
     model_name: str,
-    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[RerankModel, RerankModelDescription]:
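
The new elif branch normalizes the return shape of compute_score, which can be a bare float (a single pair), a flat list (one score per document), or a nested list depending on the backend. A minimal, self-contained sketch of the same normalization (the helper name is ours, not xinference's):

    from collections.abc import Sequence

    import numpy as np

    def normalize_scores(similarity_scores):
        # compute_score may return a bare float, a flat list, or a
        # nested list ([[s1, s2, ...]]); flatten all three shapes
        # to one score per document.
        if not isinstance(similarity_scores, Sequence):
            similarity_scores = [similarity_scores]
        elif (
            isinstance(similarity_scores, list)
            and len(similarity_scores) > 0
            and isinstance(similarity_scores[0], Sequence)
        ):
            similarity_scores = similarity_scores[0]
        return similarity_scores

    # All three shapes now rank identically, highest score first:
    for raw in (0.7, [0.7, 0.1], [[0.7, 0.1]]):
        scores = normalize_scores(raw)
        print(list(reversed(np.argsort(scores))))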
xinference/model/utils.py CHANGED
@@ -54,6 +54,13 @@ def download_from_modelscope() -> bool:
         return False
 
 
+def download_from_openmind_hub() -> bool:
+    if os.environ.get(XINFERENCE_ENV_MODEL_SRC):
+        return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "openmind_hub"
+    else:
+        return False
+
+
 def download_from_csghub() -> bool:
     if os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "csghub":
         return True
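
download_from_openmind_hub mirrors download_from_modelscope above: both key off the model-source environment variable. A hedged usage sketch, assuming XINFERENCE_ENV_MODEL_SRC resolves to the XINFERENCE_MODEL_SRC variable as in xinference/constants.py:

    import os

    # Assumption: XINFERENCE_ENV_MODEL_SRC == "XINFERENCE_MODEL_SRC"
    # (see xinference/constants.py).
    os.environ["XINFERENCE_MODEL_SRC"] = "openmind_hub"

    from xinference.model.utils import download_from_openmind_hub

    # Any other value, or an unset variable, falls through to False.
    assert download_from_openmind_hub() is True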
xinference/model/video/core.py CHANGED
@@ -97,7 +97,9 @@ def generate_video_description(
 
 def match_diffusion(
     model_name: str,
-    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
 ) -> VideoModelFamilyV1:
     from ..utils import download_from_modelscope
     from . import BUILTIN_VIDEO_MODELS, MODELSCOPE_VIDEO_MODELS
@@ -157,7 +159,9 @@ def create_video_model_instance(
     devices: List[str],
     model_uid: str,
     model_name: str,
-    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[DiffUsersVideoModel, VideoModelDescription]:
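
Both hunks, like the rerank one earlier, only widen the Literal type of download_hub so that type checkers accept "openmind_hub" as a hub name. A small illustration with a hypothetical helper, not xinference's API:

    from typing import Literal, Optional

    Hub = Literal["huggingface", "modelscope", "openmind_hub", "csghub"]

    def pick_hub(download_hub: Optional[Hub] = None) -> str:
        # None means "decide from the model-source environment variable";
        # an explicit hub name always wins.
        return download_hub or "huggingface"

    pick_hub("openmind_hub")  # only type-checks after the Literal is widened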
xinference/thirdparty/mlx/__init__.py ADDED
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
xinference/thirdparty/mlx/flux/__init__.py ADDED
@@ -0,0 +1,15 @@
+# Copyright © 2024 Apple Inc.
+
+from .datasets import Dataset, load_dataset
+from .flux import FluxPipeline
+from .lora import LoRALinear
+from .sampler import FluxSampler
+from .trainer import Trainer
+from .utils import (
+    load_ae,
+    load_clip,
+    load_clip_tokenizer,
+    load_flow_model,
+    load_t5,
+    load_t5_tokenizer,
+)
xinference/thirdparty/mlx/flux/autoencoder.py ADDED
@@ -0,0 +1,357 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import List
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.nn.layers.upsample import upsample_nearest
+
+
+@dataclass
+class AutoEncoderParams:
+    resolution: int
+    in_channels: int
+    ch: int
+    out_ch: int
+    ch_mult: List[int]
+    num_res_blocks: int
+    z_channels: int
+    scale_factor: float
+    shift_factor: float
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = nn.GroupNorm(
+            num_groups=32,
+            dims=in_channels,
+            eps=1e-6,
+            affine=True,
+            pytorch_compatible=True,
+        )
+        self.q = nn.Linear(in_channels, in_channels)
+        self.k = nn.Linear(in_channels, in_channels)
+        self.v = nn.Linear(in_channels, in_channels)
+        self.proj_out = nn.Linear(in_channels, in_channels)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        B, H, W, C = x.shape
+
+        y = x.reshape(B, 1, -1, C)
+        y = self.norm(y)
+        q = self.q(y)
+        k = self.k(y)
+        v = self.v(y)
+        y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
+        y = self.proj_out(y)
+
+        return x + y.reshape(B, H, W, C)
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+
+        self.norm1 = nn.GroupNorm(
+            num_groups=32,
+            dims=in_channels,
+            eps=1e-6,
+            affine=True,
+            pytorch_compatible=True,
+        )
+        self.conv1 = nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        self.norm2 = nn.GroupNorm(
+            num_groups=32,
+            dims=out_channels,
+            eps=1e-6,
+            affine=True,
+            pytorch_compatible=True,
+        )
+        self.conv2 = nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut = nn.Linear(in_channels, out_channels)
+
+    def __call__(self, x):
+        h = x
+        h = self.norm1(h)
+        h = nn.silu(h)
+        h = self.conv1(h)
+
+        h = self.norm2(h)
+        h = nn.silu(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            x = self.nin_shortcut(x)
+
+        return x + h
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels, in_channels, kernel_size=3, stride=2, padding=0
+        )
+
+    def __call__(self, x: mx.array):
+        x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
+        x = self.conv(x)
+        return x
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels, in_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def __call__(self, x: mx.array):
+        x = upsample_nearest(x, (2, 2))
+        x = self.conv(x)
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        resolution: int,
+        in_channels: int,
+        ch: int,
+        ch_mult: list[int],
+        num_res_blocks: int,
+        z_channels: int,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        # downsampling
+        self.conv_in = nn.Conv2d(
+            in_channels, self.ch, kernel_size=3, stride=1, padding=1
+        )
+
+        curr_res = resolution
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = []
+        block_in = self.ch
+        for i_level in range(self.num_resolutions):
+            block = []
+            attn = []  # TODO: Remove the attn, nobody appends anything to it
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_in = block_out
+            down = {}
+            down["block"] = block
+            down["attn"] = attn
+            if i_level != self.num_resolutions - 1:
+                down["downsample"] = Downsample(block_in)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = {}
+        self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid["attn_1"] = AttnBlock(block_in)
+        self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+        # end
+        self.norm_out = nn.GroupNorm(
+            num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+        )
+        self.conv_out = nn.Conv2d(
+            block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def __call__(self, x: mx.array):
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level]["block"][i_block](hs[-1])
+
+                # TODO: Remove the attn
+                if len(self.down[i_level]["attn"]) > 0:
+                    h = self.down[i_level]["attn"][i_block](h)
+
+                hs.append(h)
+
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level]["downsample"](hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid["block_1"](h)
+        h = self.mid["attn_1"](h)
+        h = self.mid["block_2"](h)
+
+        # end
+        h = self.norm_out(h)
+        h = nn.silu(h)
+        h = self.conv_out(h)
+
+        return h
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        ch: int,
+        out_ch: int,
+        ch_mult: list[int],
+        num_res_blocks: int,
+        in_channels: int,
+        resolution: int,
+        z_channels: int,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.ffactor = 2 ** (self.num_resolutions - 1)
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, curr_res, curr_res)
+
+        # z to block_in
+        self.conv_in = nn.Conv2d(
+            z_channels, block_in, kernel_size=3, stride=1, padding=1
+        )
+
+        # middle
+        self.mid = {}
+        self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid["attn_1"] = AttnBlock(block_in)
+        self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+        # upsampling
+        self.up = []
+        for i_level in reversed(range(self.num_resolutions)):
+            block = []
+            attn = []  # TODO: Remove the attn, nobody appends anything to it
+
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks + 1):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_in = block_out
+            up = {}
+            up["block"] = block
+            up["attn"] = attn
+            if i_level != 0:
+                up["upsample"] = Upsample(block_in)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = nn.GroupNorm(
+            num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+        )
+        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+    def __call__(self, z: mx.array):
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid["block_1"](h)
+        h = self.mid["attn_1"](h)
+        h = self.mid["block_2"](h)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level]["block"][i_block](h)
+
+                # TODO: Remove the attn
+                if len(self.up[i_level]["attn"]) > 0:
+                    h = self.up[i_level]["attn"][i_block](h)
+
+            if i_level != 0:
+                h = self.up[i_level]["upsample"](h)
+
+        # end
+        h = self.norm_out(h)
+        h = nn.silu(h)
+        h = self.conv_out(h)
+
+        return h
+
+
+class DiagonalGaussian(nn.Module):
+    def __call__(self, z: mx.array):
+        mean, logvar = mx.split(z, 2, axis=-1)
+        if self.training:
+            std = mx.exp(0.5 * logvar)
+            eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
+            return mean + std * eps
+        else:
+            return mean
+
+
+class AutoEncoder(nn.Module):
+    def __init__(self, params: AutoEncoderParams):
+        super().__init__()
+        self.encoder = Encoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.decoder = Decoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            out_ch=params.out_ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.reg = DiagonalGaussian()
+
+        self.scale_factor = params.scale_factor
+        self.shift_factor = params.shift_factor
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for k, w in weights.items():
+            if w.ndim == 4:
+                w = w.transpose(0, 2, 3, 1)
+                w = w.reshape(-1).reshape(w.shape)
+                if w.shape[1:3] == (1, 1):
+                    w = w.squeeze((1, 2))
+            new_weights[k] = w
+        return new_weights
+
+    def encode(self, x: mx.array):
+        z = self.reg(self.encoder(x))
+        z = self.scale_factor * (z - self.shift_factor)
+        return z
+
+    def decode(self, z: mx.array):
+        z = z / self.scale_factor + self.shift_factor
+        return self.decoder(z)
+
+    def __call__(self, x: mx.array):
+        return self.decode(self.encode(x))
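
A hedged sketch of exercising this vendored autoencoder standalone. The hyperparameters below are the commonly published FLUX VAE values and are an assumption here, not taken from this diff:

    import mlx.core as mx

    # Assumed FLUX-style VAE hyperparameters; verify against the real
    # config before relying on them.
    params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )
    ae = AutoEncoder(params)
    ae.eval()  # take the deterministic mean path in DiagonalGaussian

    # MLX convolutions are channels-last: (B, H, W, C).
    x = mx.random.normal(shape=(1, 256, 256, 3))
    z = ae.encode(x)       # (1, 32, 32, 16): three downsampling stages
    x_hat = ae.decode(z)   # back to (1, 256, 256, 3)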
xinference/thirdparty/mlx/flux/clip.py ADDED
@@ -0,0 +1,154 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
+
+
+@dataclass
+class CLIPTextModelConfig:
+    num_layers: int = 23
+    model_dims: int = 1024
+    num_heads: int = 16
+    max_length: int = 77
+    vocab_size: int = 49408
+    hidden_act: str = "quick_gelu"
+
+    @classmethod
+    def from_dict(cls, config):
+        return cls(
+            num_layers=config["num_hidden_layers"],
+            model_dims=config["hidden_size"],
+            num_heads=config["num_attention_heads"],
+            max_length=config["max_position_embeddings"],
+            vocab_size=config["vocab_size"],
+            hidden_act=config["hidden_act"],
+        )
+
+
+@dataclass
+class CLIPOutput:
+    # The last_hidden_state indexed at the EOS token and possibly projected if
+    # the model has a projection layer
+    pooled_output: Optional[mx.array] = None
+
+    # The full sequence output of the transformer after the final layernorm
+    last_hidden_state: Optional[mx.array] = None
+
+    # A list of hidden states corresponding to the outputs of the transformer layers
+    hidden_states: Optional[List[mx.array]] = None
+
+
+class CLIPEncoderLayer(nn.Module):
+    """The transformer encoder layer from CLIP."""
+
+    def __init__(self, model_dims: int, num_heads: int, activation: str):
+        super().__init__()
+
+        self.layer_norm1 = nn.LayerNorm(model_dims)
+        self.layer_norm2 = nn.LayerNorm(model_dims)
+
+        self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
+
+        self.linear1 = nn.Linear(model_dims, 4 * model_dims)
+        self.linear2 = nn.Linear(4 * model_dims, model_dims)
+
+        self.act = _ACTIVATIONS[activation]
+
+    def __call__(self, x, attn_mask=None):
+        y = self.layer_norm1(x)
+        y = self.attention(y, y, y, attn_mask)
+        x = y + x
+
+        y = self.layer_norm2(x)
+        y = self.linear1(y)
+        y = self.act(y)
+        y = self.linear2(y)
+        x = y + x
+
+        return x
+
+
+class CLIPTextModel(nn.Module):
+    """Implements the text encoder transformer from CLIP."""
+
+    def __init__(self, config: CLIPTextModelConfig):
+        super().__init__()
+
+        self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
+        self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
+        self.layers = [
+            CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
+            for i in range(config.num_layers)
+        ]
+        self.final_layer_norm = nn.LayerNorm(config.model_dims)
+
+    def _get_mask(self, N, dtype):
+        indices = mx.arange(N)
+        mask = indices[:, None] < indices[None]
+        mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
+        return mask
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for key, w in weights.items():
+            # Remove prefixes
+            if key.startswith("text_model."):
+                key = key[11:]
+            if key.startswith("embeddings."):
+                key = key[11:]
+            if key.startswith("encoder."):
+                key = key[8:]
+
+            # Map attention layers
+            if "self_attn." in key:
+                key = key.replace("self_attn.", "attention.")
+            if "q_proj." in key:
+                key = key.replace("q_proj.", "query_proj.")
+            if "k_proj." in key:
+                key = key.replace("k_proj.", "key_proj.")
+            if "v_proj." in key:
+                key = key.replace("v_proj.", "value_proj.")
+
+            # Map ffn layers
+            if "mlp.fc1" in key:
+                key = key.replace("mlp.fc1", "linear1")
+            if "mlp.fc2" in key:
+                key = key.replace("mlp.fc2", "linear2")
+
+            new_weights[key] = w
+
+        return new_weights
+
+    def __call__(self, x):
+        # Extract some shapes
+        B, N = x.shape
+        eos_tokens = x.argmax(-1)
+
+        # Compute the embeddings
+        x = self.token_embedding(x)
+        x = x + self.position_embedding.weight[:N]
+
+        # Compute the features from the transformer
+        mask = self._get_mask(N, x.dtype)
+        hidden_states = []
+        for l in self.layers:
+            x = l(x, mask)
+            hidden_states.append(x)
+
+        # Apply the final layernorm and return
+        x = self.final_layer_norm(x)
+        last_hidden_state = x
+
+        # Select the EOS token
+        pooled_output = x[mx.arange(len(x)), eos_tokens]
+
+        return CLIPOutput(
+            pooled_output=pooled_output,
+            last_hidden_state=last_hidden_state,
+            hidden_states=hidden_states,
+        )
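
A brief sketch of driving the text encoder with the dataclass defaults. The token ids below are dummies; real ids come from the CLIP tokenizer in tokenizers.py, where EOS (49407) is the largest id, which is why pooling via argmax picks it out:

    import mlx.core as mx

    config = CLIPTextModelConfig()  # defaults above: 23 layers, 1024 dims
    model = CLIPTextModel(config)

    # Dummy batch padded to max_length=77; position 3 holds the EOS id.
    tokens = mx.array([[49406, 320, 1125, 49407] + [0] * 73])
    out = model(tokens)
    print(out.pooled_output.shape)      # (1, 1024)
    print(out.last_hidden_state.shape)  # (1, 77, 1024)
    print(len(out.hidden_states))       # 23, one per encoder layer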
xinference/thirdparty/mlx/flux/datasets.py ADDED
@@ -0,0 +1,75 @@
+import json
+from pathlib import Path
+
+from PIL import Image
+
+
+class Dataset:
+    def __getitem__(self, index: int):
+        raise NotImplementedError()
+
+    def __len__(self):
+        raise NotImplementedError()
+
+
+class LocalDataset(Dataset):
+    prompt_key = "prompt"
+
+    def __init__(self, dataset: str, data_file):
+        self.dataset_base = Path(dataset)
+        with open(data_file, "r") as fid:
+            self._data = [json.loads(l) for l in fid]
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, index: int):
+        item = self._data[index]
+        image = Image.open(self.dataset_base / item["image"])
+        return image, item[self.prompt_key]
+
+
+class LegacyDataset(LocalDataset):
+    prompt_key = "text"
+
+    def __init__(self, dataset: str):
+        self.dataset_base = Path(dataset)
+        with open(self.dataset_base / "index.json") as f:
+            self._data = json.load(f)["data"]
+
+
+class HuggingFaceDataset(Dataset):
+
+    def __init__(self, dataset: str):
+        from datasets import load_dataset as hf_load_dataset
+
+        self._df = hf_load_dataset(dataset)["train"]
+
+    def __len__(self):
+        return len(self._df)
+
+    def __getitem__(self, index: int):
+        item = self._df[index]
+        return item["image"], item["prompt"]
+
+
+def load_dataset(dataset: str):
+    dataset_base = Path(dataset)
+    data_file = dataset_base / "train.jsonl"
+    legacy_file = dataset_base / "index.json"
+
+    if data_file.exists():
+        print(f"Load the local dataset {data_file} .", flush=True)
+        dataset = LocalDataset(dataset, data_file)
+    elif legacy_file.exists():
+        print(f"Load the local dataset {legacy_file} .")
+        print()
+        print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
+        print(" See the README for details.")
+        print(flush=True)
+        dataset = LegacyDataset(dataset)
+    else:
+        print(f"Load the Hugging Face dataset {dataset} .", flush=True)
+        dataset = HuggingFaceDataset(dataset)
+
+    return dataset
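
load_dataset resolves its argument in three steps: a directory containing train.jsonl, a legacy directory containing index.json, else a Hugging Face dataset name. A hedged sketch of the expected local layout (file names hypothetical, keys taken from LocalDataset above):

    # my-dataset/
    #   train.jsonl    <- one JSON object per line
    #   0001.png
    #   0002.png
    #
    # Each line pairs an image path (relative to the dataset directory)
    # with its prompt:
    #   {"image": "0001.png", "prompt": "a photo of a corgi"}
    #   {"image": "0002.png", "prompt": "a watercolor landscape"}

    dataset = load_dataset("my-dataset")
    image, prompt = dataset[0]  # (PIL.Image, str)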