tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic; see the registry's notice for this release for more details.

Files changed (56)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_envs.py +11 -32
  6. tests/test_utils.py +2 -1
  7. tpu_inference/__init__.py +3 -22
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +2 -9
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
  16. tpu_inference/layers/common/attention_interface.py +1 -7
  17. tpu_inference/layers/common/sharding.py +5 -5
  18. tpu_inference/layers/vllm/fused_moe.py +208 -170
  19. tpu_inference/layers/vllm/quantization/common.py +1 -6
  20. tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
  21. tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
  22. tpu_inference/layers/vllm/sharding.py +2 -2
  23. tpu_inference/lora/torch_punica_tpu.py +2 -1
  24. tpu_inference/mock/__init__.py +0 -0
  25. tpu_inference/mock/vllm_config_utils.py +28 -0
  26. tpu_inference/mock/vllm_envs.py +1219 -0
  27. tpu_inference/mock/vllm_logger.py +212 -0
  28. tpu_inference/mock/vllm_logging_utils.py +15 -0
  29. tpu_inference/models/common/model_loader.py +10 -43
  30. tpu_inference/models/jax/llama3.py +1 -2
  31. tpu_inference/models/jax/llama_eagle3.py +5 -8
  32. tpu_inference/models/jax/phi3.py +376 -0
  33. tpu_inference/models/jax/qwen2.py +1 -2
  34. tpu_inference/models/jax/qwen2_5_vl.py +48 -163
  35. tpu_inference/models/jax/qwen3.py +1 -2
  36. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  37. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  38. tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
  39. tpu_inference/platforms/tpu_platform.py +31 -37
  40. tpu_inference/runner/compilation_manager.py +58 -141
  41. tpu_inference/runner/kv_cache.py +1 -1
  42. tpu_inference/runner/kv_cache_manager.py +18 -17
  43. tpu_inference/runner/persistent_batch_manager.py +2 -40
  44. tpu_inference/runner/structured_decoding_manager.py +3 -2
  45. tpu_inference/runner/tpu_runner.py +147 -271
  46. tpu_inference/runner/utils.py +2 -2
  47. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  48. tpu_inference/tpu_info.py +3 -4
  49. tpu_inference/utils.py +13 -36
  50. tpu_inference/worker/tpu_worker.py +25 -162
  51. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
  52. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
  53. tpu_inference/models/jax/llama_guard_4.py +0 -361
  54. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  55. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  56. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
@@ -13,14 +13,12 @@ from typing import Any, Optional
13
13
  import jax
14
14
  import jax.numpy as jnp
15
15
  import torch
16
- import torchax
17
16
  from flax import nnx
18
17
  from jax.sharding import Mesh, NamedSharding
19
18
  from jax.sharding import PartitionSpec as P
20
19
  from safetensors import safe_open
21
- from vllm.config import VllmConfig
22
20
 
23
- from tpu_inference import envs, utils
21
+ from tpu_inference import utils
24
22
  from tpu_inference.logger import init_logger
25
23
  from tpu_inference.models.jax.utils import file_utils
26
24
 
@@ -199,11 +197,12 @@ def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
199
197
  return jax.device_put(x, shardings)
200
198
 
201
199
 
202
- def get_default_maps(model_config, mesh: Mesh,
200
+ def get_default_maps(vllm_config, mesh: Mesh,
203
201
  name_map: dict[str, str]) -> MetadataMap:
204
202
  """Load weights from one model weights file to the model, run on single thread."""
205
203
  sharding_size = mesh.shape["model"]
206
204
 
205
+ model_config = vllm_config.model_config
207
206
  hf_config = model_config.hf_config
208
207
 
209
208
  num_heads = hf_config.num_attention_heads
@@ -267,15 +266,14 @@ def get_default_maps(model_config, mesh: Mesh,
267
266
  bias_pad_map=bias_pad_keys)
268
267
 
269
268
 
270
- def _load_and_shard_weight(vllm_config,
271
- params: nnx.State,
272
- shardings: Any,
273
- metadata_map: MetadataMap,
274
- mesh: Mesh,
275
- hf_key: str,
276
- hf_weight: jax.Array,
277
- keep_original_dtype_keys_regex: list[str]
278
- | None = None):
269
+ def _load_hf_weights_on_thread(vllm_config,
270
+ params: nnx.State,
271
+ metadata_map: MetadataMap,
272
+ mesh: Mesh,
273
+ weights_file: str,
274
+ filter_regex: str | None = None,
275
+ keep_original_dtype_keys_regex: list[str]
276
+ | None = None):
279
277
  name_map = metadata_map.name_map
280
278
  reshape_keys = metadata_map.reshape_map
281
279
  bias_reshape_keys = metadata_map.bias_reshape_map
@@ -292,118 +290,6 @@ def _load_and_shard_weight(vllm_config,
292
290
  head_dim = utils.get_padded_head_dim(head_dim_original)
293
291
  head_dim_pad = head_dim - head_dim_original
294
292
 
295
- # Check if the key should retain its original dtype
296
- keep_original_dtype = False
297
- if keep_original_dtype_keys_regex:
298
- for pattern in keep_original_dtype_keys_regex:
299
- if re.match(pattern, hf_key):
300
- keep_original_dtype = True
301
- break
302
-
303
- # Converting to config's dtype
304
- if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
305
- logger.warning(
306
- f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
307
- )
308
- hf_weight = hf_weight.astype(model_config.dtype)
309
-
310
- if hf_key.endswith(".weight"):
311
- hf_key = hf_key.removesuffix(".weight")
312
-
313
- # Find the corresponding model key using the HF key
314
- if "layers" in hf_key:
315
- layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
316
- layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
317
- model_key = name_map[layer_key]
318
- model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
319
- elif "blocks" in hf_key:
320
- layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
321
- layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
322
- model_key = name_map[layer_key]
323
- model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
324
- else:
325
- if hf_key not in name_map and hf_key == "lm_head":
326
- logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings")
327
- return
328
- if hf_key not in name_map and "t2d" in hf_key:
329
- logger.warning(
330
- f"Skip loading {hf_key} as it's not used in eagle-3 for now")
331
- return
332
- model_key = name_map.get(hf_key, hf_key)
333
-
334
- model_weight, model_sharding = get_param_and_sharding(
335
- params, shardings, model_key)
336
-
337
- logger.debug(
338
- "before transform | "
339
- f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
340
- )
341
-
342
- if hf_key.endswith(".bias"):
343
- for key in bias_reshape_keys:
344
- if key in hf_key:
345
- hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
346
- if head_dim_pad > 0:
347
- hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad)))
348
- break
349
- else:
350
- for key in reshape_keys:
351
- if key in hf_key:
352
- hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
353
- if head_dim_pad > 0:
354
- if "o_proj" in key:
355
- hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
356
- (0, head_dim_pad)))
357
- else:
358
- hf_weight = jnp.pad(hf_weight,
359
- ((0, 0), (0, head_dim_pad),
360
- (0, 0)))
361
- break
362
- for key in transpose_keys:
363
- if key in hf_key:
364
- hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
365
- break
366
-
367
- # Pad num-kv-heads
368
- if hf_key.endswith(".bias"):
369
- for key, value in bias_pad_keys.items():
370
- dim = value[0]
371
- dim_size = value[1]
372
- if key in hf_key and dim_size != 0:
373
- hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
374
- break
375
- else:
376
- for key, value in pad_keys.items():
377
- dim = value[0]
378
- dim_size = value[1]
379
- if key in hf_key and dim_size != 0:
380
- hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
381
- break
382
-
383
- logger.debug(
384
- "after transform | "
385
- f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
386
- )
387
-
388
- if head_dim_pad == 0:
389
- assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
390
-
391
- # Update the model weight
392
- spec = model_weight.sharding.spec if isinstance(
393
- model_weight.sharding, NamedSharding) else model_weight.sharding
394
- model_weight.value = shard(hf_weight, spec)
395
-
396
-
397
- def _load_hf_weights_on_thread(
398
- vllm_config: VllmConfig,
399
- params: nnx.State,
400
- metadata_map: "MetadataMap",
401
- mesh: Mesh,
402
- weights_file: str,
403
- filter_regex: Optional[str] = None,
404
- keep_original_dtype_keys_regex: Optional[list[str]] = None,
405
- ):
406
- """Loads weights from a single weights file."""
407
293
  try:
408
294
  shardings = nnx.get_named_sharding(params, mesh)
409
295
  except TypeError:
@@ -411,88 +297,147 @@ def _load_hf_weights_on_thread(
411
297
 
412
298
  for hf_key, hf_weight in model_weights_single_file_generator(
413
299
  weights_file, framework="flax", filter_regex=filter_regex):
414
- _load_and_shard_weight(
415
- vllm_config,
416
- params,
417
- shardings,
418
- metadata_map,
419
- mesh,
420
- hf_key,
421
- hf_weight,
422
- keep_original_dtype_keys_regex,
423
- )
424
300
 
301
+ # Check if the key should retain its original dtype
302
+ keep_original_dtype = False
303
+ if keep_original_dtype_keys_regex:
304
+ for pattern in keep_original_dtype_keys_regex:
305
+ if re.match(pattern, hf_key):
306
+ keep_original_dtype = True
307
+ break
425
308
 
426
- def load_hf_weights(
427
- vllm_config: VllmConfig,
428
- model: nnx.Module,
429
- metadata_map: "MetadataMap",
430
- mesh: Mesh,
431
- filter_regex: Optional[str] = None,
432
- is_draft_model: bool = False,
433
- keep_original_dtype_keys_regex: Optional[list[str]] = None,
434
- ):
435
- """Load weights into a JAX model from either an iterator or files."""
436
- params = nnx.state(model)
437
- try:
438
- shardings = nnx.get_named_sharding(params, mesh)
439
- except TypeError:
440
- shardings = params
441
- weights_iterator = None
442
- if hasattr(vllm_config.model_config, "model_weights_iterator"):
443
- weights_iterator = vllm_config.model_config.model_weights_iterator
444
- env = torchax.default_env()
445
- # The weights_iterator is used in RunAI model streamer integration.
446
- if weights_iterator is not None:
447
- for hf_key, hf_weight in weights_iterator:
448
- if filter_regex and not re.match(filter_regex, hf_key):
309
+ # Converting to config's dtype
310
+ if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
311
+ logger.warning(
312
+ f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
313
+ )
314
+ hf_weight = hf_weight.astype(model_config.dtype)
315
+
316
+ if hf_key.endswith(".weight"):
317
+ hf_key = hf_key.removesuffix(".weight")
318
+
319
+ # Find the corresponding model key using the HF key
320
+ if "layers" in hf_key:
321
+ layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
322
+ layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
323
+ model_key = name_map[layer_key]
324
+ model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
325
+ elif "blocks" in hf_key:
326
+ layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
327
+ layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
328
+ model_key = name_map[layer_key]
329
+ model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
330
+ else:
331
+ if hf_key not in name_map and hf_key == "lm_head":
332
+ logger.warning(
333
+ f"Skip loading {hf_key} due to tie_word_embeddings")
334
+ continue
335
+ if hf_key not in name_map and "t2d" in hf_key:
336
+ logger.warning(
337
+ f"Skip loading {hf_key} as it's not used in eagle-3 for now"
338
+ )
449
339
  continue
340
+ model_key = name_map.get(hf_key, hf_key)
341
+ model_weight, model_sharding = get_param_and_sharding(
342
+ params, shardings, model_key)
450
343
 
451
- # Since the weights_iterator yields Pytorch tensors (torch.Tensor),
452
- # we need to convert them to JAX arrays (jax.Array).
453
- hf_weight_jax = env.t2j_copy(hf_weight)
344
+ logger.debug(
345
+ "before transform | "
346
+ f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
347
+ )
348
+
349
+ if hf_key.endswith(".bias"):
350
+ for key in bias_reshape_keys:
351
+ if key in hf_key:
352
+ hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
353
+ if head_dim_pad > 0:
354
+ hf_weight = jnp.pad(hf_weight,
355
+ ((0, 0), (0, head_dim_pad)))
356
+ break
357
+ else:
358
+ for key in reshape_keys:
359
+ if key in hf_key:
360
+ hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
361
+ if head_dim_pad > 0:
362
+ if "o_proj" in key:
363
+ hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
364
+ (0, head_dim_pad)))
365
+ else:
366
+ hf_weight = jnp.pad(hf_weight,
367
+ ((0, 0), (0, head_dim_pad),
368
+ (0, 0)))
369
+ break
370
+ for key in transpose_keys:
371
+ if key in hf_key:
372
+ hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
373
+ break
374
+
375
+ # Pad num-kv-heads
376
+ if hf_key.endswith(".bias"):
377
+ for key, value in bias_pad_keys.items():
378
+ dim = value[0]
379
+ dim_size = value[1]
380
+ if key in hf_key and dim_size != 0:
381
+ hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
382
+ break
383
+ else:
384
+ for key, value in pad_keys.items():
385
+ dim = value[0]
386
+ dim_size = value[1]
387
+ if key in hf_key and dim_size != 0:
388
+ hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
389
+ break
390
+
391
+ logger.debug(
392
+ "after transform | "
393
+ f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
394
+ )
454
395
 
455
- _load_and_shard_weight(
396
+ if head_dim_pad == 0:
397
+ assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
398
+
399
+ # Update the model weight
400
+ spec = model_weight.sharding.spec if isinstance(
401
+ model_weight.sharding, NamedSharding) else model_weight.sharding
402
+ model_weight.value = shard(hf_weight, spec)
403
+
404
+
405
+ def load_hf_weights(vllm_config,
406
+ model: nnx.Module,
407
+ metadata_map: MetadataMap,
408
+ mesh: Mesh,
409
+ filter_regex: str | None = None,
410
+ is_draft_model: bool = False,
411
+ keep_original_dtype_keys_regex: list[str] | None = None):
412
+ """Load weights from all model weights files to the model, run in multi threads."""
413
+ if is_draft_model:
414
+ model_path = vllm_config.speculative_config.draft_model_config.model
415
+ else:
416
+ model_path = vllm_config.model_config.model
417
+ weights_files = get_model_weights_files(
418
+ model_path, vllm_config.load_config.download_dir)
419
+ params = nnx.state(model)
420
+ max_workers = min(64, len(weights_files))
421
+ # NOTE(xiang): Disable multi-threading mode if running on multi-host.
422
+ # Because multi-threading would cause different JAX processes to load
423
+ # different weights at the same time.
424
+ if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
425
+ max_workers = 1
426
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
427
+ futures = [
428
+ executor.submit(
429
+ _load_hf_weights_on_thread,
456
430
  vllm_config,
457
431
  params,
458
- shardings,
459
432
  metadata_map,
460
433
  mesh,
461
- hf_key,
462
- hf_weight_jax,
463
- keep_original_dtype_keys_regex,
464
- )
465
- else:
466
- # File-based path (multi-threaded)
467
- if is_draft_model:
468
- model_path = vllm_config.speculative_config.draft_model_config.model
469
- else:
470
- model_path = vllm_config.model_config.model
471
- weights_files = get_model_weights_files(
472
- model_path, vllm_config.load_config.download_dir)
473
- max_workers = min(64, len(weights_files))
474
- # NOTE(xiang): Disable multi-threading mode if running on multi-host.
475
- # Because multi-threading would cause different JAX processes to load
476
- # different weights at the same time.
477
- if envs.TPU_MULTIHOST_BACKEND == "ray":
478
- max_workers = 1
479
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
480
- futures = [
481
- executor.submit(
482
- _load_hf_weights_on_thread,
483
- vllm_config,
484
- params,
485
- metadata_map,
486
- mesh,
487
- weights_file,
488
- filter_regex=filter_regex,
489
- keep_original_dtype_keys_regex=
490
- keep_original_dtype_keys_regex,
491
- ) for weights_file in weights_files
492
- ]
493
- for future in futures:
494
- future.result()
495
-
434
+ weights_file,
435
+ filter_regex=filter_regex,
436
+ keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
437
+ for weights_file in weights_files
438
+ ]
439
+ for future in futures:
440
+ future.result()
496
441
  check_all_loaded(params)
497
442
  nnx.update(model, params)
498
443
 
@@ -9,7 +9,6 @@ import jax
9
9
  import torch
10
10
  import torch.nn
11
11
  import torchax
12
- import vllm.envs as vllm_envs
13
12
  from flax.typing import PRNGKey
14
13
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
15
14
  from torchax.interop import jax_view, torch_view
@@ -119,16 +118,10 @@ class VllmModelWrapper:
119
118
  "torch._sync",
120
119
  return_value=None) if use_random_weights else nullcontext()
121
120
 
122
- # By default load weights to the CPU device first. If we are running
123
- # under Pathways, this would cause weights to be loaded on a CPU-only
124
- # node, so we'll need to remove this context.
125
- jax_context = jax.default_device(
126
- jax.devices("cpu")
127
- [0]) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()
128
-
129
121
  # Load the vLLM model and wrap it into a new model whose forward
130
122
  # function can calculate the hidden_state and logits.
131
- with load_context, jax_context:
123
+ available_devices = self.mesh.devices.flatten()
124
+ with load_context, jax.default_device(available_devices[0]):
132
125
  vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
133
126
  lora_manager = None
134
127
  if vllm_config_for_load.lora_config is not None:
@@ -169,7 +162,6 @@ class VllmModelWrapper:
169
162
  input_ids: jax.Array,
170
163
  attn_metadata: AttentionMetadata,
171
164
  input_embeds: jax.Array,
172
- input_positions: jax.Array,
173
165
  layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
174
166
  lora_metadata,
175
167
  intermediate_tensors: JaxIntermediateTensors = None,
@@ -196,8 +188,8 @@ class VllmModelWrapper:
196
188
  torch_view(params_and_buffers),
197
189
  kwargs={
198
190
  "input_ids": torch_view(input_ids),
199
- "positions": torch_view(input_positions),
200
- "intermediate_tensors": None,
191
+ "positions": torch_view(attn_metadata.input_positions),
192
+ "intermediate_tensors": intermediate_tensors,
201
193
  "inputs_embeds": None,
202
194
  },
203
195
  tie_weights=False,
@@ -221,7 +213,7 @@ class VllmModelWrapper:
221
213
  @functools.partial(
222
214
  jax.jit,
223
215
  out_shardings=(NamedSharding(self.mesh,
224
- PartitionSpec("data", "model"))),
216
+ PartitionSpec(None, "model"))),
225
217
  )
226
218
  def compute_logits_func(
227
219
  params_and_buffers: Any,
@@ -263,6 +255,7 @@ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
263
255
  vllm_config,
264
256
  device,
265
257
  model.embedding_modules,
258
+ model.embedding_padding_modules,
266
259
  )
267
260
  return lora_manager, lora_manager.create_lora_manager(model)
268
261
 
@@ -276,9 +269,10 @@ def replace_set_lora(model):
276
269
  index: int,
277
270
  lora_a: torch.Tensor,
278
271
  lora_b: torch.Tensor,
272
+ embeddings_tensor: Optional[torch.Tensor],
279
273
  ):
280
274
  with torchax.default_env():
281
- self._original_set_lora(index, lora_a, lora_b)
275
+ self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)
282
276
 
283
277
  def _tpu_reset_lora(self, index: int):
284
278
  with torchax.default_env():
@@ -1,10 +1,11 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
2
 
3
+ import os
3
4
  from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
4
5
 
5
6
  import jax.numpy as jnp
6
- import torch
7
7
  import vllm.envs as vllm_envs
8
+ from torchax.ops.mappings import j2t_dtype
8
9
  from tpu_info import device
9
10
  from vllm.inputs import ProcessorInputs, PromptType
10
11
  from vllm.platforms.interface import Platform, PlatformEnum
@@ -13,10 +14,9 @@ from vllm.sampling_params import SamplingParams, SamplingType
13
14
  from tpu_inference import envs
14
15
  from tpu_inference.layers.common.sharding import ShardingConfigManager
15
16
  from tpu_inference.logger import init_logger
16
- from tpu_inference.utils import to_jax_dtype, to_torch_dtype
17
17
 
18
18
  if TYPE_CHECKING:
19
- from vllm.attention.backends.registry import AttentionBackendEnum
19
+ from vllm.attention.backends.registry import _Backend
20
20
  from vllm.config import BlockSize, ModelConfig, VllmConfig
21
21
  from vllm.pooling_params import PoolingParams
22
22
  else:
@@ -24,10 +24,16 @@ else:
24
24
  ModelConfig = None
25
25
  VllmConfig = None
26
26
  PoolingParams = None
27
- AttentionBackendEnum = None
27
+ _Backend = None
28
28
 
29
29
  logger = init_logger(__name__)
30
30
 
31
+ _DTYPE: dict[str, jnp.dtype] = {
32
+ "bfloat16": jnp.bfloat16,
33
+ "float": jnp.float32,
34
+ "float32": jnp.float32,
35
+ }
36
+
31
37
 
32
38
  class TpuPlatform(Platform):
33
39
  _enum = PlatformEnum.TPU
@@ -48,13 +54,13 @@ class TpuPlatform(Platform):
48
54
  ]
49
55
 
50
56
  @classmethod
51
- def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
52
- head_size: int, dtype: jnp.dtype,
53
- kv_cache_dtype: Optional[str], block_size: int,
54
- use_v1: bool, use_mla: bool, has_sink: bool,
55
- use_sparse: bool, attn_type: Any) -> str:
56
- from vllm.attention.backends.registry import AttentionBackendEnum
57
- if selected_backend != AttentionBackendEnum.PALLAS:
57
+ def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
58
+ dtype: jnp.dtype, kv_cache_dtype: Optional[str],
59
+ block_size: int, use_v1: bool, use_mla: bool,
60
+ has_sink: bool, use_sparse: bool,
61
+ attn_type: Any) -> str:
62
+ from vllm.attention.backends.registry import _Backend
63
+ if selected_backend != _Backend.PALLAS:
58
64
  logger.info("Cannot use %s backend on TPU.", selected_backend)
59
65
 
60
66
  if use_v1:
@@ -77,14 +83,6 @@ class TpuPlatform(Platform):
77
83
  logger.warning(f"Error getting device name: {e}")
78
84
  return 'TPU'
79
85
 
80
- @classmethod
81
- def fp8_dtype(cls) -> torch.dtype:
82
- if cls.get_device_name().lower() == "tpu v6e":
83
- logger.info(
84
- "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
85
- return torch.float8_e5m2
86
- return torch.float8_e4m3fn
87
-
88
86
  @classmethod
89
87
  def get_device_total_memory(cls, device_id: int = 0) -> int:
90
88
  raise NotImplementedError
@@ -135,7 +133,6 @@ class TpuPlatform(Platform):
135
133
  # For v0, the default block size is 16.
136
134
  if cache_config and cache_config.block_size is None:
137
135
  cache_config.block_size = cast(BlockSize, 16)
138
-
139
136
  compilation_config = vllm_config.compilation_config
140
137
 
141
138
  # TPU only supports DYNAMO_TRACE_ONCE compilation level
@@ -152,19 +149,20 @@ class TpuPlatform(Platform):
152
149
  # NOTE(xiang): convert dtype to jnp.dtype
153
150
  # NOTE(wenlong): skip this logic for mm model preprocessing
154
151
  # For mm model preprocessors, it may need the output dtype to be torch.
155
- # In order to avoid a PR to vLLM, we postpone the dtype checking during
156
- # tpu_worker initialization
152
+ # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
157
153
  if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
158
- model_dtype = vllm_config.model_config.dtype
159
- try:
160
- dtype = to_jax_dtype(model_dtype)
161
- except ValueError:
162
- logger.warning(f"{model_dtype=} is not supported. "
163
- "Falling back to jnp.bfloat16")
164
- dtype = jnp.bfloat16
165
- if impl == "vllm":
166
- dtype = to_torch_dtype(dtype)
167
- vllm_config.model_config.dtype = dtype
154
+ if not isinstance(vllm_config.model_config.dtype, str):
155
+ logger.warning(
156
+ "The model dtype is not properly set for JAX backend. "
157
+ "Overwriting it to jnp.bfloat16")
158
+ vllm_config.model_config.dtype = jnp.bfloat16
159
+ else:
160
+ vllm_config.model_config.dtype = _DTYPE.get(
161
+ vllm_config.model_config.dtype, jnp.bfloat16)
162
+
163
+ if impl == "vllm":
164
+ vllm_config.model_config.dtype = j2t_dtype(
165
+ vllm_config.model_config.dtype.dtype)
168
166
 
169
167
  # TODO(cuiq): remove this dependency.
170
168
  from vllm.v1.attention.backends.pallas import PallasAttentionBackend
@@ -185,7 +183,7 @@ class TpuPlatform(Platform):
185
183
  parallel_config.worker_cls = \
186
184
  "tpu_inference.worker.tpu_worker.TPUWorker"
187
185
 
188
- multihost_backend = envs.TPU_MULTIHOST_BACKEND
186
+ multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
189
187
  if not multihost_backend: # Single host
190
188
  if parallel_config.pipeline_parallel_size == 1:
191
189
  logger.info("Force using UniProcExecutor for JAX on \
@@ -269,7 +267,3 @@ class TpuPlatform(Platform):
269
267
  Returns if the current platform needs to sync weight loader.
270
268
  """
271
269
  return True
272
-
273
- @classmethod
274
- def support_hybrid_kv_cache(cls) -> bool:
275
- return True