tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
@@ -154,9 +154,12 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
     logger.info(f"Memory usage before applying quantization of params: "
                 f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")

-    if kv_cache_dtype != "auto":
-        kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
-    else:
+    # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+    kv_cache_jnp_dtype = utils.get_jax_dtype_from_str_dtype(kv_cache_dtype)
+
+    # Handle the case where kv_cache_dtype is "auto"
+    if kv_cache_jnp_dtype is None:
+        assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
         kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE

     kv_caches = create_kv_caches(
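For orientation, the new control flow first maps the configured string to a JAX dtype and only falls back when the lookup yields None. A minimal sketch of that pattern, assuming a dictionary-backed helper in place of utils.get_jax_dtype_from_str_dtype and an assumed default dtype (both are hypothetical stand-ins, not the package's actual implementation):

import jax.numpy as jnp

# Hypothetical stand-in for utils.get_jax_dtype_from_str_dtype.
_STR_TO_JNP_DTYPE = {"bfloat16": jnp.bfloat16, "float16": jnp.float16}
DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16  # assumed default

def resolve_kv_cache_dtype(kv_cache_dtype: str):
    """Resolve a config string to a JAX dtype, falling back for 'auto'."""
    kv_cache_jnp_dtype = _STR_TO_JNP_DTYPE.get(kv_cache_dtype)
    if kv_cache_jnp_dtype is None:
        # Mirrors the assert in the hunk above: only "auto" should miss the lookup.
        assert kv_cache_dtype == "auto", f"unsupported kv_cache_dtype: {kv_cache_dtype}"
        kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
    return kv_cache_jnp_dtype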
@@ -13,14 +13,12 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import torch
-import torchax
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
-from vllm.config import VllmConfig

-from tpu_inference import envs, utils
+from tpu_inference import utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils

@@ -199,11 +197,12 @@ def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
     return jax.device_put(x, shardings)


-def get_default_maps(model_config, mesh: Mesh,
+def get_default_maps(vllm_config, mesh: Mesh,
                      name_map: dict[str, str]) -> MetadataMap:
     """Load weights from one model weights file to the model, run on single thread."""
     sharding_size = mesh.shape["model"]

+    model_config = vllm_config.model_config
     hf_config = model_config.hf_config

     num_heads = hf_config.num_attention_heads
@@ -267,15 +266,14 @@ def get_default_maps(model_config, mesh: Mesh,
                        bias_pad_map=bias_pad_keys)


-def _load_and_shard_weight(vllm_config,
-                           params: nnx.State,
-                           shardings: Any,
-                           metadata_map: MetadataMap,
-                           mesh: Mesh,
-                           hf_key: str,
-                           hf_weight: jax.Array,
-                           keep_original_dtype_keys_regex: list[str]
-                           | None = None):
+def _load_hf_weights_on_thread(vllm_config,
+                               params: nnx.State,
+                               metadata_map: MetadataMap,
+                               mesh: Mesh,
+                               weights_file: str,
+                               filter_regex: str | None = None,
+                               keep_original_dtype_keys_regex: list[str]
+                               | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -292,118 +290,6 @@ def _load_and_shard_weight(vllm_config,
     head_dim = utils.get_padded_head_dim(head_dim_original)
     head_dim_pad = head_dim - head_dim_original

-    # Check if the key should retain its original dtype
-    keep_original_dtype = False
-    if keep_original_dtype_keys_regex:
-        for pattern in keep_original_dtype_keys_regex:
-            if re.match(pattern, hf_key):
-                keep_original_dtype = True
-                break
-
-    # Converting to config's dtype
-    if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
-        logger.warning(
-            f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
-        )
-        hf_weight = hf_weight.astype(model_config.dtype)
-
-    if hf_key.endswith(".weight"):
-        hf_key = hf_key.removesuffix(".weight")
-
-    # Find the corresponding model key using the HF key
-    if "layers" in hf_key:
-        layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
-        layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
-        model_key = name_map[layer_key]
-        model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
-    elif "blocks" in hf_key:
-        layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
-        layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
-        model_key = name_map[layer_key]
-        model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
-    else:
-        if hf_key not in name_map and hf_key == "lm_head":
-            logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings")
-            return
-        if hf_key not in name_map and "t2d" in hf_key:
-            logger.warning(
-                f"Skip loading {hf_key} as it's not used in eagle-3 for now")
-            return
-        model_key = name_map.get(hf_key, hf_key)
-
-    model_weight, model_sharding = get_param_and_sharding(
-        params, shardings, model_key)
-
-    logger.debug(
-        "before transform | "
-        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-    )
-
-    if hf_key.endswith(".bias"):
-        for key in bias_reshape_keys:
-            if key in hf_key:
-                hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
-                if head_dim_pad > 0:
-                    hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad)))
-                break
-    else:
-        for key in reshape_keys:
-            if key in hf_key:
-                hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
-                if head_dim_pad > 0:
-                    if "o_proj" in key:
-                        hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
-                                                        (0, head_dim_pad)))
-                    else:
-                        hf_weight = jnp.pad(hf_weight,
-                                            ((0, 0), (0, head_dim_pad),
-                                             (0, 0)))
-                break
-        for key in transpose_keys:
-            if key in hf_key:
-                hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
-                break
-
-    # Pad num-kv-heads
-    if hf_key.endswith(".bias"):
-        for key, value in bias_pad_keys.items():
-            dim = value[0]
-            dim_size = value[1]
-            if key in hf_key and dim_size != 0:
-                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                break
-    else:
-        for key, value in pad_keys.items():
-            dim = value[0]
-            dim_size = value[1]
-            if key in hf_key and dim_size != 0:
-                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                break
-
-    logger.debug(
-        "after transform | "
-        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-    )
-
-    if head_dim_pad == 0:
-        assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
-
-    # Update the model weight
-    spec = model_weight.sharding.spec if isinstance(
-        model_weight.sharding, NamedSharding) else model_weight.sharding
-    model_weight.value = shard(hf_weight, spec)
-
-
-def _load_hf_weights_on_thread(
-    vllm_config: VllmConfig,
-    params: nnx.State,
-    metadata_map: "MetadataMap",
-    mesh: Mesh,
-    weights_file: str,
-    filter_regex: Optional[str] = None,
-    keep_original_dtype_keys_regex: Optional[list[str]] = None,
-):
-    """Loads weights from a single weights file."""
     try:
         shardings = nnx.get_named_sharding(params, mesh)
     except TypeError:
@@ -411,88 +297,147 @@ def _load_hf_weights_on_thread(

     for hf_key, hf_weight in model_weights_single_file_generator(
             weights_file, framework="flax", filter_regex=filter_regex):
-        _load_and_shard_weight(
-            vllm_config,
-            params,
-            shardings,
-            metadata_map,
-            mesh,
-            hf_key,
-            hf_weight,
-            keep_original_dtype_keys_regex,
-        )

+        # Check if the key should retain its original dtype
+        keep_original_dtype = False
+        if keep_original_dtype_keys_regex:
+            for pattern in keep_original_dtype_keys_regex:
+                if re.match(pattern, hf_key):
+                    keep_original_dtype = True
+                    break

-def load_hf_weights(
-    vllm_config: VllmConfig,
-    model: nnx.Module,
-    metadata_map: "MetadataMap",
-    mesh: Mesh,
-    filter_regex: Optional[str] = None,
-    is_draft_model: bool = False,
-    keep_original_dtype_keys_regex: Optional[list[str]] = None,
-):
-    """Load weights into a JAX model from either an iterator or files."""
-    params = nnx.state(model)
-    try:
-        shardings = nnx.get_named_sharding(params, mesh)
-    except TypeError:
-        shardings = params
-    weights_iterator = None
-    if hasattr(vllm_config.model_config, "model_weights_iterator"):
-        weights_iterator = vllm_config.model_config.model_weights_iterator
-    env = torchax.default_env()
-    # The weights_iterator is used in RunAI model streamer integration.
-    if weights_iterator is not None:
-        for hf_key, hf_weight in weights_iterator:
-            if filter_regex and not re.match(filter_regex, hf_key):
+        # Converting to config's dtype
+        if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
+            logger.warning(
+                f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
+            )
+            hf_weight = hf_weight.astype(model_config.dtype)
+
+        if hf_key.endswith(".weight"):
+            hf_key = hf_key.removesuffix(".weight")
+
+        # Find the corresponding model key using the HF key
+        if "layers" in hf_key:
+            layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
+            layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
+            model_key = name_map[layer_key]
+            model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
+        elif "blocks" in hf_key:
+            layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
+            layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
+            model_key = name_map[layer_key]
+            model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
+        else:
+            if hf_key not in name_map and hf_key == "lm_head":
+                logger.warning(
+                    f"Skip loading {hf_key} due to tie_word_embeddings")
+                continue
+            if hf_key not in name_map and "t2d" in hf_key:
+                logger.warning(
+                    f"Skip loading {hf_key} as it's not used in eagle-3 for now"
+                )
                 continue
+            model_key = name_map.get(hf_key, hf_key)
+        model_weight, model_sharding = get_param_and_sharding(
+            params, shardings, model_key)

-            # Since the weights_iterator yields Pytorch tensors (torch.Tensor),
-            # we need to convert them to JAX arrays (jax.Array).
-            hf_weight_jax = env.t2j_copy(hf_weight)
+        logger.debug(
+            "before transform | "
+            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+        )
+
+        if hf_key.endswith(".bias"):
+            for key in bias_reshape_keys:
+                if key in hf_key:
+                    hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
+                    if head_dim_pad > 0:
+                        hf_weight = jnp.pad(hf_weight,
+                                            ((0, 0), (0, head_dim_pad)))
+                    break
+        else:
+            for key in reshape_keys:
+                if key in hf_key:
+                    hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
+                    if head_dim_pad > 0:
+                        if "o_proj" in key:
+                            hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
+                                                            (0, head_dim_pad)))
+                        else:
+                            hf_weight = jnp.pad(hf_weight,
+                                                ((0, 0), (0, head_dim_pad),
+                                                 (0, 0)))
+                    break
+            for key in transpose_keys:
+                if key in hf_key:
+                    hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
+                    break
+
+        # Pad num-kv-heads
+        if hf_key.endswith(".bias"):
+            for key, value in bias_pad_keys.items():
+                dim = value[0]
+                dim_size = value[1]
+                if key in hf_key and dim_size != 0:
+                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                    break
+        else:
+            for key, value in pad_keys.items():
+                dim = value[0]
+                dim_size = value[1]
+                if key in hf_key and dim_size != 0:
+                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                    break
+
+        logger.debug(
+            "after transform | "
+            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+        )

-            _load_and_shard_weight(
+        if head_dim_pad == 0:
+            assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
+
+        # Update the model weight
+        spec = model_weight.sharding.spec if isinstance(
+            model_weight.sharding, NamedSharding) else model_weight.sharding
+        model_weight.value = shard(hf_weight, spec)
+
+
+def load_hf_weights(vllm_config,
+                    model: nnx.Module,
+                    metadata_map: MetadataMap,
+                    mesh: Mesh,
+                    filter_regex: str | None = None,
+                    is_draft_model: bool = False,
+                    keep_original_dtype_keys_regex: list[str] | None = None):
+    """Load weights from all model weights files to the model, run in multi threads."""
+    if is_draft_model:
+        model_path = vllm_config.speculative_config.draft_model_config.model
+    else:
+        model_path = vllm_config.model_config.model
+    weights_files = get_model_weights_files(
+        model_path, vllm_config.load_config.download_dir)
+    params = nnx.state(model)
+    max_workers = min(64, len(weights_files))
+    # NOTE(xiang): Disable multi-threading mode if running on multi-host.
+    # Because multi-threading would cause different JAX processes to load
+    # different weights at the same time.
+    if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+        max_workers = 1
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(
+                _load_hf_weights_on_thread,
                 vllm_config,
                 params,
-                shardings,
                 metadata_map,
                 mesh,
-                hf_key,
-                hf_weight_jax,
-                keep_original_dtype_keys_regex,
-            )
-    else:
-        # File-based path (multi-threaded)
-        if is_draft_model:
-            model_path = vllm_config.speculative_config.draft_model_config.model
-        else:
-            model_path = vllm_config.model_config.model
-        weights_files = get_model_weights_files(
-            model_path, vllm_config.load_config.download_dir)
-        max_workers = min(64, len(weights_files))
-        # NOTE(xiang): Disable multi-threading mode if running on multi-host.
-        # Because multi-threading would cause different JAX processes to load
-        # different weights at the same time.
-        if envs.TPU_MULTIHOST_BACKEND == "ray":
-            max_workers = 1
-        with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            futures = [
-                executor.submit(
-                    _load_hf_weights_on_thread,
-                    vllm_config,
-                    params,
-                    metadata_map,
-                    mesh,
-                    weights_file,
-                    filter_regex=filter_regex,
-                    keep_original_dtype_keys_regex=
-                    keep_original_dtype_keys_regex,
-                ) for weights_file in weights_files
-            ]
-            for future in futures:
-                future.result()
-
+                weights_file,
+                filter_regex=filter_regex,
+                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
+            for weights_file in weights_files
+        ]
+        for future in futures:
+            future.result()
     check_all_loaded(params)
     nnx.update(model, params)

@@ -9,7 +9,6 @@ import jax
 import torch
 import torch.nn
 import torchax
-import vllm.envs as vllm_envs
 from flax.typing import PRNGKey
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.interop import jax_view, torch_view
@@ -26,8 +25,6 @@ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.jax_intermediate_tensor import \
-    JaxIntermediateTensors
 from tpu_inference.models.vllm.vllm_model_wrapper_context import (
     get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
 from tpu_inference.runner.lora_utils import replace_lora_metadata
@@ -92,14 +89,13 @@ class VllmModelWrapper:
             slice_config = self.vllm_config.device_config.slice
             modified_slice_config = True
             self.vllm_config.device_config.slice = None
-            self.vllm_config.compilation_config.static_forward_context.clear()
-
         vllm_config_for_load = copy.deepcopy(self.vllm_config)
         if modified_slice_config:
             self.vllm_config.device_config.slice = slice_config
         assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
         vllm_config_for_load.device_config.device = "cpu"
         # Clearing the cached compilation config, otherwise vllm model init will fail
+        vllm_config_for_load.compilation_config.static_forward_context.clear()

         # When expert parallelism is enabled, vLLM loads weight in sharding
         # aware manner. Since tpu-inference has its own sharding logic, this
@@ -119,16 +115,9 @@ class VllmModelWrapper:
             "torch._sync",
             return_value=None) if use_random_weights else nullcontext()

-        # By default load weights to the CPU device first. If we are running
-        # under Pathways, this would cause weights to be loaded on a CPU-only
-        # node, so we'll need to remove this context.
-        jax_context = jax.default_device(
-            jax.devices("cpu")
-            [0]) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()
-
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        with load_context, jax_context:
+        with load_context:
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
             lora_manager = None
             if vllm_config_for_load.lora_config is not None:
@@ -160,8 +149,7 @@ class VllmModelWrapper:
                 "xla_tpu_reduce_scatter_collective_matmul_mode":
                 "post_spmd_conservative"
            },
-            static_argnames=("layer_name_to_kvcache_index", "is_first_rank",
-                             "is_last_rank"),
+            static_argnames=("layer_name_to_kvcache_index", ),
        )
        def step_fun(
            params_and_buffers,  # This has been wrapped into torchax TorchValue
@@ -169,12 +157,8 @@ class VllmModelWrapper:
            input_ids: jax.Array,
            attn_metadata: AttentionMetadata,
            input_embeds: jax.Array,
-            input_positions: jax.Array,
            layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
            lora_metadata,
-            intermediate_tensors: JaxIntermediateTensors = None,
-            is_first_rank: bool = True,
-            is_last_rank: bool = True,
            *args,
        ) -> Tuple[List[jax.Array], jax.Array]:
            layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
@@ -189,14 +173,12 @@ class VllmModelWrapper:
            # torch_view in order to call the Torch function.
            original_lora_metadata = replace_lora_metadata(
                self.model, lora_metadata, self.vllm_config.lora_config)
-            if not is_first_rank:
-                intermediate_tensors = intermediate_tensors.to_torch()
-            output_from_torch = torch.func.functional_call(
+            hidden_states = torch.func.functional_call(
                self.model,
                torch_view(params_and_buffers),
                kwargs={
                    "input_ids": torch_view(input_ids),
-                    "positions": torch_view(input_positions),
+                    "positions": torch_view(attn_metadata.input_positions),
                    "intermediate_tensors": None,
                    "inputs_embeds": None,
                },
206
188
  self.vllm_config.lora_config)
207
189
  vllm_model_wrapper_context = get_vllm_model_wrapper_context()
208
190
  new_kv_caches = vllm_model_wrapper_context.kv_caches
209
- # Wrap the output(hidden states or intermediate tensor)
210
- # from torch land into a JaxValue for the jax code to consume.
211
- if not is_last_rank:
212
- output = JaxIntermediateTensors.from_torch(output_from_torch)
213
- else:
214
- output = jax_view(output_from_torch)
215
- return new_kv_caches, output, []
191
+ # Wrap the hidden_states from torch land into a JaxValue for the jax
192
+ # code to consume.
193
+ hidden_states = jax_view(hidden_states)
194
+
195
+ return new_kv_caches, hidden_states, []
216
196
 
217
197
  return step_fun
218
198
 
@@ -221,7 +201,7 @@ class VllmModelWrapper:
        @functools.partial(
            jax.jit,
            out_shardings=(NamedSharding(self.mesh,
-                                         PartitionSpec("data", "model"))),
+                                         PartitionSpec(None, "model"))),
        )
        def compute_logits_func(
            params_and_buffers: Any,
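The compute_logits hunk only changes the output sharding: logits are now replicated along the mesh's "data" axis and sharded along "model". A small, self-contained sketch of how such an out_shardings annotation is typically written with jax.jit (toy mesh and matmul, not the wrapper's actual function):

import functools

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Toy mesh reusing the axis names from the diff ("data", "model").
mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))

@functools.partial(
    jax.jit,
    # Replicate the first (token) dimension, shard the vocab dimension over
    # "model" -- the same intent as PartitionSpec(None, "model") above.
    out_shardings=NamedSharding(mesh, PartitionSpec(None, "model")),
)
def compute_logits(hidden_states, lm_head):
    return hidden_states @ lm_head  # [num_tokens, vocab_size]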
@@ -263,6 +243,7 @@ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
        vllm_config,
        device,
        model.embedding_modules,
+        model.embedding_padding_modules,
    )
    return lora_manager, lora_manager.create_lora_manager(model)

@@ -276,9 +257,10 @@ def replace_set_lora(model):
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
+        embeddings_tensor: Optional[torch.Tensor],
    ):
        with torchax.default_env():
-            self._original_set_lora(index, lora_a, lora_b)
+            self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)

    def _tpu_reset_lora(self, index: int):
        with torchax.default_env():
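The final hunks thread the new embeddings_tensor argument through the monkey-patched set_lora. A rough sketch of the wrap-and-delegate pattern used here, with an illustrative patch_set_lora helper (the attribute names mirror the diff, but the helper function itself is hypothetical):

from typing import Optional

import torch
import torchax

def patch_set_lora(layer):
    """Wrap layer.set_lora so the original always runs inside the torchax env."""
    layer._original_set_lora = layer.set_lora

    def _tpu_set_lora(index: int,
                      lora_a: torch.Tensor,
                      lora_b: torch.Tensor,
                      embeddings_tensor: Optional[torch.Tensor] = None):
        with torchax.default_env():
            # Forward the extra embeddings_tensor, as in the updated diff.
            layer._original_set_lora(index, lora_a, lora_b, embeddings_tensor)

    layer.set_lora = _tpu_set_lora
    return layer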