tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_utils.py +16 -24
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/core_tpu.py +9 -17
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +11 -31
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/{common → jax}/sharding.py +5 -5
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/__init__.py +3 -7
- tpu_inference/layers/vllm/quantization/awq.py +3 -4
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
- tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +12 -46
- tpu_inference/models/jax/llama3.py +3 -4
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +2 -3
- tpu_inference/models/jax/qwen2_5_vl.py +50 -165
- tpu_inference/models/jax/qwen3.py +2 -3
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
- tpu_inference/platforms/tpu_platform.py +34 -47
- tpu_inference/runner/compilation_manager.py +60 -145
- tpu_inference/runner/kv_cache.py +2 -2
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +135 -283
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +15 -38
- tpu_inference/worker/tpu_worker.py +26 -163
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
- tests/test_envs.py +0 -203
- tpu_inference/layers/common/quant_methods.py +0 -8
- tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tests/kernels/fused_moe_v1_test.py CHANGED

@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from absl.testing import absltest
+from absl.testing import absltest
 from jax._src import test_util as jtu
 from jax.sharding import Mesh
 
@@ -10,15 +10,6 @@ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
 jax.config.parse_flags_with_absl()
 
 
-def cdiv(a, b):
-    assert b != 0
-    return (a + b - 1) // b
-
-
-def align_to(x, a):
-    return cdiv(x, a) * a
-
-
 def gen_moe_inputs(
     dtype,
     top_k,
@@ -28,14 +19,11 @@ def gen_moe_inputs(
     num_tokens,
     *,
     seed=1234,
-    has_bias=False,
 ):
     key = jax.random.key(seed)
-    k0, k1, k2,
-
+    k0, k1, k2, k4, k5 = jax.random.split(key, 5)
     a = jax.random.normal(k0, (num_tokens, hidden_size),
                           dtype=jnp.float32).astype(dtype) / 10
-
     w1 = (jax.random.normal(
         k1,
         (num_experts, 2, hidden_size, intermediate_size),
@@ -43,54 +31,21 @@ def gen_moe_inputs(
     ) / 10).astype(dtype)
     w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
                             dtype=jnp.float32) / 10).astype(dtype)
-
-    if has_bias:
-        b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
-                                dtype=jnp.float32) / 10).astype(dtype)
-        b2 = (jax.random.normal(k4, (num_experts, hidden_size),
-                                dtype=jnp.float32) / 10).astype(dtype)
-    else:
-        b1 = b2 = None
-
     gating_output = (
-        jax.random.normal(
+        jax.random.normal(k4, (num_tokens, num_experts), dtype=jnp.float32) +
         jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
             num_tokens, num_experts) / 100)
-
     # To generate unique top-k!
-    top_k_indices = jax.random.randint(
+    top_k_indices = jax.random.randint(k5, (num_tokens, top_k),
                                        minval=0,
                                        maxval=num_experts - 1,
                                        dtype=jnp.int32)
-
     one_hot = (jnp.sum(
         jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
         axis=1,
-    ) *
-
+    ) * 10)
     gating_output = (gating_output + one_hot).astype(dtype)
-
-    return a, w1, w2, b1, b2, gating_output
-
-
-def sub_channel_quantize(x, quant_dtype, wsz=256):
-    """Quantizes x with sub-channel quantization on the 2nd minor."""
-    if jnp.issubdtype(quant_dtype, jnp.floating):
-        dtype_info = jnp.finfo(quant_dtype)
-    else:
-        dtype_info = jnp.iinfo(quant_dtype)
-    dtype_max = float(dtype_info.max)
-    w_lst, scale_lst = [], []
-    assert len(x.shape) >= 2
-    assert x.shape[-2] % wsz == 0
-    for i in range(0, x.shape[-2], wsz):
-        y = x[..., i:i + wsz, :]
-        abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
-        scale = (abs_max / dtype_max).astype(jnp.float32)
-        w = (y / scale).astype(quant_dtype)
-        w_lst.append(w)
-        scale_lst.append(scale)
-    return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)
+    return a, w1, w2, gating_output
 
 
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
@@ -108,266 +63,42 @@ class MoEKernelTest(jtu.JaxTestCase):
         self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
                          axis_names=("data", "model"))
 
-    def
-
-
-
-
-
-
-
-
-        renormalize_topk_logits,
-        bt,
-        bf,
-        bd1,
-        bd2,
-        btc,
-        bfc,
-        bd1c,
-        bd2c,
-        act_fn="silu",
-        w_dtype=None,
-        subc_quant_wsz=None,
-        has_bias=False,
-        atol=2e-1,
-        rtol=2e-1,
-    ):
-        a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
+    def test_basic(self):
+        dtype = jnp.bfloat16
+        top_k = 2
+        num_experts = 16
+        hidden_size = 256
+        intermediate_size = 256
+        num_tokens = 8 * 2
+
+        a, w1, w2, gating_output = gen_moe_inputs(
            dtype,
            top_k,
            num_experts,
            hidden_size,
            intermediate_size,
            num_tokens,
-           seed=seed,
-           has_bias=has_bias,
        )
-        w1_scale = None
-        w2_scale = None
-        if w_dtype is not None:
-            if subc_quant_wsz is None:
-                subc_quant_wsz = 256
-            w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
-            w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)
 
-        actual =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            bfc=bfc,
-            bd1c=bd1c,
-            bd2c=bd2c,
-        )
-        expected = ref_moe(
-            a,
-            w1,
-            w2,
-            gating_output,
-            top_k,
-            b1=b1,
-            b2=b2,
-            renormalize_topk_logits=renormalize_topk_logits,
-            activation=act_fn,
-            subc_quant_wsz=subc_quant_wsz,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-        )
-        self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
-
-    @parameterized.product(renormalize_topk_logits=[True, False], )
-    def test_basic(self, renormalize_topk_logits):
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=renormalize_topk_logits,
-            bt=32,
-            bf=1024,
-            bd1=1024,
-            bd2=1024,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
-
-    @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
-    def test_activation(self, act_fn):
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=True,
-            act_fn=act_fn,
-            bt=32,
-            bf=512,
-            bd1=512,
-            bd2=512,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
-
-    def test_benchmark_qwen_235(self):
-        num_experts = 128
-        top_k = 8
-        hidden_size = 4096
-        intermediate_size = 1536
-        dtype = jnp.bfloat16
-        num_tokens = 8 * 64
-        seed = 54321
-        renormalize_topk_logits = True
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=seed,
-            renormalize_topk_logits=renormalize_topk_logits,
-            bt=64,
-            bf=768,
-            bd1=2048,
-            bd2=2048,
-            btc=64,
-            bfc=768,
-            bd1c=2048,
-            bd2c=2048,
-            act_fn="silu",
-            atol=5e-2,
-            rtol=5e-2,
-        )
-
-    def test_benchmark_qwen_30b_a3b(self):
-        num_experts = 128
-        top_k = 8
-        hidden_size = 2048
-        intermediate_size = 768
-        dtype = jnp.bfloat16
-        num_tokens = 512
-        seed = 54321
-        renormalize_topk_logits = True
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=seed,
-            renormalize_topk_logits=renormalize_topk_logits,
-            bt=16,
-            bf=384,
-            bd1=512,
-            bd2=512,
-            btc=16,
-            bfc=384,
-            bd1c=256,
-            bd2c=256,
-            act_fn="silu",
-            atol=5e-2,
-            rtol=5e-2,
-        )
-
-    @parameterized.product(
-        w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
-    def test_sub_channel_quantization(self, w_dtype):
-        if w_dtype in (
-                jnp.float8_e5m2,
-                jnp.float4_e2m1fn,
-        ) and not jtu.is_device_tpu_at_least(version=7):
-            self.skipTest("Expect TPUv7+")
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=False,
-            w_dtype=w_dtype,
-            subc_quant_wsz=256,
-            bt=32,
-            bf=1024,
-            bd1=1024,
-            bd2=1024,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
-
-    def test_bias(self):
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=False,
-            has_bias=True,
-            bt=32,
-            bf=512,
-            bd1=512,
-            bd2=512,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
+        actual = jax.block_until_ready(
+            fused_ep_moe(
+                mesh=self.mesh,
+                tokens=a,
+                w1=w1,
+                w2=w2,
+                gating_output=gating_output,
+                top_k=top_k,
+                bt=32,
+                bf=512,
+                bd1=512,
+                bd2=512,
+                btc=32,
+                bfc=256,
+                bd1c=256,
+                bd2c=256,
+            ))
+        expected = ref_moe(a, w1, w2, gating_output, top_k)
+        self.assertAllClose(expected, actual, atol=2e-2, rtol=2e-2)
 
 
 if __name__ == "__main__":
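Note on the deleted sub_channel_quantize helper above: it stores one float32 scale per wsz-row window along the second-minor axis, so the original weights are approximately recovered by re-multiplying each window by its scale row. A minimal dequantization sketch, assuming only jax.numpy and the window layout shown in the removed code (the function name is mine, not part of the package):

import jax.numpy as jnp

def sub_channel_dequantize(w, scale, wsz=256):
    # Window idx covers rows [idx*wsz, (idx+1)*wsz) of the 2nd-minor axis
    # and shares the single scale row idx produced by the removed quantizer.
    out = []
    for idx, i in enumerate(range(0, w.shape[-2], wsz)):
        y = w[..., i:i + wsz, :].astype(jnp.float32)
        out.append(y * scale[..., idx:idx + 1, :])
    return jnp.concatenate(out, axis=-2)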
tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py CHANGED

@@ -99,7 +99,7 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
             (0, 0),
             (0, 0),
         ),
-        constant_values=
+        constant_values=jnp.nan,
     ).reshape(
         -1,
         page_size,
@@ -122,7 +122,7 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
         kv_cache,
         ((0, num_pages - kv_cache.shape[0]), (0, 0), (0, 0), (0, 0),
          (0, 0)),
-        constant_values=
+        constant_values=jnp.nan,
     )
     page_indices = jnp.stack(page_indices_list, axis=0)
    page_indices = jnp.pad(
tests/lora/test_layers.py CHANGED

@@ -91,6 +91,7 @@ def populate_loras(
     index_to_id: list[Optional[int]],
     lora_layer: BaseLayerWithLoRA,
     baselayer_weights: torch.Tensor,
+    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -102,6 +103,8 @@ def populate_loras(
         lora_layer: the LoRAlayer to populate.
         baselayer_weights: the PyTorch tensor containing the layer's
             weights.
+        generate_embeddings_tensor: whether to generate an
+            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
             layers. Indicates the number of loras to compose
             together to create a single lora layer.
@@ -128,6 +131,7 @@ def populate_loras(
                 baselayer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=baselayer_weights,
+                    generate_embeddings_tensor=generate_embeddings_tensor,
                 )
             sublora.lora_b = sublora.lora_b[(sublora_len *
                                              i):(sublora_len * (i + 1)), :]
@@ -143,6 +147,7 @@ def populate_loras(
             slot_idx,
             lora_a=lora.lora_a,
             lora_b=lora.lora_b,
+            embeddings_tensor=lora.embeddings_tensor,
         )
 
         lora_dict[lora_id] = lora
@@ -541,6 +546,7 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
         index_to_id,
         lora_config.max_loras,
         vocab_size=512,
+        extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
     assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
     ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
tests/lora/utils.py CHANGED

@@ -24,6 +24,7 @@ class DummyLoRAManager:
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
+        generate_embeddings_tensor: int = 0,
     ):
         lora = LoRALayerWeights(
             module_name,
@@ -36,6 +37,13 @@ class DummyLoRAManager:
                 dtype=weight.dtype,
                 device=self._device),
         )
+        if generate_embeddings_tensor:
+            lora.embeddings_tensor = torch.rand(
+                5,
+                generate_embeddings_tensor,
+                dtype=weight.dtype,
+                device=self._device,
+            )
         self.set_module_lora(module_name, lora)
 
         return lora
tests/test_utils.py CHANGED

@@ -75,34 +75,25 @@ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
     mock_device2 = MagicMock()
     devices = [mock_device1, mock_device2]
 
-    # Create mock
-
-
-
+    # Create mock device buffers
+    mock_buffer1_dev1 = MagicMock()
+    mock_buffer1_dev1.device = mock_device1
+    mock_buffer1_dev1.nbytes = 2000  # 2000 bytes on device1
 
-
-
-
+    mock_buffer1_dev2 = MagicMock()
+    mock_buffer1_dev2.device = mock_device2
+    mock_buffer1_dev2.nbytes = 2000  # 2000 bytes on device2
 
-
-
-
+    mock_buffer2_dev1 = MagicMock()
+    mock_buffer2_dev1.device = mock_device1
+    mock_buffer2_dev1.nbytes = 1000  # 1000 bytes on device1
 
-
-    mock_shard1_dev1.data = mock_data1_dev1
-
-    mock_shard1_dev2 = MagicMock()
-    mock_shard1_dev2.data = mock_data1_dev2
-
-    mock_shard2_dev1 = MagicMock()
-    mock_shard2_dev1.data = mock_data2_dev1
-
-    # Create mock arrays with addressable_shards
+    # Create mock arrays with device buffers
     mock_array1 = MagicMock()
-    mock_array1.
+    mock_array1.device_buffers = [mock_buffer1_dev1, mock_buffer1_dev2]
 
     mock_array2 = MagicMock()
-    mock_array2.
+    mock_array2.device_buffers = [mock_buffer2_dev1]
 
     mock_live_arrays.return_value = [mock_array1, mock_array2]
 
@@ -168,7 +159,7 @@ def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
     "head_dim, expected_padded_head_dim",
     [
         (1, 128),
-        (64,
+        (64, 128),
         (127, 128),
         (128, 128),
         (129, 256),
@@ -231,5 +222,6 @@ def test_get_jax_dtype_from_str_dtype():
     assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
     assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
     assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
-    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.
+    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3
     assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
+    assert get_jax_dtype_from_str_dtype("auto") is None
tpu_inference/__init__.py CHANGED

@@ -1,40 +1,21 @@
+import os
+
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
-from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
 
-if "proxy" in
+if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
-        import traceback
-
         import pathwaysutils
-        import vllm
-        from vllm.platforms import (resolve_current_platform_cls_qualname,
-                                    resolve_obj_by_qualname)
         pathwaysutils.initialize()
         logger.info("Module pathwaysutils is imported.")
-
-        # Pathways requires eager resolution of vllm.current_platform instead of
-        # lazy resolution in the normal code path. Since this part involves
-        # global topology discovery across multiple hosts, the platform
-        # resolution must happen before other components are loaded.
-        logger.info("Eagerly resolving vLLM current_platform for Pathways.")
-        platform_cls_qualname = resolve_current_platform_cls_qualname()
-        resolved_platform_instance = resolve_obj_by_qualname(
-            platform_cls_qualname)()
-        vllm.platforms._current_platform = resolved_platform_instance
-        vllm.platforms._init_trace = "".join(traceback.format_stack())
-        logger.info(
-            f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
-        )
-
     except Exception as e:
         logger.error(
             f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
tpu_inference/core/core_tpu.py CHANGED

@@ -29,7 +29,6 @@ from vllm.v1.request import Request, RequestStatus
 
 from tpu_inference import utils as common_utils
 from tpu_inference.core import disagg_executor, disagg_utils
-from tpu_inference.runner.tpu_runner import AsyncTPUModelRunnerOutput
 # ======================================================================================
 # Imports for _DisaggOrchestrator (decoupled from vLLM)
 # ======================================================================================
@@ -187,8 +186,6 @@ class _DisaggOrchestrator:
                 if model_output is None:
                     model_output = prefill_engine.model_executor.sample_tokens(
                         grammar_output)
-                if isinstance(model_output, AsyncTPUModelRunnerOutput):
-                    model_output = model_output.get_output()
 
                 if scheduler_output.total_num_scheduled_tokens > 0:
                     logger.debug(f"Prefill result: {model_output}")
@@ -221,16 +218,15 @@ class _DisaggOrchestrator:
                         f"request-{req_id}: tokens={request.all_token_ids} after prefill"
                     )
                     # Remove request from the prefill engine.
-                    if req_id in prefill_engine.scheduler.requests:
-                        request = prefill_engine.scheduler.requests[req_id]
-                        prefill_engine.scheduler.running.remove(request)
-                        prefill_engine.scheduler.encoder_cache_manager.free(
-                            request)
 
-
-
+                    request = prefill_engine.scheduler.requests[req_id]
+                    prefill_engine.scheduler.running.remove(request)
+                    prefill_engine.scheduler.encoder_cache_manager.free(
+                        request)
 
-
+                    prefill_engine.scheduler.kv_cache_manager.free(request)
+
+                    prefill_engine.scheduler.requests.pop(req_id)
 
                 for output in (engine_core_outputs.items()
                                if engine_core_outputs else ()):
@@ -339,10 +335,8 @@ class _DisaggOrchestrator:
                 new_block_ids = kv_cache_manager.get_block_ids(req_id)
                 logger.debug(
                     f"inserting {req_id} new_block_ids {new_block_ids}")
-
-
-                logger.warning("Running out of blocks in decode engine! ")
-                break
+                assert (len(new_block_ids[0]) == math.ceil(
+                    prompt_tokens / self._config.cache_config.block_size))
 
                 decode_engine.model_executor.driver_worker.model_runner.insert_request_with_kv_cache(
                     vllm_request, kv_cache, new_block_ids)
@@ -372,8 +366,6 @@ class _DisaggOrchestrator:
                 if model_output is None:
                     model_output = decode_engine.model_executor.sample_tokens(
                         grammar_output)
-                if isinstance(model_output, AsyncTPUModelRunnerOutput):
-                    model_output = model_output.get_output()
 
                 if scheduler_output.total_num_scheduled_tokens > 0:
                     logger.debug(f"Decode result: {model_output}")
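The replaced warning-and-break becomes a hard invariant: after prefill, the request's first KV-cache block list must hold exactly ceil(prompt_tokens / block_size) blocks. Worked example of the arithmetic:

import math

# A 130-token prompt with block_size=16 needs ceil(130 / 16) = 9 KV blocks.
prompt_tokens, block_size = 130, 16
assert math.ceil(prompt_tokens / block_size) == 9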
tpu_inference/core/disagg_utils.py CHANGED

@@ -1,15 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import os
 from typing import Tuple
 
-
+PREFILL_SLICES = 'PREFILL_SLICES'
+DECODE_SLICES = 'DECODE_SLICES'
 
 
 def is_disagg_enabled() -> bool:
     # We triggrer our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return
+    return PREFILL_SLICES in os.environ
 
 
 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -38,12 +40,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:
 
 
 def get_prefill_slices() -> Tuple[int, ...]:
-    if not
+    if PREFILL_SLICES not in os.environ:
         return ()
-    return _parse_slices(
+    return _parse_slices(os.environ[PREFILL_SLICES])
 
 
 def get_decode_slices() -> Tuple[int, ...]:
-    if not
+    if DECODE_SLICES not in os.environ:
         return ()
-    return _parse_slices(
+    return _parse_slices(os.environ[DECODE_SLICES])
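Usage sketch for the restored environment-variable plumbing; the slice-string format accepted by _parse_slices is elided above, so the value below is illustrative only:

import os

os.environ["PREFILL_SLICES"] = "2x2"   # format assumed; see _parse_slices
from tpu_inference.core import disagg_utils

assert disagg_utils.is_disagg_enabled()
assert disagg_utils.get_decode_slices() == ()   # DECODE_SLICES unset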