tpu-inference 0.0.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.
Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/distributed/utils.py ADDED
@@ -0,0 +1,60 @@
+ import os
+
+ from vllm.utils.network_utils import get_ip
+
+ from tpu_inference import envs
+ from tpu_inference.logger import init_logger
+
+ logger = init_logger(__name__)
+
+ # For multi-host usage only, to collect IP and port for all nodes.
+ _NODES_KV_IP_PORT = dict()
+
+
+ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
+     global _NODES_KV_IP_PORT
+     node_id, ip, port = ip_port
+     _NODES_KV_IP_PORT[node_id] = (ip, port)
+
+
+ def get_kv_ips() -> str:
+     if envs.TPU_MULTIHOST_BACKEND == "ray":
+         num_nodes = len(_NODES_KV_IP_PORT)
+         ips = []
+         for node_id in range(num_nodes):
+             ips.append(_NODES_KV_IP_PORT[node_id][0])
+         return ips
+     else:
+         return get_host_ip()
+
+
+ def get_kv_ports() -> str:
+     if envs.TPU_MULTIHOST_BACKEND == "ray":
+         num_nodes = len(_NODES_KV_IP_PORT)
+         ports = []
+         for node_id in range(num_nodes):
+             ports.append(_NODES_KV_IP_PORT[node_id][1])
+         return ports
+     else:
+         return get_kv_transfer_port()
+
+
+ def get_host_ip() -> str:
+     """Use `VLLM_HOST_IP` if set, otherwise use default network interface IP."""
+     return get_ip()
+
+
+ def get_kv_transfer_port() -> str:
+     port = os.getenv("TPU_KV_TRANSFER_PORT", "9100")
+     return port
+
+
+ def get_side_channel_port() -> str:
+     port = os.getenv("TPU_SIDE_CHANNEL_PORT", "9600")
+     return port
+
+
+ def get_node_id() -> int:
+     # TODO(xiang): Is it possible to get this from a pre-defined env?
+     id = os.getenv("TPU_NODE_ID", 0)
+     return int(id)
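The module above keeps a process-local registry of KV-transfer endpoints. A minimal usage sketch follows; the two-node Ray setup, IPs, and ports are illustrative assumptions, not values from the package, and it presumes tpu_inference and vLLM import cleanly in your environment:

    import os
    os.environ["TPU_MULTIHOST_BACKEND"] = "ray"  # read lazily via tpu_inference.envs

    from tpu_inference.distributed import utils as dist_utils

    # Each node reports (node_id, ip, port); the executor collects these via RPC.
    dist_utils.set_node_kv_ip_port((0, "10.0.0.1", 9100))
    dist_utils.set_node_kv_ip_port((1, "10.0.0.2", 9100))

    print(dist_utils.get_kv_ips())    # ['10.0.0.1', '10.0.0.2'] with the ray backend
    print(dist_utils.get_kv_ports())  # [9100, 9100]

Without the ray backend, get_kv_ips() falls back to the host IP from vLLM's get_ip(), and get_kv_ports() falls back to TPU_KV_TRANSFER_PORT (default "9100").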
tpu_inference/env_override.py ADDED
@@ -0,0 +1,9 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
+
+ import os
+
+ # Disable CUDA-specific shared experts stream for TPU
+ # This prevents errors when trying to create CUDA streams on TPU hardware
+ # The issue was introduced by vllm-project/vllm#26440
+ os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
tpu_inference/envs.py ADDED
@@ -0,0 +1,160 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
+
+ import functools
+ import os
+ from collections.abc import Callable
+ from typing import TYPE_CHECKING, Any
+
+ if TYPE_CHECKING:
+     JAX_PLATFORMS: str = ""
+     TPU_ACCELERATOR_TYPE: str | None = None
+     TPU_NAME: str | None = None
+     TPU_WORKER_ID: str | None = None
+     TPU_MULTIHOST_BACKEND: str = ""
+     PREFILL_SLICES: str = ""
+     DECODE_SLICES: str = ""
+     SKIP_JAX_PRECOMPILE: bool = False
+     VLLM_XLA_CHECK_RECOMPILATION: bool = False
+     MODEL_IMPL_TYPE: str = "flax_nnx"
+     NEW_MODEL_DESIGN: bool = False
+     PHASED_PROFILING_DIR: str = ""
+     PYTHON_TRACER_LEVEL: int = 1
+     USE_MOE_EP_KERNEL: bool = False
+     NUM_SLICES: int = 1
+     RAY_USAGE_STATS_ENABLED: str = "0"
+     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
+
+
+ def env_with_choices(
+     env_name: str,
+     default: str | None,
+     choices: list[str] | Callable[[], list[str]],
+     case_sensitive: bool = True,
+ ) -> Callable[[], str | None]:
+     """
+     Create a lambda that validates environment variable against allowed choices
+
+     Args:
+         env_name: Name of the environment variable
+         default: Default value if not set (can be None)
+         choices: List of valid string options or callable that returns list
+         case_sensitive: Whether validation should be case sensitive
+
+     Returns:
+         Lambda function for environment_variables dict
+     """
+
+     def _get_validated_env() -> str | None:
+         value = os.getenv(env_name)
+         if value is None:
+             return default
+
+         # Resolve choices if it's a callable (for lazy loading)
+         actual_choices = choices() if callable(choices) else choices
+
+         if not case_sensitive:
+             check_value = value.lower()
+             check_choices = [choice.lower() for choice in actual_choices]
+         else:
+             check_value = value
+             check_choices = actual_choices
+
+         if check_value not in check_choices:
+             raise ValueError(f"Invalid value '{value}' for {env_name}. "
+                              f"Valid options: {actual_choices}.")
+
+         return value
+
+     return _get_validated_env
+
+
+ environment_variables: dict[str, Callable[[], Any]] = {
+     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
+     "JAX_PLATFORMS":
+     lambda: os.getenv("JAX_PLATFORMS", "").lower(),
+     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
+     "TPU_ACCELERATOR_TYPE":
+     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
+     # Name of the TPU resource
+     "TPU_NAME":
+     lambda: os.getenv("TPU_NAME", None),
+     # Worker ID for multi-host TPU setups
+     "TPU_WORKER_ID":
+     lambda: os.getenv("TPU_WORKER_ID", None),
+     # Backend for multi-host communication on TPU
+     "TPU_MULTIHOST_BACKEND":
+     env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
+     # Slice configuration for disaggregated prefill workers
+     "PREFILL_SLICES":
+     lambda: os.getenv("PREFILL_SLICES", ""),
+     # Slice configuration for disaggregated decode workers
+     "DECODE_SLICES":
+     lambda: os.getenv("DECODE_SLICES", ""),
+     # Skip JAX precompilation step during initialization
+     "SKIP_JAX_PRECOMPILE":
+     lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
+     # Check for XLA recompilation during execution
+     "VLLM_XLA_CHECK_RECOMPILATION":
+     lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
+     # Model implementation type (e.g., "flax_nnx")
+     "MODEL_IMPL_TYPE":
+     env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
+                      ["vllm", "flax_nnx", "jetpack"]),
+     # Enable new experimental model design
+     "NEW_MODEL_DESIGN":
+     lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
+     # Directory to store phased profiling output
+     "PHASED_PROFILING_DIR":
+     lambda: os.getenv("PHASED_PROFILING_DIR", ""),
+     # Python tracer level for profiling
+     "PYTHON_TRACER_LEVEL":
+     lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
+     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
+     "USE_MOE_EP_KERNEL":
+     lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
+     # Number of TPU slices for multi-slice mesh
+     "NUM_SLICES":
+     lambda: int(os.getenv("NUM_SLICES") or "1"),
+     # Enable/disable Ray usage statistics collection
+     "RAY_USAGE_STATS_ENABLED":
+     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
+     # Ray compiled DAG channel type for TPU
+     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
+     env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
+ }
+
+
+ def __getattr__(name: str) -> Any:
+     """
+     Gets environment variables lazily.
+
+     NOTE: After enable_envs_cache() invocation (which triggered after service
+     initialization), all environment variables will be cached.
+     """
+     if name in environment_variables:
+         return environment_variables[name]()
+     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+ def enable_envs_cache() -> None:
+     """
+     Enables caching of environment variables by wrapping the module's __getattr__
+     function with functools.cache(). This improves performance by avoiding
+     repeated re-evaluation of environment variables.
+
+     NOTE: This should be called after service initialization. Once enabled,
+     environment variable values are cached and will not reflect changes to
+     os.environ until the process is restarted.
+     """
+     # Tag __getattr__ with functools.cache
+     global __getattr__
+     __getattr__ = functools.cache(__getattr__)
+
+     # Cache all environment variables
+     for key in environment_variables:
+         __getattr__(key)
+
+
+ def __dir__() -> list[str]:
+     return list(environment_variables.keys())
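Because the module resolves variables lazily through __getattr__, values reflect os.environ at access time until enable_envs_cache() is called. A minimal sketch of that behavior; the variable values are illustrative, and it assumes tpu_inference imports cleanly in your environment:

    import os

    os.environ["MODEL_IMPL_TYPE"] = "vllm"
    os.environ["SKIP_JAX_PRECOMPILE"] = "1"

    from tpu_inference import envs

    print(envs.MODEL_IMPL_TYPE)      # 'vllm' (validated against ['vllm', 'flax_nnx', 'jetpack'])
    print(envs.SKIP_JAX_PRECOMPILE)  # True

    # env_with_choices raises on values outside the allowed list:
    os.environ["TPU_MULTIHOST_BACKEND"] = "mpi"
    try:
        envs.TPU_MULTIHOST_BACKEND
    except ValueError as e:
        print(e)  # Invalid value 'mpi' for TPU_MULTIHOST_BACKEND. Valid options: ['ray'].

    # After enable_envs_cache(), values are frozen for the life of the process.
    os.environ["TPU_MULTIHOST_BACKEND"] = "ray"
    envs.enable_envs_cache()
    os.environ["TPU_MULTIHOST_BACKEND"] = ""
    print(envs.TPU_MULTIHOST_BACKEND)  # still 'ray'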
tpu_inference/executors/__init__.py: File without changes
tpu_inference/executors/ray_distributed_executor.py ADDED
@@ -0,0 +1,382 @@
+ import os
+ from array import array
+ from typing import Any, Dict, List, Optional
+
+ import ray
+ import vllm.envs as envs
+ from ray.util.placement_group import PlacementGroup
+ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+ from vllm.multimodal.inputs import MultiModalKwargs
+ from vllm.platforms import current_platform
+ from vllm.ray.ray_env import get_env_vars_to_copy
+ from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
+ from vllm.utils.network_utils import (get_distributed_init_method, get_ip,
+                                        get_open_port)
+ from vllm.v1.core.sched.output import SchedulerOutput
+ from vllm.v1.executor.ray_distributed_executor import \
+     RayDistributedExecutor as RayDistributedExecutorV1
+ from vllm.v1.executor.ray_executor import RayWorkerMetaData
+ from vllm.v1.executor.ray_utils import RayWorkerWrapper, _wait_until_pg_ready
+
+ from tpu_inference.logger import init_logger
+
+ try:
+     from ray._private.state import available_resources_per_node
+ except ImportError:
+     # Ray 2.9.x doesn't expose `available_resources_per_node`
+     from ray._private.state import state as _state
+     available_resources_per_node = _state._available_resources_per_node
+
+ import asyncio
+ from collections import defaultdict
+
+ import msgspec
+ from vllm.v1.outputs import SamplerOutput
+
+ from tpu_inference.distributed.utils import set_node_kv_ip_port
+
+ logger = init_logger(__name__)
+
+
+ def _encode_hook(obj: Any) -> Any:
+     """Custom msgspec enc hook that supports array types and MultiModalKwargs.
+
+     See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
+     """
+     if isinstance(obj, array):
+         assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
+             f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
+             f"Given array has a type code of {obj.typecode}.")
+         return obj.tobytes()
+     if isinstance(obj, MultiModalKwargs):
+         return dict(obj)
+
+
+ class RayDistributedExecutor(RayDistributedExecutorV1):
+     """Ray-based distributed executor for TPU.
+
+     The implementation is similar to vllm/executor/ray_distributed_executor.py
+     with these major differences:
+
+     1. self._init_executor():
+        VLLM_USE_RAY_SPMD_WORKER=1, in which the driver worker is the same as
+        the other workers.
+     2. self._initialize_ray_cluster():
+        This sets placement_group_specs for TPU. In vLLM, one GPU maps to one
+        placement group, while here one TPU node with all of its chips maps to
+        one placement group.
+     3. self._init_workers_ray():
+        This sets TPU resources when creating each worker and omits the
+        driver-worker-specific logic.
+     """
+
+     def _init_executor(self) -> None:
+         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
+
+         os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
+
+         # Currently, this requires USE_RAY_SPMD_WORKER=True.
+         self.use_ray_compiled_dag = True
+         # If it is true, then we do not distinguish between the
+         # "driver worker" vs other workers. Also, the rank 0 worker will
+         # be executed in a remote Ray worker. Currently this requires
+         # USE_RAY_COMPILED_DAG=True.
+         self.use_ray_spmd_worker = True
+
+         assert self.uses_ray
+         self._initialize_ray_cluster()
+         placement_group = self.parallel_config.placement_group
+
+         # Disable Ray usage stats collection.
+         ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
+         if ray_usage != "1":
+             os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
+
+         # Create the parallel GPU workers.
+         self._init_workers_ray(placement_group)
+
+         self.input_encoder = msgspec.msgpack.Encoder(enc_hook=_encode_hook)
+         self.output_decoder = msgspec.msgpack.Decoder(
+             Optional[List[SamplerOutput]])
+
+         self.pp_locks: Optional[List[asyncio.Lock]] = None
+
+         self.scheduler_output: SchedulerOutput | None = None
+
+         # KV connector setup
+         self.has_connector = self.vllm_config.kv_transfer_config is not None
+         if self.has_connector:
+             ip_port = self.collective_rpc("get_node_kv_ip_port")
+             for item in ip_port:
+                 set_node_kv_ip_port(item)
+         self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
+             self.vllm_config.ec_transfer_config is None
+             or not self.vllm_config.ec_transfer_config.is_ec_producer)
+
+     def _initialize_ray_cluster(self) -> None:
+         """Initialize the distributed cluster with Ray.
+
+         It connects to the Ray cluster and creates a placement group for the
+         workers, which includes the specification of the resources for each
+         distributed worker.
+         """
+         from vllm.platforms import current_platform
+
+         if ray.is_initialized():
+             logger.info(
+                 "Ray is already initialized. Skipping Ray initialization.")
+         else:
+             logger.warning("Ray is not initialized, this is mainly for test.")
+             ray.init()
+
+         device_str = current_platform.ray_device_key
+         if not device_str:
+             raise ValueError(
+                 f"current platform {current_platform.device_name} does not "
+                 "support ray.")
+
+         pp_size = self.parallel_config.pipeline_parallel_size
+         placement_group_specs: List[Dict[str, float]] = []
+
+         ray_nodes = ray.nodes()
+         logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
+
+         if pp_size == 1:
+             placement_group_specs = [{
+                 device_str: node['Resources'][device_str]
+             } for node in ray_nodes]
+         else:
+             num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
+             placement_group_specs = [{
+                 device_str: num_devices_per_pp_rank
+             } for _ in range(pp_size)]
+
+         # vLLM engine is also a worker to execute model with an accelerator,
+         # so it requires to have the device in a current node. Check if
+         # the current node has at least one device.
+         current_ip = get_ip()
+         current_node_id = ray.get_runtime_context().get_node_id()
+         current_node_resource = available_resources_per_node()[current_node_id]
+         if current_node_resource.get(device_str, 0) < 1:
+             raise ValueError(
+                 f"Current node has no {device_str} available. "
+                 f"{current_node_resource=}. vLLM engine cannot start without "
+                 f"{device_str}. Make sure you have at least 1 {device_str} "
+                 f"available in a node {current_node_id=} {current_ip=}.")
+         # This way, at least one bundle is required to be created in the
+         # current node.
+         placement_group_specs[0][f"node:{current_ip}"] = 0.001
+         logger.info(
+             f"RayDistributedExecutor | placement_group_specs={placement_group_specs}"
+         )
+
+         # By default, Ray packs resources as much as possible.
+         current_placement_group = ray.util.placement_group(
+             placement_group_specs, strategy="PACK")
+         _wait_until_pg_ready(current_placement_group)
+
+         assert current_placement_group is not None
+         # Set the placement group in the parallel config
+         self.parallel_config.placement_group = current_placement_group
+
+     def _init_workers_ray(self, placement_group: "PlacementGroup",
+                           **ray_remote_kwargs):
+         # The workers are the actual ray actors.
+         self.workers: List[RayWorkerWrapper] = []
+
+         # Used in ray compiled DAG: indexed first by PP rank,
+         # and then TP rank. In other words, the inner list is
+         # the TP group of workers for a PP rank.
+         self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
+
+         if self.parallel_config.ray_workers_use_nsight:
+             ray_remote_kwargs = self._configure_ray_workers_use_nsight(
+                 ray_remote_kwargs)
+
+         # Create the workers.
+         bundle_indices: List[int]
+         if envs.VLLM_RAY_BUNDLE_INDICES:
+             # Use the bundle indices specified by the user.
+             bundle_indices = list(
+                 map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
+             assert len(bundle_indices) == self.parallel_config.world_size, \
+                 ("VLLM_RAY_BUNDLE_INDICES must have the same size"
+                  f" as the world size, but got {bundle_indices=} "
+                  f"and {self.parallel_config.world_size=}")
+             assert len(set(bundle_indices)) == len(bundle_indices), \
+                 ("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
+                  f" but got {bundle_indices=}")
+         else:
+             bundle_indices = []
+             for bundle_id, bundle in enumerate(placement_group.bundle_specs):
+                 if bundle.get(current_platform.ray_device_key, 0):
+                     bundle_indices.append(bundle_id)
+
+         worker_metadata: List[RayWorkerMetaData] = []
+         driver_ip = get_ip()
+         num_tpu_per_worker = placement_group.bundle_specs[0].get(
+             current_platform.ray_device_key, 0)
+         for rank, bundle_id in enumerate(bundle_indices):
+             scheduling_strategy = PlacementGroupSchedulingStrategy(
+                 placement_group=placement_group,
+                 placement_group_capture_child_tasks=True,
+                 placement_group_bundle_index=bundle_id,
+             )
+             worker = ray.remote(
+                 num_cpus=0,
+                 num_gpus=0,
+                 resources={
+                     current_platform.ray_device_key: num_tpu_per_worker
+                 },
+                 scheduling_strategy=scheduling_strategy,
+                 **ray_remote_kwargs,
+             )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
+                                        rpc_rank=rank)
+             worker_metadata.append(
+                 RayWorkerMetaData(worker=worker, created_rank=rank))
+
+         worker_ips = ray.get([
+             each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
+             for each in worker_metadata
+         ])
+
+         for each, ip in zip(worker_metadata, worker_ips):
+             each.ip = ip
+
+         logger.debug(f"Initialized worker_metadata: {worker_metadata}")
+
+         ip_counts: Dict[str, int] = {}
+         for ip in worker_ips:
+             ip_counts[ip] = ip_counts.get(ip, 0) + 1
+
+         def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
+             """
+             Sort the workers based on 3 properties:
+             1. If the worker is on the same node as the driver (vllm engine),
+                it should be placed first.
+             2. Then, if the worker is on a node with fewer workers, it should
+                be placed first.
+             3. Finally, if the worker is on a node with a smaller IP address,
+                it should be placed first.
+             """
+             ip = item.ip
+             return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
+
+         # After sorting, the workers on the same node will be
+         # close to each other, and the workers on the driver
+         # node will be placed first.
+         sorted_worker_metadata = sorted(worker_metadata,
+                                         key=sort_by_driver_then_worker_ip)
+         start_rank = 0
+         for i, item in enumerate(sorted_worker_metadata):
+             item.adjusted_rank = i + start_rank
+         logger.info(
+             f"Initialized sorted worker_metadata: {sorted_worker_metadata}")
+
+         self.workers = [item.worker for item in sorted_worker_metadata]
+         rerank_mapping = {
+             item.created_rank: item.adjusted_rank
+             for item in sorted_worker_metadata
+         }
+         self.collective_rpc("adjust_rank", args=(rerank_mapping, ))
+
+         # Get the set of TPU IDs used on each node.
+         worker_node_and_tpu_ids = []
+         for worker in self.workers:
+             worker_node_and_tpu_ids.append(
+                 ray.get(worker.get_node_and_gpu_ids.remote())  # type: ignore
+             )
+
+         node_workers = defaultdict(list)  # node id -> list of worker ranks
+         node_tpus = defaultdict(list)  # node id -> list of tpu ids
+
+         for i, (node_id, tpu_ids) in enumerate(worker_node_and_tpu_ids):
+             node_workers[node_id].append(i)
+             # `tpu_ids` can be a list of strings or integers.
+             # Convert them to integers for consistency.
+             tpu_ids = [int(x) for x in tpu_ids]
+             node_tpus[node_id].extend(tpu_ids)
+         for node_id, tpu_ids in node_tpus.items():
+             node_tpus[node_id] = sorted(tpu_ids)
+         logger.info(
+             f"RayDistributedExecutor | node_workers={node_workers} | node_tpus={node_tpus}"
+         )
+
+         all_ips = set(worker_ips + [driver_ip])
+         n_ips = len(all_ips)
+         n_nodes = len(node_workers)
+
+         if n_nodes != n_ips:
+             logger.warning(
+                 f"Got {n_nodes} nodes but with {n_ips} IP addresses. "
+                 "This is not a typical production setup, where the "
+                 "number of nodes and IPs is equal. This setup may "
+                 "lead to unexpected behaviors.")
+
+         # Set environment variables for the driver and workers.
+         all_args_to_update_environment_variables = [{
+             current_platform.device_control_env_var:
+             ",".join(map(str, node_tpus[node_id])),
+         } for (node_id, _) in worker_node_and_tpu_ids]
+
+         # Environment variables to copy from driver to workers
+         env_vars_to_copy = get_env_vars_to_copy(
+             exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
+             additional_vars=set(current_platform.additional_env_vars),
+             destination="workers")
+
+         # Copy existing env vars to each worker's args
+         for args in all_args_to_update_environment_variables:
+             for name in env_vars_to_copy:
+                 if name in os.environ:
+                     args[name] = os.environ[name]
+
+         self._env_vars_for_all_workers = (
+             all_args_to_update_environment_variables)
+
+         self.collective_rpc("update_environment_variables",
+                             args=(self._get_env_vars_to_be_updated(), ))
+
+         distributed_init_method = get_distributed_init_method(
+             driver_ip, get_open_port())
+
+         # Initialize the actual workers inside worker wrapper.
+         all_kwargs = []
+         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
+             local_rank = node_workers[node_id].index(rank)
+             ip = sorted_worker_metadata[rank].ip
+             prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
+             kwargs = dict(
+                 vllm_config=self.vllm_config,
+                 local_rank=local_rank,
+                 rank=rank,
+                 distributed_init_method=distributed_init_method,
+                 is_driver_worker=(not self.parallel_config)
+                 or (rank % self.parallel_config.tensor_parallel_size == 0),
+                 ip=ip,
+                 prev_worker_ip=prev_ip,
+             )
+             all_kwargs.append(kwargs)
+         self.collective_rpc("init_worker", args=(all_kwargs, ))
+         self.collective_rpc("init_device")
+         if self.parallel_config.pipeline_parallel_size > 1:
+             self.collective_rpc("initialize_pp_transfer_connect")
+         self.collective_rpc("load_model")
+
+         if self.use_ray_spmd_worker:
+             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
+                 self.pp_tp_workers.append([])
+                 num_tp_workers = int(
+                     self.parallel_config.tensor_parallel_size //
+                     num_tpu_per_worker)
+                 for tp_rank in range(num_tp_workers):
+                     # PP=2, TP=4, num_tpu_per_worker=2
+                     # pp_tp_workers = [[0, 1], [2, 3]]
+                     rank = (pp_rank * num_tp_workers) + tp_rank
+                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
+                     assert pp_rank < len(self.pp_tp_workers)
+                     self.pp_tp_workers[pp_rank].append(self.workers[rank])
+
+     # The Ray executor does not need handshake metadata,
+     # as we pass the kv_parameters through the proxy server.
+     def get_kv_connector_handshake_metadata(self) -> None:
+         pass
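Two details of _init_workers_ray above are easy to check in isolation. The sketch below reproduces the worker-ordering key and the PP/TP grouping arithmetic with made-up IPs (no Ray or vLLM required); note that the third sort criterion compares IPs as strings, not numerically:

    driver_ip = "10.0.0.1"
    worker_ips = ["10.0.0.2", "10.0.0.1", "10.0.0.2", "10.0.0.3"]

    ip_counts = {}
    for ip in worker_ips:
        ip_counts[ip] = ip_counts.get(ip, 0) + 1

    def sort_key(ip: str):
        # Driver-co-located first, then nodes with fewer workers, then IP (string order).
        return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

    print(sorted(worker_ips, key=sort_key))
    # ['10.0.0.1', '10.0.0.3', '10.0.0.2', '10.0.0.2']

    # PP/TP grouping for PP=2, TP=4 with 2 TPU chips per Ray worker,
    # matching the comment inside the spmd-worker loop above:
    pp_size, tp_size, num_tpu_per_worker = 2, 4, 2
    num_tp_workers = tp_size // num_tpu_per_worker
    pp_tp_workers = [[pp_rank * num_tp_workers + tp_rank
                      for tp_rank in range(num_tp_workers)]
                     for pp_rank in range(pp_size)]
    print(pp_tp_workers)  # [[0, 1], [2, 3]]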
tpu_inference/experimental/__init__.py: File without changes