tpu-inference 0.11.1rc2__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (50)
  1. tpu_inference/kernels/collectives/__init__.py +0 -0
  2. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  3. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  4. tpu_inference/kernels/collectives/util.py +47 -0
  5. tpu_inference/layers/__init__.py +0 -0
  6. tpu_inference/layers/common/__init__.py +0 -0
  7. tpu_inference/layers/common/attention_metadata.py +34 -0
  8. tpu_inference/layers/jax/__init__.py +0 -0
  9. tpu_inference/layers/jax/attention/__init__.py +0 -0
  10. tpu_inference/layers/jax/attention/attention.py +254 -0
  11. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  12. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  13. tpu_inference/layers/jax/attention_interface.py +356 -0
  14. tpu_inference/layers/jax/base.py +151 -0
  15. tpu_inference/layers/jax/binary_search.py +295 -0
  16. tpu_inference/layers/jax/constants.py +88 -0
  17. tpu_inference/layers/jax/layers.py +301 -0
  18. tpu_inference/layers/jax/misc.py +16 -0
  19. tpu_inference/layers/jax/moe/__init__.py +0 -0
  20. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  21. tpu_inference/layers/jax/moe/moe.py +209 -0
  22. tpu_inference/layers/jax/rope.py +172 -0
  23. tpu_inference/layers/jax/rope_interface.py +214 -0
  24. tpu_inference/layers/jax/sample/__init__.py +0 -0
  25. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  26. tpu_inference/layers/jax/sample/sampling.py +95 -0
  27. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  28. tpu_inference/layers/jax/sharding.py +406 -0
  29. tpu_inference/layers/jax/transformer_block.py +76 -0
  30. tpu_inference/layers/vllm/__init__.py +0 -0
  31. tpu_inference/layers/vllm/attention.py +184 -0
  32. tpu_inference/layers/vllm/fused_moe.py +399 -0
  33. tpu_inference/layers/vllm/linear_common.py +186 -0
  34. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  35. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  36. tpu_inference/layers/vllm/quantization/common.py +105 -0
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  38. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  39. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  40. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  41. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  42. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  43. tpu_inference/layers/vllm/sharding.py +151 -0
  44. tpu_inference/models/common/__init__.py +0 -0
  45. tpu_inference/models/common/model_loader.py +433 -0
  46. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
  47. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
  48. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +1 -1
  49. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
  50. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
tpu_inference/layers/jax/sample/rejection_sampler.py
@@ -0,0 +1,515 @@
+ """
+ JAX-based rejection sampler for speculative decoding on TPU.
+
+ This implementation follows the same algorithm as the GPU version but is
+ designed for JAX/TPU compatibility. It currently only supports greedy sampling.
+ """
+
+ import functools
+ from typing import Optional
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
+ from tpu_inference.layers.jax.sample.sampling_metadata import \
+     TPUSupportedSamplingMetadata
+
+ # Placeholder token ID for rejected tokens
+ PLACEHOLDER_TOKEN_ID = -1
+ GREEDY_TEMPERATURE = -1
+
+
+ class RejectionSampler:
+     """
+     JAX-based rejection sampler for speculative decoding.
+
+     The implementation follows the algorithm described in
+     https://arxiv.org/abs/2211.17192.
+     """
+
+     def __init__(self):
+         pass
+
+     def __call__(
+         self,
+         # [num_tokens] - flattened format
+         draft_token_ids: jnp.ndarray,
+         # [batch_size] - number of draft tokens per request
+         num_draft_tokens: jnp.ndarray,
+         # [num_tokens, vocab_size] - flattened format
+         draft_probs: Optional[jnp.ndarray],
+         # [num_tokens, vocab_size] - flattened format
+         target_logits: jnp.ndarray,
+         # [batch_size]
+         bonus_token_ids: jnp.ndarray,
+         sampling_metadata: TPUSupportedSamplingMetadata,
+         key: Optional[jax.random.PRNGKey] = None,
+     ) -> jnp.ndarray:
+         """
+         Perform rejection sampling on draft tokens with flattened inputs.
+
+         Args:
+             draft_token_ids: Draft token IDs in flattened format [num_tokens].
+             num_draft_tokens: Number of draft tokens per request [batch_size].
+             draft_probs: Draft probabilities in flattened format [num_tokens, vocab_size].
+             target_logits: Target logits in flattened format [num_tokens, vocab_size].
+             bonus_token_ids: Bonus token IDs [batch_size].
+             sampling_metadata: Additional metadata needed for sampling.
+             key: JAX random key for non-greedy sampling.
+
+         Returns:
+             output_token_ids: A tensor containing the final output token IDs.
+         """
+         return self.forward(
+             draft_token_ids=draft_token_ids,
+             num_draft_tokens=num_draft_tokens,
+             draft_probs=draft_probs,
+             target_logits=target_logits,
+             bonus_token_ids=bonus_token_ids,
+             sampling_metadata=sampling_metadata,
+             key=key,
+         )
+
+     @functools.partial(jax.jit, static_argnums=(0, ))
+     def forward(
+         self,
+         # [num_tokens] - flattened format
+         draft_token_ids: jnp.ndarray,
+         # [batch_size] - number of draft tokens per request
+         num_draft_tokens: jnp.ndarray,
+         # [num_tokens, vocab_size] - flattened format
+         draft_probs: Optional[jnp.ndarray],
+         # [num_tokens, vocab_size] - flattened format
+         target_logits: jnp.ndarray,
+         # [batch_size]
+         bonus_token_ids: jnp.ndarray,
+         sampling_metadata: TPUSupportedSamplingMetadata,
+         key: Optional[jax.random.PRNGKey] = None,
+     ) -> jnp.ndarray:
+         """
+         Perform rejection sampling on draft tokens with flattened inputs.
+
+         Args:
+             draft_token_ids: Draft token IDs in flattened format [num_tokens].
+             num_draft_tokens: Number of draft tokens per request [batch_size].
+             draft_probs: Draft probabilities in flattened format [num_tokens, vocab_size].
+             target_logits: Target logits in flattened format [num_tokens, vocab_size].
+             bonus_token_ids: Bonus token IDs [batch_size].
+             sampling_metadata: Additional metadata needed for sampling.
+             key: JAX random key for non-greedy sampling.
+
+         Returns:
+             output_token_ids: A tensor containing the final output token IDs.
+         """
+
+         if sampling_metadata.do_sampling:
+             target_probs = _compute_probs(target_logits, num_draft_tokens,
+                                           sampling_metadata)
+         else:
+             target_probs = target_logits
+
+         output_token_ids = rejection_sample(
+             draft_token_ids,
+             num_draft_tokens,
+             draft_probs,
+             target_probs,
+             bonus_token_ids,
+             sampling_metadata,
+             key=key,
+         )
+         return output_token_ids
+
+     @staticmethod
+     def parse_output(
+         output_token_ids: jnp.ndarray,
+         vocab_size: int,
+         num_draft_tokens_cpu: np.ndarray,
+         batch_size: int,
+         padded_tokens_length: int,
+     ) -> list[list[int]]:
+         """Parse the output of the rejection sampler.
+
+         Args:
+             output_token_ids: The sampled token IDs in shape
+                 [num_tokens + batch_size]. The first num_tokens elements are
+                 the main tokens, and the last batch_size elements are bonus tokens.
+                 Rejected tokens are replaced with `PLACEHOLDER_TOKEN_ID`.
+             vocab_size: The size of the vocabulary.
+             num_draft_tokens_cpu: Number of draft tokens per request [batch_size]
+                 as a numpy array on CPU.
+             batch_size: The number of requests in the batch.
+             padded_tokens_length: The padded length of the main tokens in the output.
+
+         Returns:
+             A list of lists of token IDs.
+         """
+         # Convert JAX array to numpy for easier manipulation
+         output_token_ids_np = np.asarray(output_token_ids)
+
+         # Split main tokens and bonus tokens
+         main_tokens = output_token_ids_np[:padded_tokens_length]  # [num_tokens]
+         bonus_tokens = output_token_ids_np[padded_tokens_length:]  # [batch_size]
+
+         # Reconstruct per-sequence outputs
+         outputs = []
+         start_idx = 0
+
+         for i in range(batch_size):
+             seq_length = int(num_draft_tokens_cpu[i])
+             end_idx = start_idx + seq_length
+
+             # Get main tokens for this sequence
+             seq_main_tokens = main_tokens[start_idx:end_idx]
+
+             # Filter out placeholder tokens
+             valid_main_tokens = seq_main_tokens[
+                 (seq_main_tokens != PLACEHOLDER_TOKEN_ID)
+                 & (seq_main_tokens < vocab_size)]
+
+             # Add bonus token if it's valid
+             bonus_token = bonus_tokens[i]
+             if bonus_token != PLACEHOLDER_TOKEN_ID and bonus_token < vocab_size:
+                 seq_tokens = np.concatenate([valid_main_tokens, [bonus_token]])
+             else:
+                 seq_tokens = valid_main_tokens
+
+             outputs.append(seq_tokens.tolist())
+             start_idx = end_idx
+
+         return outputs
+
+
+ def _compute_probs(
+     logits: jnp.ndarray,
+     num_draft_tokens: jnp.ndarray,
+     sampling_metadata: TPUSupportedSamplingMetadata,
+ ) -> jnp.ndarray:
+     """
+     Apply top-k, top-p, and temperature to logits and compute probabilities.
+     """
+     total_tokens = logits.shape[0]
+     segment_ids, _ = _get_segment_info(num_draft_tokens, total_tokens)
+
+     # Expand sampling params from [batch_size] to [num_tokens]
+     top_k = sampling_metadata.top_k[segment_ids]
+     top_p = sampling_metadata.top_p[segment_ids]
+     temperatures = sampling_metadata.temperature[segment_ids]
+
+     # Apply top-k and top-p masking
+     logits = logits.astype(jnp.float32)
+     # Only apply top-k masking if k > 0 for each token
+     should_apply_topk = jnp.expand_dims(top_k > 0, axis=-1)
+     topk_masked = topk_mask(logits, top_k, replace_val=-jnp.inf)
+     logits = jnp.where(should_apply_topk, topk_masked, logits)
+
+     # Only apply top-p masking if p < 1.0 for each token
+     should_apply_topp = jnp.expand_dims(top_p < 1.0, axis=-1)
+     topp_masked = topp_mask(logits, top_p, replace_val=-jnp.inf)
+     logits = jnp.where(should_apply_topp, topp_masked, logits)
+
+     # Apply temperature scaling
+     temperatures = jnp.expand_dims(temperatures, axis=-1)
+     # Add epsilon to avoid division by zero
+     logits /= (temperatures + 1e-9)
+
+     return jax.nn.softmax(logits, axis=-1)
+
+
+ def _get_segment_info(num_draft_tokens: jax.Array, total_tokens: int):
+     """Helper to create segment IDs and per-segment indices."""
+     batch_size = num_draft_tokens.shape[0]
+
+     # `segment_ids` assigns a unique ID to each token, corresponding to its
+     # sequence in the batch. E.g., [0, 0, 0, 1, 1, 2, 2, 2, 2] for sequences [3, 2, 4].
+     segment_ids = jnp.repeat(jnp.arange(batch_size),
+                              num_draft_tokens,
+                              total_repeat_length=total_tokens)
+
+     # `group_indices` creates a within-segment index for each token.
+     # E.g., [0, 1, 2, 0, 1, 0, 1, 2, 3] for the example above.
+     segment_starts = jnp.concatenate(
+         [jnp.array([0]), jnp.cumsum(num_draft_tokens)[:-1]])
+     broadcast_starts = jnp.repeat(segment_starts,
+                                   num_draft_tokens,
+                                   total_repeat_length=total_tokens)
+     group_indices = jnp.arange(total_tokens) - broadcast_starts
+     return segment_ids, group_indices
+
+
+ def _sample_recovered_tokens(
+     draft_token_ids: jax.Array,
+     draft_probs: Optional[jax.Array],
+     target_probs: jax.Array,
+     key: jax.random.PRNGKey,
+ ) -> jax.Array:
+     """
+     Sample recovered tokens using the Gumbel-Max trick.
+     This is used when a draft token is rejected in random sampling.
+     """
+     if draft_probs is not None:
+         # The new distribution is p' = max(p_target - p_draft, 0)
+         new_dist = jnp.maximum(target_probs - draft_probs, 0)
+     else:
+         # If no draft probs, the new distribution is the target distribution
+         # with the draft token's probability zeroed out.
+         vocab_size = target_probs.shape[-1]
+         mask = jax.nn.one_hot(draft_token_ids, vocab_size, dtype=jnp.bool)
+         new_dist = target_probs * ~mask
+
+     # Gumbel-Max trick to sample from the new distribution:
+     # y = argmax(log(p') + g) where g ~ Gumbel(0,1)
+     # This is equivalent to argmax(p' / q) where q ~ Exponential(1)
+     q = jax.random.exponential(key, shape=new_dist.shape)
+
+     # Add a small epsilon to avoid division by zero
+     recovered_token_ids = jnp.argmax(new_dist / (q + 1e-9), axis=-1)
+     return recovered_token_ids
+
+
+ def rejection_sample(
+     # [num_tokens] - flattened format
+     draft_token_ids: jnp.ndarray,
+     # [batch_size] - JAX array
+     num_draft_tokens: jnp.ndarray,
+     # [num_tokens, vocab_size] - flattened format
+     draft_probs: Optional[jnp.ndarray],
+     # [num_tokens, vocab_size] - flattened format
+     target_probs: jnp.ndarray,
+     # [batch_size]
+     bonus_token_ids: jnp.ndarray,
+     sampling_metadata: TPUSupportedSamplingMetadata,
+     key: Optional[jax.random.PRNGKey] = None,
+ ) -> jnp.ndarray:
+     """
+     Perform rejection sampling on draft tokens with flattened inputs.
+
+     Args:
+         draft_token_ids: Draft token IDs in flattened format [num_tokens].
+         num_draft_tokens: Number of draft tokens per request [batch_size].
+         draft_probs: Draft probabilities in flattened format [num_tokens, vocab_size].
+         target_probs: Target probabilities in flattened format [num_tokens, vocab_size].
+         bonus_token_ids: Bonus token IDs [batch_size].
+         sampling_metadata: Sampling metadata.
+         key: JAX random key for non-greedy sampling.
+
+     Returns:
+         output_token_ids: Output token IDs [num_tokens + batch_size].
+     """
+     if sampling_metadata.do_sampling is False:
+         greedy_output = _greedy_rejection_sample_with_segment(
+             draft_token_ids, target_probs, num_draft_tokens, bonus_token_ids)
+         return greedy_output
+
+     # Random path
+     if key is None:
+         raise ValueError(
+             "A random key must be provided for non-greedy sampling.")
+
+     random_output = _random_rejection_sample_with_segment(
+         draft_token_ids,
+         draft_probs,
+         target_probs,
+         num_draft_tokens,
+         bonus_token_ids,
+         key,
+     )
+
+     return random_output
+
+
+ def _random_rejection_sample_with_segment(
+     draft_token_ids: jax.Array,
+     draft_probs: Optional[jax.Array],
+     target_probs: jax.Array,
+     num_draft_tokens: jax.Array,
+     bonus_token_ids: jax.Array,
+     key: jax.random.PRNGKey,
+ ) -> jax.Array:
+     """
+     Performs random speculative decoding validation in a vectorized, jittable manner.
+     """
+     total_tokens = draft_token_ids.shape[0]
+     batch_size = num_draft_tokens.shape[0]
+
+     # Split random key
+     uniform_key, recover_key = jax.random.split(key)
+
+     # --- Step 1: Get Segment Info ---
+     segment_ids, group_indices = _get_segment_info(num_draft_tokens,
+                                                    total_tokens)
+
+     # --- Step 2: Acceptance/Rejection Logic ---
+     if draft_probs is not None:
+         draft_token_probs = jnp.take_along_axis(draft_probs,
+                                                 draft_token_ids[:, None],
+                                                 axis=-1).squeeze(-1)
+     else:
+         draft_token_probs = 1.0
+
+     target_token_probs = jnp.take_along_axis(target_probs,
+                                              draft_token_ids[:, None],
+                                              axis=-1).squeeze(-1)
+
+     uniform_probs = jax.random.uniform(uniform_key, shape=(total_tokens, ))
+
+     # Acceptance condition: p_target(d) / p_draft(d) >= u
+     ratio = target_token_probs / (draft_token_probs + 1e-9)
+     accepted = ratio >= uniform_probs
+
+     # --- Step 3: Find First Rejection ---
+     rejections = ~accepted
+     large_value = total_tokens
+     rejection_indices = jnp.where(rejections, group_indices, large_value)
+
+     first_rejection_idx_per_segment = jax.ops.segment_min(
+         data=rejection_indices.astype(jnp.int32),
+         segment_ids=segment_ids,
+         num_segments=batch_size,
+         indices_are_sorted=True,
+     )
+
+     max_int = jnp.iinfo(jnp.int32).max
+     first_rejection_idx_per_segment = jnp.where(
+         first_rejection_idx_per_segment == max_int, large_value,
+         first_rejection_idx_per_segment)
+
+     # --- Step 4: Sample Recovered Tokens ---
+     recovered_token_ids = _sample_recovered_tokens(draft_token_ids,
+                                                    draft_probs, target_probs,
+                                                    recover_key)
+
+     # --- Step 5: Generate Main Token Output ---
+     first_rejection_idx_broadcast = jnp.repeat(
+         first_rejection_idx_per_segment,
+         num_draft_tokens,
+         total_repeat_length=total_tokens)
+
+     main_tokens = jnp.where(
+         group_indices < first_rejection_idx_broadcast, draft_token_ids,
+         jnp.where(group_indices == first_rejection_idx_broadcast,
+                   recovered_token_ids, PLACEHOLDER_TOKEN_ID))
+
+     # --- Step 6: Handle Bonus Tokens ---
+     all_accepted = first_rejection_idx_per_segment == large_value
+     no_draft_tokens = num_draft_tokens == 0
+     should_get_bonus = all_accepted | no_draft_tokens
+     bonus_tokens = jnp.where(should_get_bonus, bonus_token_ids,
+                              PLACEHOLDER_TOKEN_ID)
+
+     # --- Step 7: Concatenate ---
+     return jnp.concatenate([main_tokens, bonus_tokens])
+
+
+ # TODO(pooyam): Optimize/profile this implementation further. For now the goal
+ # is a working end-to-end path; `parse_output` may add overhead that can be
+ # optimized on TPU. Benchmark against the following approaches:
+ # - Using `jax.lax.segment_xyz` to work with flattened inputs instead of batched inputs.
+ # - A vectorized implementation using `cumprod` and other masking tricks.
+ # - A Pallas kernel similar to the Triton implementation.
+ # - A scan-based approach.
+ # Overall, I expect XLA to optimize the scan-based approach pretty well, but
+ # it would be good to compare performance against other methods.
+ def _greedy_rejection_sample_with_segment(
+     draft_token_ids: jax.Array,
+     target_probs: jax.Array,
+     num_draft_tokens: jax.Array,
+     bonus_token_ids: jax.Array,
+ ) -> jax.Array:
+     """
+     Performs greedy speculative decoding validation in a vectorized, jittable manner.
+
+     This function compares draft tokens with the target model's outputs. For each
+     sequence in the batch, it accepts tokens as long as the draft and target match.
+     When a mismatch occurs, it takes the target model's token and invalidates the
+     rest of the tokens in that sequence by setting them to -1.
+
+     Args:
+         draft_token_ids: A 1D JAX array (num_tokens,) of integers representing the
+             concatenated draft tokens for all sequences in the batch.
+         target_probs: A 2D JAX array (num_tokens, vocab_size) of floats representing
+             the concatenated target model's probabilities.
+         num_draft_tokens: A 1D JAX array (batch_size,) of integers specifying the
+             number of draft tokens for each sequence in the batch.
+         bonus_token_ids: A 1D JAX array (batch_size,) of integers representing the
+             bonus token for each sequence.
+
+     Returns:
+         A 1D JAX array (num_tokens + batch_size,) containing the validated token
+         sequence followed by bonus tokens (or -1 if not accepted).
+     """
+     # Get target argmax
+     target_logits_argmax = jnp.argmax(target_probs, axis=-1)
+
+     # --- Step 1: Create Segment IDs and Per-Segment Indices ---
+     total_tokens = draft_token_ids.shape[0]
+     batch_size = num_draft_tokens.shape[0]
+     segment_ids, group_indices = _get_segment_info(num_draft_tokens,
+                                                    total_tokens)
+
+     # --- Step 2: Find the First Mismatch in Each Segment ---
+
+     # Find all mismatches between draft and target tokens.
+     mismatches = draft_token_ids != target_logits_argmax
+
+     # To find the *first* mismatch, we use a trick with segment_min.
+     # We create an array where mismatched positions hold their `group_index`
+     # and matched positions hold a large value.
+     large_value = total_tokens
+     mismatch_indices = jnp.where(mismatches, group_indices, large_value)
+
+     # `segment_min` finds the minimum `mismatch_index` for each segment. This
+     # effectively gives us the `group_index` of the first mismatch.
+     # For sequences with no mismatches, the result will be `large_value`.
+     first_mismatch_idx_per_segment = jax.ops.segment_min(
+         data=mismatch_indices.astype(jnp.int32),
+         segment_ids=segment_ids,
+         num_segments=batch_size,
+         indices_are_sorted=True,
+     )
+
+     # Handle empty segments (where num_draft_tokens is 0). `segment_min` returns
+     # the dtype's max value for empty segments; we replace it with our large_value
+     # for consistency.
+     max_int = jnp.iinfo(jnp.int32).max
+     first_mismatch_idx_per_segment = jnp.where(
+         first_mismatch_idx_per_segment == max_int, large_value,
+         first_mismatch_idx_per_segment)
+
+     # --- Step 3: Broadcast Mismatch Info and Generate Main Token Output ---
+
+     # Broadcast the first mismatch index back to the original token dimension.
+     first_mismatch_idx_broadcast = jnp.repeat(first_mismatch_idx_per_segment,
+                                               num_draft_tokens,
+                                               total_repeat_length=total_tokens)
+
+     # The final logic for main tokens:
+     # A token is valid if its `group_index` is less than or equal to the
+     # index of the first mismatch in its segment.
+     # - If `group_index < first_mismatch_idx`, the draft was correct.
+     # - If `group_index == first_mismatch_idx`, this is the correction token.
+     # - If `group_index > first_mismatch_idx`, the token is invalid (-1).
+     main_tokens = jnp.where(group_indices <= first_mismatch_idx_broadcast,
+                             target_logits_argmax, PLACEHOLDER_TOKEN_ID)
+
+     # --- Step 4: Handle Bonus Tokens ---
+
+     # A sequence gets its bonus token if there were no mismatches
+     # (first_mismatch_idx_per_segment == large_value)
+     all_accepted = first_mismatch_idx_per_segment == large_value
+
+     # For sequences with no draft tokens, we should still give them the bonus token
+     # since there's nothing to reject
+     no_draft_tokens = num_draft_tokens == 0
+     should_get_bonus = all_accepted | no_draft_tokens
+
+     bonus_tokens = jnp.where(should_get_bonus, bonus_token_ids,
+                              PLACEHOLDER_TOKEN_ID)
+
+     # --- Step 5: Concatenate Main Tokens and Bonus Tokens ---
+
+     output = jnp.concatenate([main_tokens, bonus_tokens])
+
+     return output
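
For readers skimming the new rejection sampler, here is a minimal, hypothetical usage sketch of the greedy path; it is not taken from the package. Only the classes and module paths shown in this diff are assumed, and the expected values in the comments follow from the logic of `_greedy_rejection_sample_with_segment` above.

# Illustrative sketch (not part of the release): greedy rejection sampling on
# two tiny requests with 2 and 1 draft tokens and a vocabulary of 4.
import jax.numpy as jnp
import numpy as np

from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
from tpu_inference.layers.jax.sample.sampling_metadata import \
    TPUSupportedSamplingMetadata

sampler = RejectionSampler()

draft_token_ids = jnp.array([1, 2, 3])      # flattened across both requests
num_draft_tokens = jnp.array([2, 1])        # per-request draft counts
# Greedy path only looks at argmax, so one-hot rows stand in for target probs.
target_logits = jnp.array([
    [0., 1., 0., 0.],   # argmax 1 == draft 1 -> accept
    [0., 0., 0., 1.],   # argmax 3 != draft 2 -> correct to 3, reject the rest
    [0., 0., 0., 1.],   # argmax 3 == draft 3 -> accept
])
bonus_token_ids = jnp.array([0, 2])
metadata = TPUSupportedSamplingMetadata(do_sampling=False)

out = sampler(draft_token_ids, num_draft_tokens, None, target_logits,
              bonus_token_ids, metadata)
# Expected: [1, 3, 3, -1, 2] -> main tokens then per-request bonus slots,
# with -1 (PLACEHOLDER_TOKEN_ID) for the request whose draft was corrected.

parsed = RejectionSampler.parse_output(out, vocab_size=4,
                                       num_draft_tokens_cpu=np.array([2, 1]),
                                       batch_size=2, padded_tokens_length=3)
# Expected: [[1, 3], [3, 2]]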
tpu_inference/layers/jax/sample/sampling.py
@@ -0,0 +1,95 @@
+ import functools
+
+ import jax
+ import jax.numpy as jnp
+ from jax.sharding import Mesh, NamedSharding
+ from jax.sharding import PartitionSpec as P
+ from vllm.v1.outputs import LogprobsTensors
+
+ from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
+ from tpu_inference.layers.jax.sample.sampling_metadata import \
+     TPUSupportedSamplingMetadata
+
+ _SAMPLING_EPS = 1e-5
+
+
+ @functools.partial(
+     jax.jit,
+     static_argnames=["mesh"],
+ )
+ def sample(
+     rng: jax.Array,
+     mesh: Mesh,
+     logits: jax.Array,
+     tpu_sampling_metadata: TPUSupportedSamplingMetadata,
+ ) -> jax.Array:
+     # (B, vocab_size)
+     if tpu_sampling_metadata.do_sampling:
+         # Unshard the logits explicitly to avoid latency increase.
+         logits = jax.lax.with_sharding_constraint(
+             logits, NamedSharding(mesh, P(None, None)))
+     greedy_sampled = jnp.argmax(logits, axis=-1)
+     if not tpu_sampling_metadata.do_sampling:
+         return greedy_sampled
+
+     logits = logits.astype(jnp.float32)
+     logits = topk_mask(logits, tpu_sampling_metadata.top_k, replace_val=-1e12)
+     logits = topp_mask(logits, tpu_sampling_metadata.top_p, replace_val=-1e12)
+
+     temperatures = tpu_sampling_metadata.temperature.astype(logits.dtype)
+     temperatures = jnp.expand_dims(temperatures, axis=-1)
+     logits /= temperatures
+
+     # (batch_size,)
+     next_tokens = jax.random.categorical(rng, logits)
+     # Note: avoid using the sampled result when temperature < _SAMPLING_EPS.
+     # If temperature < 0, `logits /= temperatures` flips the result, causing errors.
+     return jnp.where(tpu_sampling_metadata.temperature < _SAMPLING_EPS,
+                      greedy_sampled, next_tokens)
+
+
+ def compute_logprobs(logits: jax.Array) -> jax.Array:
+     return jax.nn.log_softmax(logits, axis=-1)
+
+
+ def gather_logprobs(
+     logprobs: jax.Array,
+     token_ids: jax.Array,
+     num_logprobs: int,
+ ) -> LogprobsTensors:
+     """
+     Gather logprobs for top-k and sampled/prompt tokens.
+
+     Args:
+       logprobs: (num tokens) x (vocab) tensor
+       token_ids: prompt tokens (if prompt logprobs) or sampled tokens
+                  (if sampled logprobs); 1D token ID tensor with
+                  (num tokens) elements
+       num_logprobs: minimum number of logprobs to retain per token
+
+     Returns:
+       Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
+       Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
+       Sampled token rank tensor, (num tokens)
+     """
+     # Find the top-k values.
+     topk_logprobs, topk_indices = jax.lax.top_k(logprobs, k=num_logprobs)
+
+     # Get the logprob of the prompt or sampled token.
+     token_ids = jnp.expand_dims(token_ids, axis=-1)
+     token_logprobs = jnp.take_along_axis(logprobs, token_ids, axis=-1)
+
+     # Compute the ranks of the actual token.
+     token_ranks = jnp.sum(logprobs >= token_logprobs, axis=-1)
+
+     # Concatenate together with the top-k.
+     indices = jnp.concatenate((token_ids, topk_indices), axis=1)
+     logprobs = jnp.concatenate((token_logprobs, topk_logprobs), axis=1)
+
+     # Use int32 to reduce the tensor size.
+     indices = jnp.int32(indices)
+
+     return LogprobsTensors(indices, logprobs, token_ranks)
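
As a rough illustration of the logprob helpers above, the sketch below (not part of the release) runs `compute_logprobs` and `gather_logprobs` on a tiny logits matrix. It assumes vllm is installed so `LogprobsTensors` imports cleanly, and it assumes the returned tuple unpacks in the order it is constructed in `gather_logprobs` (indices, logprobs, ranks).

# Illustrative sketch (not part of the release): logprob gathering shapes for
# 2 tokens over a vocabulary of 5, keeping the top-2 logprobs per token.
import jax.numpy as jnp

from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
                                                      gather_logprobs)

logits = jnp.array([
    [2.0, 1.0, 0.5, -1.0, 0.0],
    [0.0, 0.0, 3.0, 0.2, -2.0],
])
sampled = jnp.array([0, 2])                  # token actually chosen per row

logprobs = compute_logprobs(logits)          # (2, 5) log-softmax rows
out = gather_logprobs(logprobs, sampled, num_logprobs=2)

# Assumed unpacking order, matching the constructor call above.
indices, gathered_logprobs, ranks = out
# indices:           (2, 3) int32, column 0 is the sampled token id,
#                    columns 1-2 are the top-2 token ids.
# gathered_logprobs: (2, 3) matching log-probabilities.
# ranks:             (2,), 1 when the sampled token is the row argmax
#                    (both rows here sample their argmax).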
tpu_inference/layers/jax/sample/sampling_metadata.py
@@ -0,0 +1,69 @@
+ import functools
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from jax.sharding import Mesh
+
+ from tpu_inference.runner.input_batch_jax import InputBatch
+ from tpu_inference.utils import device_array
+
+ DEFAULT_SAMPLING_PARAMS = dict(
+     temperature=-1.0,
+     top_k=0,
+     top_p=1.0,
+ )
+
+
+ @functools.partial(
+     jax.tree_util.register_dataclass,
+     data_fields=[
+         "temperature",
+         "top_k",
+         "top_p",
+     ],
+     meta_fields=["do_sampling", "logprobs"],
+ )
+ @dataclass
+ class TPUSupportedSamplingMetadata:
+     temperature: Optional[jnp.ndarray] = None
+     top_k: Optional[jnp.ndarray] = None
+     top_p: Optional[jnp.ndarray] = None
+     do_sampling: bool = False
+     logprobs: bool = False
+
+     @classmethod
+     def from_input_batch(
+         cls,
+         mesh: Mesh,
+         input_batch: InputBatch,
+         padded_num_reqs: int,
+     ) -> "TPUSupportedSamplingMetadata":
+         needs_logprobs = input_batch.max_num_logprobs > 0 if input_batch.max_num_logprobs else False
+         if input_batch.all_greedy:
+             return cls(do_sampling=False, logprobs=needs_logprobs)
+         num_reqs = input_batch.num_reqs
+
+         def fill_slice(cpu_torch_tensor: torch.Tensor,
+                        fill_val: float) -> torch.Tensor:
+             # Pad value is the default one.
+             cpu_torch_tensor[num_reqs:padded_num_reqs] = fill_val
+             return cpu_torch_tensor
+
+         temp_tensor = fill_slice(input_batch.temperature_cpu,
+                                  DEFAULT_SAMPLING_PARAMS["temperature"])
+         top_k_tensor = fill_slice(input_batch.top_k_cpu,
+                                   DEFAULT_SAMPLING_PARAMS["top_k"])
+         top_p_tensor = fill_slice(input_batch.top_p_cpu,
+                                   DEFAULT_SAMPLING_PARAMS["top_p"])
+
+         # Slice persistent device tensors to a fixed pre-compiled padded shape.
+         return cls(
+             temperature=device_array(mesh, temp_tensor[:padded_num_reqs]),
+             top_p=device_array(mesh, top_p_tensor[:padded_num_reqs]),
+             top_k=device_array(mesh, top_k_tensor[:padded_num_reqs]),
+             do_sampling=not input_batch.all_greedy,
+             logprobs=needs_logprobs,
+         )
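
Finally, a hypothetical sketch (not from the release) of building `TPUSupportedSamplingMetadata` by hand for a padded batch. This can be convenient in tests that do not want to construct an `InputBatch` or a TPU `Mesh`; the field names, defaults, and padding values come from the code above, everything else is assumed.

# Illustrative sketch (not part of the release): metadata for 2 real requests
# padded to 4 slots, with the padded slots filled from DEFAULT_SAMPLING_PARAMS.
import jax
import jax.numpy as jnp

from tpu_inference.layers.jax.sample.sampling_metadata import (
    DEFAULT_SAMPLING_PARAMS, TPUSupportedSamplingMetadata)

pad = DEFAULT_SAMPLING_PARAMS
metadata = TPUSupportedSamplingMetadata(
    temperature=jnp.array([0.7, 1.0, pad["temperature"], pad["temperature"]]),
    top_k=jnp.array([40, 0, pad["top_k"], pad["top_k"]]),
    top_p=jnp.array([0.9, 1.0, pad["top_p"], pad["top_p"]]),
    do_sampling=True,
    logprobs=False,
)

# Because the class is registered via jax.tree_util.register_dataclass,
# temperature/top_k/top_p are array leaves that can be traced under jit,
# while do_sampling/logprobs are static metadata: changing them triggers a
# retrace rather than being baked into a single compiled program.
print(jax.tree_util.tree_leaves(metadata))   # the three parameter arrays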