tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +32 -11
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +61 -8
- tpu_inference/executors/ray_distributed_executor.py +31 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +45 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
- tpu_inference/platforms/tpu_platform.py +28 -22
- tpu_inference/runner/compilation_manager.py +144 -59
- tpu_inference/runner/kv_cache_manager.py +17 -18
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +271 -147
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +36 -13
- tpu_inference/worker/tpu_worker.py +162 -25
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
tests/kernels/fused_moe_v1_test.py
CHANGED

@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 from jax._src import test_util as jtu
 from jax.sharding import Mesh

@@ -10,6 +10,15 @@ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
 jax.config.parse_flags_with_absl()


+def cdiv(a, b):
+    assert b != 0
+    return (a + b - 1) // b
+
+
+def align_to(x, a):
+    return cdiv(x, a) * a
+
+
 def gen_moe_inputs(
     dtype,
     top_k,
@@ -19,11 +28,14 @@ def gen_moe_inputs(
     num_tokens,
     *,
     seed=1234,
+    has_bias=False,
 ):
     key = jax.random.key(seed)
-    k0, k1, k2, k4, k5 = jax.random.split(key, 5)
+    k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7)
+
     a = jax.random.normal(k0, (num_tokens, hidden_size),
                           dtype=jnp.float32).astype(dtype) / 10
+
     w1 = (jax.random.normal(
         k1,
         (num_experts, 2, hidden_size, intermediate_size),
@@ -31,21 +43,54 @@ def gen_moe_inputs(
     ) / 10).astype(dtype)
     w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
                             dtype=jnp.float32) / 10).astype(dtype)
+
+    if has_bias:
+        b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
+                                dtype=jnp.float32) / 10).astype(dtype)
+        b2 = (jax.random.normal(k4, (num_experts, hidden_size),
+                                dtype=jnp.float32) / 10).astype(dtype)
+    else:
+        b1 = b2 = None
+
     gating_output = (
-        jax.random.normal(
+        jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) +
         jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
             num_tokens, num_experts) / 100)
+
     # To generate unique top-k!
-    top_k_indices = jax.random.randint(
+    top_k_indices = jax.random.randint(k6, (num_tokens, top_k),
                                        minval=0,
                                        maxval=num_experts - 1,
                                        dtype=jnp.int32)
+
     one_hot = (jnp.sum(
         jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
         axis=1,
-    ) *
+    ) * 30)
+
     gating_output = (gating_output + one_hot).astype(dtype)
-    return a, w1, w2, gating_output
+
+    return a, w1, w2, b1, b2, gating_output
+
+
+def sub_channel_quantize(x, quant_dtype, wsz=256):
+    """Quantizes x with sub-channel quantization on the 2nd minor."""
+    if jnp.issubdtype(quant_dtype, jnp.floating):
+        dtype_info = jnp.finfo(quant_dtype)
+    else:
+        dtype_info = jnp.iinfo(quant_dtype)
+    dtype_max = float(dtype_info.max)
+    w_lst, scale_lst = [], []
+    assert len(x.shape) >= 2
+    assert x.shape[-2] % wsz == 0
+    for i in range(0, x.shape[-2], wsz):
+        y = x[..., i:i + wsz, :]
+        abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
+        scale = (abs_max / dtype_max).astype(jnp.float32)
+        w = (y / scale).astype(quant_dtype)
+        w_lst.append(w)
+        scale_lst.append(scale)
+    return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)


 @jtu.with_config(jax_numpy_dtype_promotion="standard")
@@ -63,42 +108,266 @@ class MoEKernelTest(jtu.JaxTestCase):
         self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
                          axis_names=("data", "model"))

-    def
-
-
-
-
-
-
-
-
+    def _test_moe(
+        self,
+        dtype,
+        top_k,
+        num_experts,
+        hidden_size,
+        intermediate_size,
+        num_tokens,
+        seed,
+        renormalize_topk_logits,
+        bt,
+        bf,
+        bd1,
+        bd2,
+        btc,
+        bfc,
+        bd1c,
+        bd2c,
+        act_fn="silu",
+        w_dtype=None,
+        subc_quant_wsz=None,
+        has_bias=False,
+        atol=2e-1,
+        rtol=2e-1,
+    ):
+        a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
             dtype,
             top_k,
             num_experts,
             hidden_size,
             intermediate_size,
             num_tokens,
+            seed=seed,
+            has_bias=has_bias,
         )
+        w1_scale = None
+        w2_scale = None
+        if w_dtype is not None:
+            if subc_quant_wsz is None:
+                subc_quant_wsz = 256
+            w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
+            w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)

-        actual =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        actual = fused_ep_moe(
+            mesh=self.mesh,
+            tokens=a,
+            w1=w1,
+            w2=w2,
+            gating_output=gating_output,
+            top_k=top_k,
+            renormalize_topk_logits=renormalize_topk_logits,
+            act_fn=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            b1=b1,
+            b2=b2,
+            bt=bt,
+            bf=bf,
+            bd1=bd1,
+            bd2=bd2,
+            btc=btc,
+            bfc=bfc,
+            bd1c=bd1c,
+            bd2c=bd2c,
+        )
+        expected = ref_moe(
+            a,
+            w1,
+            w2,
+            gating_output,
+            top_k,
+            b1=b1,
+            b2=b2,
+            renormalize_topk_logits=renormalize_topk_logits,
+            activation=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+        )
+        self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
+
+    @parameterized.product(renormalize_topk_logits=[True, False], )
+    def test_basic(self, renormalize_topk_logits):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=32,
+            bf=1024,
+            bd1=1024,
+            bd2=1024,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
+    def test_activation(self, act_fn):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=True,
+            act_fn=act_fn,
+            bt=32,
+            bf=512,
+            bd1=512,
+            bd2=512,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    def test_benchmark_qwen_235(self):
+        num_experts = 128
+        top_k = 8
+        hidden_size = 4096
+        intermediate_size = 1536
+        dtype = jnp.bfloat16
+        num_tokens = 8 * 64
+        seed = 54321
+        renormalize_topk_logits = True
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=seed,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=64,
+            bf=768,
+            bd1=2048,
+            bd2=2048,
+            btc=64,
+            bfc=768,
+            bd1c=2048,
+            bd2c=2048,
+            act_fn="silu",
+            atol=5e-2,
+            rtol=5e-2,
+        )
+
+    def test_benchmark_qwen_30b_a3b(self):
+        num_experts = 128
+        top_k = 8
+        hidden_size = 2048
+        intermediate_size = 768
+        dtype = jnp.bfloat16
+        num_tokens = 512
+        seed = 54321
+        renormalize_topk_logits = True
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=seed,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=16,
+            bf=384,
+            bd1=512,
+            bd2=512,
+            btc=16,
+            bfc=384,
+            bd1c=256,
+            bd2c=256,
+            act_fn="silu",
+            atol=5e-2,
+            rtol=5e-2,
+        )
+
+    @parameterized.product(
+        w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
+    def test_sub_channel_quantization(self, w_dtype):
+        if w_dtype in (
+                jnp.float8_e5m2,
+                jnp.float4_e2m1fn,
+        ) and not jtu.is_device_tpu_at_least(version=7):
+            self.skipTest("Expect TPUv7+")
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=False,
+            w_dtype=w_dtype,
+            subc_quant_wsz=256,
+            bt=32,
+            bf=1024,
+            bd1=1024,
+            bd2=1024,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    def test_bias(self):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=False,
+            has_bias=True,
+            bt=32,
+            bf=512,
+            bd1=512,
+            bd2=512,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )


 if __name__ == "__main__":
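A note on the sub_channel_quantize helper added above: each block of wsz rows along the second-minor axis gets its own float32 scale row (the block's abs-max divided by the target dtype's max), so for an input of shape (..., K, N) the returned scales have shape (..., K // wsz, N). A minimal sketch of the inverse transform under the same layout assumptions (sub_channel_dequantize is a hypothetical helper, not part of this package):

import jax.numpy as jnp

def sub_channel_dequantize(w, scale, wsz=256):
    # w: (..., K, N) quantized values, K % wsz == 0.
    # scale: (..., K // wsz, N), one float32 scale row per wsz-row block.
    out = []
    for i in range(0, w.shape[-2], wsz):
        block = w[..., i:i + wsz, :].astype(jnp.float32)
        out.append(block * scale[..., i // wsz:i // wsz + 1, :])
    return jnp.concatenate(out, axis=-2)

Round-tripping a weight through sub_channel_quantize and this inverse recovers it only up to quantization error, which is what the atol/rtol bounds in test_sub_channel_quantization absorb.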
tests/lora/test_layers.py
CHANGED

@@ -91,7 +91,6 @@ def populate_loras(
     index_to_id: list[Optional[int]],
     lora_layer: BaseLayerWithLoRA,
     baselayer_weights: torch.Tensor,
-    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -103,8 +102,6 @@ def populate_loras(
         lora_layer: the LoRAlayer to populate.
         baselayer_weights: the PyTorch tensor containing the layer's
             weights.
-        generate_embeddings_tensor: whether to generate an
-            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
             layers. Indicates the number of loras to compose
             together to create a single lora layer.
@@ -131,7 +128,6 @@ def populate_loras(
                 baselayer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=baselayer_weights,
-                    generate_embeddings_tensor=generate_embeddings_tensor,
                 )
             sublora.lora_b = sublora.lora_b[(sublora_len *
                                              i):(sublora_len * (i + 1)), :]
@@ -147,7 +143,6 @@ def populate_loras(
             slot_idx,
             lora_a=lora.lora_a,
             lora_b=lora.lora_b,
-            embeddings_tensor=lora.embeddings_tensor,
         )

         lora_dict[lora_id] = lora
@@ -546,7 +541,6 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
         index_to_id,
         lora_config.max_loras,
         vocab_size=512,
-        extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
     assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
     ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
tests/lora/utils.py
CHANGED

@@ -24,7 +24,6 @@ class DummyLoRAManager:
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
-        generate_embeddings_tensor: int = 0,
     ):
         lora = LoRALayerWeights(
             module_name,
@@ -37,13 +36,6 @@ class DummyLoRAManager:
                 dtype=weight.dtype,
                 device=self._device),
         )
-        if generate_embeddings_tensor:
-            lora.embeddings_tensor = torch.rand(
-                5,
-                generate_embeddings_tensor,
-                dtype=weight.dtype,
-                device=self._device,
-            )
         self.set_module_lora(module_name, lora)

         return lora
tests/test_envs.py
CHANGED

@@ -56,6 +56,12 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):


 def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Ensure clean environment for boolean vars by setting to default "0"
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
+
     # Test SKIP_JAX_PRECOMPILE (default False)
     assert envs.SKIP_JAX_PRECOMPILE is False
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
@@ -63,6 +69,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
     assert envs.SKIP_JAX_PRECOMPILE is False

+    # Test VLLM_XLA_CHECK_RECOMPILATION (default False)
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+
     # Test NEW_MODEL_DESIGN (default False)
     assert envs.NEW_MODEL_DESIGN is False
     monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
@@ -75,20 +88,32 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):


 def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Ensure clean environment for integer vars by setting to defaults
+    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "1")
+    monkeypatch.setenv("NUM_SLICES", "1")
+
     assert envs.PYTHON_TRACER_LEVEL == 1
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
     assert envs.PYTHON_TRACER_LEVEL == 3
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
     assert envs.PYTHON_TRACER_LEVEL == 0

+    # Test NUM_SLICES (default 1)
+    assert envs.NUM_SLICES == 1
+    monkeypatch.setenv("NUM_SLICES", "2")
+    assert envs.NUM_SLICES == 2
+    monkeypatch.setenv("NUM_SLICES", "4")
+    assert envs.NUM_SLICES == 4

-def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
-    assert envs.TPU_MULTIHOST_BACKEND == "grpc"

-
+def test_model_impl_type_choices(monkeypatch: pytest.MonkeyPatch):
+    # Test case sensitive choices
+    monkeypatch.setenv("MODEL_IMPL_TYPE", "flax_nnx")
     assert envs.MODEL_IMPL_TYPE == "flax_nnx"

+    monkeypatch.setenv("MODEL_IMPL_TYPE", "vllm")
+    assert envs.MODEL_IMPL_TYPE == "vllm"
+

 def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("JAX_PLATFORMS", raising=False)
@@ -117,8 +142,6 @@ def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
     assert envs.RAY_USAGE_STATS_ENABLED == "1"

     assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
-    monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
-    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"


 def test_invalid_attribute_raises_error():
@@ -134,6 +157,7 @@ def test_dir_returns_all_env_vars():
     assert "JAX_PLATFORMS" in env_vars
     assert "TPU_NAME" in env_vars
     assert "SKIP_JAX_PRECOMPILE" in env_vars
+    assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
     assert "MODEL_IMPL_TYPE" in env_vars


@@ -141,11 +165,8 @@ def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("TPU_WORKER_ID", "0")
     assert envs.TPU_WORKER_ID == "0"

-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "
-    assert envs.TPU_MULTIHOST_BACKEND == "
-
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
-    assert envs.TPU_MULTIHOST_BACKEND == "xla"
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "ray")
+    assert envs.TPU_MULTIHOST_BACKEND == "ray"


 def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
tests/test_utils.py
CHANGED

@@ -231,6 +231,5 @@ def test_get_jax_dtype_from_str_dtype():
     assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
     assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
     assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
-    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.
+    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3fn
     assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
-    assert get_jax_dtype_from_str_dtype("auto") is None
tpu_inference/__init__.py
CHANGED

@@ -1,21 +1,40 @@
-import os
-
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
+from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)

-if "proxy" in
+if "proxy" in envs.JAX_PLATFORMS:
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
+        import traceback
+
         import pathwaysutils
+        import vllm
+        from vllm.platforms import (resolve_current_platform_cls_qualname,
+                                    resolve_obj_by_qualname)
         pathwaysutils.initialize()
         logger.info("Module pathwaysutils is imported.")
+
+        # Pathways requires eager resolution of vllm.current_platform instead of
+        # lazy resolution in the normal code path. Since this part involves
+        # global topology discovery across multiple hosts, the platform
+        # resolution must happen before other components are loaded.
+        logger.info("Eagerly resolving vLLM current_platform for Pathways.")
+        platform_cls_qualname = resolve_current_platform_cls_qualname()
+        resolved_platform_instance = resolve_obj_by_qualname(
+            platform_cls_qualname)()
+        vllm.platforms._current_platform = resolved_platform_instance
+        vllm.platforms._init_trace = "".join(traceback.format_stack())
+        logger.info(
+            f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
+        )
+
     except Exception as e:
         logger.error(
             f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
tpu_inference/core/disagg_utils.py
CHANGED

@@ -1,17 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0

-import os
 from typing import Tuple

-
-DECODE_SLICES = 'DECODE_SLICES'
+from tpu_inference import envs


 def is_disagg_enabled() -> bool:
     # We triggrer our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return PREFILL_SLICES
+    return bool(envs.PREFILL_SLICES)


 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:


 def get_prefill_slices() -> Tuple[int, ...]:
-    if
+    if not envs.PREFILL_SLICES:
         return ()
-    return _parse_slices(
+    return _parse_slices(envs.PREFILL_SLICES)


 def get_decode_slices() -> Tuple[int, ...]:
-    if
+    if not envs.DECODE_SLICES:
         return ()
-    return _parse_slices(
+    return _parse_slices(envs.DECODE_SLICES)
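The net effect of this file's changes: slice configuration is read through the typed envs module instead of raw os.environ lookups, and disaggregation is keyed off envs.PREFILL_SLICES being non-empty. Illustrative usage (the slice-string syntax accepted by _parse_slices is not shown in this diff):

from tpu_inference.core import disagg_utils

# With PREFILL_SLICES unset or empty, is_disagg_enabled() is False and
# get_prefill_slices() returns (); when set, the raw string is handed
# to _parse_slices to produce a tuple of slice sizes.
if disagg_utils.is_disagg_enabled():
    prefill_slices = disagg_utils.get_prefill_slices()
    decode_slices = disagg_utils.get_decode_slices()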
tpu_inference/distributed/tpu_connector.py
CHANGED

@@ -60,7 +60,6 @@ D workflow:

 import copy
 import functools
-import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@ if TYPE_CHECKING:
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request

+from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
                                              get_kv_transfer_port, get_node_id,
@@ -441,8 +441,7 @@ class TPUConnectorWorker:

         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        self.multi_host =
-                          "").lower() == "ray"
+        self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work
         # for TPU host topology.
@@ -458,7 +457,6 @@ class TPUConnectorWorker:
         self.side_channel_port = get_side_channel_port()

         self.kv_transfer_server = None
-        self._maybe_start_p2p_server()
         self.zmq_cxt = zmq.Context()
         if self.is_producer:
             ready_event = threading.Event()
@@ -500,6 +498,7 @@ class TPUConnectorWorker:
         self.shape = list(kv_layer.shape)
         self.dtype = kv_layer.dtype
         self.sharding = kv_layer.sharding
+        self._maybe_start_p2p_server()

     def _maybe_start_p2p_server(self):
         if self.kv_transfer_server is not None: