tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (58)
  1. tests/lora/test_layers.py +0 -6
  2. tests/lora/utils.py +0 -8
  3. tests/test_envs.py +182 -0
  4. tests/test_utils.py +23 -14
  5. tpu_inference/__init__.py +22 -3
  6. tpu_inference/core/core_tpu.py +17 -9
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +2 -3
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +1 -1
  11. tpu_inference/executors/ray_distributed_executor.py +27 -11
  12. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +110 -64
  14. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
  15. tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
  16. tpu_inference/layers/common/quant_methods.py +8 -0
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  19. tpu_inference/layers/jax/sample/sampling.py +2 -2
  20. tpu_inference/layers/vllm/attention.py +1 -1
  21. tpu_inference/layers/vllm/quantization/__init__.py +7 -3
  22. tpu_inference/layers/vllm/quantization/awq.py +4 -3
  23. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
  24. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  25. tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
  26. tpu_inference/layers/vllm/sharding.py +2 -2
  27. tpu_inference/lora/torch_punica_tpu.py +1 -2
  28. tpu_inference/models/common/model_loader.py +12 -11
  29. tpu_inference/models/jax/llama3.py +4 -3
  30. tpu_inference/models/jax/llama_eagle3.py +9 -5
  31. tpu_inference/models/jax/llama_guard_4.py +361 -0
  32. tpu_inference/models/jax/qwen2.py +3 -2
  33. tpu_inference/models/jax/qwen2_5_vl.py +4 -3
  34. tpu_inference/models/jax/qwen3.py +3 -2
  35. tpu_inference/models/jax/utils/weight_utils.py +21 -8
  36. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
  37. tpu_inference/platforms/tpu_platform.py +17 -7
  38. tpu_inference/runner/compilation_manager.py +37 -17
  39. tpu_inference/runner/kv_cache.py +1 -1
  40. tpu_inference/runner/kv_cache_manager.py +8 -2
  41. tpu_inference/runner/tpu_runner.py +199 -87
  42. tpu_inference/spec_decode/jax/eagle3.py +2 -1
  43. tpu_inference/tpu_info.py +4 -3
  44. tpu_inference/utils.py +7 -6
  45. tpu_inference/worker/tpu_worker.py +159 -23
  46. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
  47. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +52 -54
  48. tpu_inference/mock/__init__.py +0 -0
  49. tpu_inference/mock/vllm_config_utils.py +0 -28
  50. tpu_inference/mock/vllm_envs.py +0 -1219
  51. tpu_inference/mock/vllm_logger.py +0 -212
  52. tpu_inference/mock/vllm_logging_utils.py +0 -15
  53. tpu_inference/models/jax/phi3.py +0 -376
  54. /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
  55. /tpu_inference/layers/{jax → common}/sharding.py +0 -0
  56. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
  57. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
  58. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tests/lora/test_layers.py CHANGED
@@ -91,7 +91,6 @@ def populate_loras(
     index_to_id: list[Optional[int]],
     lora_layer: BaseLayerWithLoRA,
     baselayer_weights: torch.Tensor,
-    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -103,8 +102,6 @@ def populate_loras(
         lora_layer: the LoRAlayer to populate.
         baselayer_weights: the PyTorch tensor containing the layer's
             weights.
-        generate_embeddings_tensor: whether to generate an
-            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
             layers. Indicates the number of loras to compose
             together to create a single lora layer.
@@ -131,7 +128,6 @@ def populate_loras(
                 baselayer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=baselayer_weights,
-                    generate_embeddings_tensor=generate_embeddings_tensor,
                 )
             sublora.lora_b = sublora.lora_b[(sublora_len *
                                              i):(sublora_len * (i + 1)), :]
@@ -147,7 +143,6 @@ def populate_loras(
             slot_idx,
             lora_a=lora.lora_a,
             lora_b=lora.lora_b,
-            embeddings_tensor=lora.embeddings_tensor,
         )

         lora_dict[lora_id] = lora
@@ -546,7 +541,6 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
         index_to_id,
         lora_config.max_loras,
         vocab_size=512,
-        extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
     assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
     ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
tests/lora/utils.py CHANGED
@@ -24,7 +24,6 @@ class DummyLoRAManager:
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
-        generate_embeddings_tensor: int = 0,
     ):
         lora = LoRALayerWeights(
             module_name,
@@ -37,13 +36,6 @@ class DummyLoRAManager:
                         dtype=weight.dtype,
                         device=self._device),
         )
-        if generate_embeddings_tensor:
-            lora.embeddings_tensor = torch.rand(
-                5,
-                generate_embeddings_tensor,
-                dtype=weight.dtype,
-                device=self._device,
-            )
         self.set_module_lora(module_name, lora)

         return lora
tests/test_envs.py ADDED
@@ -0,0 +1,182 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
+
+import pytest
+
+import tpu_inference.envs as envs
+from tpu_inference.envs import enable_envs_cache, environment_variables
+
+
+def test_getattr_without_cache(monkeypatch: pytest.MonkeyPatch):
+    assert envs.JAX_PLATFORMS == ""
+    assert envs.PHASED_PROFILING_DIR == ""
+    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+    monkeypatch.setenv("PHASED_PROFILING_DIR", "/tmp/profiling")
+    assert envs.JAX_PLATFORMS == "tpu"
+    assert envs.PHASED_PROFILING_DIR == "/tmp/profiling"
+
+    assert envs.TPU_NAME is None
+    assert envs.TPU_ACCELERATOR_TYPE is None
+    monkeypatch.setenv("TPU_NAME", "my-tpu")
+    monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v5litepod-16")
+    assert envs.TPU_NAME == "my-tpu"
+    assert envs.TPU_ACCELERATOR_TYPE == "v5litepod-16"
+
+    # __getattr__ is not decorated with functools.cache
+    assert not hasattr(envs.__getattr__, "cache_info")
+
+
+def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+    monkeypatch.setenv("TPU_NAME", "my-tpu")
+
+    # __getattr__ is not decorated with functools.cache
+    assert not hasattr(envs.__getattr__, "cache_info")
+
+    enable_envs_cache()
+
+    # __getattr__ is decorated with functools.cache
+    assert hasattr(envs.__getattr__, "cache_info")
+    start_hits = envs.__getattr__.cache_info().hits
+
+    # 2 more hits due to JAX_PLATFORMS and TPU_NAME accesses
+    assert envs.JAX_PLATFORMS == "tpu"
+    assert envs.TPU_NAME == "my-tpu"
+    assert envs.__getattr__.cache_info().hits == start_hits + 2
+
+    # All environment variables are cached
+    for environment_variable in environment_variables:
+        envs.__getattr__(environment_variable)
+    assert envs.__getattr__.cache_info(
+    ).hits == start_hits + 2 + len(environment_variables)
+
+    # Reset envs.__getattr__ back to non-cached version to
+    # avoid affecting other tests
+    envs.__getattr__ = envs.__getattr__.__wrapped__
+
+
+def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Test SKIP_JAX_PRECOMPILE (default False)
+    assert envs.SKIP_JAX_PRECOMPILE is False
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
+    assert envs.SKIP_JAX_PRECOMPILE is True
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
+    assert envs.SKIP_JAX_PRECOMPILE is False
+
+    # Test NEW_MODEL_DESIGN (default False)
+    assert envs.NEW_MODEL_DESIGN is False
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
+    assert envs.NEW_MODEL_DESIGN is True
+
+    # Test USE_MOE_EP_KERNEL (default False)
+    assert envs.USE_MOE_EP_KERNEL is False
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
+    assert envs.USE_MOE_EP_KERNEL is True
+
+
+def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
+    assert envs.PYTHON_TRACER_LEVEL == 1
+    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
+    assert envs.PYTHON_TRACER_LEVEL == 3
+    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
+    assert envs.PYTHON_TRACER_LEVEL == 0
+
+
+def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
+    assert envs.TPU_MULTIHOST_BACKEND == "grpc"
+
+    monkeypatch.setenv("MODEL_IMPL_TYPE", "FLAX_NNX")
+    assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+
+
+def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.delenv("JAX_PLATFORMS", raising=False)
+    monkeypatch.delenv("PREFILL_SLICES", raising=False)
+    monkeypatch.delenv("DECODE_SLICES", raising=False)
+
+    assert envs.JAX_PLATFORMS == ""
+    assert envs.PREFILL_SLICES == ""
+    assert envs.DECODE_SLICES == ""
+    assert envs.PHASED_PROFILING_DIR == ""
+
+
+def test_none_default_env_vars(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.delenv("TPU_ACCELERATOR_TYPE", raising=False)
+    monkeypatch.delenv("TPU_NAME", raising=False)
+    monkeypatch.delenv("TPU_WORKER_ID", raising=False)
+
+    assert envs.TPU_ACCELERATOR_TYPE is None
+    assert envs.TPU_NAME is None
+    assert envs.TPU_WORKER_ID is None
+
+
+def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
+    assert envs.RAY_USAGE_STATS_ENABLED == "0"
+    monkeypatch.setenv("RAY_USAGE_STATS_ENABLED", "1")
+    assert envs.RAY_USAGE_STATS_ENABLED == "1"
+
+    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
+    monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
+    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
+
+
+def test_invalid_attribute_raises_error():
+    with pytest.raises(AttributeError,
+                       match="has no attribute 'NONEXISTENT_VAR'"):
+        _ = envs.NONEXISTENT_VAR
+
+
+def test_dir_returns_all_env_vars():
+    env_vars = envs.__dir__()
+    assert isinstance(env_vars, list)
+    assert len(env_vars) == len(environment_variables)
+    assert "JAX_PLATFORMS" in env_vars
+    assert "TPU_NAME" in env_vars
+    assert "SKIP_JAX_PRECOMPILE" in env_vars
+    assert "MODEL_IMPL_TYPE" in env_vars
+
+
+def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("TPU_WORKER_ID", "0")
+    assert envs.TPU_WORKER_ID == "0"
+
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "grpc")
+    assert envs.TPU_MULTIHOST_BACKEND == "grpc"
+
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
+    assert envs.TPU_MULTIHOST_BACKEND == "xla"
+
+
+def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("PREFILL_SLICES", "0,1,2,3")
+    assert envs.PREFILL_SLICES == "0,1,2,3"
+
+    monkeypatch.setenv("DECODE_SLICES", "4,5,6,7")
+    assert envs.DECODE_SLICES == "4,5,6,7"
+
+
+def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
+    assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+
+
+def test_cache_preserves_values_across_env_changes(
+        monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+
+    enable_envs_cache()
+
+    assert envs.JAX_PLATFORMS == "tpu"
+
+    # Change environment variable
+    monkeypatch.setenv("JAX_PLATFORMS", "cpu")
+
+    # Cached value should still be "tpu"
+    assert envs.JAX_PLATFORMS == "tpu"
+
+    # Reset envs.__getattr__ back to non-cached version
+    envs.__getattr__ = envs.__getattr__.__wrapped__
+
+    # Now it should reflect the new value
+    assert envs.JAX_PLATFORMS == "cpu"
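
The tests above assume tpu_inference.envs exposes a lazy module-level __getattr__ over an environment_variables lookup table, plus an opt-in enable_envs_cache() that wraps it with functools.cache (hence the cache_info() and __wrapped__ accesses). Only a one-line change to envs.py appears later in this diff, so the following is a minimal, hypothetical sketch of that pattern for orientation, not the packaged implementation:

import functools
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    # Each value is a thunk, so the variable is re-read on every access.
    "JAX_PLATFORMS": lambda: os.getenv("JAX_PLATFORMS", "").lower(),
    "TPU_NAME": lambda: os.getenv("TPU_NAME", None),
}


def __getattr__(name: str) -> Any:
    # Module-level attribute hook: evaluate the thunk on each lookup.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__() -> list[str]:
    return list(environment_variables.keys())


def enable_envs_cache() -> None:
    # Freeze values by wrapping the hook with functools.cache; callers can
    # undo this via __getattr__.__wrapped__, as the tests above do.
    global __getattr__
    __getattr__ = functools.cache(__getattr__)
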
tests/test_utils.py CHANGED
@@ -75,25 +75,34 @@ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
     mock_device2 = MagicMock()
     devices = [mock_device1, mock_device2]

-    # Create mock device buffers
-    mock_buffer1_dev1 = MagicMock()
-    mock_buffer1_dev1.device = mock_device1
-    mock_buffer1_dev1.nbytes = 2000  # 2000 bytes on device1
+    # Create mock addressable shards with data property
+    mock_data1_dev1 = MagicMock()
+    mock_data1_dev1.device = mock_device1
+    mock_data1_dev1.nbytes = 2000  # 2000 bytes on device1

-    mock_buffer1_dev2 = MagicMock()
-    mock_buffer1_dev2.device = mock_device2
-    mock_buffer1_dev2.nbytes = 2000  # 2000 bytes on device2
+    mock_data1_dev2 = MagicMock()
+    mock_data1_dev2.device = mock_device2
+    mock_data1_dev2.nbytes = 2000  # 2000 bytes on device2

-    mock_buffer2_dev1 = MagicMock()
-    mock_buffer2_dev1.device = mock_device1
-    mock_buffer2_dev1.nbytes = 1000  # 1000 bytes on device1
+    mock_data2_dev1 = MagicMock()
+    mock_data2_dev1.device = mock_device1
+    mock_data2_dev1.nbytes = 1000  # 1000 bytes on device1

-    # Create mock arrays with device buffers
+    mock_shard1_dev1 = MagicMock()
+    mock_shard1_dev1.data = mock_data1_dev1
+
+    mock_shard1_dev2 = MagicMock()
+    mock_shard1_dev2.data = mock_data1_dev2
+
+    mock_shard2_dev1 = MagicMock()
+    mock_shard2_dev1.data = mock_data2_dev1
+
+    # Create mock arrays with addressable_shards
     mock_array1 = MagicMock()
-    mock_array1.device_buffers = [mock_buffer1_dev1, mock_buffer1_dev2]
+    mock_array1.addressable_shards = [mock_shard1_dev1, mock_shard1_dev2]

     mock_array2 = MagicMock()
-    mock_array2.device_buffers = [mock_buffer2_dev1]
+    mock_array2.addressable_shards = [mock_shard2_dev1]

     mock_live_arrays.return_value = [mock_array1, mock_array2]

@@ -159,7 +168,7 @@ def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
     "head_dim, expected_padded_head_dim",
     [
         (1, 128),
-        (64, 128),
+        (64, 64),
         (127, 128),
         (128, 128),
         (129, 256),
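
The updated mocks reflect a switch from the older device_buffers attribute to jax.Array.addressable_shards, with per-shard bytes read from shard.data. A rough sketch of the per-device HBM accounting these mocks model (the real helper lives in tpu_inference/utils.py and may differ in details) could look like:

from collections import defaultdict

import jax


def hbm_usage_bytes(devices):
    # Sum shard.data.nbytes of every live array, grouped by the device that
    # holds the shard; devices with no arrays report zero bytes.
    usage = defaultdict(int, {device: 0 for device in devices})
    for array in jax.live_arrays():
        for shard in array.addressable_shards:
            usage[shard.data.device] += shard.data.nbytes
    return dict(usage)
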
tpu_inference/__init__.py CHANGED
@@ -1,21 +1,40 @@
-import os
-
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
+from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)

-if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
+if "proxy" in envs.JAX_PLATFORMS:
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
+        import traceback
+
         import pathwaysutils
+        import vllm
+        from vllm.platforms import (resolve_current_platform_cls_qualname,
+                                    resolve_obj_by_qualname)
         pathwaysutils.initialize()
         logger.info("Module pathwaysutils is imported.")
+
+        # Pathways requires eager resolution of vllm.current_platform instead of
+        # lazy resolution in the normal code path. Since this part involves
+        # global topology discovery across multiple hosts, the platform
+        # resolution must happen before other components are loaded.
+        logger.info("Eagerly resolving vLLM current_platform for Pathways.")
+        platform_cls_qualname = resolve_current_platform_cls_qualname()
+        resolved_platform_instance = resolve_obj_by_qualname(
+            platform_cls_qualname)()
+        vllm.platforms._current_platform = resolved_platform_instance
+        vllm.platforms._init_trace = "".join(traceback.format_stack())
+        logger.info(
+            f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
+        )
+
     except Exception as e:
         logger.error(
             f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
tpu_inference/core/core_tpu.py CHANGED
@@ -29,6 +29,7 @@ from vllm.v1.request import Request, RequestStatus

 from tpu_inference import utils as common_utils
 from tpu_inference.core import disagg_executor, disagg_utils
+from tpu_inference.runner.tpu_runner import AsyncTPUModelRunnerOutput
 # ======================================================================================
 # Imports for _DisaggOrchestrator (decoupled from vLLM)
 # ======================================================================================
@@ -186,6 +187,8 @@ class _DisaggOrchestrator:
             if model_output is None:
                 model_output = prefill_engine.model_executor.sample_tokens(
                     grammar_output)
+            if isinstance(model_output, AsyncTPUModelRunnerOutput):
+                model_output = model_output.get_output()

             if scheduler_output.total_num_scheduled_tokens > 0:
                 logger.debug(f"Prefill result: {model_output}")
@@ -218,15 +221,16 @@ class _DisaggOrchestrator:
                     f"request-{req_id}: tokens={request.all_token_ids} after prefill"
                 )
                 # Remove request from the prefill engine.
+                if req_id in prefill_engine.scheduler.requests:
+                    request = prefill_engine.scheduler.requests[req_id]
+                    prefill_engine.scheduler.running.remove(request)
+                    prefill_engine.scheduler.encoder_cache_manager.free(
+                        request)

-                request = prefill_engine.scheduler.requests[req_id]
-                prefill_engine.scheduler.running.remove(request)
-                prefill_engine.scheduler.encoder_cache_manager.free(
-                    request)
+                    prefill_engine.scheduler.kv_cache_manager.free(
+                        request)

-                prefill_engine.scheduler.kv_cache_manager.free(request)
-
-                prefill_engine.scheduler.requests.pop(req_id)
+                    prefill_engine.scheduler.requests.pop(req_id)

                 for output in (engine_core_outputs.items()
                                if engine_core_outputs else ()):
@@ -335,8 +339,10 @@ class _DisaggOrchestrator:
                 new_block_ids = kv_cache_manager.get_block_ids(req_id)
                 logger.debug(
                     f"inserting {req_id} new_block_ids {new_block_ids}")
-                assert (len(new_block_ids[0]) == math.ceil(
-                    prompt_tokens / self._config.cache_config.block_size))
+                if len(new_block_ids[0]) != math.ceil(
+                        prompt_tokens / self._config.cache_config.block_size):
+                    logger.warning("Running out of blocks in decode engine! ")
+                    break

                 decode_engine.model_executor.driver_worker.model_runner.insert_request_with_kv_cache(
                     vllm_request, kv_cache, new_block_ids)
@@ -366,6 +372,8 @@ class _DisaggOrchestrator:
             if model_output is None:
                 model_output = decode_engine.model_executor.sample_tokens(
                     grammar_output)
+            if isinstance(model_output, AsyncTPUModelRunnerOutput):
+                model_output = model_output.get_output()

             if scheduler_output.total_num_scheduled_tokens > 0:
                 logger.debug(f"Decode result: {model_output}")
tpu_inference/core/disagg_utils.py CHANGED
@@ -1,17 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0

-import os
 from typing import Tuple

-PREFILL_SLICES = 'PREFILL_SLICES'
-DECODE_SLICES = 'DECODE_SLICES'
+from tpu_inference import envs


 def is_disagg_enabled() -> bool:
     # We triggrer our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return PREFILL_SLICES in os.environ
+    return bool(envs.PREFILL_SLICES)


 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:


 def get_prefill_slices() -> Tuple[int, ...]:
-    if PREFILL_SLICES not in os.environ:
+    if not envs.PREFILL_SLICES:
         return ()
-    return _parse_slices(os.environ[PREFILL_SLICES])
+    return _parse_slices(envs.PREFILL_SLICES)


 def get_decode_slices() -> Tuple[int, ...]:
-    if DECODE_SLICES not in os.environ:
+    if not envs.DECODE_SLICES:
         return ()
-    return _parse_slices(os.environ[DECODE_SLICES])
+    return _parse_slices(envs.DECODE_SLICES)
tpu_inference/distributed/tpu_connector.py CHANGED
@@ -60,7 +60,6 @@ D workflow:

 import copy
 import functools
-import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@ if TYPE_CHECKING:
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request

+from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
                                              get_kv_transfer_port, get_node_id,
@@ -441,8 +441,7 @@ class TPUConnectorWorker:

         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
-                                    "").lower() == "ray"
+        self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work
         # for TPU host topology.
tpu_inference/distributed/utils.py CHANGED
@@ -2,6 +2,7 @@ import os

 from vllm.utils.network_utils import get_ip

+from tpu_inference import envs
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):


 def get_kv_ips() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:


 def get_kv_ports() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):
tpu_inference/envs.py CHANGED
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", ""),
+    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
tpu_inference/executors/ray_distributed_executor.py CHANGED
@@ -108,6 +108,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
             ip_port = self.collective_rpc("get_node_kv_ip_port")
             for item in ip_port:
                 set_node_kv_ip_port(item)
+        self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
+            self.vllm_config.ec_transfer_config is None
+            or not self.vllm_config.ec_transfer_config.is_ec_producer)

     def _initialize_ray_cluster(self) -> None:
         """Initialize the distributed cluster with Ray.
@@ -131,10 +134,17 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")

-        placement_group_specs: List[Dict[str, float]] = [{
-            device_str:
-            node['Resources'][device_str]
-        } for node in ray.nodes()]
+        pp_size = self.parallel_config.pipeline_parallel_size
+        placement_group_specs: List[Dict[str, float]] = []
+        if pp_size == 1:
+            placement_group_specs = [{
+                device_str: node['Resources'][device_str]
+            } for node in ray.nodes()]
+        else:
+            num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
+            placement_group_specs = [{
+                device_str: num_devices_per_pp_rank
+            } for _ in range(pp_size)]

         # vLLM engine is also a worker to execute model with an accelerator,
         # so it requires to have the device in a current node. Check if
@@ -329,6 +339,8 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         all_kwargs = []
         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
             local_rank = node_workers[node_id].index(rank)
+            ip = sorted_worker_metadata[rank].ip
+            prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
             kwargs = dict(
                 vllm_config=self.vllm_config,
                 local_rank=local_rank,
@@ -336,22 +348,26 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 distributed_init_method=distributed_init_method,
                 is_driver_worker=(not self.parallel_config)
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
+                ip=ip,
+                prev_worker_ip=prev_ip,
             )
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
+        if self.parallel_config.pipeline_parallel_size > 1:
+            self.collective_rpc("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")

         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
-                for tp_rank in range(
-                        int(self.parallel_config.tensor_parallel_size //
-                            num_tpu_per_worker)):
-                    # PP=2, TP=4
-                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
-                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
-                            ) + tp_rank
+                num_tp_workers = int(
+                    self.parallel_config.tensor_parallel_size //
+                    num_tpu_per_worker)
+                for tp_rank in range(num_tp_workers):
+                    # PP=2, TP=4, num_tpu_per_worker=2
+                    # pp_tp_workers = [[0, 1], [2, 3]]
+                    rank = (pp_rank * num_tp_workers) + tp_rank
                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                     assert pp_rank < len(self.pp_tp_workers)
                     self.pp_tp_workers[pp_rank].append(self.workers[rank])