xinference 0.15.4__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of xinference might be problematic.
- xinference/__init__.py +0 -4
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +48 -0
- xinference/client/restful/restful_client.py +19 -0
- xinference/constants.py +4 -4
- xinference/core/chat_interface.py +5 -1
- xinference/core/image_interface.py +5 -1
- xinference/core/model.py +195 -34
- xinference/core/scheduler.py +10 -7
- xinference/core/utils.py +9 -0
- xinference/model/__init__.py +4 -0
- xinference/model/audio/chattts.py +25 -14
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/audio/model_spec_modelscope.json +1 -1
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/image/core.py +59 -4
- xinference/model/image/model_spec.json +24 -3
- xinference/model/image/model_spec_modelscope.json +25 -3
- xinference/model/image/ocr/__init__.py +13 -0
- xinference/model/image/ocr/got_ocr2.py +76 -0
- xinference/model/image/scheduler/__init__.py +13 -0
- xinference/model/image/scheduler/flux.py +533 -0
- xinference/model/image/stable_diffusion/core.py +8 -34
- xinference/model/image/stable_diffusion/mlx.py +221 -0
- xinference/model/image/utils.py +39 -3
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +178 -1
- xinference/model/llm/llm_family_modelscope.json +119 -0
- xinference/model/llm/transformers/chatglm.py +104 -0
- xinference/model/llm/transformers/core.py +37 -111
- xinference/model/llm/transformers/deepseek_v2.py +0 -226
- xinference/model/llm/transformers/internlm2.py +3 -95
- xinference/model/llm/transformers/opt.py +68 -0
- xinference/model/llm/transformers/utils.py +4 -284
- xinference/model/llm/utils.py +2 -2
- xinference/model/llm/vllm/core.py +16 -1
- xinference/thirdparty/mlx/__init__.py +13 -0
- xinference/thirdparty/mlx/flux/__init__.py +15 -0
- xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
- xinference/thirdparty/mlx/flux/clip.py +154 -0
- xinference/thirdparty/mlx/flux/datasets.py +75 -0
- xinference/thirdparty/mlx/flux/flux.py +247 -0
- xinference/thirdparty/mlx/flux/layers.py +302 -0
- xinference/thirdparty/mlx/flux/lora.py +76 -0
- xinference/thirdparty/mlx/flux/model.py +134 -0
- xinference/thirdparty/mlx/flux/sampler.py +56 -0
- xinference/thirdparty/mlx/flux/t5.py +244 -0
- xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
- xinference/thirdparty/mlx/flux/trainer.py +98 -0
- xinference/thirdparty/mlx/flux/utils.py +179 -0
- xinference/utils.py +2 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.e51a356d.js → main.b76aeeb7.js} +3 -3
- xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/METADATA +49 -10
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/RECORD +64 -44
- xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
- /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/mlx/flux/model.py
@@ -0,0 +1,134 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .layers import (
+    DoubleStreamBlock,
+    EmbedND,
+    LastLayer,
+    MLPEmbedder,
+    SingleStreamBlock,
+    timestep_embedding,
+)
+
+
+@dataclass
+class FluxParams:
+    in_channels: int
+    vec_in_dim: int
+    context_in_dim: int
+    hidden_size: int
+    mlp_ratio: float
+    num_heads: int
+    depth: int
+    depth_single_blocks: int
+    axes_dim: list[int]
+    theta: int
+    qkv_bias: bool
+    guidance_embed: bool
+
+
+class Flux(nn.Module):
+    def __init__(self, params: FluxParams):
+        super().__init__()
+
+        self.params = params
+        self.in_channels = params.in_channels
+        self.out_channels = self.in_channels
+        if params.hidden_size % params.num_heads != 0:
+            raise ValueError(
+                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+            )
+        pe_dim = params.hidden_size // params.num_heads
+        if sum(params.axes_dim) != pe_dim:
+            raise ValueError(
+                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
+            )
+        self.hidden_size = params.hidden_size
+        self.num_heads = params.num_heads
+        self.pe_embedder = EmbedND(
+            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
+        )
+        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+        self.guidance_in = (
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+            if params.guidance_embed
+            else nn.Identity()
+        )
+        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+        self.double_blocks = [
+            DoubleStreamBlock(
+                self.hidden_size,
+                self.num_heads,
+                mlp_ratio=params.mlp_ratio,
+                qkv_bias=params.qkv_bias,
+            )
+            for _ in range(params.depth)
+        ]
+
+        self.single_blocks = [
+            SingleStreamBlock(
+                self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
+            )
+            for _ in range(params.depth_single_blocks)
+        ]
+
+        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for k, w in weights.items():
+            if k.endswith(".scale"):
+                k = k[:-6] + ".weight"
+            for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
+                if f".{seq}." in k:
+                    k = k.replace(f".{seq}.", f".{seq}.layers.")
+                    break
+            new_weights[k] = w
+        return new_weights
+
+    def __call__(
+        self,
+        img: mx.array,
+        img_ids: mx.array,
+        txt: mx.array,
+        txt_ids: mx.array,
+        timesteps: mx.array,
+        y: mx.array,
+        guidance: Optional[mx.array] = None,
+    ) -> mx.array:
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        img = self.img_in(img)
+        vec = self.time_in(timestep_embedding(timesteps, 256))
+        if self.params.guidance_embed:
+            if guidance is None:
+                raise ValueError(
+                    "Didn't get guidance strength for guidance distilled model."
+                )
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+        vec = vec + self.vector_in(y)
+        txt = self.txt_in(txt)
+
+        ids = mx.concatenate([txt_ids, img_ids], axis=1)
+        pe = self.pe_embedder(ids).astype(img.dtype)
+
+        for block in self.double_blocks:
+            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
+        img = mx.concatenate([txt, img], axis=1)
+        for block in self.single_blocks:
+            img = block(img, vec=vec, pe=pe)
+        img = img[:, txt.shape[1] :, ...]
+
+        img = self.final_layer(img, vec)
+
+        return img
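A minimal forward-pass sketch of the added Flux transformer, not part of the diff. The tiny hyperparameters below are made up for illustration (the real flux-dev/schnell configs are far larger), it assumes the sibling layers.py module added in the same release provides the imported blocks, and it uses zero-filled ids and conditioning purely as a shape check:

import mlx.core as mx

# Toy config: axes_dim must sum to hidden_size // num_heads (here 16).
params = FluxParams(
    in_channels=16, vec_in_dim=32, context_in_dim=48, hidden_size=64,
    mlp_ratio=4.0, num_heads=4, depth=1, depth_single_blocks=1,
    axes_dim=[4, 6, 6], theta=10_000, qkv_bias=True, guidance_embed=False,
)
model = Flux(params)

B, L_img, L_txt = 1, 32, 8
out = model(
    img=mx.random.normal((B, L_img, params.in_channels)),   # packed latent tokens
    img_ids=mx.zeros((B, L_img, 3)),                          # 3 positional axes per token
    txt=mx.random.normal((B, L_txt, params.context_in_dim)),  # T5 text embeddings
    txt_ids=mx.zeros((B, L_txt, 3)),
    timesteps=mx.array([0.5]),
    y=mx.zeros((B, params.vec_in_dim)),                       # pooled CLIP conditioning
)
mx.eval(out)  # out has shape (B, L_img, in_channels)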
xinference/thirdparty/mlx/flux/sampler.py
@@ -0,0 +1,56 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from functools import lru_cache
+
+import mlx.core as mx
+
+
+class FluxSampler:
+    def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5):
+        self._base_shift = base_shift
+        self._max_shift = max_shift
+        self._schnell = "schnell" in name
+
+    def _time_shift(self, x, t):
+        x1, x2 = 256, 4096
+        t1, t2 = self._base_shift, self._max_shift
+        exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1)
+        t = exp_mu / (exp_mu + (1 / t - 1))
+        return t
+
+    @lru_cache
+    def timesteps(
+        self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
+    ):
+        t = mx.linspace(start, stop, num_steps + 1)
+
+        if self._schnell:
+            t = self._time_shift(image_sequence_length, t)
+
+        return t.tolist()
+
+    def random_timesteps(self, B, L, dtype=mx.float32, key=None):
+        if self._schnell:
+            # TODO: Should we upweigh 1 and 0.75?
+            t = mx.random.randint(1, 5, shape=(B,), key=key)
+            t = t.astype(dtype) / 4
+        else:
+            t = mx.random.uniform(shape=(B,), dtype=dtype, key=key)
+            t = self._time_shift(L, t)
+
+        return t
+
+    def sample_prior(self, shape, dtype=mx.float32, key=None):
+        return mx.random.normal(shape, dtype=dtype, key=key)
+
+    def add_noise(self, x, t, noise=None, key=None):
+        noise = (
+            noise
+            if noise is not None
+            else mx.random.normal(x.shape, dtype=x.dtype, key=key)
+        )
+        return x * (1 - t) + t * noise
+
+    def step(self, pred, x_t, t, t_prev):
+        return x_t + (t_prev - t) * pred
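The sampler implements a plain rectified-flow Euler update: step() moves x_t along the predicted velocity by (t_prev - t). A minimal denoising-loop sketch, not part of the diff; the stand-in denoise_fn and the toy latent shape are assumptions for illustration only:

import mlx.core as mx

sampler = FluxSampler("flux1-dev")            # "schnell" in the name toggles the time shift
ts = sampler.timesteps(num_steps=4, image_sequence_length=256)
x = sampler.sample_prior((1, 256, 64), dtype=mx.float32)

def denoise_fn(x, t):
    # Stand-in for the Flux transformer's velocity prediction.
    return mx.zeros_like(x)

for t, t_prev in zip(ts[:-1], ts[1:]):
    pred = denoise_fn(x, t)
    x = sampler.step(pred, x, t, t_prev)      # x <- x + (t_prev - t) * pred
mx.eval(x)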
xinference/thirdparty/mlx/flux/t5.py
@@ -0,0 +1,244 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+_SHARED_REPLACEMENT_PATTERNS = [
+    (".block.", ".layers."),
+    (".k.", ".key_proj."),
+    (".o.", ".out_proj."),
+    (".q.", ".query_proj."),
+    (".v.", ".value_proj."),
+    ("shared.", "wte."),
+    ("lm_head.", "lm_head.linear."),
+    (".layer.0.layer_norm.", ".ln1."),
+    (".layer.1.layer_norm.", ".ln2."),
+    (".layer.2.layer_norm.", ".ln3."),
+    (".final_layer_norm.", ".ln."),
+    (
+        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
+        "relative_attention_bias.embeddings.",
+    ),
+]
+
+_ENCODER_REPLACEMENT_PATTERNS = [
+    (".layer.0.SelfAttention.", ".attention."),
+    (".layer.1.DenseReluDense.", ".dense."),
+]
+
+
+@dataclass
+class T5Config:
+    vocab_size: int
+    num_layers: int
+    num_heads: int
+    relative_attention_num_buckets: int
+    d_kv: int
+    d_model: int
+    feed_forward_proj: str
+    tie_word_embeddings: bool
+
+    d_ff: Optional[int] = None
+    num_decoder_layers: Optional[int] = None
+    relative_attention_max_distance: int = 128
+    layer_norm_epsilon: float = 1e-6
+
+    @classmethod
+    def from_dict(cls, config):
+        return cls(
+            vocab_size=config["vocab_size"],
+            num_layers=config["num_layers"],
+            num_heads=config["num_heads"],
+            relative_attention_num_buckets=config["relative_attention_num_buckets"],
+            d_kv=config["d_kv"],
+            d_model=config["d_model"],
+            feed_forward_proj=config["feed_forward_proj"],
+            tie_word_embeddings=config["tie_word_embeddings"],
+            d_ff=config.get("d_ff", 4 * config["d_model"]),
+            num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
+            relative_attention_max_distance=config.get(
+                "relative_attention_max_distance", 128
+            ),
+            layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
+        )
+
+
+class RelativePositionBias(nn.Module):
+    def __init__(self, config: T5Config, bidirectional: bool):
+        self.bidirectional = bidirectional
+        self.num_buckets = config.relative_attention_num_buckets
+        self.max_distance = config.relative_attention_max_distance
+        self.n_heads = config.num_heads
+        self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)
+
+    @staticmethod
+    def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
+        num_buckets = num_buckets // 2 if bidirectional else num_buckets
+        max_exact = num_buckets // 2
+
+        abspos = rpos.abs()
+        is_small = abspos < max_exact
+
+        scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
+        buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
+        buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
+
+        buckets = mx.where(is_small, abspos, buckets_large)
+        if bidirectional:
+            buckets = buckets + (rpos > 0) * num_buckets
+        else:
+            buckets = buckets * (rpos < 0)
+
+        return buckets
+
+    def __call__(self, query_length: int, key_length: int, offset: int = 0):
+        """Compute binned relative position bias"""
+        context_position = mx.arange(offset, query_length)[:, None]
+        memory_position = mx.arange(key_length)[None, :]
+
+        # shape (query_length, key_length)
+        relative_position = memory_position - context_position
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,
+            bidirectional=self.bidirectional,
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
+        )
+
+        # shape (query_length, key_length, num_heads)
+        values = self.embeddings(relative_position_bucket)
+
+        # shape (num_heads, query_length, key_length)
+        return values.transpose(2, 0, 1)
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        inner_dim = config.d_kv * config.num_heads
+        self.num_heads = config.num_heads
+        self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
+
+    def __call__(
+        self,
+        queries: mx.array,
+        keys: mx.array,
+        values: mx.array,
+        mask: Optional[mx.array],
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> [mx.array, Tuple[mx.array, mx.array]]:
+        queries = self.query_proj(queries)
+        keys = self.key_proj(keys)
+        values = self.value_proj(values)
+
+        num_heads = self.num_heads
+        B, L, _ = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            key_cache, value_cache = cache
+            keys = mx.concatenate([key_cache, keys], axis=3)
+            values = mx.concatenate([value_cache, values], axis=2)
+
+        values_hat = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
+        )
+        values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.out_proj(values_hat), (keys, values)
+
+
+class DenseActivation(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        mlp_dims = config.d_ff or config.d_model * 4
+        self.gated = config.feed_forward_proj.startswith("gated")
+        if self.gated:
+            self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
+            self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
+        else:
+            self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
+        self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
+        activation = config.feed_forward_proj.removeprefix("gated-")
+        if activation == "relu":
+            self.act = nn.relu
+        elif activation == "gelu":
+            self.act = nn.gelu
+        elif activation == "silu":
+            self.act = nn.silu
+        else:
+            raise ValueError(f"Unknown activation: {activation}")
+
+    def __call__(self, x):
+        if self.gated:
+            hidden_act = self.act(self.wi_0(x))
+            hidden_linear = self.wi_1(x)
+            x = hidden_act * hidden_linear
+        else:
+            x = self.act(self.wi(x))
+        return self.wo(x)
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        self.attention = MultiHeadAttention(config)
+        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.dense = DenseActivation(config)
+
+    def __call__(self, x, mask):
+        y = self.ln1(x)
+        y, _ = self.attention(y, y, y, mask=mask)
+        x = x + y
+
+        y = self.ln2(x)
+        y = self.dense(y)
+        return x + y
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        self.layers = [
+            TransformerEncoderLayer(config) for i in range(config.num_layers)
+        ]
+        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
+
+    def __call__(self, x: mx.array):
+        pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
+        pos_bias = pos_bias.astype(x.dtype)
+        for layer in self.layers:
+            x = layer(x, mask=pos_bias)
+        return self.ln(x)
+
+
+class T5Encoder(nn.Module):
+    def __init__(self, config: T5Config):
+        self.wte = nn.Embedding(config.vocab_size, config.d_model)
+        self.encoder = TransformerEncoder(config)
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for k, w in weights.items():
+            for old, new in _SHARED_REPLACEMENT_PATTERNS:
+                k = k.replace(old, new)
+            if k.startswith("encoder."):
+                for old, new in _ENCODER_REPLACEMENT_PATTERNS:
+                    k = k.replace(old, new)
+            new_weights[k] = w
+        return new_weights
+
+    def __call__(self, inputs: mx.array):
+        return self.encoder(self.wte(inputs))
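A minimal sketch of running the added T5 encoder on random token ids, not part of the diff. The tiny configuration values below are made up for illustration and are far smaller than the real T5 text-encoder config shipped with FLUX:

import mlx.core as mx

config = T5Config.from_dict(
    {
        "vocab_size": 128,
        "num_layers": 2,
        "num_heads": 4,
        "relative_attention_num_buckets": 32,
        "d_kv": 8,
        "d_model": 32,
        "feed_forward_proj": "gated-gelu",
        "tie_word_embeddings": False,
    }
)
encoder = T5Encoder(config)
ids = mx.random.randint(0, config.vocab_size, shape=(1, 16))
hidden = encoder(ids)   # shape (1, 16, d_model)
mx.eval(hidden)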
xinference/thirdparty/mlx/flux/tokenizers.py
@@ -0,0 +1,185 @@
+# Copyright © 2024 Apple Inc.
+
+import mlx.core as mx
+import regex
+from sentencepiece import SentencePieceProcessor
+
+
+class CLIPTokenizer:
+    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
+
+    def __init__(self, bpe_ranks, vocab, max_length=77):
+        self.max_length = max_length
+        self.bpe_ranks = bpe_ranks
+        self.vocab = vocab
+        self.pat = regex.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            regex.IGNORECASE,
+        )
+
+        self._cache = {self.bos: self.bos, self.eos: self.eos}
+
+    @property
+    def bos(self):
+        return "<|startoftext|>"
+
+    @property
+    def bos_token(self):
+        return self.vocab[self.bos]
+
+    @property
+    def eos(self):
+        return "<|endoftext|>"
+
+    @property
+    def eos_token(self):
+        return self.vocab[self.eos]
+
+    def bpe(self, text):
+        if text in self._cache:
+            return self._cache[text]
+
+        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
+        unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        if not unique_bigrams:
+            return unigrams
+
+        # In every iteration try to merge the two most likely bigrams. If none
+        # was merged we are done.
+        #
+        # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
+        while unique_bigrams:
+            bigram = min(
+                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
+            )
+            if bigram not in self.bpe_ranks:
+                break
+
+            new_unigrams = []
+            skip = False
+            for a, b in zip(unigrams, unigrams[1:]):
+                if skip:
+                    skip = False
+                    continue
+
+                if (a, b) == bigram:
+                    new_unigrams.append(a + b)
+                    skip = True
+
+                else:
+                    new_unigrams.append(a)
+
+            if not skip:
+                new_unigrams.append(b)
+
+            unigrams = new_unigrams
+            unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        self._cache[text] = unigrams
+
+        return unigrams
+
+    def tokenize(self, text, prepend_bos=True, append_eos=True):
+        if isinstance(text, list):
+            return [self.tokenize(t, prepend_bos, append_eos) for t in text]
+
+        # Lower case cleanup and split according to self.pat. Hugging Face does
+        # a much more thorough job here but this should suffice for 95% of
+        # cases.
+        clean_text = regex.sub(r"\s+", " ", text.lower())
+        tokens = regex.findall(self.pat, clean_text)
+
+        # Split the tokens according to the byte-pair merge file
+        bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
+
+        # Map to token ids and return
+        tokens = [self.vocab[t] for t in bpe_tokens]
+        if prepend_bos:
+            tokens = [self.bos_token] + tokens
+        if append_eos:
+            tokens.append(self.eos_token)
+
+        if len(tokens) > self.max_length:
+            tokens = tokens[: self.max_length]
+            if append_eos:
+                tokens[-1] = self.eos_token
+
+        return tokens
+
+    def encode(self, text):
+        if not isinstance(text, list):
+            return self.encode([text])
+
+        tokens = self.tokenize(text)
+        length = max(len(t) for t in tokens)
+        for t in tokens:
+            t.extend([self.eos_token] * (length - len(t)))
+
+        return mx.array(tokens)
+
+
+class T5Tokenizer:
+    def __init__(self, model_file, max_length=512):
+        self._tokenizer = SentencePieceProcessor(model_file)
+        self.max_length = max_length
+
+    @property
+    def pad(self):
+        try:
+            return self._tokenizer.id_to_piece(self.pad_token)
+        except IndexError:
+            return None
+
+    @property
+    def pad_token(self):
+        return self._tokenizer.pad_id()
+
+    @property
+    def bos(self):
+        try:
+            return self._tokenizer.id_to_piece(self.bos_token)
+        except IndexError:
+            return None
+
+    @property
+    def bos_token(self):
+        return self._tokenizer.bos_id()
+
+    @property
+    def eos(self):
+        try:
+            return self._tokenizer.id_to_piece(self.eos_token)
+        except IndexError:
+            return None
+
+    @property
+    def eos_token(self):
+        return self._tokenizer.eos_id()
+
+    def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
+        if isinstance(text, list):
+            return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
+
+        tokens = self._tokenizer.encode(text)
+
+        if prepend_bos and self.bos_token >= 0:
+            tokens = [self.bos_token] + tokens
+        if append_eos and self.eos_token >= 0:
+            tokens.append(self.eos_token)
+        if pad and len(tokens) < self.max_length and self.pad_token >= 0:
+            tokens += [self.pad_token] * (self.max_length - len(tokens))
+
+        return tokens
+
+    def encode(self, text, pad=True):
+        if not isinstance(text, list):
+            return self.encode([text], pad=pad)
+
+        pad_token = self.pad_token if self.pad_token >= 0 else 0
+        tokens = self.tokenize(text, pad=pad)
+        length = max(len(t) for t in tokens)
+        for t in tokens:
+            t.extend([pad_token] * (length - len(t)))
+
+        return mx.array(tokens)
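A minimal CLIPTokenizer sketch, not part of the diff. The hand-built merge table and vocabulary below are toy values made up to show the BPE merge and padding behaviour; in practice both come from the downloaded CLIP checkpoint, and T5Tokenizer works the same way but needs a real SentencePiece model file:

import mlx.core as mx

bpe_ranks = {("h", "i</w>"): 0}                                  # toy merge table
vocab = {"<|startoftext|>": 0, "<|endoftext|>": 1, "hi</w>": 2}  # toy vocabulary

tok = CLIPTokenizer(bpe_ranks, vocab, max_length=77)
print(tok.tokenize("hi"))            # [0, 2, 1]: bos, merged "hi</w>", eos
print(tok.encode(["hi", "hi hi"]))   # batch padded to equal length with eos ids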