tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/utils/weight_utils.py

@@ -13,12 +13,14 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import torch
+import torchax
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
+from vllm.config import VllmConfig

-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils

@@ -65,7 +67,13 @@ def transpose_params(param_key: str, param_tensor: jax.Array, transpose_map):
 def reshape_params(param_key: str, param_tensor: jax.Array, shape_map):
     for key, new_shape in shape_map.items():
         if key in param_key:
-            return jnp.reshape(param_tensor, new_shape)
+            try:
+                #TODO:(gpolovets) Add validation on whether reshape preserves data layout.
+                return jnp.reshape(param_tensor, new_shape)
+            except TypeError:
+                raise TypeError(
+                    f"Cannot reshape for key={key}, new_shape={new_shape}, param_shape={param_tensor.shape}"
+                )
     return param_tensor  # Base case / no-op


@@ -197,12 +205,11 @@ def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
     return jax.device_put(x, shardings)


-def get_default_maps(vllm_config, mesh: Mesh,
+def get_default_maps(model_config, mesh: Mesh,
                      name_map: dict[str, str]) -> MetadataMap:
     """Load weights from one model weights file to the model, run on single thread."""
     sharding_size = mesh.shape["model"]

-    model_config = vllm_config.model_config
     hf_config = model_config.hf_config

     num_heads = hf_config.num_attention_heads
@@ -266,14 +273,15 @@ def get_default_maps(vllm_config, mesh: Mesh,
                        bias_pad_map=bias_pad_keys)


-def _load_hf_weights_on_thread(vllm_config,
-                               params: nnx.State,
-                               metadata_map: MetadataMap,
-                               mesh: Mesh,
-                               weights_file: str,
-                               filter_regex: str | None = None,
-                               keep_original_dtype_keys_regex: list[str]
-                               | None = None):
+def _load_and_shard_weight(vllm_config,
+                           params: nnx.State,
+                           shardings: Any,
+                           metadata_map: MetadataMap,
+                           mesh: Mesh,
+                           hf_key: str,
+                           hf_weight: jax.Array,
+                           keep_original_dtype_keys_regex: list[str]
+                           | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -290,6 +298,118 @@ def _load_hf_weights_on_thread(vllm_config,
     head_dim = utils.get_padded_head_dim(head_dim_original)
     head_dim_pad = head_dim - head_dim_original

+    # Check if the key should retain its original dtype
+    keep_original_dtype = False
+    if keep_original_dtype_keys_regex:
+        for pattern in keep_original_dtype_keys_regex:
+            if re.match(pattern, hf_key):
+                keep_original_dtype = True
+                break
+
+    # Converting to config's dtype
+    if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
+        logger.warning(
+            f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
+        )
+        hf_weight = hf_weight.astype(model_config.dtype)
+
+    if hf_key.endswith(".weight"):
+        hf_key = hf_key.removesuffix(".weight")
+
+    # Find the corresponding model key using the HF key
+    if "layers" in hf_key:
+        layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
+        layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
+        model_key = name_map[layer_key]
+        model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
+    elif "blocks" in hf_key:
+        layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
+        layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
+        model_key = name_map[layer_key]
+        model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
+    else:
+        if hf_key not in name_map and hf_key == "lm_head":
+            logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings")
+            return
+        if hf_key not in name_map and "t2d" in hf_key:
+            logger.warning(
+                f"Skip loading {hf_key} as it's not used in eagle-3 for now")
+            return
+        model_key = name_map.get(hf_key, hf_key)
+
+    model_weight, model_sharding = get_param_and_sharding(
+        params, shardings, model_key)
+
+    logger.debug(
+        "before transform | "
+        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+    )
+
+    if hf_key.endswith(".bias"):
+        for key in bias_reshape_keys:
+            if key in hf_key:
+                hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
+                if head_dim_pad > 0:
+                    hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad)))
+                break
+    else:
+        for key in reshape_keys:
+            if key in hf_key:
+                hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
+                if head_dim_pad > 0:
+                    if "o_proj" in key:
+                        hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
+                                                        (0, head_dim_pad)))
+                    else:
+                        hf_weight = jnp.pad(hf_weight,
+                                            ((0, 0), (0, head_dim_pad),
+                                             (0, 0)))
+                break
+        for key in transpose_keys:
+            if key in hf_key:
+                hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
+                break
+
+    # Pad num-kv-heads
+    if hf_key.endswith(".bias"):
+        for key, value in bias_pad_keys.items():
+            dim = value[0]
+            dim_size = value[1]
+            if key in hf_key and dim_size != 0:
+                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                break
+    else:
+        for key, value in pad_keys.items():
+            dim = value[0]
+            dim_size = value[1]
+            if key in hf_key and dim_size != 0:
+                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                break
+
+    logger.debug(
+        "after transform | "
+        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+    )
+
+    if head_dim_pad == 0:
+        assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
+
+    # Update the model weight
+    spec = model_weight.sharding.spec if isinstance(
+        model_weight.sharding, NamedSharding) else model_weight.sharding
+    model_weight.value = shard(hf_weight, spec)
+
+
+def _load_hf_weights_on_thread(
+    vllm_config: VllmConfig,
+    params: nnx.State,
+    metadata_map: "MetadataMap",
+    mesh: Mesh,
+    weights_file: str,
+    filter_regex: Optional[str] = None,
+    keep_original_dtype_keys_regex: Optional[list[str]] = None,
+):
+    """Loads weights from a single weights file."""
     try:
         shardings = nnx.get_named_sharding(params, mesh)
     except TypeError:
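The new `_load_and_shard_weight` helper above applies one fixed per-tensor pipeline: optional dtype cast, HF-key-to-model-key mapping, reshape (with head-dim padding), transpose, KV-head padding, and finally a sharded placement. The following is a simplified, self-contained sketch of that order of operations; the maps, key names, and single-device mesh are hypothetical stand-ins, not the library's real metadata.

```python
import re

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:1], ("model",))  # assumption: a 1-device mesh
# Hypothetical metadata, standing in for MetadataMap contents.
name_map = {"model.layers.*.self_attn.q_proj": "layers.*.attn.q_proj"}
reshape_map = {"q_proj": (4, 64, 256)}   # (heads, head_dim, hidden)
transpose_map = {"q_proj": (2, 0, 1)}    # model layout: (hidden, heads, head_dim)

hf_key = "model.layers.0.self_attn.q_proj.weight"
hf_weight = jnp.zeros((4 * 64, 256), dtype=jnp.float32)

hf_weight = hf_weight.astype(jnp.bfloat16)                # 1. cast to the config dtype
hf_key = hf_key.removesuffix(".weight")
layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)    # 2. HF key -> model key
model_key = name_map[layer_key].replace("layers.*", "layers.0")
for key, shape in reshape_map.items():                    # 3. reshape (pad head_dim here if needed)
    if key in hf_key:
        hf_weight = jnp.reshape(hf_weight, shape)
for key, perm in transpose_map.items():                   # 4. transpose into the model layout
    if key in hf_key:
        hf_weight = jnp.transpose(hf_weight, perm)
sharded = jax.device_put(                                 # 5. place with the target sharding
    hf_weight, NamedSharding(mesh, P(None, None, "model")))
print(model_key, sharded.shape)
```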
@@ -297,147 +417,88 @@ def _load_hf_weights_on_thread(vllm_config,

     for hf_key, hf_weight in model_weights_single_file_generator(
             weights_file, framework="flax", filter_regex=filter_regex):
+        _load_and_shard_weight(
+            vllm_config,
+            params,
+            shardings,
+            metadata_map,
+            mesh,
+            hf_key,
+            hf_weight,
+            keep_original_dtype_keys_regex,
+        )

-        # Check if the key should retain its original dtype
-        keep_original_dtype = False
-        if keep_original_dtype_keys_regex:
-            for pattern in keep_original_dtype_keys_regex:
-                if re.match(pattern, hf_key):
-                    keep_original_dtype = True
-                    break

-        # Converting to config's dtype
-        if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
-            logger.warning(
-                f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
-            )
-            hf_weight = hf_weight.astype(model_config.dtype)
-
-        if hf_key.endswith(".weight"):
-            hf_key = hf_key.removesuffix(".weight")
-
-        # Find the corresponding model key using the HF key
-        if "layers" in hf_key:
-            layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
-            layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
-            model_key = name_map[layer_key]
-            model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
-        elif "blocks" in hf_key:
-            layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
-            layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
-            model_key = name_map[layer_key]
-            model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
-        else:
-            if hf_key not in name_map and hf_key == "lm_head":
-                logger.warning(
-                    f"Skip loading {hf_key} due to tie_word_embeddings")
-                continue
-            if hf_key not in name_map and "t2d" in hf_key:
-                logger.warning(
-                    f"Skip loading {hf_key} as it's not used in eagle-3 for now"
-                )
+def load_hf_weights(
+    vllm_config: VllmConfig,
+    model: nnx.Module,
+    metadata_map: "MetadataMap",
+    mesh: Mesh,
+    filter_regex: Optional[str] = None,
+    is_draft_model: bool = False,
+    keep_original_dtype_keys_regex: Optional[list[str]] = None,
+):
+    """Load weights into a JAX model from either an iterator or files."""
+    params = nnx.state(model)
+    try:
+        shardings = nnx.get_named_sharding(params, mesh)
+    except TypeError:
+        shardings = params
+    weights_iterator = None
+    if hasattr(vllm_config.model_config, "model_weights_iterator"):
+        weights_iterator = vllm_config.model_config.model_weights_iterator
+    env = torchax.default_env()
+    # The weights_iterator is used in RunAI model streamer integration.
+    if weights_iterator is not None:
+        for hf_key, hf_weight in weights_iterator:
+            if filter_regex and not re.match(filter_regex, hf_key):
                 continue
-            model_key = name_map.get(hf_key, hf_key)
-        model_weight, model_sharding = get_param_and_sharding(
-            params, shardings, model_key)
-
-        logger.debug(
-            "before transform | "
-            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-        )

-        if hf_key.endswith(".bias"):
-            for key in bias_reshape_keys:
-                if key in hf_key:
-                    hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
-                    if head_dim_pad > 0:
-                        hf_weight = jnp.pad(hf_weight,
-                                            ((0, 0), (0, head_dim_pad)))
-                    break
-        else:
-            for key in reshape_keys:
-                if key in hf_key:
-                    hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
-                    if head_dim_pad > 0:
-                        if "o_proj" in key:
-                            hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
-                                                            (0, head_dim_pad)))
-                        else:
-                            hf_weight = jnp.pad(hf_weight,
-                                                ((0, 0), (0, head_dim_pad),
-                                                 (0, 0)))
-                    break
-            for key in transpose_keys:
-                if key in hf_key:
-                    hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
-                    break
-
-        # Pad num-kv-heads
-        if hf_key.endswith(".bias"):
-            for key, value in bias_pad_keys.items():
-                dim = value[0]
-                dim_size = value[1]
-                if key in hf_key and dim_size != 0:
-                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                    break
-        else:
-            for key, value in pad_keys.items():
-                dim = value[0]
-                dim_size = value[1]
-                if key in hf_key and dim_size != 0:
-                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                    break
-
-        logger.debug(
-            "after transform | "
-            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-        )
+            # Since the weights_iterator yields Pytorch tensors (torch.Tensor),
+            # we need to convert them to JAX arrays (jax.Array).
+            hf_weight_jax = env.t2j_copy(hf_weight)

-        if head_dim_pad == 0:
-            assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
-
-        # Update the model weight
-        spec = model_weight.sharding.spec if isinstance(
-            model_weight.sharding, NamedSharding) else model_weight.sharding
-        model_weight.value = shard(hf_weight, spec)
-
-
-def load_hf_weights(vllm_config,
-                    model: nnx.Module,
-                    metadata_map: MetadataMap,
-                    mesh: Mesh,
-                    filter_regex: str | None = None,
-                    is_draft_model: bool = False,
-                    keep_original_dtype_keys_regex: list[str] | None = None):
-    """Load weights from all model weights files to the model, run in multi threads."""
-    if is_draft_model:
-        model_path = vllm_config.speculative_config.draft_model_config.model
-    else:
-        model_path = vllm_config.model_config.model
-    weights_files = get_model_weights_files(
-        model_path, vllm_config.load_config.download_dir)
-    params = nnx.state(model)
-    max_workers = min(64, len(weights_files))
-    # NOTE(xiang): Disable multi-threading mode if running on multi-host.
-    # Because multi-threading would cause different JAX processes to load
-    # different weights at the same time.
-    if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
-        max_workers = 1
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = [
-            executor.submit(
-                _load_hf_weights_on_thread,
+            _load_and_shard_weight(
                 vllm_config,
                 params,
+                shardings,
                 metadata_map,
                 mesh,
-                weights_file,
-                filter_regex=filter_regex,
-                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
-            for weights_file in weights_files
-        ]
-        for future in futures:
-            future.result()
+                hf_key,
+                hf_weight_jax,
+                keep_original_dtype_keys_regex,
+            )
+    else:
+        # File-based path (multi-threaded)
+        if is_draft_model:
+            model_path = vllm_config.speculative_config.draft_model_config.model
+        else:
+            model_path = vllm_config.model_config.model
+        weights_files = get_model_weights_files(
+            model_path, vllm_config.load_config.download_dir)
+        max_workers = min(64, len(weights_files))
+        # NOTE(xiang): Disable multi-threading mode if running on multi-host.
+        # Because multi-threading would cause different JAX processes to load
+        # different weights at the same time.
+        if envs.TPU_MULTIHOST_BACKEND == "ray":
+            max_workers = 1
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = [
+                executor.submit(
+                    _load_hf_weights_on_thread,
+                    vllm_config,
+                    params,
+                    metadata_map,
+                    mesh,
+                    weights_file,
+                    filter_regex=filter_regex,
+                    keep_original_dtype_keys_regex=
+                    keep_original_dtype_keys_regex,
+                ) for weights_file in weights_files
+            ]
+            for future in futures:
+                future.result()
+
     check_all_loaded(params)
     nnx.update(model, params)

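`load_hf_weights` now takes two paths: a weights iterator (used by the RunAI model streamer integration), whose torch tensors are converted to JAX arrays through torchax, or the existing multi-threaded file loader. A minimal sketch of just the tensor-conversion step, using a toy tensor in place of a streamed weight and assuming `torchax` is installed:

```python
import torch
import torchax

env = torchax.default_env()
cpu_tensor = torch.ones((4, 8), dtype=torch.bfloat16)  # stand-in for a streamed HF weight
jax_array = env.t2j_copy(cpu_tensor)                   # torch.Tensor -> jax.Array copy
print(type(jax_array), jax_array.shape, jax_array.dtype)
```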
tpu_inference/models/vllm/vllm_model_wrapper.py

@@ -9,6 +9,7 @@ import jax
 import torch
 import torch.nn
 import torchax
+import vllm.envs as vllm_envs
 from flax.typing import PRNGKey
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.interop import jax_view, torch_view
@@ -118,10 +119,16 @@ class VllmModelWrapper:
             "torch._sync",
             return_value=None) if use_random_weights else nullcontext()

+        # By default load weights to the CPU device first. If we are running
+        # under Pathways, this would cause weights to be loaded on a CPU-only
+        # node, so we'll need to remove this context.
+        jax_context = jax.default_device(
+            jax.devices("cpu")
+            [0]) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()
+
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        available_devices = self.mesh.devices.flatten()
-        with load_context, jax.default_device(available_devices[0]):
+        with load_context, jax_context:
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
             lora_manager = None
             if vllm_config_for_load.lora_config is not None:
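The wrapper now defaults weight loading to the host CPU device and only skips that redirection under Pathways (where it would pin weights to a CPU-only node). A standalone sketch of the same context-manager pattern; the `running_under_pathways` flag stands in for `vllm_envs.VLLM_TPU_USING_PATHWAYS`:

```python
from contextlib import nullcontext

import jax
import jax.numpy as jnp

running_under_pathways = False  # stand-in for vllm_envs.VLLM_TPU_USING_PATHWAYS

load_device_ctx = (jax.default_device(jax.devices("cpu")[0])
                   if not running_under_pathways else nullcontext())

with load_device_ctx:
    # Arrays created inside the context land on the CPU device, so device HBM
    # stays free until the weights are actually sharded onto the mesh.
    w = jnp.zeros((1024, 1024))

print(w.devices())
```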
@@ -162,6 +169,7 @@
         input_ids: jax.Array,
         attn_metadata: AttentionMetadata,
         input_embeds: jax.Array,
+        input_positions: jax.Array,
         layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
         lora_metadata,
         intermediate_tensors: JaxIntermediateTensors = None,
@@ -188,8 +196,8 @@
             torch_view(params_and_buffers),
             kwargs={
                 "input_ids": torch_view(input_ids),
-                "positions": torch_view(attn_metadata.input_positions),
-                "intermediate_tensors": intermediate_tensors,
+                "positions": torch_view(input_positions),
+                "intermediate_tensors": None,
                 "inputs_embeds": None,
             },
             tie_weights=False,
@@ -213,7 +221,7 @@
         @functools.partial(
             jax.jit,
             out_shardings=(NamedSharding(self.mesh,
-                                         PartitionSpec(None, "model"))),
+                                         PartitionSpec("data", "model"))),
         )
         def compute_logits_func(
             params_and_buffers: Any,
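The logits `out_shardings` changed from `PartitionSpec(None, "model")` to `PartitionSpec("data", "model")`, so the batch dimension is now split over the `data` mesh axis instead of replicated. A small sketch of what that spec does on a hypothetical 2x4 mesh (run on CPU with `XLA_FLAGS=--xla_force_host_platform_device_count=8` if fewer than 8 devices are available):

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()[:8]).reshape(2, 4)  # assumes 8 addressable devices
mesh = Mesh(devices, ("data", "model"))

logits_sharding = NamedSharding(mesh, PartitionSpec("data", "model"))
logits = jax.device_put(jnp.zeros((16, 32000)), logits_sharding)

# Rows (batch) are split across "data", columns (vocab) across "model";
# the previous spec replicated the batch dimension on every device.
print(logits.sharding)
```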
@@ -255,7 +263,6 @@ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
         vllm_config,
         device,
         model.embedding_modules,
-        model.embedding_padding_modules,
     )
     return lora_manager, lora_manager.create_lora_manager(model)

@@ -269,10 +276,9 @@ def replace_set_lora(model):
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
-        embeddings_tensor: Optional[torch.Tensor],
     ):
         with torchax.default_env():
-            self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)
+            self._original_set_lora(index, lora_a, lora_b)

     def _tpu_reset_lora(self, index: int):
         with torchax.default_env():
tpu_inference/platforms/tpu_platform.py

@@ -1,39 +1,34 @@
 # SPDX-License-Identifier: Apache-2.0

-import os
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast

 import jax.numpy as jnp
+import torch
 import vllm.envs as vllm_envs
-from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
-from vllm.sampling_params import SamplingParams, SamplingType

 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger

 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import _Backend
+    from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams, SamplingType
 else:
     BlockSize = None
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
-    _Backend = None
+    AttentionBackendEnum = None
+    SamplingParams = None
+    SamplingType = None

 logger = init_logger(__name__)

-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-

 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
@@ -50,25 +45,22 @@ class TpuPlatform(Platform):

     additional_env_vars: list[str] = [
         "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
-        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE"
+        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE",
+        "NEW_MODEL_DESIGN"
     ]

     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
-                             dtype: jnp.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool, use_sparse: bool,
-                             attn_type: Any) -> str:
-        from vllm.attention.backends.registry import _Backend
-        if selected_backend != _Backend.PALLAS:
+    def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
+                             head_size: int, dtype: jnp.dtype,
+                             kv_cache_dtype: Optional[str], block_size: int,
+                             use_mla: bool, has_sink: bool, use_sparse: bool,
+                             use_mm_prefix: bool, attn_type: Any) -> str:
+        from vllm.attention.backends.registry import AttentionBackendEnum
+        if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)

-        if use_v1:
-            logger.info("Using Pallas V1 backend.")
-            return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
-        else:
-            logger.info("Using Pallas backend.")
-            return "vllm.attention.backends.pallas.PallasAttentionBackend"
+        logger.info("Using Pallas V1 backend.")
+        return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
@@ -83,6 +75,14 @@
             logger.warning(f"Error getting device name: {e}")
         return 'TPU'

+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        if cls.get_device_name().lower() == "tpu v6e":
+            logger.info(
+                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
+            return torch.float8_e5m2
+        return torch.float8_e4m3fn
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError
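The new `fp8_dtype` hook picks the FP8 KV-cache dtype per TPU generation: `float8_e5m2` on v6e, `float8_e4m3fn` otherwise. A sketch of the same selection with the device name passed in directly; the helper name is hypothetical:

```python
import torch


def pick_fp8_kv_cache_dtype(device_name: str) -> torch.dtype:
    # TPU v6e prefers float8_e5m2; other generations use float8_e4m3fn.
    if device_name.lower() == "tpu v6e":
        return torch.float8_e5m2
    return torch.float8_e4m3fn


assert pick_fp8_kv_cache_dtype("TPU v6e") is torch.float8_e5m2
assert pick_fp8_kv_cache_dtype("TPU v5p") is torch.float8_e4m3fn
```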
@@ -133,6 +133,7 @@
         # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = cast(BlockSize, 16)
+
         compilation_config = vllm_config.compilation_config

         # TPU only supports DYNAMO_TRACE_ONCE compilation level
@@ -143,27 +144,6 @@
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"

-        # If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-
-        # NOTE(xiang): convert dtype to jnp.dtype
-        # NOTE(wenlong): skip this logic for mm model preprocessing
-        # For mm model preprocessors, it may need the output dtype to be torch.
-        # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
-        if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
-        if impl == "vllm":
-            vllm_config.model_config.dtype = j2t_dtype(
-                vllm_config.model_config.dtype.dtype)
-
         # TODO(cuiq): remove this dependency.
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
         cache_config.block_size = PallasAttentionBackend.get_page_size(
@@ -171,8 +151,7 @@
         min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config)
         if min_page_size > cache_config.block_size:
             logger.warning(
-                "Increase the page size from %s to %s to make sure there's"
-                "no SMEM OOM",
+                "Increase the page size from %s to %s to avoid SMEM OOM",
                 cache_config.block_size,
                 min_page_size,
             )
@@ -183,7 +162,7 @@
             parallel_config.worker_cls = \
                 "tpu_inference.worker.tpu_worker.TPUWorker"

-        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        multihost_backend = envs.TPU_MULTIHOST_BACKEND
        if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
                 logger.info("Force using UniProcExecutor for JAX on \
@@ -247,10 +226,11 @@
     def validate_request(
         cls,
         prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
+        params: Union["SamplingParams", PoolingParams],
         processed_inputs: ProcessorInputs,
     ) -> None:
         """Raises if this request is unsupported on this platform"""
+        from vllm.sampling_params import SamplingParams, SamplingType

         if isinstance(params, SamplingParams):
             if params.sampling_type == SamplingType.RANDOM_SEED:
@@ -267,3 +247,7 @@
         Returns if the current platform needs to sync weight loader.
         """
         return True
+
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True
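Several call sites in this release (here and in `weight_utils.py`) replace raw `os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()` lookups with `envs.TPU_MULTIHOST_BACKEND`. A sketch of the kind of centralized accessor this implies; it is illustrative only, not the actual `tpu_inference/envs.py`:

```python
# Illustrative only: a centralized, normalized env-var accessor that replaces
# scattered os.environ.get(...) calls at each call site.
import os


def _tpu_multihost_backend() -> str:
    # Normalized once here, so call sites can compare against "ray" directly.
    return os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()


TPU_MULTIHOST_BACKEND: str = _tpu_multihost_backend()

if TPU_MULTIHOST_BACKEND == "ray":
    print("multi-host Ray backend: weight loading falls back to a single worker thread")
```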