xinference 0.15.4__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xinference might be problematic.

Files changed (67)
  1. xinference/__init__.py +0 -4
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +48 -0
  4. xinference/client/restful/restful_client.py +19 -0
  5. xinference/constants.py +4 -4
  6. xinference/core/chat_interface.py +5 -1
  7. xinference/core/image_interface.py +5 -1
  8. xinference/core/model.py +195 -34
  9. xinference/core/scheduler.py +10 -7
  10. xinference/core/utils.py +9 -0
  11. xinference/model/__init__.py +4 -0
  12. xinference/model/audio/chattts.py +25 -14
  13. xinference/model/audio/model_spec.json +1 -1
  14. xinference/model/audio/model_spec_modelscope.json +1 -1
  15. xinference/model/embedding/model_spec.json +1 -1
  16. xinference/model/image/core.py +59 -4
  17. xinference/model/image/model_spec.json +24 -3
  18. xinference/model/image/model_spec_modelscope.json +25 -3
  19. xinference/model/image/ocr/__init__.py +13 -0
  20. xinference/model/image/ocr/got_ocr2.py +76 -0
  21. xinference/model/image/scheduler/__init__.py +13 -0
  22. xinference/model/image/scheduler/flux.py +533 -0
  23. xinference/model/image/stable_diffusion/core.py +8 -34
  24. xinference/model/image/stable_diffusion/mlx.py +221 -0
  25. xinference/model/image/utils.py +39 -3
  26. xinference/model/llm/__init__.py +2 -0
  27. xinference/model/llm/llm_family.json +178 -1
  28. xinference/model/llm/llm_family_modelscope.json +119 -0
  29. xinference/model/llm/transformers/chatglm.py +104 -0
  30. xinference/model/llm/transformers/core.py +37 -111
  31. xinference/model/llm/transformers/deepseek_v2.py +0 -226
  32. xinference/model/llm/transformers/internlm2.py +3 -95
  33. xinference/model/llm/transformers/opt.py +68 -0
  34. xinference/model/llm/transformers/utils.py +4 -284
  35. xinference/model/llm/utils.py +2 -2
  36. xinference/model/llm/vllm/core.py +16 -1
  37. xinference/thirdparty/mlx/__init__.py +13 -0
  38. xinference/thirdparty/mlx/flux/__init__.py +15 -0
  39. xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
  40. xinference/thirdparty/mlx/flux/clip.py +154 -0
  41. xinference/thirdparty/mlx/flux/datasets.py +75 -0
  42. xinference/thirdparty/mlx/flux/flux.py +247 -0
  43. xinference/thirdparty/mlx/flux/layers.py +302 -0
  44. xinference/thirdparty/mlx/flux/lora.py +76 -0
  45. xinference/thirdparty/mlx/flux/model.py +134 -0
  46. xinference/thirdparty/mlx/flux/sampler.py +56 -0
  47. xinference/thirdparty/mlx/flux/t5.py +244 -0
  48. xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
  49. xinference/thirdparty/mlx/flux/trainer.py +98 -0
  50. xinference/thirdparty/mlx/flux/utils.py +179 -0
  51. xinference/utils.py +2 -3
  52. xinference/web/ui/build/asset-manifest.json +3 -3
  53. xinference/web/ui/build/index.html +1 -1
  54. xinference/web/ui/build/static/js/{main.e51a356d.js → main.b76aeeb7.js} +3 -3
  55. xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
  58. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/METADATA +49 -10
  59. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/RECORD +64 -44
  60. xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
  63. /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
  64. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
  65. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
  66. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
  67. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/model.py
@@ -0,0 +1,134 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .layers import (
+    DoubleStreamBlock,
+    EmbedND,
+    LastLayer,
+    MLPEmbedder,
+    SingleStreamBlock,
+    timestep_embedding,
+)
+
+
+@dataclass
+class FluxParams:
+    in_channels: int
+    vec_in_dim: int
+    context_in_dim: int
+    hidden_size: int
+    mlp_ratio: float
+    num_heads: int
+    depth: int
+    depth_single_blocks: int
+    axes_dim: list[int]
+    theta: int
+    qkv_bias: bool
+    guidance_embed: bool
+
+
+class Flux(nn.Module):
+    def __init__(self, params: FluxParams):
+        super().__init__()
+
+        self.params = params
+        self.in_channels = params.in_channels
+        self.out_channels = self.in_channels
+        if params.hidden_size % params.num_heads != 0:
+            raise ValueError(
+                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+            )
+        pe_dim = params.hidden_size // params.num_heads
+        if sum(params.axes_dim) != pe_dim:
+            raise ValueError(
+                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
+            )
+        self.hidden_size = params.hidden_size
+        self.num_heads = params.num_heads
+        self.pe_embedder = EmbedND(
+            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
+        )
+        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+        self.guidance_in = (
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+            if params.guidance_embed
+            else nn.Identity()
+        )
+        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+        self.double_blocks = [
+            DoubleStreamBlock(
+                self.hidden_size,
+                self.num_heads,
+                mlp_ratio=params.mlp_ratio,
+                qkv_bias=params.qkv_bias,
+            )
+            for _ in range(params.depth)
+        ]
+
+        self.single_blocks = [
+            SingleStreamBlock(
+                self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
+            )
+            for _ in range(params.depth_single_blocks)
+        ]
+
+        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for k, w in weights.items():
+            if k.endswith(".scale"):
+                k = k[:-6] + ".weight"
+            for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
+                if f".{seq}." in k:
+                    k = k.replace(f".{seq}.", f".{seq}.layers.")
+                    break
+            new_weights[k] = w
+        return new_weights
+
+    def __call__(
+        self,
+        img: mx.array,
+        img_ids: mx.array,
+        txt: mx.array,
+        txt_ids: mx.array,
+        timesteps: mx.array,
+        y: mx.array,
+        guidance: Optional[mx.array] = None,
+    ) -> mx.array:
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        img = self.img_in(img)
+        vec = self.time_in(timestep_embedding(timesteps, 256))
+        if self.params.guidance_embed:
+            if guidance is None:
+                raise ValueError(
+                    "Didn't get guidance strength for guidance distilled model."
+                )
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+        vec = vec + self.vector_in(y)
+        txt = self.txt_in(txt)
+
+        ids = mx.concatenate([txt_ids, img_ids], axis=1)
+        pe = self.pe_embedder(ids).astype(img.dtype)
+
+        for block in self.double_blocks:
+            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
+        img = mx.concatenate([txt, img], axis=1)
+        for block in self.single_blocks:
+            img = block(img, vec=vec, pe=pe)
+        img = img[:, txt.shape[1] :, ...]
+
+        img = self.final_layer(img, vec)
+
+        return img
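Usage sketch (illustrative, not taken from the package): instantiating the new MLX Flux transformer from xinference/thirdparty/mlx/flux/model.py and running one forward pass. The FluxParams values below are the commonly published flux-schnell dimensions and are assumptions made for the example; real values come from the downloaded checkpoint's configuration, and running this requires mlx plus the sibling layers module shipped in the same directory.

import mlx.core as mx

from xinference.thirdparty.mlx.flux.model import Flux, FluxParams

# Illustrative (assumed) flux-schnell-like dimensions.
params = FluxParams(
    in_channels=64,
    vec_in_dim=768,
    context_in_dim=4096,
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,
    depth_single_blocks=38,
    axes_dim=[16, 56, 56],  # must sum to hidden_size // num_heads (128)
    theta=10_000,
    qkv_bias=True,
    guidance_embed=False,  # True for guidance-distilled (dev) checkpoints
)
model = Flux(params)

# Random latents following the __call__ signature above: (batch, tokens, features)
# for img/txt, (batch, tokens, 3) for positional ids, one timestep and one CLIP
# pooled vector y per sample.
B, L_img, L_txt = 1, 256, 77
out = model(
    img=mx.random.normal((B, L_img, 64)),
    img_ids=mx.zeros((B, L_img, 3)),
    txt=mx.random.normal((B, L_txt, 4096)),
    txt_ids=mx.zeros((B, L_txt, 3)),
    timesteps=mx.array([1.0]),
    y=mx.random.normal((B, 768)),
)
print(out.shape)  # (1, 256, 64)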
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/sampler.py
@@ -0,0 +1,56 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from functools import lru_cache
+
+import mlx.core as mx
+
+
+class FluxSampler:
+    def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5):
+        self._base_shift = base_shift
+        self._max_shift = max_shift
+        self._schnell = "schnell" in name
+
+    def _time_shift(self, x, t):
+        x1, x2 = 256, 4096
+        t1, t2 = self._base_shift, self._max_shift
+        exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1)
+        t = exp_mu / (exp_mu + (1 / t - 1))
+        return t
+
+    @lru_cache
+    def timesteps(
+        self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
+    ):
+        t = mx.linspace(start, stop, num_steps + 1)
+
+        if self._schnell:
+            t = self._time_shift(image_sequence_length, t)
+
+        return t.tolist()
+
+    def random_timesteps(self, B, L, dtype=mx.float32, key=None):
+        if self._schnell:
+            # TODO: Should we upweigh 1 and 0.75?
+            t = mx.random.randint(1, 5, shape=(B,), key=key)
+            t = t.astype(dtype) / 4
+        else:
+            t = mx.random.uniform(shape=(B,), dtype=dtype, key=key)
+            t = self._time_shift(L, t)
+
+        return t
+
+    def sample_prior(self, shape, dtype=mx.float32, key=None):
+        return mx.random.normal(shape, dtype=dtype, key=key)
+
+    def add_noise(self, x, t, noise=None, key=None):
+        noise = (
+            noise
+            if noise is not None
+            else mx.random.normal(x.shape, dtype=x.dtype, key=key)
+        )
+        return x * (1 - t) + t * noise
+
+    def step(self, pred, x_t, t, t_prev):
+        return x_t + (t_prev - t) * pred
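Usage sketch (illustrative, not taken from the package): how FluxSampler's methods compose into a plain Euler denoising loop. The predict function is a stand-in for the Flux transformer's velocity prediction with text conditioning omitted; it is not an API defined in this diff.

import mlx.core as mx

from xinference.thirdparty.mlx.flux.sampler import FluxSampler


def predict(x_t, t):
    # Placeholder for the flow-matching prediction from the Flux transformer.
    return mx.zeros_like(x_t)


sampler = FluxSampler("flux-schnell")
x = sampler.sample_prior((1, 256, 64))  # start from Gaussian noise
ts = sampler.timesteps(num_steps=4, image_sequence_length=256)

for t, t_prev in zip(ts[:-1], ts[1:]):
    pred = predict(x, t)
    x = sampler.step(pred, x, t, t_prev)  # x_t + (t_prev - t) * pred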
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/t5.py
@@ -0,0 +1,244 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+_SHARED_REPLACEMENT_PATTERNS = [
+    (".block.", ".layers."),
+    (".k.", ".key_proj."),
+    (".o.", ".out_proj."),
+    (".q.", ".query_proj."),
+    (".v.", ".value_proj."),
+    ("shared.", "wte."),
+    ("lm_head.", "lm_head.linear."),
+    (".layer.0.layer_norm.", ".ln1."),
+    (".layer.1.layer_norm.", ".ln2."),
+    (".layer.2.layer_norm.", ".ln3."),
+    (".final_layer_norm.", ".ln."),
+    (
+        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
+        "relative_attention_bias.embeddings.",
+    ),
+]
+
+_ENCODER_REPLACEMENT_PATTERNS = [
+    (".layer.0.SelfAttention.", ".attention."),
+    (".layer.1.DenseReluDense.", ".dense."),
+]
+
+
+@dataclass
+class T5Config:
+    vocab_size: int
+    num_layers: int
+    num_heads: int
+    relative_attention_num_buckets: int
+    d_kv: int
+    d_model: int
+    feed_forward_proj: str
+    tie_word_embeddings: bool
+
+    d_ff: Optional[int] = None
+    num_decoder_layers: Optional[int] = None
+    relative_attention_max_distance: int = 128
+    layer_norm_epsilon: float = 1e-6
+
+    @classmethod
+    def from_dict(cls, config):
+        return cls(
+            vocab_size=config["vocab_size"],
+            num_layers=config["num_layers"],
+            num_heads=config["num_heads"],
+            relative_attention_num_buckets=config["relative_attention_num_buckets"],
+            d_kv=config["d_kv"],
+            d_model=config["d_model"],
+            feed_forward_proj=config["feed_forward_proj"],
+            tie_word_embeddings=config["tie_word_embeddings"],
+            d_ff=config.get("d_ff", 4 * config["d_model"]),
+            num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
+            relative_attention_max_distance=config.get(
+                "relative_attention_max_distance", 128
+            ),
+            layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
+        )
+
+
+class RelativePositionBias(nn.Module):
+    def __init__(self, config: T5Config, bidirectional: bool):
+        self.bidirectional = bidirectional
+        self.num_buckets = config.relative_attention_num_buckets
+        self.max_distance = config.relative_attention_max_distance
+        self.n_heads = config.num_heads
+        self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)
+
+    @staticmethod
+    def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
+        num_buckets = num_buckets // 2 if bidirectional else num_buckets
+        max_exact = num_buckets // 2
+
+        abspos = rpos.abs()
+        is_small = abspos < max_exact
+
+        scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
+        buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
+        buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
+
+        buckets = mx.where(is_small, abspos, buckets_large)
+        if bidirectional:
+            buckets = buckets + (rpos > 0) * num_buckets
+        else:
+            buckets = buckets * (rpos < 0)
+
+        return buckets
+
+    def __call__(self, query_length: int, key_length: int, offset: int = 0):
+        """Compute binned relative position bias"""
+        context_position = mx.arange(offset, query_length)[:, None]
+        memory_position = mx.arange(key_length)[None, :]
+
+        # shape (query_length, key_length)
+        relative_position = memory_position - context_position
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,
+            bidirectional=self.bidirectional,
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
+        )
+
+        # shape (query_length, key_length, num_heads)
+        values = self.embeddings(relative_position_bucket)
+
+        # shape (num_heads, query_length, key_length)
+        return values.transpose(2, 0, 1)
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        inner_dim = config.d_kv * config.num_heads
+        self.num_heads = config.num_heads
+        self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
+
+    def __call__(
+        self,
+        queries: mx.array,
+        keys: mx.array,
+        values: mx.array,
+        mask: Optional[mx.array],
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> [mx.array, Tuple[mx.array, mx.array]]:
+        queries = self.query_proj(queries)
+        keys = self.key_proj(keys)
+        values = self.value_proj(values)
+
+        num_heads = self.num_heads
+        B, L, _ = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            key_cache, value_cache = cache
+            keys = mx.concatenate([key_cache, keys], axis=3)
+            values = mx.concatenate([value_cache, values], axis=2)
+
+        values_hat = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
+        )
+        values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.out_proj(values_hat), (keys, values)
+
+
+class DenseActivation(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        mlp_dims = config.d_ff or config.d_model * 4
+        self.gated = config.feed_forward_proj.startswith("gated")
+        if self.gated:
+            self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
+            self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
+        else:
+            self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
+        self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
+        activation = config.feed_forward_proj.removeprefix("gated-")
+        if activation == "relu":
+            self.act = nn.relu
+        elif activation == "gelu":
+            self.act = nn.gelu
+        elif activation == "silu":
+            self.act = nn.silu
+        else:
+            raise ValueError(f"Unknown activation: {activation}")
+
+    def __call__(self, x):
+        if self.gated:
+            hidden_act = self.act(self.wi_0(x))
+            hidden_linear = self.wi_1(x)
+            x = hidden_act * hidden_linear
+        else:
+            x = self.act(self.wi(x))
+        return self.wo(x)
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        self.attention = MultiHeadAttention(config)
+        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.dense = DenseActivation(config)
+
+    def __call__(self, x, mask):
+        y = self.ln1(x)
+        y, _ = self.attention(y, y, y, mask=mask)
+        x = x + y
+
+        y = self.ln2(x)
+        y = self.dense(y)
+        return x + y
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        self.layers = [
+            TransformerEncoderLayer(config) for i in range(config.num_layers)
+        ]
+        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
+
+    def __call__(self, x: mx.array):
+        pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
+        pos_bias = pos_bias.astype(x.dtype)
+        for layer in self.layers:
+            x = layer(x, mask=pos_bias)
+        return self.ln(x)
+
+
+class T5Encoder(nn.Module):
+    def __init__(self, config: T5Config):
+        self.wte = nn.Embedding(config.vocab_size, config.d_model)
+        self.encoder = TransformerEncoder(config)
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for k, w in weights.items():
+            for old, new in _SHARED_REPLACEMENT_PATTERNS:
+                k = k.replace(old, new)
+            if k.startswith("encoder."):
+                for old, new in _ENCODER_REPLACEMENT_PATTERNS:
+                    k = k.replace(old, new)
+            new_weights[k] = w
+        return new_weights
+
+    def __call__(self, inputs: mx.array):
+        return self.encoder(self.wte(inputs))
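Usage sketch (illustrative, not taken from the package): building the MLX T5 text encoder from a Hugging Face style config dict and encoding a batch of token ids. The dimensions below are deliberately tiny stand-ins; real values come from the T5 configuration shipped with the FLUX checkpoints, and sanitize() exists to remap Hugging Face parameter names (e.g. ".q." -> ".query_proj.") before those weights are loaded into the module.

import mlx.core as mx

from xinference.thirdparty.mlx.flux.t5 import T5Config, T5Encoder

# Tiny illustrative config in the same shape as a Hugging Face T5 config.json.
config = T5Config.from_dict(
    {
        "vocab_size": 32128,
        "num_layers": 2,
        "num_heads": 4,
        "relative_attention_num_buckets": 32,
        "d_kv": 16,
        "d_model": 64,
        "feed_forward_proj": "gated-gelu",
        "tie_word_embeddings": False,
    }
)
encoder = T5Encoder(config)

tokens = mx.zeros((1, 8), dtype=mx.int32)  # a batch of token ids
hidden = encoder(tokens)                   # (1, 8, d_model)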
--- /dev/null
+++ b/xinference/thirdparty/mlx/flux/tokenizers.py
@@ -0,0 +1,185 @@
+# Copyright © 2024 Apple Inc.
+
+import mlx.core as mx
+import regex
+from sentencepiece import SentencePieceProcessor
+
+
+class CLIPTokenizer:
+    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
+
+    def __init__(self, bpe_ranks, vocab, max_length=77):
+        self.max_length = max_length
+        self.bpe_ranks = bpe_ranks
+        self.vocab = vocab
+        self.pat = regex.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            regex.IGNORECASE,
+        )
+
+        self._cache = {self.bos: self.bos, self.eos: self.eos}
+
+    @property
+    def bos(self):
+        return "<|startoftext|>"
+
+    @property
+    def bos_token(self):
+        return self.vocab[self.bos]
+
+    @property
+    def eos(self):
+        return "<|endoftext|>"
+
+    @property
+    def eos_token(self):
+        return self.vocab[self.eos]
+
+    def bpe(self, text):
+        if text in self._cache:
+            return self._cache[text]
+
+        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
+        unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        if not unique_bigrams:
+            return unigrams
+
+        # In every iteration try to merge the two most likely bigrams. If none
+        # was merged we are done.
+        #
+        # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
+        while unique_bigrams:
+            bigram = min(
+                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
+            )
+            if bigram not in self.bpe_ranks:
+                break
+
+            new_unigrams = []
+            skip = False
+            for a, b in zip(unigrams, unigrams[1:]):
+                if skip:
+                    skip = False
+                    continue
+
+                if (a, b) == bigram:
+                    new_unigrams.append(a + b)
+                    skip = True
+
+                else:
+                    new_unigrams.append(a)
+
+            if not skip:
+                new_unigrams.append(b)
+
+            unigrams = new_unigrams
+            unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        self._cache[text] = unigrams
+
+        return unigrams
+
+    def tokenize(self, text, prepend_bos=True, append_eos=True):
+        if isinstance(text, list):
+            return [self.tokenize(t, prepend_bos, append_eos) for t in text]
+
+        # Lower case cleanup and split according to self.pat. Hugging Face does
+        # a much more thorough job here but this should suffice for 95% of
+        # cases.
+        clean_text = regex.sub(r"\s+", " ", text.lower())
+        tokens = regex.findall(self.pat, clean_text)
+
+        # Split the tokens according to the byte-pair merge file
+        bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
+
+        # Map to token ids and return
+        tokens = [self.vocab[t] for t in bpe_tokens]
+        if prepend_bos:
+            tokens = [self.bos_token] + tokens
+        if append_eos:
+            tokens.append(self.eos_token)
+
+        if len(tokens) > self.max_length:
+            tokens = tokens[: self.max_length]
+            if append_eos:
+                tokens[-1] = self.eos_token
+
+        return tokens
+
+    def encode(self, text):
+        if not isinstance(text, list):
+            return self.encode([text])
+
+        tokens = self.tokenize(text)
+        length = max(len(t) for t in tokens)
+        for t in tokens:
+            t.extend([self.eos_token] * (length - len(t)))
+
+        return mx.array(tokens)
+
+
+class T5Tokenizer:
+    def __init__(self, model_file, max_length=512):
+        self._tokenizer = SentencePieceProcessor(model_file)
+        self.max_length = max_length
+
+    @property
+    def pad(self):
+        try:
+            return self._tokenizer.id_to_piece(self.pad_token)
+        except IndexError:
+            return None
+
+    @property
+    def pad_token(self):
+        return self._tokenizer.pad_id()
+
+    @property
+    def bos(self):
+        try:
+            return self._tokenizer.id_to_piece(self.bos_token)
+        except IndexError:
+            return None
+
+    @property
+    def bos_token(self):
+        return self._tokenizer.bos_id()
+
+    @property
+    def eos(self):
+        try:
+            return self._tokenizer.id_to_piece(self.eos_token)
+        except IndexError:
+            return None
+
+    @property
+    def eos_token(self):
+        return self._tokenizer.eos_id()
+
+    def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
+        if isinstance(text, list):
+            return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
+
+        tokens = self._tokenizer.encode(text)
+
+        if prepend_bos and self.bos_token >= 0:
+            tokens = [self.bos_token] + tokens
+        if append_eos and self.eos_token >= 0:
+            tokens.append(self.eos_token)
+        if pad and len(tokens) < self.max_length and self.pad_token >= 0:
+            tokens += [self.pad_token] * (self.max_length - len(tokens))
+
+        return tokens
+
+    def encode(self, text, pad=True):
+        if not isinstance(text, list):
+            return self.encode([text], pad=pad)
+
+        pad_token = self.pad_token if self.pad_token >= 0 else 0
+        tokens = self.tokenize(text, pad=pad)
+        length = max(len(t) for t in tokens)
+        for t in tokens:
+            t.extend([pad_token] * (length - len(t)))
+
+        return mx.array(tokens)
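Usage sketch (illustrative, not taken from the package): the two tokenizer front-ends above, assuming the tokenizer assets have already been downloaded. The file paths and the vocab/merges parsing are hypothetical placeholders following the standard CLIP tokenizer file layout; in practice these files ship alongside the FLUX checkpoints.

import json

from xinference.thirdparty.mlx.flux.tokenizers import CLIPTokenizer, T5Tokenizer

# T5Tokenizer wraps a SentencePiece model; prompts are padded to max_length
# when the model defines a pad id.
t5_tok = T5Tokenizer("/path/to/spiece.model", max_length=256)
t5_ids = t5_tok.encode(["a photo of an astronaut riding a horse"])

# CLIPTokenizer needs the BPE merge ranks and the vocab mapping; it truncates
# to 77 tokens and pads shorter prompts in a batch with the <|endoftext|> token.
with open("/path/to/vocab.json") as f:
    vocab = json.load(f)
with open("/path/to/merges.txt") as f:
    merges = [tuple(line.split()) for line in f.read().splitlines()[1:]]
bpe_ranks = {merge: i for i, merge in enumerate(merges)}
clip_tok = CLIPTokenizer(bpe_ranks, vocab, max_length=77)
clip_ids = clip_tok.encode("a photo of an astronaut riding a horse")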