tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_envs.py +11 -32
- tests/test_utils.py +2 -1
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +2 -9
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
- tpu_inference/layers/common/attention_interface.py +1 -7
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
- tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +10 -43
- tpu_inference/models/jax/llama3.py +1 -2
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +1 -2
- tpu_inference/models/jax/qwen2_5_vl.py +48 -163
- tpu_inference/models/jax/qwen3.py +1 -2
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
- tpu_inference/platforms/tpu_platform.py +31 -37
- tpu_inference/runner/compilation_manager.py +58 -141
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +147 -271
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +13 -36
- tpu_inference/worker/tpu_worker.py +25 -162
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tests/kernels/fused_moe_v1_test.py
CHANGED

@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from absl.testing import absltest, parameterized
+from absl.testing import absltest
 from jax._src import test_util as jtu
 from jax.sharding import Mesh
 
@@ -10,15 +10,6 @@ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
 jax.config.parse_flags_with_absl()
 
 
-def cdiv(a, b):
-    assert b != 0
-    return (a + b - 1) // b
-
-
-def align_to(x, a):
-    return cdiv(x, a) * a
-
-
 def gen_moe_inputs(
     dtype,
     top_k,
@@ -28,14 +19,11 @@ def gen_moe_inputs(
     num_tokens,
     *,
     seed=1234,
-    has_bias=False,
 ):
     key = jax.random.key(seed)
-    k0, k1, k2,
-
+    k0, k1, k2, k4, k5 = jax.random.split(key, 5)
     a = jax.random.normal(k0, (num_tokens, hidden_size),
                           dtype=jnp.float32).astype(dtype) / 10
-
     w1 = (jax.random.normal(
         k1,
         (num_experts, 2, hidden_size, intermediate_size),
@@ -43,54 +31,21 @@ def gen_moe_inputs(
     ) / 10).astype(dtype)
     w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
                             dtype=jnp.float32) / 10).astype(dtype)
-
-    if has_bias:
-        b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
-                                dtype=jnp.float32) / 10).astype(dtype)
-        b2 = (jax.random.normal(k4, (num_experts, hidden_size),
-                                dtype=jnp.float32) / 10).astype(dtype)
-    else:
-        b1 = b2 = None
-
     gating_output = (
-        jax.random.normal(
+        jax.random.normal(k4, (num_tokens, num_experts), dtype=jnp.float32) +
         jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
             num_tokens, num_experts) / 100)
-
     # To generate unique top-k!
-    top_k_indices = jax.random.randint(
+    top_k_indices = jax.random.randint(k5, (num_tokens, top_k),
                                        minval=0,
                                        maxval=num_experts - 1,
                                        dtype=jnp.int32)
-
     one_hot = (jnp.sum(
         jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
         axis=1,
-    ) *
-
+    ) * 10)
     gating_output = (gating_output + one_hot).astype(dtype)
-
-    return a, w1, w2, b1, b2, gating_output
-
-
-def sub_channel_quantize(x, quant_dtype, wsz=256):
-    """Quantizes x with sub-channel quantization on the 2nd minor."""
-    if jnp.issubdtype(quant_dtype, jnp.floating):
-        dtype_info = jnp.finfo(quant_dtype)
-    else:
-        dtype_info = jnp.iinfo(quant_dtype)
-    dtype_max = float(dtype_info.max)
-    w_lst, scale_lst = [], []
-    assert len(x.shape) >= 2
-    assert x.shape[-2] % wsz == 0
-    for i in range(0, x.shape[-2], wsz):
-        y = x[..., i:i + wsz, :]
-        abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
-        scale = (abs_max / dtype_max).astype(jnp.float32)
-        w = (y / scale).astype(quant_dtype)
-        w_lst.append(w)
-        scale_lst.append(scale)
-    return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)
+    return a, w1, w2, gating_output
 
 
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
@@ -108,266 +63,42 @@ class MoEKernelTest(jtu.JaxTestCase):
         self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
                          axis_names=("data", "model"))
 
-    def _test_moe(
-        self,
-        dtype,
-        top_k,
-        num_experts,
-        hidden_size,
-        intermediate_size,
-        num_tokens,
-        seed,
-        renormalize_topk_logits,
-        bt,
-        bf,
-        bd1,
-        bd2,
-        btc,
-        bfc,
-        bd1c,
-        bd2c,
-        act_fn="silu",
-        w_dtype=None,
-        subc_quant_wsz=None,
-        has_bias=False,
-        atol=2e-1,
-        rtol=2e-1,
-    ):
-        a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
+    def test_basic(self):
+        dtype = jnp.bfloat16
+        top_k = 2
+        num_experts = 16
+        hidden_size = 256
+        intermediate_size = 256
+        num_tokens = 8 * 2
+
+        a, w1, w2, gating_output = gen_moe_inputs(
             dtype,
             top_k,
             num_experts,
             hidden_size,
             intermediate_size,
             num_tokens,
-            seed=seed,
-            has_bias=has_bias,
         )
-        w1_scale = None
-        w2_scale = None
-        if w_dtype is not None:
-            if subc_quant_wsz is None:
-                subc_quant_wsz = 256
-            w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
-            w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)
 
-        actual =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            bfc=bfc,
-            bd1c=bd1c,
-            bd2c=bd2c,
-        )
-        expected = ref_moe(
-            a,
-            w1,
-            w2,
-            gating_output,
-            top_k,
-            b1=b1,
-            b2=b2,
-            renormalize_topk_logits=renormalize_topk_logits,
-            activation=act_fn,
-            subc_quant_wsz=subc_quant_wsz,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-        )
-        self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
-
-    @parameterized.product(renormalize_topk_logits=[True, False], )
-    def test_basic(self, renormalize_topk_logits):
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=renormalize_topk_logits,
-            bt=32,
-            bf=1024,
-            bd1=1024,
-            bd2=1024,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
-
-    @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
-    def test_activation(self, act_fn):
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=True,
-            act_fn=act_fn,
-            bt=32,
-            bf=512,
-            bd1=512,
-            bd2=512,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
-
-    def test_benchmark_qwen_235(self):
-        num_experts = 128
-        top_k = 8
-        hidden_size = 4096
-        intermediate_size = 1536
-        dtype = jnp.bfloat16
-        num_tokens = 8 * 64
-        seed = 54321
-        renormalize_topk_logits = True
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=seed,
-            renormalize_topk_logits=renormalize_topk_logits,
-            bt=64,
-            bf=768,
-            bd1=2048,
-            bd2=2048,
-            btc=64,
-            bfc=768,
-            bd1c=2048,
-            bd2c=2048,
-            act_fn="silu",
-            atol=5e-2,
-            rtol=5e-2,
-        )
-
-    def test_benchmark_qwen_30b_a3b(self):
-        num_experts = 128
-        top_k = 8
-        hidden_size = 2048
-        intermediate_size = 768
-        dtype = jnp.bfloat16
-        num_tokens = 512
-        seed = 54321
-        renormalize_topk_logits = True
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=seed,
-            renormalize_topk_logits=renormalize_topk_logits,
-            bt=16,
-            bf=384,
-            bd1=512,
-            bd2=512,
-            btc=16,
-            bfc=384,
-            bd1c=256,
-            bd2c=256,
-            act_fn="silu",
-            atol=5e-2,
-            rtol=5e-2,
-        )
-
-    @parameterized.product(
-        w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
-    def test_sub_channel_quantization(self, w_dtype):
-        if w_dtype in (
-                jnp.float8_e5m2,
-                jnp.float4_e2m1fn,
-        ) and not jtu.is_device_tpu_at_least(version=7):
-            self.skipTest("Expect TPUv7+")
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=False,
-            w_dtype=w_dtype,
-            subc_quant_wsz=256,
-            bt=32,
-            bf=1024,
-            bd1=1024,
-            bd2=1024,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
-
-    def test_bias(self):
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=False,
-            has_bias=True,
-            bt=32,
-            bf=512,
-            bd1=512,
-            bd2=512,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
+        actual = jax.block_until_ready(
+            fused_ep_moe(
+                mesh=self.mesh,
+                tokens=a,
+                w1=w1,
+                w2=w2,
+                gating_output=gating_output,
+                top_k=top_k,
+                bt=32,
+                bf=512,
+                bd1=512,
+                bd2=512,
+                btc=32,
+                bfc=256,
+                bd1c=256,
+                bd2c=256,
+            ))
+        expected = ref_moe(a, w1, w2, gating_output, top_k)
+        self.assertAllClose(expected, actual, atol=2e-2, rtol=2e-2)
 
 
 if __name__ == "__main__":
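The rewritten test above checks fused_ep_moe against ref_moe on a small bfloat16 configuration. As background on what such a reference computes, the following is a self-contained sketch of a top-k-routed MoE layer in plain JAX; tiny_ref_moe, the softmax routing, and the dense einsum formulation are illustrative assumptions, not the actual ref_moe from tpu_inference.kernels.fused_moe.v1.kernel:

    import jax
    import jax.numpy as jnp


    def tiny_ref_moe(a, w1, w2, gating_output, top_k):
        """Dense reference MoE: route each token to its top-k experts.

        Shapes follow gen_moe_inputs above: a [T, H], w1 [E, 2, H, F],
        w2 [E, F, H], gating_output [T, E].
        """
        num_experts = w1.shape[0]
        probs = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
        topk_w, topk_idx = jax.lax.top_k(probs, top_k)           # [T, K]
        # Dense formulation: run every expert on every token, then mask.
        gate = jnp.einsum("th,ehf->etf", a, w1[:, 0])            # [E, T, F]
        up = jnp.einsum("th,ehf->etf", a, w1[:, 1])              # [E, T, F]
        expert_out = jnp.einsum("etf,efh->eth", jax.nn.silu(gate) * up, w2)
        one_hot = jax.nn.one_hot(topk_idx, num_experts)          # [T, K, E]
        combine = jnp.einsum("tke,tk->et", one_hot, topk_w)      # [E, T]
        return jnp.einsum("eth,et->th", expert_out, combine).astype(a.dtype)

A production kernel like fused_ep_moe avoids this dense all-experts compute by dispatching tokens to experts, which is what the bt/bf/bd1/bd2 block-size arguments in the test tune.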
tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py
CHANGED

@@ -99,7 +99,7 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
             (0, 0),
             (0, 0),
         ),
-        constant_values=
+        constant_values=jnp.nan,
     ).reshape(
         -1,
         page_size,
@@ -122,7 +122,7 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
         kv_cache,
         ((0, num_pages - kv_cache.shape[0]), (0, 0), (0, 0), (0, 0),
          (0, 0)),
-        constant_values=
+        constant_values=jnp.nan,
     )
     page_indices = jnp.stack(page_indices_list, axis=0)
     page_indices = jnp.pad(
tests/lora/test_layers.py
CHANGED

@@ -91,6 +91,7 @@ def populate_loras(
     index_to_id: list[Optional[int]],
     lora_layer: BaseLayerWithLoRA,
     baselayer_weights: torch.Tensor,
+    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -102,6 +103,8 @@ def populate_loras(
         lora_layer: the LoRAlayer to populate.
         baselayer_weights: the PyTorch tensor containing the layer's
             weights.
+        generate_embeddings_tensor: whether to generate an
+            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
             layers. Indicates the number of loras to compose
             together to create a single lora layer.
@@ -128,6 +131,7 @@ def populate_loras(
                 baselayer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=baselayer_weights,
+                    generate_embeddings_tensor=generate_embeddings_tensor,
                 )
         sublora.lora_b = sublora.lora_b[(sublora_len *
                                          i):(sublora_len * (i + 1)), :]
@@ -143,6 +147,7 @@ def populate_loras(
             slot_idx,
             lora_a=lora.lora_a,
             lora_b=lora.lora_b,
+            embeddings_tensor=lora.embeddings_tensor,
         )
 
         lora_dict[lora_id] = lora
@@ -541,6 +546,7 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
         index_to_id,
         lora_config.max_loras,
         vocab_size=512,
+        extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
     assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
     ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
tests/lora/utils.py
CHANGED

@@ -24,6 +24,7 @@ class DummyLoRAManager:
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
+        generate_embeddings_tensor: int = 0,
     ):
         lora = LoRALayerWeights(
             module_name,
@@ -36,6 +37,13 @@ class DummyLoRAManager:
                 dtype=weight.dtype,
                 device=self._device),
         )
+        if generate_embeddings_tensor:
+            lora.embeddings_tensor = torch.rand(
+                5,
+                generate_embeddings_tensor,
+                dtype=weight.dtype,
+                device=self._device,
+            )
         self.set_module_lora(module_name, lora)
 
         return lora
tests/test_envs.py
CHANGED

@@ -56,12 +56,6 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
 
 
 def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
-    # Ensure clean environment for boolean vars by setting to default "0"
-    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
-    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
-    monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
-    monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
-
     # Test SKIP_JAX_PRECOMPILE (default False)
     assert envs.SKIP_JAX_PRECOMPILE is False
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
@@ -69,13 +63,6 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
     assert envs.SKIP_JAX_PRECOMPILE is False
 
-    # Test VLLM_XLA_CHECK_RECOMPILATION (default False)
-    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
-    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
-    assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
-    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
-    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
-
     # Test NEW_MODEL_DESIGN (default False)
     assert envs.NEW_MODEL_DESIGN is False
     monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
@@ -88,32 +75,20 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
 
 
 def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
-    # Ensure clean environment for integer vars by setting to defaults
-    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "1")
-    monkeypatch.setenv("NUM_SLICES", "1")
-
     assert envs.PYTHON_TRACER_LEVEL == 1
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
     assert envs.PYTHON_TRACER_LEVEL == 3
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
     assert envs.PYTHON_TRACER_LEVEL == 0
 
-    # Test NUM_SLICES (default 1)
-    assert envs.NUM_SLICES == 1
-    monkeypatch.setenv("NUM_SLICES", "2")
-    assert envs.NUM_SLICES == 2
-    monkeypatch.setenv("NUM_SLICES", "4")
-    assert envs.NUM_SLICES == 4
 
+def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
+    assert envs.TPU_MULTIHOST_BACKEND == "grpc"
 
-
-    # Test case sensitive choices
-    monkeypatch.setenv("MODEL_IMPL_TYPE", "flax_nnx")
+    monkeypatch.setenv("MODEL_IMPL_TYPE", "FLAX_NNX")
     assert envs.MODEL_IMPL_TYPE == "flax_nnx"
 
-    monkeypatch.setenv("MODEL_IMPL_TYPE", "vllm")
-    assert envs.MODEL_IMPL_TYPE == "vllm"
-
 
 def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("JAX_PLATFORMS", raising=False)
@@ -142,6 +117,8 @@ def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
     assert envs.RAY_USAGE_STATS_ENABLED == "1"
 
     assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
+    monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
+    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
 
 
 def test_invalid_attribute_raises_error():
@@ -157,7 +134,6 @@ def test_dir_returns_all_env_vars():
     assert "JAX_PLATFORMS" in env_vars
     assert "TPU_NAME" in env_vars
     assert "SKIP_JAX_PRECOMPILE" in env_vars
-    assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
     assert "MODEL_IMPL_TYPE" in env_vars
 
 
@@ -165,8 +141,11 @@ def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("TPU_WORKER_ID", "0")
     assert envs.TPU_WORKER_ID == "0"
 
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "
-    assert envs.TPU_MULTIHOST_BACKEND == "
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "grpc")
+    assert envs.TPU_MULTIHOST_BACKEND == "grpc"
+
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
+    assert envs.TPU_MULTIHOST_BACKEND == "xla"
 
 
 def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
tests/test_utils.py
CHANGED

@@ -231,5 +231,6 @@ def test_get_jax_dtype_from_str_dtype():
     assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
     assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
     assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
-    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.
+    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3
     assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
+    assert get_jax_dtype_from_str_dtype("auto") is None
tpu_inference/__init__.py
CHANGED

@@ -1,40 +1,21 @@
+import os
+
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
-from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
 
-if "proxy" in
+if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
-        import traceback
-
         import pathwaysutils
-        import vllm
-        from vllm.platforms import (resolve_current_platform_cls_qualname,
-                                    resolve_obj_by_qualname)
         pathwaysutils.initialize()
         logger.info("Module pathwaysutils is imported.")
-
-        # Pathways requires eager resolution of vllm.current_platform instead of
-        # lazy resolution in the normal code path. Since this part involves
-        # global topology discovery across multiple hosts, the platform
-        # resolution must happen before other components are loaded.
-        logger.info("Eagerly resolving vLLM current_platform for Pathways.")
-        platform_cls_qualname = resolve_current_platform_cls_qualname()
-        resolved_platform_instance = resolve_obj_by_qualname(
-            platform_cls_qualname)()
-        vllm.platforms._current_platform = resolved_platform_instance
-        vllm.platforms._init_trace = "".join(traceback.format_stack())
-        logger.info(
-            f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
-        )
-
     except Exception as e:
         logger.error(
             f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
tpu_inference/core/disagg_utils.py
CHANGED

@@ -1,15 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import os
 from typing import Tuple
 
-
+PREFILL_SLICES = 'PREFILL_SLICES'
+DECODE_SLICES = 'DECODE_SLICES'
 
 
 def is_disagg_enabled() -> bool:
     # We triggrer our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return
+    return PREFILL_SLICES in os.environ
 
 
 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -38,12 +40,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:
 
 
 def get_prefill_slices() -> Tuple[int, ...]:
-    if not
+    if PREFILL_SLICES not in os.environ:
         return ()
-    return _parse_slices(
+    return _parse_slices(os.environ[PREFILL_SLICES])
 
 
 def get_decode_slices() -> Tuple[int, ...]:
-    if not
+    if DECODE_SLICES not in os.environ:
         return ()
-    return _parse_slices(
+    return _parse_slices(os.environ[DECODE_SLICES])