tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202511270815__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +2 -1
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +43 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
- tpu_inference/platforms/tpu_platform.py +15 -2
- tpu_inference/runner/compilation_manager.py +58 -33
- tpu_inference/runner/kv_cache_manager.py +9 -3
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +203 -102
- tpu_inference/spec_decode/jax/eagle3.py +19 -2
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +160 -23
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0
tests/kernels/fused_moe_v1_test.py CHANGED

@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 from jax._src import test_util as jtu
 from jax.sharding import Mesh
 
@@ -10,6 +10,15 @@ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
 jax.config.parse_flags_with_absl()
 
 
+def cdiv(a, b):
+    assert b != 0
+    return (a + b - 1) // b
+
+
+def align_to(x, a):
+    return cdiv(x, a) * a
+
+
 def gen_moe_inputs(
     dtype,
     top_k,
@@ -19,11 +28,14 @@ def gen_moe_inputs(
     num_tokens,
     *,
     seed=1234,
+    has_bias=False,
 ):
     key = jax.random.key(seed)
-    k0, k1, k2, k4, k5 = jax.random.split(key,
+    k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7)
+
     a = jax.random.normal(k0, (num_tokens, hidden_size),
                           dtype=jnp.float32).astype(dtype) / 10
+
     w1 = (jax.random.normal(
         k1,
         (num_experts, 2, hidden_size, intermediate_size),
@@ -31,21 +43,54 @@ def gen_moe_inputs(
     ) / 10).astype(dtype)
     w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
                             dtype=jnp.float32) / 10).astype(dtype)
+
+    if has_bias:
+        b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
+                                dtype=jnp.float32) / 10).astype(dtype)
+        b2 = (jax.random.normal(k4, (num_experts, hidden_size),
+                                dtype=jnp.float32) / 10).astype(dtype)
+    else:
+        b1 = b2 = None
+
     gating_output = (
-        jax.random.normal(
+        jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) +
         jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
             num_tokens, num_experts) / 100)
+
     # To generate unique top-k!
-    top_k_indices = jax.random.randint(
+    top_k_indices = jax.random.randint(k6, (num_tokens, top_k),
                                        minval=0,
                                        maxval=num_experts - 1,
                                        dtype=jnp.int32)
+
     one_hot = (jnp.sum(
         jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
         axis=1,
-    ) *
+    ) * 30)
+
     gating_output = (gating_output + one_hot).astype(dtype)
-
+
+    return a, w1, w2, b1, b2, gating_output
+
+
+def sub_channel_quantize(x, quant_dtype, wsz=256):
+    """Quantizes x with sub-channel quantization on the 2nd minor."""
+    if jnp.issubdtype(quant_dtype, jnp.floating):
+        dtype_info = jnp.finfo(quant_dtype)
+    else:
+        dtype_info = jnp.iinfo(quant_dtype)
+    dtype_max = float(dtype_info.max)
+    w_lst, scale_lst = [], []
+    assert len(x.shape) >= 2
+    assert x.shape[-2] % wsz == 0
+    for i in range(0, x.shape[-2], wsz):
+        y = x[..., i:i + wsz, :]
+        abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
+        scale = (abs_max / dtype_max).astype(jnp.float32)
+        w = (y / scale).astype(quant_dtype)
+        w_lst.append(w)
+        scale_lst.append(scale)
+    return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)
 
 
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
@@ -63,42 +108,266 @@ class MoEKernelTest(jtu.JaxTestCase):
         self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
                          axis_names=("data", "model"))
 
-    def
-
-
-
-
-
-
-
-
+    def _test_moe(
+        self,
+        dtype,
+        top_k,
+        num_experts,
+        hidden_size,
+        intermediate_size,
+        num_tokens,
+        seed,
+        renormalize_topk_logits,
+        bt,
+        bf,
+        bd1,
+        bd2,
+        btc,
+        bfc,
+        bd1c,
+        bd2c,
+        act_fn="silu",
+        w_dtype=None,
+        subc_quant_wsz=None,
+        has_bias=False,
+        atol=2e-1,
+        rtol=2e-1,
+    ):
+        a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
             dtype,
             top_k,
             num_experts,
             hidden_size,
             intermediate_size,
             num_tokens,
+            seed=seed,
+            has_bias=has_bias,
         )
+        w1_scale = None
+        w2_scale = None
+        if w_dtype is not None:
+            if subc_quant_wsz is None:
+                subc_quant_wsz = 256
+            w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
+            w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)
 
-        actual =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        actual = fused_ep_moe(
+            mesh=self.mesh,
+            tokens=a,
+            w1=w1,
+            w2=w2,
+            gating_output=gating_output,
+            top_k=top_k,
+            renormalize_topk_logits=renormalize_topk_logits,
+            act_fn=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            b1=b1,
+            b2=b2,
+            bt=bt,
+            bf=bf,
+            bd1=bd1,
+            bd2=bd2,
+            btc=btc,
+            bfc=bfc,
+            bd1c=bd1c,
+            bd2c=bd2c,
+        )
+        expected = ref_moe(
+            a,
+            w1,
+            w2,
+            gating_output,
+            top_k,
+            b1=b1,
+            b2=b2,
+            renormalize_topk_logits=renormalize_topk_logits,
+            activation=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+        )
+        self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
+
+    @parameterized.product(renormalize_topk_logits=[True, False], )
+    def test_basic(self, renormalize_topk_logits):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=32,
+            bf=1024,
+            bd1=1024,
+            bd2=1024,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
+    def test_activation(self, act_fn):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=True,
+            act_fn=act_fn,
+            bt=32,
+            bf=512,
+            bd1=512,
+            bd2=512,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    def test_benchmark_qwen_235(self):
+        num_experts = 128
+        top_k = 8
+        hidden_size = 4096
+        intermediate_size = 1536
+        dtype = jnp.bfloat16
+        num_tokens = 8 * 64
+        seed = 54321
+        renormalize_topk_logits = True
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=seed,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=64,
+            bf=768,
+            bd1=2048,
+            bd2=2048,
+            btc=64,
+            bfc=768,
+            bd1c=2048,
+            bd2c=2048,
+            act_fn="silu",
+            atol=5e-2,
+            rtol=5e-2,
+        )
+
+    def test_benchmark_qwen_30b_a3b(self):
+        num_experts = 128
+        top_k = 8
+        hidden_size = 2048
+        intermediate_size = 768
+        dtype = jnp.bfloat16
+        num_tokens = 512
+        seed = 54321
+        renormalize_topk_logits = True
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=seed,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=16,
+            bf=384,
+            bd1=512,
+            bd2=512,
+            btc=16,
+            bfc=384,
+            bd1c=256,
+            bd2c=256,
+            act_fn="silu",
+            atol=5e-2,
+            rtol=5e-2,
+        )
+
+    @parameterized.product(
+        w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
+    def test_sub_channel_quantization(self, w_dtype):
+        if w_dtype in (
+                jnp.float8_e5m2,
+                jnp.float4_e2m1fn,
+        ) and not jtu.is_device_tpu_at_least(version=7):
+            self.skipTest("Expect TPUv7+")
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=False,
+            w_dtype=w_dtype,
+            subc_quant_wsz=256,
+            bt=32,
+            bf=1024,
+            bd1=1024,
+            bd2=1024,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    def test_bias(self):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=False,
+            has_bias=True,
+            bt=32,
+            bf=512,
+            bd1=512,
+            bd2=512,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
 
 
 if __name__ == "__main__":
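Note (not part of the diff): the new sub_channel_quantize helper above returns a pair of (quantized values, per-window scales), one scale row for every wsz rows along the second-to-last axis. A minimal dequantization sketch follows, assuming the kernel consumes w1_scale/w2_scale by multiplying each window back by its scale; the function name and shapes are illustrative only.

import jax.numpy as jnp

def sub_channel_dequantize(q, scale, wsz=256):
    # Inverse of sub_channel_quantize: window j of `wsz` rows on the 2nd-minor
    # axis is expanded back to float by multiplying with scale[..., j:j+1, :].
    chunks = []
    for j, i in enumerate(range(0, q.shape[-2], wsz)):
        chunks.append(q[..., i:i + wsz, :].astype(jnp.float32) *
                      scale[..., j:j + 1, :])
    return jnp.concatenate(chunks, axis=-2)

# Example round trip against the test helper (names from the test file):
#   q, s = sub_channel_quantize(w, jnp.int8, wsz=256)
#   w_hat = sub_channel_dequantize(q, s, wsz=256)   # w_hat ~ w, window by window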
tests/lora/test_layers.py CHANGED

@@ -91,7 +91,6 @@ def populate_loras(
     index_to_id: list[Optional[int]],
     lora_layer: BaseLayerWithLoRA,
     baselayer_weights: torch.Tensor,
-    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -103,8 +102,6 @@ def populate_loras(
         lora_layer: the LoRAlayer to populate.
         baselayer_weights: the PyTorch tensor containing the layer's
             weights.
-        generate_embeddings_tensor: whether to generate an
-            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
             layers. Indicates the number of loras to compose
             together to create a single lora layer.
@@ -131,7 +128,6 @@ def populate_loras(
                 baselayer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=baselayer_weights,
-                    generate_embeddings_tensor=generate_embeddings_tensor,
                 )
             sublora.lora_b = sublora.lora_b[(sublora_len *
                                              i):(sublora_len * (i + 1)), :]
@@ -147,7 +143,6 @@ def populate_loras(
             slot_idx,
             lora_a=lora.lora_a,
             lora_b=lora.lora_b,
-            embeddings_tensor=lora.embeddings_tensor,
         )
 
         lora_dict[lora_id] = lora
@@ -546,7 +541,6 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
         index_to_id,
         lora_config.max_loras,
         vocab_size=512,
-        extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
     assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
     ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
tests/lora/utils.py CHANGED

@@ -24,7 +24,6 @@ class DummyLoRAManager:
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
-        generate_embeddings_tensor: int = 0,
     ):
         lora = LoRALayerWeights(
             module_name,
@@ -37,13 +36,6 @@ class DummyLoRAManager:
                                dtype=weight.dtype,
                                device=self._device),
         )
-        if generate_embeddings_tensor:
-            lora.embeddings_tensor = torch.rand(
-                5,
-                generate_embeddings_tensor,
-                dtype=weight.dtype,
-                device=self._device,
-            )
         self.set_module_lora(module_name, lora)
 
         return lora
tpu_inference/__init__.py CHANGED

@@ -1,21 +1,40 @@
-import os
-
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
+from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
 
-if "proxy" in
+if "proxy" in envs.JAX_PLATFORMS:
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
+        import traceback
+
         import pathwaysutils
+        import vllm
+        from vllm.platforms import (resolve_current_platform_cls_qualname,
+                                    resolve_obj_by_qualname)
         pathwaysutils.initialize()
         logger.info("Module pathwaysutils is imported.")
+
+        # Pathways requires eager resolution of vllm.current_platform instead of
+        # lazy resolution in the normal code path. Since this part involves
+        # global topology discovery across multiple hosts, the platform
+        # resolution must happen before other components are loaded.
+        logger.info("Eagerly resolving vLLM current_platform for Pathways.")
+        platform_cls_qualname = resolve_current_platform_cls_qualname()
+        resolved_platform_instance = resolve_obj_by_qualname(
+            platform_cls_qualname)()
+        vllm.platforms._current_platform = resolved_platform_instance
+        vllm.platforms._init_trace = "".join(traceback.format_stack())
+        logger.info(
+            f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
+        )
+
     except Exception as e:
         logger.error(
             f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
tpu_inference/core/disagg_utils.py CHANGED

@@ -1,17 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 from typing import Tuple
 
-
-DECODE_SLICES = 'DECODE_SLICES'
+from tpu_inference import envs
 
 
 def is_disagg_enabled() -> bool:
     # We triggrer our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return PREFILL_SLICES
+    return bool(envs.PREFILL_SLICES)
 
 
 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:
 
 
 def get_prefill_slices() -> Tuple[int, ...]:
-    if
+    if not envs.PREFILL_SLICES:
         return ()
-    return _parse_slices(
+    return _parse_slices(envs.PREFILL_SLICES)
 
 
 def get_decode_slices() -> Tuple[int, ...]:
-    if
+    if not envs.DECODE_SLICES:
         return ()
-    return _parse_slices(
+    return _parse_slices(envs.DECODE_SLICES)
tpu_inference/distributed/tpu_connector.py CHANGED

@@ -60,7 +60,6 @@ D workflow:
 
 import copy
 import functools
-import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@ if TYPE_CHECKING:
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
 
+from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
                                              get_kv_transfer_port, get_node_id,
@@ -441,8 +441,7 @@ class TPUConnectorWorker:
 
         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        self.multi_host =
-            "").lower() == "ray"
+        self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work
         # for TPU host topology.
tpu_inference/distributed/utils.py CHANGED

@@ -2,6 +2,7 @@ import os
 
 from vllm.utils.network_utils import get_ip
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
 
 
 def get_kv_ips() -> str:
-    if
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:
 
 
 def get_kv_ports() -> str:
-    if
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):
tpu_inference/envs.py CHANGED

@@ -26,7 +26,7 @@ if TYPE_CHECKING:
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", ""),
+    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
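Note (not part of the diff): several files above switch from reading os.environ directly to attribute access on tpu_inference.envs (envs.JAX_PLATFORMS, envs.PREFILL_SLICES, envs.TPU_MULTIHOST_BACKEND). The hunk above shows the declaration side: a dict of name -> lambda. Below is a minimal sketch of how such a module typically resolves attribute access; the module-level __getattr__ and the non-JAX_PLATFORMS defaults are assumptions (mirroring vLLM's envs module), not taken from this package.

import os
from typing import Any, Callable

# Declaration side: each variable is a lazily evaluated lambda.
environment_variables: dict[str, Callable[[], Any]] = {
    "JAX_PLATFORMS": lambda: os.getenv("JAX_PLATFORMS", "").lower(),
    "PREFILL_SLICES": lambda: os.getenv("PREFILL_SLICES", ""),  # assumed default
    "DECODE_SLICES": lambda: os.getenv("DECODE_SLICES", ""),  # assumed default
    "TPU_MULTIHOST_BACKEND": lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),  # assumed default
}

def __getattr__(name: str) -> Any:
    # Resolution side (assumed): envs.X re-reads the environment on each access,
    # which is why callers like is_disagg_enabled() can test bool(envs.PREFILL_SLICES).
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")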
tpu_inference/executors/ray_distributed_executor.py CHANGED

@@ -108,6 +108,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         ip_port = self.collective_rpc("get_node_kv_ip_port")
         for item in ip_port:
             set_node_kv_ip_port(item)
+        self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
+            self.vllm_config.ec_transfer_config is None
+            or not self.vllm_config.ec_transfer_config.is_ec_producer)
 
     def _initialize_ray_cluster(self) -> None:
         """Initialize the distributed cluster with Ray.
@@ -131,10 +134,17 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")
 
-
-
-
-
+        pp_size = self.parallel_config.pipeline_parallel_size
+        placement_group_specs: List[Dict[str, float]] = []
+        if pp_size == 1:
+            placement_group_specs = [{
+                device_str: node['Resources'][device_str]
+            } for node in ray.nodes()]
+        else:
+            num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
+            placement_group_specs = [{
+                device_str: num_devices_per_pp_rank
+            } for _ in range(pp_size)]
 
         # vLLM engine is also a worker to execute model with an accelerator,
         # so it requires to have the device in a current node. Check if
@@ -329,6 +339,8 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         all_kwargs = []
         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
             local_rank = node_workers[node_id].index(rank)
+            ip = sorted_worker_metadata[rank].ip
+            prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
             kwargs = dict(
                 vllm_config=self.vllm_config,
                 local_rank=local_rank,
@@ -336,22 +348,26 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 distributed_init_method=distributed_init_method,
                 is_driver_worker=(not self.parallel_config)
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
+                ip=ip,
+                prev_worker_ip=prev_ip,
             )
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
+        if self.parallel_config.pipeline_parallel_size > 1:
+            self.collective_rpc("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")
 
         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
-
-
-
-
-                #
-
-
+                num_tp_workers = int(
+                    self.parallel_config.tensor_parallel_size //
+                    num_tpu_per_worker)
+                for tp_rank in range(num_tp_workers):
+                    # PP=2, TP=4, num_tpu_per_worker=2
+                    # pp_tp_workers = [[0, 1], [2, 3]]
+                    rank = (pp_rank * num_tp_workers) + tp_rank
                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                     assert pp_rank < len(self.pp_tp_workers)
                     self.pp_tp_workers[pp_rank].append(self.workers[rank])