tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (56)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_envs.py +11 -32
  6. tests/test_utils.py +2 -1
  7. tpu_inference/__init__.py +3 -22
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +2 -9
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
  16. tpu_inference/layers/common/attention_interface.py +1 -7
  17. tpu_inference/layers/common/sharding.py +5 -5
  18. tpu_inference/layers/vllm/fused_moe.py +208 -170
  19. tpu_inference/layers/vllm/quantization/common.py +1 -6
  20. tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
  21. tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
  22. tpu_inference/layers/vllm/sharding.py +2 -2
  23. tpu_inference/lora/torch_punica_tpu.py +2 -1
  24. tpu_inference/mock/__init__.py +0 -0
  25. tpu_inference/mock/vllm_config_utils.py +28 -0
  26. tpu_inference/mock/vllm_envs.py +1219 -0
  27. tpu_inference/mock/vllm_logger.py +212 -0
  28. tpu_inference/mock/vllm_logging_utils.py +15 -0
  29. tpu_inference/models/common/model_loader.py +10 -43
  30. tpu_inference/models/jax/llama3.py +1 -2
  31. tpu_inference/models/jax/llama_eagle3.py +5 -8
  32. tpu_inference/models/jax/phi3.py +376 -0
  33. tpu_inference/models/jax/qwen2.py +1 -2
  34. tpu_inference/models/jax/qwen2_5_vl.py +48 -163
  35. tpu_inference/models/jax/qwen3.py +1 -2
  36. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  37. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  38. tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
  39. tpu_inference/platforms/tpu_platform.py +31 -37
  40. tpu_inference/runner/compilation_manager.py +58 -141
  41. tpu_inference/runner/kv_cache.py +1 -1
  42. tpu_inference/runner/kv_cache_manager.py +18 -17
  43. tpu_inference/runner/persistent_batch_manager.py +2 -40
  44. tpu_inference/runner/structured_decoding_manager.py +3 -2
  45. tpu_inference/runner/tpu_runner.py +147 -271
  46. tpu_inference/runner/utils.py +2 -2
  47. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  48. tpu_inference/tpu_info.py +3 -4
  49. tpu_inference/utils.py +13 -36
  50. tpu_inference/worker/tpu_worker.py +25 -162
  51. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
  52. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
  53. tpu_inference/models/jax/llama_guard_4.py +0 -361
  54. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  55. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  56. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tpu_inference/mock/vllm_logger.py
@@ -0,0 +1,212 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ """Logging configuration for vLLM."""
+ import datetime
+ import json
+ import logging
+ import sys
+ from collections.abc import Hashable
+ from functools import lru_cache, partial
+ from logging import Logger
+ from logging.config import dictConfig
+ from os import path
+ from types import MethodType
+ from typing import Any, cast
+
+ import tpu_inference.mock.vllm_envs as envs
+
+ VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
+ VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
+ VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
+ VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX
+
+ _FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
+            "[%(filename)s:%(lineno)d] %(message)s")
+ _DATE_FORMAT = "%m-%d %H:%M:%S"
+
+ DEFAULT_LOGGING_CONFIG = {
+     "formatters": {
+         "vllm": {
+             "class": "tpu_inference.vllm_logging_utils.NewLineFormatter",
+             "datefmt": _DATE_FORMAT,
+             "format": _FORMAT,
+         },
+     },
+     "handlers": {
+         "vllm": {
+             "class": "logging.StreamHandler",
+             "formatter": "vllm",
+             "level": VLLM_LOGGING_LEVEL,
+             "stream": "ext://sys.stdout",
+         },
+     },
+     "loggers": {
+         "vllm": {
+             "handlers": ["vllm"],
+             "level": "DEBUG",
+             "propagate": False,
+         },
+     },
+     "version": 1,
+     "disable_existing_loggers": False
+ }
+
+
+ @lru_cache
+ def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
+     # Set the stacklevel to 2 to print the original caller's line info
+     logger.debug(msg, *args, stacklevel=2)
+
+
+ @lru_cache
+ def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
+     # Set the stacklevel to 2 to print the original caller's line info
+     logger.info(msg, *args, stacklevel=2)
+
+
+ @lru_cache
+ def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None:
+     # Set the stacklevel to 2 to print the original caller's line info
+     logger.warning(msg, *args, stacklevel=2)
+
+
+ class _VllmLogger(Logger):
+     """
+     Note:
+         This class is just to provide type information.
+         We actually patch the methods directly on the [`logging.Logger`][]
+         instance to avoid conflicting with other libraries such as
+         `intel_extension_for_pytorch.utils._logger`.
+     """
+
+     def debug_once(self, msg: str, *args: Hashable) -> None:
+         """
+         As [`debug`][logging.Logger.debug], but subsequent calls with
+         the same message are silently dropped.
+         """
+         _print_debug_once(self, msg, *args)
+
+     def info_once(self, msg: str, *args: Hashable) -> None:
+         """
+         As [`info`][logging.Logger.info], but subsequent calls with
+         the same message are silently dropped.
+         """
+         _print_info_once(self, msg, *args)
+
+     def warning_once(self, msg: str, *args: Hashable) -> None:
+         """
+         As [`warning`][logging.Logger.warning], but subsequent calls with
+         the same message are silently dropped.
+         """
+         _print_warning_once(self, msg, *args)
+
+
+ # Pre-defined methods mapping to avoid repeated dictionary creation
+ _METHODS_TO_PATCH = {
+     "debug_once": _print_debug_once,
+     "info_once": _print_info_once,
+     "warning_once": _print_warning_once,
+ }
+
+
+ def _configure_vllm_root_logger() -> None:
+     logging_config = dict[str, Any]()
+
+     if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
+         raise RuntimeError(
+             "VLLM_CONFIGURE_LOGGING evaluated to false, but "
+             "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
+             "implies VLLM_CONFIGURE_LOGGING. Please enable "
+             "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.")
+
+     if VLLM_CONFIGURE_LOGGING:
+         logging_config = DEFAULT_LOGGING_CONFIG
+
+     if VLLM_LOGGING_CONFIG_PATH:
+         if not path.exists(VLLM_LOGGING_CONFIG_PATH):
+             raise RuntimeError(
+                 "Could not load logging config. File does not exist: %s",
+                 VLLM_LOGGING_CONFIG_PATH)
+         with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
+             custom_config = json.loads(file.read())
+
+         if not isinstance(custom_config, dict):
+             raise ValueError("Invalid logging config. Expected dict, got %s.",
+                              type(custom_config).__name__)
+         logging_config = custom_config
+
+     for formatter in logging_config.get("formatters", {}).values():
+         # This provides backwards compatibility after #10134.
+         if formatter.get(
+                 "class"
+         ) == "tpu_inference.vllm_logging_utils.NewLineFormatter":
+             formatter[
+                 "class"] = "tpu_inference.mock.vllm_logging_utils.NewLineFormatter"
+
+     if logging_config:
+         dictConfig(logging_config)
+
+
+ # The root logger is initialized when the module is imported.
+ # This is thread-safe as the module is only imported once,
+ # guaranteed by the Python GIL.
+ _configure_vllm_root_logger()
+
+
+ def init_logger(name: str) -> _VllmLogger:
+     """The main purpose of this function is to ensure that loggers are
+     retrieved in such a way that we can be sure the root vllm logger has
+     already been configured."""
+
+     logger = logging.getLogger(name)
+
+     for method_name, method in _METHODS_TO_PATCH.items():
+         setattr(logger, method_name, MethodType(method, logger))
+
+     return cast(_VllmLogger, logger)
+
+
+ logger = init_logger(__name__)
+
+
+ def _trace_calls(log_path, root_dir, frame, event, arg=None):
+     if event in ['call', 'return']:
+         # Extract the filename, line number, function name, and the code object
+         filename = frame.f_code.co_filename
+         lineno = frame.f_lineno
+         func_name = frame.f_code.co_name
+         if not filename.startswith(root_dir):
+             # only log the functions in the vllm root_dir
+             return
+         # Log every function call or return
+         try:
+             last_frame = frame.f_back
+             if last_frame is not None:
+                 last_filename = last_frame.f_code.co_filename
+                 last_lineno = last_frame.f_lineno
+                 last_func_name = last_frame.f_code.co_name
+             else:
+                 # initial frame
+                 last_filename = ""
+                 last_lineno = 0
+                 last_func_name = ""
+             with open(log_path, 'a') as f:
+                 ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
+                 if event == 'call':
+                     f.write(f"{ts} Call to"
+                             f" {func_name} in {filename}:{lineno}"
+                             f" from {last_func_name} in {last_filename}:"
+                             f"{last_lineno}\n")
+                 else:
+                     f.write(f"{ts} Return from"
+                             f" {func_name} in {filename}:{lineno}"
+                             f" to {last_func_name} in {last_filename}:"
+                             f"{last_lineno}\n")
+         except NameError:
+             # modules are deleted during shutdown
+             pass
+     return partial(_trace_calls, log_path, root_dir)
+
+
+ def enable_trace_function_call(log_file_path: str, root_dir: str):
+     sys.settrace(partial(_trace_calls, log_file_path, root_dir))
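
The *_once helpers in this new module are memoized with functools.lru_cache, so repeated calls with the same logger, message, and arguments never reach the underlying logging method again. A minimal usage sketch (the import path is the one added above; the logger name and message are illustrative):

    from tpu_inference.mock.vllm_logger import init_logger

    logger = init_logger("tpu_inference.example")

    for _ in range(3):
        # Only the first call goes through; the lru_cache on _print_warning_once
        # drops the identical (logger, message) pairs that follow.
        logger.warning_once("this message is only logged once")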
tpu_inference/mock/vllm_logging_utils.py
@@ -0,0 +1,15 @@
+ import logging
+
+
+ class NewLineFormatter(logging.Formatter):
+     """Adds logging prefix to newlines to align multi-line messages."""
+
+     def __init__(self, fmt, datefmt=None, style="%"):
+         logging.Formatter.__init__(self, fmt, datefmt, style)
+
+     def format(self, record):
+         msg = logging.Formatter.format(self, record)
+         if record.message != "":
+             parts = msg.split(record.message)
+             msg = msg.replace("\n", "\r\n" + parts[0])
+         return msg
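
NewLineFormatter re-applies whatever precedes the message (level, timestamp, file and line) to every continuation line of a multi-line record. A small sketch of the effect, using an illustrative format string and logger name:

    import logging

    from tpu_inference.mock.vllm_logging_utils import NewLineFormatter

    handler = logging.StreamHandler()
    handler.setFormatter(NewLineFormatter("%(levelname)s %(message)s"))
    demo = logging.getLogger("newline-demo")
    demo.addHandler(handler)
    demo.setLevel(logging.INFO)

    # Prints roughly:
    #   INFO first line
    #   INFO second line
    # because format() replaces "\n" with "\r\n" plus the prefix found before the message.
    demo.info("first line\nsecond line")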
tpu_inference/models/common/model_loader.py
@@ -8,9 +8,6 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
  from torchax.ops.mappings import j2t_dtype
  from transformers import PretrainedConfig
  from vllm.config import VllmConfig
- from vllm.model_executor.model_loader import get_model_loader
- from vllm.model_executor.model_loader.runai_streamer_loader import \
- RunaiModelStreamerLoader
  from vllm.utils.func_utils import supports_kw

  from tpu_inference import envs
@@ -39,17 +36,19 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
  from tpu_inference.models.jax.llama3 import LlamaForCausalLM
  from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
  from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
- from tpu_inference.models.jax.llama_guard_4 import LlamaGuard4ForCausalLM
+ from tpu_inference.models.jax.phi3 import Phi3ForCausalLM
+ from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
  from tpu_inference.models.jax.qwen2_5_vl import \
  Qwen2_5_VLForConditionalGeneration
  from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
  _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
  _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
  _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
- _MODEL_REGISTRY["Llama4ForConditionalGeneration"] = LlamaGuard4ForCausalLM
+ _MODEL_REGISTRY["Qwen2ForCausalLM"] = Qwen2ForCausalLM
  _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
  _MODEL_REGISTRY[
  "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
+ _MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
  _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
  _MODEL_REGISTRY["GptOssForCausalLM"] = GptOss

@@ -58,10 +57,8 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
  if arch in _MODEL_REGISTRY:
  return _MODEL_REGISTRY[arch]
  raise UnsupportedArchitectureError(
- f"Model architectures {architectures} not "
- "registered in tpu-inference. Falling back to vLLM-native "
- f"Pytorch definition. JAX-native architectures: {list(_MODEL_REGISTRY.keys())}"
- )
+ f"Model architectures {architectures} are not supported for now. "
+ f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


  def _get_nnx_model(
@@ -180,23 +177,7 @@ def _get_nnx_model(
  # the model creation again, otherwise the model forward will have
  # non-trivial overhead in PjitFunction.
  with mesh:
- loader = get_model_loader(vllm_config.load_config)
- if isinstance(loader, RunaiModelStreamerLoader):
- model_weights = vllm_config.model_config.model
- if hasattr(vllm_config.model_config, "model_weights"):
- model_weights = vllm_config.model_config.model_weights
- weights_iterator = loader._get_weights_iterator(
- model_weights, vllm_config.model_config.revision)
- # We set the weights iterator at runtime, to prevent having to change
- # every model's load_weights signature. This also prevents us from hitting
- # a TypeError at runtime if you use the RunaiModelStreamerLoader with any
- # flax_nnx model whose load_weights function does not accept the
- # weights_iterator keyword argument.
- vllm_config.model_config.model_weights_iterator = weights_iterator
- model.load_weights(rng)
- del vllm_config.model_config.model_weights_iterator
- else:
- model.load_weights(rng)
+ model.load_weights(rng)
  jit_model = create_jit_model(
  model,
  use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
@@ -236,9 +217,7 @@ def get_flax_model(
  hidden_states_sharding, # aux hidden states
  ),
  donate_argnums=2, # 0 is graphdef, 1 is state, 2 is kv_cache
- static_argnums=(
- 7, 10, 11
- ), #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
+ static_argnums=6, #6 is layer_name_to_kvcache_index
  )
  def run_model(graphdef, state, *args):
  model = nnx.merge(graphdef, state)
@@ -347,8 +326,8 @@ def get_model(
  # Convert the error message to a string to check its contents
  error_msg = str(e)

- logger.warning(error_msg)
-
+ logger.warning(f"Flax model failed with: '{error_msg}'. "
+ "Falling back to vLLM implementation.")
  # Fall back to the vLLM model and updating the dtype accordingly
  vllm_config.model_config.dtype = j2t_dtype(
  vllm_config.model_config.dtype.dtype)
@@ -442,17 +421,6 @@ def register_model(arch: str, model: Any) -> None:
  "This is a JAX model and does not implement the PyTorch forward method."
  )

- # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
- def unimplemented_get_input_embeddings(
- self,
- input_ids: "torch.Tensor",
- positions: "torch.Tensor",
- inputs_embeds: Optional["torch.Tensor"] = None,
- ) -> "torch.Tensor":
- raise NotImplementedError(
- "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
- )
-
  # We need a custom __init__ that only calls torch.nn.Module's init,
  # to avoid triggering JAX logic when vLLM inspects the class.
  def wrapper_init(self, *args, **kwargs):
@@ -466,7 +434,6 @@ def register_model(arch: str, model: Any) -> None:
  {
  "__init__": wrapper_init,
  "forward": unimplemented_forward,
- "get_input_embeddings": unimplemented_get_input_embeddings,
  # Prevent vLLM from trying to load weights into this dummy class.
  "load_weights": lambda self, *args, **kwargs: None,
  })
tpu_inference/models/jax/llama3.py
@@ -368,8 +368,7 @@ class LlamaForCausalLM(nnx.Module):
  "lm_head": "model.lm_head",
  })

- metadata_map = get_default_maps(self.vllm_config.model_config,
- self.mesh, mappings)
+ metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
  load_hf_weights(vllm_config=self.vllm_config,
  model=self,
  metadata_map=metadata_map,
tpu_inference/models/jax/llama_eagle3.py
@@ -194,12 +194,13 @@ class Eagle3LlamaModel(nnx.Module):

  def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
  metadata_map: MetadataMap):
- model_config = vllm_config.speculative_config.draft_model_config
+ model_config = vllm_config.model_config
  hf_config = model_config.hf_config

  num_heads = hf_config.num_attention_heads
  num_kv_heads = hf_config.num_key_value_heads
- hidden_size = hf_config.hidden_size
+ hidden_size = model_config.get_hidden_size()
+
  head_dim_original = model_config.get_head_size()

  metadata_map.reshape_map.update({
@@ -304,8 +305,6 @@ class EagleLlama3ForCausalLM(nnx.Module):
  "fc": "model.fc.kernel",
  "lm_head": "lm_head.kernel",
  "d2t": "draft_id_to_target_id",
- "embed_tokens":
- "model.embed_tokens.embedding", # Some checkpoints need this
  }

  # Define keys to keep in original dtype (e.g., float32 for stability)
@@ -313,9 +312,7 @@ class EagleLlama3ForCausalLM(nnx.Module):
  r".*d2t.*",
  ]

- metadata_map = get_default_maps(
- self.vllm_config.speculative_config.draft_model_config, self.mesh,
- mappings)
+ metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)

  update_reshape_map_for_eagle3(self.vllm_config, metadata_map)

@@ -327,7 +324,7 @@ class EagleLlama3ForCausalLM(nnx.Module):
  is_draft_model=True,
  keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)

- # If the embedding is not initialized, initialize it with a dummy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
+ # If the embedding is not initialized, initialize it with a dummpy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
  if isinstance(self.model.embed_tokens.embedding.value,
  jax.ShapeDtypeStruct):
  self.model.embed_tokens.embedding.value = jnp.zeros(