xinference 0.16.0__py3-none-any.whl → 0.16.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +48 -0
- xinference/client/restful/restful_client.py +19 -0
- xinference/constants.py +1 -0
- xinference/core/chat_interface.py +5 -1
- xinference/core/image_interface.py +5 -1
- xinference/core/model.py +106 -16
- xinference/core/scheduler.py +1 -1
- xinference/core/worker.py +3 -1
- xinference/deploy/supervisor.py +0 -4
- xinference/model/audio/chattts.py +25 -14
- xinference/model/audio/core.py +6 -2
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/audio/model_spec_modelscope.json +1 -1
- xinference/model/core.py +3 -1
- xinference/model/embedding/core.py +6 -2
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/image/core.py +65 -6
- xinference/model/image/model_spec.json +24 -3
- xinference/model/image/model_spec_modelscope.json +25 -3
- xinference/model/image/ocr/__init__.py +13 -0
- xinference/model/image/ocr/got_ocr2.py +79 -0
- xinference/model/image/scheduler/flux.py +1 -1
- xinference/model/image/stable_diffusion/core.py +2 -3
- xinference/model/image/stable_diffusion/mlx.py +221 -0
- xinference/model/llm/__init__.py +33 -0
- xinference/model/llm/core.py +3 -1
- xinference/model/llm/llm_family.json +9 -0
- xinference/model/llm/llm_family.py +68 -2
- xinference/model/llm/llm_family_modelscope.json +11 -0
- xinference/model/llm/llm_family_openmind_hub.json +1359 -0
- xinference/model/rerank/core.py +9 -1
- xinference/model/utils.py +7 -0
- xinference/model/video/core.py +6 -2
- xinference/thirdparty/mlx/__init__.py +13 -0
- xinference/thirdparty/mlx/flux/__init__.py +15 -0
- xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
- xinference/thirdparty/mlx/flux/clip.py +154 -0
- xinference/thirdparty/mlx/flux/datasets.py +75 -0
- xinference/thirdparty/mlx/flux/flux.py +247 -0
- xinference/thirdparty/mlx/flux/layers.py +302 -0
- xinference/thirdparty/mlx/flux/lora.py +76 -0
- xinference/thirdparty/mlx/flux/model.py +134 -0
- xinference/thirdparty/mlx/flux/sampler.py +56 -0
- xinference/thirdparty/mlx/flux/t5.py +244 -0
- xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
- xinference/thirdparty/mlx/flux/trainer.py +98 -0
- xinference/thirdparty/mlx/flux/utils.py +179 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.f7da0140.js → main.2f269bb3.js} +3 -3
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +1 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/METADATA +16 -9
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/RECORD +60 -42
- xinference/web/ui/build/static/js/main.f7da0140.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
- /xinference/web/ui/build/static/js/{main.f7da0140.js.LICENSE.txt → main.2f269bb3.js.LICENSE.txt} +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/LICENSE +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/WHEEL +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.0.dist-info → xinference-0.16.2.dist-info}/top_level.txt +0 -0
xinference/model/rerank/core.py
CHANGED
@@ -268,6 +268,12 @@ class RerankModel:
         similarity_scores = self._model.compute_score(sentence_combinations)
         if not isinstance(similarity_scores, Sequence):
             similarity_scores = [similarity_scores]
+        elif (
+            isinstance(similarity_scores, list)
+            and len(similarity_scores) > 0
+            and isinstance(similarity_scores[0], Sequence)
+        ):
+            similarity_scores = similarity_scores[0]

         sim_scores_argsort = list(reversed(np.argsort(similarity_scores)))
         if top_n is not None:
@@ -341,7 +347,9 @@ def create_rerank_model_instance(
     devices: List[str],
     model_uid: str,
     model_name: str,
-    download_hub: Optional[
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[RerankModel, RerankModelDescription]:
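The new elif branch normalizes rerankers whose compute_score returns a nested list (one inner list of scores per query) instead of a flat list. A minimal standalone sketch of the same normalization, with the model call left out:

from collections.abc import Sequence

def normalize_scores(similarity_scores):
    # Scalars become a one-element list; a nested list such as
    # [[0.9, 0.2, 0.7]] is unwrapped to its inner list; flat lists pass through.
    if not isinstance(similarity_scores, Sequence):
        similarity_scores = [similarity_scores]
    elif (
        isinstance(similarity_scores, list)
        and len(similarity_scores) > 0
        and isinstance(similarity_scores[0], Sequence)
    ):
        similarity_scores = similarity_scores[0]
    return similarity_scores

assert normalize_scores(0.5) == [0.5]
assert normalize_scores([[0.9, 0.2, 0.7]]) == [0.9, 0.2, 0.7]
assert normalize_scores([0.9, 0.2]) == [0.9, 0.2]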
xinference/model/utils.py
CHANGED
@@ -54,6 +54,13 @@ def download_from_modelscope() -> bool:
         return False


+def download_from_openmind_hub() -> bool:
+    if os.environ.get(XINFERENCE_ENV_MODEL_SRC):
+        return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "openmind_hub"
+    else:
+        return False
+
+
 def download_from_csghub() -> bool:
     if os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "csghub":
         return True
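The new helper follows the same pattern as download_from_modelscope and download_from_csghub: a single environment variable selects the model source. A minimal sketch of the behavior, assuming XINFERENCE_ENV_MODEL_SRC resolves to the XINFERENCE_MODEL_SRC environment variable (as in xinference.constants; stated here as an assumption):

import os

XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"  # assumed constant value

def download_from_openmind_hub() -> bool:
    # Only report openmind_hub when the variable is explicitly set to it.
    if os.environ.get(XINFERENCE_ENV_MODEL_SRC):
        return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "openmind_hub"
    return False

os.environ[XINFERENCE_ENV_MODEL_SRC] = "openmind_hub"
print(download_from_openmind_hub())  # True

os.environ[XINFERENCE_ENV_MODEL_SRC] = "modelscope"
print(download_from_openmind_hub())  # False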
xinference/model/video/core.py
CHANGED
@@ -97,7 +97,9 @@ def generate_video_description(

 def match_diffusion(
     model_name: str,
-    download_hub: Optional[
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
 ) -> VideoModelFamilyV1:
     from ..utils import download_from_modelscope
     from . import BUILTIN_VIDEO_MODELS, MODELSCOPE_VIDEO_MODELS
@@ -157,7 +159,9 @@ def create_video_model_instance(
     devices: List[str],
     model_uid: str,
     model_name: str,
-    download_hub: Optional[
+    download_hub: Optional[
+        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
+    ] = None,
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[DiffUsersVideoModel, VideoModelDescription]:
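Both signatures now include "openmind_hub" in the download_hub Literal, so callers can target the openMind hub explicitly while static type checkers still reject unknown hub names. A small illustrative sketch of the annotation pattern (match_hub below is a hypothetical stand-in, not the real matching code):

from typing import Literal, Optional

Hub = Literal["huggingface", "modelscope", "openmind_hub", "csghub"]

def match_hub(download_hub: Optional[Hub] = None) -> str:
    # Fall back to a default hub when no explicit hub is requested.
    return download_hub or "huggingface"

print(match_hub("openmind_hub"))  # "openmind_hub"
print(match_hub())                # "huggingface"
# match_hub("gitee")  # rejected by a type checker such as mypy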
xinference/thirdparty/mlx/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
xinference/thirdparty/mlx/flux/__init__.py
ADDED
@@ -0,0 +1,15 @@
+# Copyright © 2024 Apple Inc.
+
+from .datasets import Dataset, load_dataset
+from .flux import FluxPipeline
+from .lora import LoRALinear
+from .sampler import FluxSampler
+from .trainer import Trainer
+from .utils import (
+    load_ae,
+    load_clip,
+    load_clip_tokenizer,
+    load_flow_model,
+    load_t5,
+    load_t5_tokenizer,
+)
xinference/thirdparty/mlx/flux/autoencoder.py
ADDED
@@ -0,0 +1,357 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import List
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.nn.layers.upsample import upsample_nearest
+
+
+@dataclass
+class AutoEncoderParams:
+    resolution: int
+    in_channels: int
+    ch: int
+    out_ch: int
+    ch_mult: List[int]
+    num_res_blocks: int
+    z_channels: int
+    scale_factor: float
+    shift_factor: float
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = nn.GroupNorm(
+            num_groups=32,
+            dims=in_channels,
+            eps=1e-6,
+            affine=True,
+            pytorch_compatible=True,
+        )
+        self.q = nn.Linear(in_channels, in_channels)
+        self.k = nn.Linear(in_channels, in_channels)
+        self.v = nn.Linear(in_channels, in_channels)
+        self.proj_out = nn.Linear(in_channels, in_channels)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        B, H, W, C = x.shape
+
+        y = x.reshape(B, 1, -1, C)
+        y = self.norm(y)
+        q = self.q(y)
+        k = self.k(y)
+        v = self.v(y)
+        y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
+        y = self.proj_out(y)
+
+        return x + y.reshape(B, H, W, C)
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+
+        self.norm1 = nn.GroupNorm(
+            num_groups=32,
+            dims=in_channels,
+            eps=1e-6,
+            affine=True,
+            pytorch_compatible=True,
+        )
+        self.conv1 = nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        self.norm2 = nn.GroupNorm(
+            num_groups=32,
+            dims=out_channels,
+            eps=1e-6,
+            affine=True,
+            pytorch_compatible=True,
+        )
+        self.conv2 = nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut = nn.Linear(in_channels, out_channels)
+
+    def __call__(self, x):
+        h = x
+        h = self.norm1(h)
+        h = nn.silu(h)
+        h = self.conv1(h)
+
+        h = self.norm2(h)
+        h = nn.silu(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            x = self.nin_shortcut(x)
+
+        return x + h
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels, in_channels, kernel_size=3, stride=2, padding=0
+        )
+
+    def __call__(self, x: mx.array):
+        x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
+        x = self.conv(x)
+        return x
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels, in_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def __call__(self, x: mx.array):
+        x = upsample_nearest(x, (2, 2))
+        x = self.conv(x)
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        resolution: int,
+        in_channels: int,
+        ch: int,
+        ch_mult: list[int],
+        num_res_blocks: int,
+        z_channels: int,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        # downsampling
+        self.conv_in = nn.Conv2d(
+            in_channels, self.ch, kernel_size=3, stride=1, padding=1
+        )
+
+        curr_res = resolution
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = []
+        block_in = self.ch
+        for i_level in range(self.num_resolutions):
+            block = []
+            attn = []  # TODO: Remove the attn, nobody appends anything to it
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_in = block_out
+            down = {}
+            down["block"] = block
+            down["attn"] = attn
+            if i_level != self.num_resolutions - 1:
+                down["downsample"] = Downsample(block_in)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = {}
+        self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid["attn_1"] = AttnBlock(block_in)
+        self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+        # end
+        self.norm_out = nn.GroupNorm(
+            num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+        )
+        self.conv_out = nn.Conv2d(
+            block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def __call__(self, x: mx.array):
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level]["block"][i_block](hs[-1])
+
+                # TODO: Remove the attn
+                if len(self.down[i_level]["attn"]) > 0:
+                    h = self.down[i_level]["attn"][i_block](h)
+
+                hs.append(h)
+
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level]["downsample"](hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid["block_1"](h)
+        h = self.mid["attn_1"](h)
+        h = self.mid["block_2"](h)
+
+        # end
+        h = self.norm_out(h)
+        h = nn.silu(h)
+        h = self.conv_out(h)
+
+        return h
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        ch: int,
+        out_ch: int,
+        ch_mult: list[int],
+        num_res_blocks: int,
+        in_channels: int,
+        resolution: int,
+        z_channels: int,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.ffactor = 2 ** (self.num_resolutions - 1)
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, curr_res, curr_res)
+
+        # z to block_in
+        self.conv_in = nn.Conv2d(
+            z_channels, block_in, kernel_size=3, stride=1, padding=1
+        )
+
+        # middle
+        self.mid = {}
+        self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid["attn_1"] = AttnBlock(block_in)
+        self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+        # upsampling
+        self.up = []
+        for i_level in reversed(range(self.num_resolutions)):
+            block = []
+            attn = []  # TODO: Remove the attn, nobody appends anything to it
+
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks + 1):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_in = block_out
+            up = {}
+            up["block"] = block
+            up["attn"] = attn
+            if i_level != 0:
+                up["upsample"] = Upsample(block_in)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = nn.GroupNorm(
+            num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
+        )
+        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+    def __call__(self, z: mx.array):
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid["block_1"](h)
+        h = self.mid["attn_1"](h)
+        h = self.mid["block_2"](h)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level]["block"][i_block](h)
+
+                # TODO: Remove the attn
+                if len(self.up[i_level]["attn"]) > 0:
+                    h = self.up[i_level]["attn"][i_block](h)
+
+            if i_level != 0:
+                h = self.up[i_level]["upsample"](h)
+
+        # end
+        h = self.norm_out(h)
+        h = nn.silu(h)
+        h = self.conv_out(h)
+
+        return h
+
+
+class DiagonalGaussian(nn.Module):
+    def __call__(self, z: mx.array):
+        mean, logvar = mx.split(z, 2, axis=-1)
+        if self.training:
+            std = mx.exp(0.5 * logvar)
+            eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
+            return mean + std * eps
+        else:
+            return mean
+
+
+class AutoEncoder(nn.Module):
+    def __init__(self, params: AutoEncoderParams):
+        super().__init__()
+        self.encoder = Encoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.decoder = Decoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            out_ch=params.out_ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+        )
+        self.reg = DiagonalGaussian()
+
+        self.scale_factor = params.scale_factor
+        self.shift_factor = params.shift_factor
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for k, w in weights.items():
+            if w.ndim == 4:
+                w = w.transpose(0, 2, 3, 1)
+                w = w.reshape(-1).reshape(w.shape)
+                if w.shape[1:3] == (1, 1):
+                    w = w.squeeze((1, 2))
+            new_weights[k] = w
+        return new_weights
+
+    def encode(self, x: mx.array):
+        z = self.reg(self.encoder(x))
+        z = self.scale_factor * (z - self.shift_factor)
+        return z
+
+    def decode(self, z: mx.array):
+        z = z / self.scale_factor + self.shift_factor
+        return self.decoder(z)
+
+    def __call__(self, x: mx.array):
+        return self.decode(self.encode(x))
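A minimal sketch of how this vendored autoencoder can be exercised, assuming xinference 0.16.2 and mlx are installed on Apple silicon. The parameter values below are assumptions (figures commonly used for the Flux VAE), not defined by this diff; in practice load_ae() from the vendored flux utils loads the real config and weights:

import mlx.core as mx
from xinference.thirdparty.mlx.flux.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(
    resolution=256,
    in_channels=3,
    ch=128,
    out_ch=3,
    ch_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    z_channels=16,
    scale_factor=0.3611,
    shift_factor=0.1159,
)
ae = AutoEncoder(params)  # randomly initialized here; load_ae() loads real weights
ae.eval()                 # take the deterministic (mean) path of DiagonalGaussian

x = mx.random.normal(shape=(1, 256, 256, 3))  # MLX convolutions expect NHWC input
z = ae.encode(x)       # latent of shape (1, 32, 32, 16) with three downsamplings
recon = ae.decode(z)   # back to (1, 256, 256, 3)
print(z.shape, recon.shape)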
xinference/thirdparty/mlx/flux/clip.py
ADDED
@@ -0,0 +1,154 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
+
+
+@dataclass
+class CLIPTextModelConfig:
+    num_layers: int = 23
+    model_dims: int = 1024
+    num_heads: int = 16
+    max_length: int = 77
+    vocab_size: int = 49408
+    hidden_act: str = "quick_gelu"
+
+    @classmethod
+    def from_dict(cls, config):
+        return cls(
+            num_layers=config["num_hidden_layers"],
+            model_dims=config["hidden_size"],
+            num_heads=config["num_attention_heads"],
+            max_length=config["max_position_embeddings"],
+            vocab_size=config["vocab_size"],
+            hidden_act=config["hidden_act"],
+        )
+
+
+@dataclass
+class CLIPOutput:
+    # The last_hidden_state indexed at the EOS token and possibly projected if
+    # the model has a projection layer
+    pooled_output: Optional[mx.array] = None
+
+    # The full sequence output of the transformer after the final layernorm
+    last_hidden_state: Optional[mx.array] = None
+
+    # A list of hidden states corresponding to the outputs of the transformer layers
+    hidden_states: Optional[List[mx.array]] = None
+
+
+class CLIPEncoderLayer(nn.Module):
+    """The transformer encoder layer from CLIP."""
+
+    def __init__(self, model_dims: int, num_heads: int, activation: str):
+        super().__init__()
+
+        self.layer_norm1 = nn.LayerNorm(model_dims)
+        self.layer_norm2 = nn.LayerNorm(model_dims)
+
+        self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
+
+        self.linear1 = nn.Linear(model_dims, 4 * model_dims)
+        self.linear2 = nn.Linear(4 * model_dims, model_dims)
+
+        self.act = _ACTIVATIONS[activation]
+
+    def __call__(self, x, attn_mask=None):
+        y = self.layer_norm1(x)
+        y = self.attention(y, y, y, attn_mask)
+        x = y + x
+
+        y = self.layer_norm2(x)
+        y = self.linear1(y)
+        y = self.act(y)
+        y = self.linear2(y)
+        x = y + x
+
+        return x
+
+
+class CLIPTextModel(nn.Module):
+    """Implements the text encoder transformer from CLIP."""
+
+    def __init__(self, config: CLIPTextModelConfig):
+        super().__init__()
+
+        self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
+        self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
+        self.layers = [
+            CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
+            for i in range(config.num_layers)
+        ]
+        self.final_layer_norm = nn.LayerNorm(config.model_dims)
+
+    def _get_mask(self, N, dtype):
+        indices = mx.arange(N)
+        mask = indices[:, None] < indices[None]
+        mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
+        return mask
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for key, w in weights.items():
+            # Remove prefixes
+            if key.startswith("text_model."):
+                key = key[11:]
+            if key.startswith("embeddings."):
+                key = key[11:]
+            if key.startswith("encoder."):
+                key = key[8:]
+
+            # Map attention layers
+            if "self_attn." in key:
+                key = key.replace("self_attn.", "attention.")
+            if "q_proj." in key:
+                key = key.replace("q_proj.", "query_proj.")
+            if "k_proj." in key:
+                key = key.replace("k_proj.", "key_proj.")
+            if "v_proj." in key:
+                key = key.replace("v_proj.", "value_proj.")
+
+            # Map ffn layers
+            if "mlp.fc1" in key:
+                key = key.replace("mlp.fc1", "linear1")
+            if "mlp.fc2" in key:
+                key = key.replace("mlp.fc2", "linear2")
+
+            new_weights[key] = w
+
+        return new_weights
+
+    def __call__(self, x):
+        # Extract some shapes
+        B, N = x.shape
+        eos_tokens = x.argmax(-1)
+
+        # Compute the embeddings
+        x = self.token_embedding(x)
+        x = x + self.position_embedding.weight[:N]
+
+        # Compute the features from the transformer
+        mask = self._get_mask(N, x.dtype)
+        hidden_states = []
+        for l in self.layers:
+            x = l(x, mask)
+            hidden_states.append(x)
+
+        # Apply the final layernorm and return
+        x = self.final_layer_norm(x)
+        last_hidden_state = x
+
+        # Select the EOS token
+        pooled_output = x[mx.arange(len(x)), eos_tokens]
+
+        return CLIPOutput(
+            pooled_output=pooled_output,
+            last_hidden_state=last_hidden_state,
+            hidden_states=hidden_states,
+        )
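Because the model pools at argmax of the token ids, the EOS token (the largest id in CLIP's vocabulary) selects the pooled embedding. A minimal forward-pass sketch with a deliberately tiny layer count and placeholder token ids (49406/49407 as the usual CLIP BOS/EOS ids is an assumption; a real pipeline uses load_clip() plus the CLIP tokenizer):

import mlx.core as mx
from xinference.thirdparty.mlx.flux.clip import CLIPTextModel, CLIPTextModelConfig

model = CLIPTextModel(CLIPTextModelConfig(num_layers=2))  # tiny, randomly initialized

# One padded 77-token sequence: BOS, a few word ids, EOS, then zero padding.
tokens = mx.array([[49406, 320, 1125, 539, 320, 2368, 49407] + [0] * 70])
out = model(tokens)

print(out.pooled_output.shape)      # (1, 1024), taken at the EOS position
print(out.last_hidden_state.shape)  # (1, 77, 1024)
print(len(out.hidden_states))       # one entry per encoder layer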
xinference/thirdparty/mlx/flux/datasets.py
ADDED
@@ -0,0 +1,75 @@
+import json
+from pathlib import Path
+
+from PIL import Image
+
+
+class Dataset:
+    def __getitem__(self, index: int):
+        raise NotImplementedError()
+
+    def __len__(self):
+        raise NotImplementedError()
+
+
+class LocalDataset(Dataset):
+    prompt_key = "prompt"
+
+    def __init__(self, dataset: str, data_file):
+        self.dataset_base = Path(dataset)
+        with open(data_file, "r") as fid:
+            self._data = [json.loads(l) for l in fid]
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, index: int):
+        item = self._data[index]
+        image = Image.open(self.dataset_base / item["image"])
+        return image, item[self.prompt_key]
+
+
+class LegacyDataset(LocalDataset):
+    prompt_key = "text"
+
+    def __init__(self, dataset: str):
+        self.dataset_base = Path(dataset)
+        with open(self.dataset_base / "index.json") as f:
+            self._data = json.load(f)["data"]
+
+
+class HuggingFaceDataset(Dataset):
+
+    def __init__(self, dataset: str):
+        from datasets import load_dataset as hf_load_dataset
+
+        self._df = hf_load_dataset(dataset)["train"]
+
+    def __len__(self):
+        return len(self._df)
+
+    def __getitem__(self, index: int):
+        item = self._df[index]
+        return item["image"], item["prompt"]
+
+
+def load_dataset(dataset: str):
+    dataset_base = Path(dataset)
+    data_file = dataset_base / "train.jsonl"
+    legacy_file = dataset_base / "index.json"
+
+    if data_file.exists():
+        print(f"Load the local dataset {data_file} .", flush=True)
+        dataset = LocalDataset(dataset, data_file)
+    elif legacy_file.exists():
+        print(f"Load the local dataset {legacy_file} .")
+        print()
+        print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
+        print(" See the README for details.")
+        print(flush=True)
+        dataset = LegacyDataset(dataset)
+    else:
+        print(f"Load the Hugging Face dataset {dataset} .", flush=True)
+        dataset = HuggingFaceDataset(dataset)
+
+    return dataset
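load_dataset dispatches on what it finds on disk: a train.jsonl directory becomes a LocalDataset, a legacy index.json becomes a LegacyDataset, and anything else is treated as a Hugging Face dataset name. A sketch of a hypothetical local layout ("my_dataset" and its files are placeholders you would create yourself):

# my_dataset/
#   train.jsonl
#   0001.png
#   0002.png
#
# Each train.jsonl line is one record, e.g.:
#   {"image": "0001.png", "prompt": "a photo of a red bicycle"}

from xinference.thirdparty.mlx.flux.datasets import load_dataset

dataset = load_dataset("my_dataset")  # LocalDataset, since train.jsonl exists
image, prompt = dataset[0]            # (PIL.Image.Image, str)
print(len(dataset), prompt)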