tpu-inference 0.12.0.dev20251222__py3-none-any.whl → 0.12.0.dev20251224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/core/test_dp_scheduler.py +128 -71
- tests/e2e/test_data_parallel.py +176 -280
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_speculative_decoding.py +26 -6
- tests/layers/jax/test_qwix.py +1 -1
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +36 -21
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +36 -21
- tests/layers/vllm/test_mxfp4.py +25 -10
- tests/layers/vllm/test_unquantized.py +61 -31
- tests/layers/vllm/utils.py +19 -4
- tests/models/common/test_model_loader.py +2 -2
- tests/models/jax/test_qwen2_5_vl.py +10 -11
- tests/runner/test_multimodal_manager.py +3 -3
- tests/runner/test_tpu_runner.py +67 -8
- tests/runner/test_tpu_runner_dp.py +66 -0
- tpu_inference/core/sched/dp_scheduler.py +65 -40
- tpu_inference/kernels/mla/v1/kernel.py +7 -26
- tpu_inference/layers/common/sharding.py +8 -3
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +3 -3
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +3 -3
- tpu_inference/layers/jax/attention/llama4_attention.py +3 -4
- tpu_inference/layers/jax/sample/sampling.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +51 -47
- tpu_inference/layers/vllm/quantization/common.py +14 -13
- tpu_inference/layers/vllm/quantization/mxfp4.py +21 -7
- tpu_inference/layers/vllm/quantization/unquantized.py +19 -7
- tpu_inference/layers/vllm/sharding.py +7 -4
- tpu_inference/models/common/model_loader.py +11 -14
- tpu_inference/models/jax/llama3.py +13 -10
- tpu_inference/models/jax/llama_guard_4.py +1 -1
- tpu_inference/models/jax/qwen2.py +3 -2
- tpu_inference/models/jax/qwen2_5_vl.py +4 -4
- tpu_inference/models/jax/utils/multi_modal_utils.py +4 -4
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
- tpu_inference/models/vllm/vllm_model_wrapper.py +5 -2
- tpu_inference/platforms/tpu_platform.py +7 -7
- tpu_inference/runner/compilation_manager.py +43 -33
- tpu_inference/runner/kv_cache_manager.py +1 -2
- tpu_inference/runner/multimodal_manager.py +1 -1
- tpu_inference/runner/tpu_runner.py +12 -9
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/tpu_worker.py +5 -2
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/METADATA +1 -1
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/RECORD +47 -46
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/top_level.txt +0 -0
tpu_inference/core/sched/dp_scheduler.py

@@ -18,6 +18,7 @@ from collections import defaultdict, deque
 from dataclasses import dataclass
 from enum import Enum
 from multiprocessing import Process, Queue
+from time import time
 from typing import Any, Dict, List, Optional, Tuple

 import cloudpickle

@@ -102,7 +103,7 @@ def _disable_cloudpickle():
 def _scheduler_worker_process(
     rank: int,
     input_queue: Queue,
-
+    output_queues: Dict[str, Queue],
     vllm_config: Any,
     kv_cache_config: Any,
     structured_output_manager: Any,

@@ -135,55 +136,55 @@ def _scheduler_worker_process(
                 case SchedulerCommand.ADD_REQUEST:
                     request = data
                     scheduler.add_request(request)
-
+                    output_queues[command.value].put(None)  # Signal completion

                 case SchedulerCommand.SCHEDULE:
                     output = scheduler.schedule()
-
+                    output_queues[command.value].put(output)

                 case SchedulerCommand.FINISH_REQUESTS:
                     request_ids, finished_status = data
                     scheduler.finish_requests(request_ids, finished_status)
-
+                    output_queues[command.value].put(None)  # Signal completion

                 case SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS:
                     draft_token_ids = data
                     scheduler.update_draft_token_ids(draft_token_ids)
-
+                    output_queues[command.value].put(None)  # Signal completion

                 case SchedulerCommand.UPDATE_FROM_OUTPUT:
                     scheduler_output, model_runner_output = data
                     result = scheduler.update_from_output(
                         scheduler_output, model_runner_output)
-
+                    output_queues[command.value].put(result)

                 case SchedulerCommand.GET_GRAMMAR_BITMASK:
                     scheduler_output = data
                     result = scheduler.get_grammar_bitmask(scheduler_output)
-
+                    output_queues[command.value].put(result)

                 case SchedulerCommand.MAKE_STATS:
                     spec_decoding_stats, kv_connector_stats = data
                     result = scheduler.make_stats(spec_decoding_stats,
                                                   kv_connector_stats)
-
+                    output_queues[command.value].put(result)

                 case SchedulerCommand.RESET_PREFIX_CACHE:
                     result = scheduler.reset_prefix_cache()
-
+                    output_queues[command.value].put(result)

                 case SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS:
                     result = scheduler.get_num_unfinished_requests()
-
+                    output_queues[command.value].put(result)

                 case SchedulerCommand.HAS_FINISHED_REQUESTS:
                     result = scheduler.has_finished_requests()
-
+                    output_queues[command.value].put(result)

                 case SchedulerCommand.GET_REQUEST_COUNTS:
                     running = len(scheduler.running)
                     waiting = len(scheduler.waiting)
-
+                    output_queues[command.value].put((running, waiting))

                 case SchedulerCommand.GET_TOKEN_COUNT:
                     # Calculate total tokens across running and waiting requests
@@ -192,30 +193,29 @@ def _scheduler_worker_process(
                         total_tokens += len(req.all_token_ids)
                     for req in scheduler.waiting:
                         total_tokens += len(req.all_token_ids)
-
+                    output_queues[command.value].put(total_tokens)

                 case SchedulerCommand.GET_COMPUTED_BLOCKS:
                     request = data
                     blocks, cached_tokens = scheduler.kv_cache_manager.get_computed_blocks(
                         request)
-
+                    output_queues[command.value].put((blocks, cached_tokens))

                 case SchedulerCommand.SHUTDOWN:
                     scheduler.shutdown()
-
+                    output_queues[command.value].put(None)  # Signal completion
                     break
                 case _:
                     error = SchedulerWorkerError(
                         rank, f"Unknown command: {command}")
-
+                    output_queues[command.value].put(error)
                     raise error

         except Exception as e:
             logger.error(f"Error in scheduler worker {rank}: {e}",
                          exc_info=True)
-            # Put error on output queue
             error = SchedulerWorkerError(rank, str(e))
-
+            output_queues[command.value].put(error)


 @dataclass
@@ -276,26 +276,29 @@ class DPScheduler(SchedulerInterface):
         # Enable cloudpickle for multiprocessing to handle local functions
         _enable_cloudpickle()

-        # Create worker processes with
+        # Create worker processes with separate output queues for each command type
        import multiprocessing
         ctx = multiprocessing.get_context('fork')
         self.input_queues: List[Queue] = []
-        self.output_queues:
+        self.output_queues: Dict[Tuple[int, str], Queue] = {}
         self.processes: List[Process] = []

         for rank in range(self.dp_size):
             input_queue = ctx.Queue()
-            output_queue = ctx.Queue()
-
             self.input_queues.append(input_queue)
-
+
+            output_queues_for_rank: Dict[str, Queue] = {}
+            for cmd in SchedulerCommand:
+                output_queues_for_rank[cmd.value] = ctx.Queue()
+                self.output_queues[(
+                    rank, cmd.value)] = output_queues_for_rank[cmd.value]

             process = ctx.Process(
                 target=_scheduler_worker_process,
                 args=(
                     rank,
                     input_queue,
-
+                    output_queues_for_rank,
                     self.vllm_config,
                     self.per_rank_kv_cache_configs[rank],
                     structured_output_manager,
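As an aside, a minimal self-contained sketch of the queue layout this hunk introduces (hypothetical names, not the production DPScheduler code, and it assumes a fork-capable platform): one multiprocessing Queue per (rank, command) pair, so replies to different command types can never be read out of order.

from enum import Enum
from multiprocessing import get_context

class Cmd(Enum):  # stand-in for SchedulerCommand
    ADD_REQUEST = "add_request"
    SCHEDULE = "schedule"

ctx = get_context("fork")
dp_size = 2

# One reply queue per (rank, command) pair, keyed the same way as
# self.output_queues above; each worker only receives its own rank's dict.
output_queues = {(rank, cmd.value): ctx.Queue()
                 for rank in range(dp_size) for cmd in Cmd}
queues_for_rank0 = {cmd.value: output_queues[(0, cmd.value)] for cmd in Cmd}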
@@ -323,8 +326,24 @@
             rank_config.num_blocks = kv_cache_config.num_blocks // self.dp_size
             self.per_rank_kv_cache_configs.append(rank_config)

-    def _get_result_from_queue(self,
-
+    def _get_result_from_queue(self, rank: int,
+                               command: SchedulerCommand) -> Any:
+        """Get result from the output queue for a specific rank and command type."""
+        queue_obj = self.output_queues[(rank, command.value)]
+        try:
+            start_time = time()
+            result = queue_obj.get()
+            end_time = time()
+            if end_time - start_time > 1.0:
+                logger.warning(
+                    f"Long wait time ({end_time - start_time:.2f}s) for rank {rank} "
+                    f"command {command.value} response.")
+        except EOFError as e:
+            raise RuntimeError(
+                f"Queue error for rank {rank} command {command.value}: "
+                "Worker process terminated unexpectedly. "
+                "This may indicate a crash in the scheduler worker process."
+            ) from e
         if isinstance(result, SchedulerWorkerError):
             raise result
         return result
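For illustration, a small stand-alone version of the blocking read with a slow-response warning that _get_result_from_queue now performs (hypothetical helper name; the 1.0 s threshold mirrors the hunk above):

import logging
from queue import Queue
from time import time

logger = logging.getLogger(__name__)

def get_with_warning(q: Queue, label: str, threshold_s: float = 1.0):
    # Blocking get; warn if the reply took longer than threshold_s.
    start = time()
    result = q.get()
    elapsed = time() - start
    if elapsed > threshold_s:
        logger.warning("Long wait time (%.2fs) for %s response.", elapsed, label)
    return result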
@@ -337,7 +356,8 @@ class DPScheduler(SchedulerInterface):

         rank_tokens = {}
         for rank in range(self.dp_size):
-            token_count = self._get_result_from_queue(
+            token_count = self._get_result_from_queue(
+                rank, SchedulerCommand.GET_TOKEN_COUNT)
             rank_tokens[rank] = token_count

         return rank_tokens

@@ -355,7 +375,7 @@ class DPScheduler(SchedulerInterface):
         best_cache_tokens = 0
         for rank in range(self.dp_size):
             blocks, cached_tokens = self._get_result_from_queue(
-
+                rank, SchedulerCommand.GET_COMPUTED_BLOCKS)
             if cached_tokens > best_cache_tokens:
                 best_cache_tokens = cached_tokens
                 best_cache_rank = rank

@@ -382,7 +402,7 @@ class DPScheduler(SchedulerInterface):
         self.assigned_dp_rank[request.request_id] = rank

         self.input_queues[rank].put((SchedulerCommand.ADD_REQUEST, request))
-        self._get_result_from_queue(
+        self._get_result_from_queue(rank, SchedulerCommand.ADD_REQUEST)

     @time_function
     def schedule(self) -> DPSchedulerOutput:

@@ -402,7 +422,8 @@ class DPScheduler(SchedulerInterface):
         # Collect outputs from all workers (blocking)
         rank_outputs = []
         for rank in range(self.dp_size):
-            output = self._get_result_from_queue(
+            output = self._get_result_from_queue(rank,
+                                                 SchedulerCommand.SCHEDULE)
             rank_outputs.append(output)

         # Cache scheduler outputs to use in `update_from_output`

@@ -531,7 +552,7 @@ class DPScheduler(SchedulerInterface):
                 rank_scheduler_outputs[rank]))
         for rank in range(self.dp_size):
             grammar_output = self._get_result_from_queue(
-
+                rank, SchedulerCommand.GET_GRAMMAR_BITMASK)
             if grammar_output is not None:
                 combined_structured_output_request_ids.extend(
                     grammar_output.structured_output_request_ids)

@@ -572,7 +593,7 @@ class DPScheduler(SchedulerInterface):
         combined_engine_outputs = defaultdict(list)
         for rank in range(self.dp_size):
             rank_engine_outputs = self._get_result_from_queue(
-
+                rank, SchedulerCommand.UPDATE_FROM_OUTPUT)
             for client_idx, engine_output in rank_engine_outputs.items():
                 combined_engine_outputs[client_idx].append(engine_output)

@@ -640,7 +661,7 @@ class DPScheduler(SchedulerInterface):
         for rank, req_ids in rank_request_ids.items():
             self.input_queues[rank].put(
                 (SchedulerCommand.FINISH_REQUESTS, (req_ids, finished_status)))
-            self._get_result_from_queue(
+            self._get_result_from_queue(rank, SchedulerCommand.FINISH_REQUESTS)

     def get_num_unfinished_requests(self) -> int:
         """Get total number of unfinished requests across all DP ranks."""

@@ -650,7 +671,8 @@ class DPScheduler(SchedulerInterface):

         total = 0
         for rank in range(self.dp_size):
-            count = self._get_result_from_queue(
+            count = self._get_result_from_queue(
+                rank, SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS)
             total += count
         return total

@@ -663,7 +685,7 @@ class DPScheduler(SchedulerInterface):
         has_finished_any = False
         for rank in range(self.dp_size):
             has_finished_any |= self._get_result_from_queue(
-
+                rank, SchedulerCommand.HAS_FINISHED_REQUESTS)
         return has_finished_any

     def get_request_counts(self) -> Tuple[int, int]:

@@ -676,7 +698,7 @@ class DPScheduler(SchedulerInterface):
         total_waiting = 0
         for rank in range(self.dp_size):
             running, waiting = self._get_result_from_queue(
-
+                rank, SchedulerCommand.GET_REQUEST_COUNTS)
             total_running += running
             total_waiting += waiting
         return total_running, total_waiting

@@ -689,7 +711,8 @@ class DPScheduler(SchedulerInterface):

         all_success = True
         for rank in range(self.dp_size):
-            success = self._get_result_from_queue(
+            success = self._get_result_from_queue(
+                rank, SchedulerCommand.RESET_PREFIX_CACHE)
             all_success &= success
         return all_success

@@ -715,7 +738,8 @@ class DPScheduler(SchedulerInterface):
                                      kv_connector_stats)))

         for rank in range(self.dp_size):
-            rank_stats = self._get_result_from_queue(
+            rank_stats = self._get_result_from_queue(
+                rank, SchedulerCommand.MAKE_STATS)
             if rank_stats is None:
                 continue

@@ -776,7 +800,8 @@ class DPScheduler(SchedulerInterface):
             self.input_queues[rank].put(
                 (SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS,
                  rank_draft_token_ids))
-            self._get_result_from_queue(
+            self._get_result_from_queue(
+                rank, SchedulerCommand.UPDATE_DRAFT_TOKEN_IDS)

     def shutdown(self) -> None:
         """Shutdown all DP rank scheduler worker processes."""

@@ -786,7 +811,7 @@ class DPScheduler(SchedulerInterface):

         # Wait for acknowledgment (blocking)
         for rank in range(self.dp_size):
-            self._get_result_from_queue(
+            self._get_result_from_queue(rank, SchedulerCommand.SHUTDOWN)

         # Terminate and join all processes
         for process in self.processes:
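Taken together, the call sites above all follow the same round trip; a hedged sketch of that pattern (a hypothetical standalone function, not part of the package):

def call_worker(input_queues, output_queues, rank, cmd, payload=None):
    # Send (command, payload) to the rank's input queue, then block on the
    # per-(rank, command) reply queue, mirroring the DPScheduler methods above.
    input_queues[rank].put((cmd, payload))
    return output_queues[(rank, cmd.value)].get()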
tpu_inference/kernels/mla/v1/kernel.py

@@ -822,36 +822,17 @@ def _mla_ragged_paged_attention_kernel(
         return q_nope_vec, q_rope_vec

     def load_bkv(bkv_sem_idx, *, bkvc_mask, bkpe_mask):
-        bitwidth = 32 // kv_packing
-        repack_ty = jnp.dtype(f"uint{bitwidth}")
         bkvc_ref = (bkvc_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
             bkv_sz_per_kv_packing, lkv_dim))
-        bkvc_vec = bkvc_ref[...]
-
-
-            masked_bkvc_vec = bkvc_vec >> (i * bitwidth)
-            bkvc_vecs.append(masked_bkvc_vec)
-        concated_bkvc_vec = jnp.concatenate(bkvc_vecs, axis=-1)
-        concated_bkvc_vec = concated_bkvc_vec.reshape(bkv_sz, lkv_dim)
-        concated_bkvc_vec = lax.select(bkvc_mask, concated_bkvc_vec,
-                                       jnp.zeros_like(concated_bkvc_vec))
-        concated_bkvc_vec = pltpu.bitcast(concated_bkvc_vec.astype(repack_ty),
-                                          kv_dtype)
+        bkvc_vec = pltpu.bitcast(bkvc_ref[...], kv_dtype)
+        bkvc_vec = lax.select(bkvc_mask, bkvc_vec, jnp.zeros_like(bkvc_vec))
+
         bkpe_ref = (bkpe_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
             bkv_sz_per_kv_packing, r_dim))
-        bkpe_vec = bkpe_ref[...]
-
-
-
-            bkpe_vecs.append(masked_bkpe_vec)
-        concated_bkpe_vec = jnp.concatenate(bkpe_vecs, axis=-1)
-        concated_bkpe_vec = concated_bkpe_vec.reshape(bkv_sz, r_dim)
-        concated_bkpe_vec = lax.select(bkpe_mask, concated_bkpe_vec,
-                                       jnp.zeros_like(concated_bkpe_vec))
-        concated_bkpe_vec = pltpu.bitcast(concated_bkpe_vec.astype(repack_ty),
-                                          kv_dtype)
-
-        return concated_bkvc_vec, concated_bkpe_vec
+        bkpe_vec = pltpu.bitcast(bkpe_ref[...], kv_dtype)
+        bkpe_vec = lax.select(bkpe_mask, bkpe_vec, jnp.zeros_like(bkpe_vec))
+
+        return bkvc_vec, bkpe_vec

     def broadcast_minor(src, shape):
         if src.shape == shape:
tpu_inference/layers/common/sharding.py

@@ -40,7 +40,7 @@ class ShardingAxisNameBase:
     MLP_TENSOR = ('attn_dp', 'model', 'expert')
     MOE_TENSOR = ('attn_dp', 'model')
     EXPERT = ('attn_dp', 'expert', 'model')
-    VOCAB = ('expert', 'model')
+    VOCAB = ('expert', 'attn_dp', 'model')


 class ShardingAxisName2D:

@@ -141,6 +141,11 @@ class ShardingConfigManager:
         kv_dtype = utils.get_jax_dtype_from_str_dtype(
             cache_dtype) or jnp.bfloat16
         packing = 4 // jnp.dtype(kv_dtype).itemsize
+
+        # The default head dim is 128 but 64 is also supported as a special case.
+        if vllm_config.model_config.get_head_size() == 64:
+            packing *= 2
+
         # When num_kv_heads * 2 / packing < TP, tensor parallelism would
         # duplicate KV heads across devices, wasting kv cache memory.
         # Use attention DP instead to reduce per-device num_kv_heads and
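To make the packing arithmetic above concrete (illustrative values only): packing counts how many KV elements fit in one 32-bit word, and a head size of 64 doubles it.

import jax.numpy as jnp

kv_dtype = jnp.bfloat16                      # itemsize 2 bytes
packing = 4 // jnp.dtype(kv_dtype).itemsize  # 4 // 2 -> 2
head_size = 64                               # the special-cased head dim
if head_size == 64:
    packing *= 2                             # -> 4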
@@ -186,8 +191,8 @@ class ShardingConfigManager:
         if sharding_strategy.attention_data_parallelism > 1:
             if not envs.NEW_MODEL_DESIGN:
                 raise ValueError(
-                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set
-                    "NEW_MODEL_DESIGN=True
+                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set "
+                    "NEW_MODEL_DESIGN=True")

     @property
     def total_dp_size(self) -> int:
tpu_inference/layers/jax/attention/deepseek_v3_attention.py

@@ -30,6 +30,7 @@ from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
 from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
     get_tuned_block_sizes
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.layers import RMSNorm

@@ -310,9 +311,8 @@ class MLA(nnx.Module):
         # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
         k_scale = self._k_scale
         v_scale = self._v_scale
-        k_SNH, v_SNH =
-
-                       v_scale)
+        k_SNH, v_SNH = quantize_kv(self.kv_cache_quantized_dtype,
+                                   k_SNH, v_SNH, k_scale, v_scale)

         new_kv_cache, outputs_TNH = self.attention(
             is_prefill,
tpu_inference/layers/jax/attention/gpt_oss_attention.py

@@ -26,6 +26,7 @@ from tpu_inference import utils
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 import \
     ragged_paged_attention_hd64
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding

@@ -248,9 +249,8 @@ class GptOssAttention(nnx.Module):
         # q_scale = self._q_scale
         k_scale = self._k_scale
         v_scale = self._v_scale
-        k_TKH, v_TKH =
-
-                       k_scale, v_scale)
+        k_TKH, v_TKH = quantize_kv(self.kv_cache_quantized_dtype, k_TKH,
+                                   v_TKH, k_scale, v_scale)

         with jax.named_scope("attn_op"):
             new_kv_cache, attn_out_TNH = self.attention(
tpu_inference/layers/jax/attention/llama4_attention.py

@@ -19,8 +19,8 @@ import jax.numpy as jnp
 from flax import nnx
 from jax.sharding import Sharding

-from tpu_inference import utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.attention.attention import Attention, KVCache
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger

@@ -128,9 +128,8 @@ class Llama4Attention(Attention):
         # q_scale = self._q_scale
         k_scale = self._k_scale
         v_scale = self._v_scale
-        k_SKH, v_SKH =
-
-                       k_scale, v_scale)
+        k_SKH, v_SKH = quantize_kv(self.kv_cache_quantized_dtype, k_SKH,
+                                   v_SKH, k_scale, v_scale)

         with jax.named_scope("attn_op"):
             new_kv_cache, outputs_TNH = self.attention(
tpu_inference/layers/jax/sample/sampling.py

@@ -42,7 +42,7 @@ def sample(
     if tpu_sampling_metadata.do_sampling:
         # Unshard the logits explicity to avoid latency increase.
         logits = jax.lax.with_sharding_constraint(
-            logits, NamedSharding(mesh, P(ShardingAxisName.
+            logits, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
         greedy_sampled = jnp.argmax(logits, axis=-1)
     if not tpu_sampling_metadata.do_sampling:
         return greedy_sampled
tpu_inference/layers/vllm/fused_moe.py

@@ -16,12 +16,14 @@ import functools

 import jax
 from jax import numpy as jnp
-from jax.sharding import Mesh
+from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P

 from tpu_inference.kernels.megablox.gmm import gmm
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.linear_common import \
     slice_sharded_tensor_for_concatenation
+from tpu_inference.utils import get_mesh_shape_product


 def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:

@@ -137,25 +139,23 @@ def tensor_sharded_gmm_merged_column_parallel(
         group_offset=jnp.array(0),
     )

-    rhs_scale_spec = None if rhs_scale is None else P(
-
-    rhs_bias_spec = None if rhs_bias is None else P(
+    rhs_scale_spec = None if rhs_scale is None else P(
+        None, None, None, ShardingAxisName.MLP_TENSOR)
+    rhs_bias_spec = None if rhs_bias is None else P(
+        None, None, ShardingAxisName.MLP_TENSOR)

     gmm_result = jax.shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(
-
-
-
-
-            P("data"),
-        ),
-        out_specs=(P("data", "model")),
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(None, ShardingAxisName.MLP_TENSOR,
+                             None), rhs_scale_spec, rhs_bias_spec,
+                  P(ShardingAxisName.MLP_DATA)),
+        out_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR)),
         check_vma=False,
     )(lhs, rhs, rhs_scale, rhs_bias, group_sizes)

-    tp_size = mesh.
+    tp_size = get_mesh_shape_product(mesh, ShardingAxisName.MLP_TENSOR)
     intermediate_size = gmm_result.shape[-1] // 2
     output_sizes = [intermediate_size, intermediate_size]
     return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
@@ -175,7 +175,7 @@ def tensor_sharded_gmm_row_parallel(
         m, g, n, k = lhs.shape[0], *rhs.shape
         tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
         if rhs_bias is not None:
-            shard_id = jax.lax.axis_index(
+            shard_id = jax.lax.axis_index(ShardingAxisName.MLP_TENSOR).sum()
             rhs_bias = jnp.where(shard_id == 0, rhs_bias, 0)
         out = gmm(
             lhs,

@@ -188,22 +188,19 @@ def tensor_sharded_gmm_row_parallel(
             transpose_rhs=True,
             group_offset=jnp.array(0),
         )
-        return jax.lax.psum(out, axis_name=
+        return jax.lax.psum(out, axis_name=ShardingAxisName.MLP_TENSOR)

     num_blocks = 1 if rhs_scale is None else rhs_scale.shape[1]
-    rhs_scale_spec = None if num_blocks == 1 else P(
+    rhs_scale_spec = None if num_blocks == 1 else P(
+        None, ShardingAxisName.MLP_TENSOR, None, None)
     rhs_bias_spec = None if rhs_bias is None else P(None, None, None)
     gmm_result = jax.shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(
-
-
-
-            rhs_bias_spec,
-            P("data"),
-        ),
-        out_specs=(P("data")),
+        in_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR),
+                  P(None, None, ShardingAxisName.MLP_TENSOR), rhs_scale_spec,
+                  rhs_bias_spec, P(ShardingAxisName.MLP_DATA)),
+        out_specs=(P(ShardingAxisName.MLP_DATA)),
         check_vma=False,
     )(lhs, rhs, rhs_scale, rhs_bias, group_sizes)

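A minimal runnable sketch of the row-parallel reduction pattern above, with literal axis names standing in for ShardingAxisName.MLP_DATA / MLP_TENSOR (not the package's code; it assumes a JAX version exposing jax.shard_map, as the diff itself uses, and a single-device host so the mesh collapses to 1x1):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))

def _matmul_all_reduce(lhs, rhs):
    # Each "model" shard holds a slice of the contraction dimension;
    # psum over "model" yields the full row-parallel result.
    return jax.lax.psum(lhs @ rhs, axis_name="model")

out = jax.shard_map(
    _matmul_all_reduce,
    mesh=mesh,
    in_specs=(P("data", "model"), P("model", None)),
    out_specs=P("data"),
)(jnp.ones((8, 4)), jnp.ones((4, 16)))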
@@ -219,8 +216,8 @@ def expert_sharded_gmm(
     is_last_expert: bool,
     mesh: Mesh,
 ) -> jax.Array:
-    ep_size = mesh.
-
+    ep_size = get_mesh_shape_product(mesh, ShardingAxisName.MLP_TENSOR)
+    ep_p_spec = P(ShardingAxisName.EXPERT)
     num_experts = rhs.shape[0]
     num_experts_per_shard = num_experts // ep_size
     group_offset = jnp.arange(0, num_experts, num_experts_per_shard)

@@ -260,21 +257,22 @@ def expert_sharded_gmm(
     # 0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0    D, D, D, D
     #   shard-0       shard-1       shard-2       shard-3
     # Each shards has 3 (row A), 2 (row B), 5 (row C) and 4 (row D).
-    lhs_spec =
-
-
+    lhs_spec = ep_p_spec if is_last_expert else P()
+    rhs_spec = ep_p_spec
+    rhs_scale_spec = None if rhs_scale is None else ep_p_spec
+    rhs_bias_spec = None if rhs_bias is None else ep_p_spec
     gmm_res = jax.shard_map(
         _gmm,
         mesh=mesh,
         in_specs=(
             lhs_spec,
-
+            rhs_spec,
             rhs_scale_spec,
             rhs_bias_spec,
             P(),
-
+            ep_p_spec,
         ),
-        out_specs=
+        out_specs=ep_p_spec,
         check_vma=False,
     )(lhs, rhs, rhs_scale, rhs_bias, group_sizes, group_offset)

@@ -320,15 +318,13 @@ def expert_sharded_gmm(
        # send_sizes_of_shard     [3, 3, 3, 3]  [2, 2, 2, 2]  [5, 5, 5, 5]  [4, 4, 4, 4 ]
        # output_offsets_of_shard [0, 0, 0, 0]  [0, 0, 0, 0]  [0, 0, 0, 0]  [10,10,10,10]
        # recv_sizes_of_shard     [3, 2, 5, 4]  [3, 2, 5, 4]  [3, 2, 5, 4]  [3, 2, 5, 4]
-        return jax.lax.ragged_all_to_all(
-
-
-
-
-
-
-            axis_name="model",
-        )
+        return jax.lax.ragged_all_to_all(operand,
+                                         output,
+                                         input_offsets_of_shard,
+                                         send_sizes_of_shard,
+                                         output_offsets_of_shard,
+                                         recv_sizes_of_shard,
+                                         axis_name=ShardingAxisName.EXPERT)

     # Use ragged_all_to_all to send the result from gmm for each expert to all
     # the shards. In the working example, the result would be:

@@ -350,8 +346,8 @@ def expert_sharded_gmm(
     return jax.shard_map(
         _ragged_all_to_all,
         mesh=mesh,
-        in_specs=(
-        out_specs=(P()),
+        in_specs=(ep_p_spec, ep_p_spec, ep_p_spec, ep_p_spec, P()),
+        out_specs=(P(ShardingAxisName.MLP_DATA)),
         check_vma=False,
     )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)

@@ -412,6 +408,9 @@ def fused_moe_func(
     assert gating_output.shape == (num_tokens, global_num_experts)

     topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
+    # All-gather topk weights for attention dp
+    topk_weights = jax.lax.with_sharding_constraint(
+        topk_weights, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
     topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)

@@ -434,8 +433,10 @@ def fused_moe_func(
     x, group_sizes, topk_argsort_revert_indices = jax.shard_map(
         _process_tokens_locally,
         mesh=mesh,
-        in_specs=(P(
-
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(ShardingAxisName.MLP_DATA, None)),
+        out_specs=(P(ShardingAxisName.MLP_DATA, None),
+                   P(ShardingAxisName.MLP_DATA), P(ShardingAxisName.MLP_DATA)),
     )(hidden_states, topk_indices)

     x = jnp.pad(x, ((0, 0), (0, padded_hidden_size - hidden_size)))

@@ -495,8 +496,11 @@ def fused_moe_func(
     x = jax.shard_map(
         _finalize_output,
         mesh=mesh,
-        in_specs=(P(
-
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(ShardingAxisName.MLP_DATA),
+                  P(ShardingAxisName.MLP_DATA, None)),
+        out_specs=(P(ShardingAxisName.ATTN_DATA, None)),
+        check_vma=False,
     )(x, topk_argsort_revert_indices, topk_weights)

     return x[:num_tokens, :hidden_size]
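Finally, the with_sharding_constraint pattern used for topk_weights above (and for logits in sampling.py) can be seen in isolation; a hedged sketch with placeholder axis names rather than the ShardingAxisName constants, runnable on any host since the mesh collapses to 1x1 on a single device:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))

@jax.jit
def gather_rows(x):
    # Constrain x to be sharded only along "data" (replicated along "model"),
    # i.e. the all-gather the comment above describes.
    return jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P("data", None)))

y = gather_rows(jnp.ones((8, 16)))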