tpu-inference 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/distributed/utils.py
@@ -0,0 +1,59 @@
+import os
+
+from vllm.utils import get_ip
+
+from tpu_inference.logger import init_logger
+
+logger = init_logger(__name__)
+
+# For multi-host usage only, to collect IP and port for all nodes.
+_NODES_KV_IP_PORT = dict()
+
+
+def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
+    global _NODES_KV_IP_PORT
+    node_id, ip, port = ip_port
+    _NODES_KV_IP_PORT[node_id] = (ip, port)
+
+
+def get_kv_ips() -> str:
+    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+        num_nodes = len(_NODES_KV_IP_PORT)
+        ips = []
+        for node_id in range(num_nodes):
+            ips.append(_NODES_KV_IP_PORT[node_id][0])
+        return ips
+    else:
+        return get_host_ip()
+
+
+def get_kv_ports() -> str:
+    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+        num_nodes = len(_NODES_KV_IP_PORT)
+        ports = []
+        for node_id in range(num_nodes):
+            ports.append(_NODES_KV_IP_PORT[node_id][1])
+        return ports
+    else:
+        return get_kv_transfer_port()
+
+
+def get_host_ip() -> str:
+    """Use `VLLM_HOST_IP` if set, otherwise use default network interface IP."""
+    return get_ip()
+
+
+def get_kv_transfer_port() -> str:
+    port = os.getenv("TPU_KV_TRANSFER_PORT", "9100")
+    return port
+
+
+def get_side_channel_port() -> str:
+    port = os.getenv("TPU_SIDE_CHANNEL_PORT", "9600")
+    return port
+
+
+def get_node_id() -> int:
+    # TODO(xiang): Is it possible to get this from a pre-defiend env?
+    id = os.getenv("TPU_NODE_ID", 0)
+    return id
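The helpers above are small enough to exercise in isolation. Below is a minimal sketch of how they fit together under the Ray multi-host backend; the node IDs, IP addresses, and port values are made up for illustration and are not taken from the package:

import os

from tpu_inference.distributed.utils import (get_kv_ips, get_kv_ports,
                                             set_node_kv_ip_port)

# Select the Ray multi-host code path in get_kv_ips()/get_kv_ports().
os.environ["TPU_MULTIHOST_BACKEND"] = "ray"

# Each node reports (node_id, ip, port); the executor collects them all.
set_node_kv_ip_port((0, "10.0.0.1", 9100))
set_node_kv_ip_port((1, "10.0.0.2", 9100))

print(get_kv_ips())    # ['10.0.0.1', '10.0.0.2']
print(get_kv_ports())  # [9100, 9100]

In single-host mode the same calls fall back to get_host_ip() and get_kv_transfer_port() (the TPU_KV_TRANSFER_PORT environment variable, default "9100").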
tpu_inference/executors/ray_distributed_executor.py
@@ -0,0 +1,346 @@
+import os
+from typing import Dict, List, Optional
+
+import ray
+import vllm.envs as envs
+from ray.util.placement_group import PlacementGroup
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
+from vllm.executor.ray_distributed_executor import RayWorkerMetaData
+from vllm.executor.ray_utils import RayWorkerWrapper, _wait_until_pg_ready
+from vllm.platforms import current_platform
+from vllm.ray.ray_env import get_env_vars_to_copy
+from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.v1.executor.ray_distributed_executor import \
+    RayDistributedExecutor as RayDistributedExecutorV1
+
+from tpu_inference.logger import init_logger
+
+try:
+    from ray._private.state import available_resources_per_node
+except ImportError:
+    # Ray 2.9.x doesn't expose `available_resources_per_node`
+    from ray._private.state import state as _state
+    available_resources_per_node = _state._available_resources_per_node
+
+import asyncio
+from collections import defaultdict
+
+import msgspec
+from vllm.executor.msgspec_utils import encode_hook
+from vllm.v1.outputs import SamplerOutput
+
+from tpu_inference.distributed.utils import set_node_kv_ip_port
+
+logger = init_logger(__name__)
+
+
+class RayDistributedExecutor(RayDistributedExecutorV1):
+    """Ray-based distributed executor for TPU.
+
+    The implementation is similar to vllm/executor/ray_distributed_executor.py
+    with these major differences:
+
+    1. self._init_executor():
+       VLLM_USE_RAY_SPMD_WORKER=1, in which the driver worker is the same as other workers.
+    2. self._initialize_ray_cluster():
+       This sets placement_group_specs for TPU.
+       In vLLM one GPU maps to one placement group.
+       While here one TPU node with all chips maps to one placement group.
+    3. self._init_workers_ray():
+       This set TPU resources when create each worker.
+       And we omit the driver worker related logic.
+    """
+
+    def _init_executor(self) -> None:
+        self.forward_dag: Optional[ray.dag.CompiledDAG] = None
+        # V1 uses SPMD worker and compiled DAG
+        os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
+        os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
+        os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
+
+        # If the env var is set, it uses the Ray's compiled DAG API
+        # which optimizes the control plane overhead.
+        # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
+        # Currently, this requires USE_RAY_SPMD_WORKER=True.
+        self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
+        # If the env var is set, then we do not distinguish between the
+        # "driver worker" vs other workers. Also, the rank 0 worker will
+        # be executed in a remote Ray worker. Currently this requires
+        # USE_RAY_COMPILED_DAG=True.
+        self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
+
+        assert self.uses_ray
+        self._initialize_ray_cluster()
+        placement_group = self.parallel_config.placement_group
+
+        # Disable Ray usage stats collection.
+        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
+        if ray_usage != "1":
+            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
+
+        # Create the parallel GPU workers.
+        self._init_workers_ray(placement_group)
+
+        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
+        self.output_decoder = msgspec.msgpack.Decoder(
+            Optional[List[SamplerOutput]])
+        self.use_v1 = envs.VLLM_USE_V1
+
+        self.pp_locks: Optional[List[asyncio.Lock]] = None
+
+        # KV connector setup
+        self.has_connector = self.vllm_config.kv_transfer_config is not None
+        self.kv_output_aggregator = KVOutputAggregator(
+            self.parallel_config.world_size)
+        if self.has_connector:
+            ip_port = self._run_workers("get_node_kv_ip_port")
+            for item in ip_port:
+                set_node_kv_ip_port(item)
+
+    def _initialize_ray_cluster(self) -> None:
+        """Initialize the distributed cluster with Ray.
+
+        it will connect to the Ray cluster and create a placement group
+        for the workers, which includes the specification of the resources
+        for each distributed worker.
+        """
+        from vllm.platforms import current_platform
+
+        if ray.is_initialized():
+            logger.info(
+                "Ray is already initialized. Skipping Ray initialization.")
+        else:
+            logger.warning("Ray is not initialized, this is mainly for test.")
+            ray.init()
+
+        device_str = current_platform.ray_device_key
+        if not device_str:
+            raise ValueError(
+                f"current platform {current_platform.device_name} does not "
+                "support ray.")
+
+        placement_group_specs: List[Dict[str, float]] = [{
+            device_str:
+            node['Resources'][device_str]
+        } for node in ray.nodes()]
+
+        # vLLM engine is also a worker to execute model with an accelerator,
+        # so it requires to have the device in a current node. Check if
+        # the current node has at least one device.
+        current_ip = get_ip()
+        current_node_id = ray.get_runtime_context().get_node_id()
+        current_node_resource = available_resources_per_node()[current_node_id]
+        if current_node_resource.get(device_str, 0) < 1:
+            raise ValueError(
+                f"Current node has no {device_str} available. "
+                f"{current_node_resource=}. vLLM engine cannot start without "
+                f"{device_str}. Make sure you have at least 1 {device_str} "
+                f"available in a node {current_node_id=} {current_ip=}.")
+        # This way, at least bundle is required to be created in a current
+        # node.
+        placement_group_specs[0][f"node:{current_ip}"] = 0.001
+        logger.info(
+            f"RayDistributedExecutor | placement_group_specs={placement_group_specs}"
+        )
+
+        # By default, Ray packs resources as much as possible.
+        current_placement_group = ray.util.placement_group(
+            placement_group_specs, strategy="PACK")
+        _wait_until_pg_ready(current_placement_group)
+
+        assert current_placement_group is not None
+        # Set the placement group in the parallel config
+        self.parallel_config.placement_group = current_placement_group
+
+    def _init_workers_ray(self, placement_group: "PlacementGroup",
+                          **ray_remote_kwargs):
+        # The workers are the actual ray actors.
+        self.workers: List[RayWorkerWrapper] = []
+
+        # Used in ray compiled DAG: indexed first by PP rank,
+        # and then TP rank. In other words, the inner list is
+        # the TP group of workers for a PP rank.
+        self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
+
+        if self.parallel_config.ray_workers_use_nsight:
+            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
+                ray_remote_kwargs)
+
+        # Create the workers.
+        bundle_indices: List[int]
+        if envs.VLLM_RAY_BUNDLE_INDICES:
+            # Use the bundle indices specified by the user.
+            bundle_indices = list(
+                map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
+            assert len(bundle_indices) == self.parallel_config.world_size, \
+                ("VLLM_RAY_BUNDLE_INDICES must have the same size"
+                 f" as the world size, but got {bundle_indices=} "
+                 f"and {self.parallel_config.world_size=}")
+            assert len(set(bundle_indices)) == len(bundle_indices), \
+                ("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
+                 f" but got {bundle_indices=}")
+        else:
+            bundle_indices = []
+            for bundle_id, bundle in enumerate(placement_group.bundle_specs):
+                if bundle.get(current_platform.ray_device_key, 0):
+                    bundle_indices.append(bundle_id)
+
+        worker_metadata: List[RayWorkerMetaData] = []
+        driver_ip = get_ip()
+        num_tpu_per_worker = placement_group.bundle_specs[0].get(
+            current_platform.ray_device_key, 0)
+        for rank, bundle_id in enumerate(bundle_indices):
+            scheduling_strategy = PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_capture_child_tasks=True,
+                placement_group_bundle_index=bundle_id,
+            )
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=0,
+                resources={
+                    current_platform.ray_device_key: num_tpu_per_worker
+                },
+                scheduling_strategy=scheduling_strategy,
+                **ray_remote_kwargs,
+            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
+                                       rpc_rank=rank)
+            worker_metadata.append(
+                RayWorkerMetaData(worker=worker, created_rank=rank))
+
+        worker_ips = ray.get([
+            each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
+            for each in worker_metadata
+        ])
+
+        for each, ip in zip(worker_metadata, worker_ips):
+            each.ip = ip
+
+        logger.debug("workers: %s", worker_metadata)
+
+        ip_counts: Dict[str, int] = {}
+        for ip in worker_ips:
+            ip_counts[ip] = ip_counts.get(ip, 0) + 1
+
+        def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
+            """
+            Sort the workers based on 3 properties:
+            1. If the worker is on the same node as the driver (vllm engine),
+               it should be placed first.
+            2. Then, if the worker is on a node with fewer workers, it should
+               be placed first.
+            3. Finally, if the work is on a node with smaller IP address, it
+               should be placed first.
+            """
+            ip = item.ip
+            return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
+
+        # After sorting, the workers on the same node will be
+        # close to each other, and the workers on the driver
+        # node will be placed first.
+        sorted_worker_metadata = sorted(worker_metadata,
+                                        key=sort_by_driver_then_worker_ip)
+        start_rank = 0
+        for i, item in enumerate(sorted_worker_metadata):
+            item.adjusted_rank = i + start_rank
+        self.workers = [item.worker for item in sorted_worker_metadata]
+        rerank_mapping = {
+            item.created_rank: item.adjusted_rank
+            for item in sorted_worker_metadata
+        }
+        self._run_workers("adjust_rank", rerank_mapping)
+
+        # Get the set of TPU IDs used on each node.
+        worker_node_and_tpu_ids = []
+        for worker in self.workers:
+            worker_node_and_tpu_ids.append(
+                ray.get(worker.get_node_and_gpu_ids.remote()) \
+            )  # type: ignore
+
+        node_workers = defaultdict(list)  # node id -> list of worker ranks
+        node_tpus = defaultdict(list)  # node id -> list of tpu ids
+
+        for i, (node_id, tpu_ids) in enumerate(worker_node_and_tpu_ids):
+            node_workers[node_id].append(i)
+            # `tpu_ids` can be a list of strings or integers.
+            # convert them to integers for consistency.
+            tpu_ids = [int(x) for x in tpu_ids]
+            node_tpus[node_id].extend(tpu_ids)
+        for node_id, tpu_ids in node_tpus.items():
+            node_tpus[node_id] = sorted(tpu_ids)
+        logger.info(
+            f"RayDistributedExecutor | node_workers={node_workers} | node_tpus={node_tpus}"
+        )
+
+        all_ips = set(worker_ips + [driver_ip])
+        n_ips = len(all_ips)
+        n_nodes = len(node_workers)
+
+        if n_nodes != n_ips:
+            logger.warning(
+                f"Got {n_nodes} nodes but with {n_ips} IP addresses. "
+                "This is not a typical production setup whose "
+                "number of nodes and IPs is euqal. This setup may "
+                "lead to unexpected behaviors.")
+
+        # Set environment variables for the driver and workers.
+        all_args_to_update_environment_variables = [{
+            current_platform.device_control_env_var:
+            ",".join(map(str, node_tpus[node_id])),
+        } for (node_id, _) in worker_node_and_tpu_ids]
+
+        # Environment variables to copy from driver to workers
+        env_vars_to_copy = get_env_vars_to_copy(
+            exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
+            additional_vars=set(current_platform.additional_env_vars),
+            destination="workers")
+
+        # Copy existing env vars to each worker's args
+        for args in all_args_to_update_environment_variables:
+            for name in env_vars_to_copy:
+                if name in os.environ:
+                    args[name] = os.environ[name]
+
+        self._env_vars_for_all_workers = (
+            all_args_to_update_environment_variables)
+
+        self._run_workers("update_environment_variables",
+                          self._get_env_vars_to_be_updated())
+
+        distributed_init_method = get_distributed_init_method(
+            driver_ip, get_open_port())
+
+        # Initialize the actual workers inside worker wrapper.
+        all_kwargs = []
+        for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
+            local_rank = node_workers[node_id].index(rank)
+            kwargs = dict(
+                vllm_config=self.vllm_config,
+                local_rank=local_rank,
+                rank=rank,
+                distributed_init_method=distributed_init_method,
+                is_driver_worker=(not self.parallel_config)
+                or (rank % self.parallel_config.tensor_parallel_size == 0),
+            )
+            all_kwargs.append(kwargs)
+        self._run_workers("init_worker", all_kwargs)
+
+        self._run_workers("init_device")
+        self._run_workers("load_model",
+                          max_concurrent_workers=self.parallel_config.
+                          max_parallel_loading_workers)
+
+        if self.use_ray_spmd_worker:
+            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
+                self.pp_tp_workers.append([])
+                for tp_rank in range(
+                        int(self.parallel_config.tensor_parallel_size //
+                            num_tpu_per_worker)):
+                    # PP=2, TP=4
+                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
+                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
+                            ) + tp_rank
+                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
+                    assert pp_rank < len(self.pp_tp_workers)
+                    self.pp_tp_workers[pp_rank].append(self.workers[rank])
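One detail worth calling out from _initialize_ray_cluster above: unlike upstream vLLM, which requests one bundle per GPU, each Ray node here becomes a single bundle requesting all of that node's TPU chips, and a tiny node:<ip> resource pins the first bundle to the driver's node. Below is a self-contained sketch of just that spec construction; the node list, chip count, and IP address are fabricated for illustration (the real code reads them from ray.nodes() and get_ip()):

from typing import Dict, List

fake_nodes = [
    {"Resources": {"TPU": 4.0}},
    {"Resources": {"TPU": 4.0}},
]
device_str = "TPU"       # current_platform.ray_device_key on TPU platforms
current_ip = "10.0.0.1"  # get_ip() in the real executor

# One bundle per node, each asking for all of that node's TPU chips.
placement_group_specs: List[Dict[str, float]] = [
    {device_str: node["Resources"][device_str]} for node in fake_nodes
]
# Pin the first bundle to the driver's node via a tiny custom resource.
placement_group_specs[0][f"node:{current_ip}"] = 0.001

print(placement_group_specs)
# [{'TPU': 4.0, 'node:10.0.0.1': 0.001}, {'TPU': 4.0}]

The executor then hands these specs to ray.util.placement_group(..., strategy="PACK") and waits for the group to become ready before creating one Ray worker per TPU bundle.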
tpu_inference/experimental/llama3_jax_stashed.py
@@ -0,0 +1,258 @@
+# TODO: Update documentation
+
+from typing import List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+from vllm.config import VllmConfig
+
+from tpu_inference.layers.jax.attention.attention import (Attention,
+                                                           AttentionMetadata)
+from tpu_inference.layers.jax.constants import KVCacheType
+from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+from tpu_inference.layers.jax.transformer_block import TransformerBlock
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
+                                                          load_hf_weights)
+
+logger = init_logger(__name__)
+
+
+class LlamaForCausalLM(nnx.Module):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rng: jax.Array,
+                 mesh: Mesh,
+                 force_random_weights: bool = False):
+        assert mesh is not None
+
+        self.vllm_config = vllm_config
+        self.rng = nnx.Rngs(rng)
+        self.mesh = mesh
+
+        model_name = self.vllm_config.model_config.model.lower()
+        if "70b" in model_name:
+            logger.info("Initializing Llama3 70B model variant.")
+            self.hidden_size = 8192
+            num_layers = 80
+            self.num_attention_heads = 64
+            self.num_key_value_heads = 8
+            intermediate_size = 28672
+        elif "8b" in model_name:
+            logger.info("Initializing Llama3 8B model variant.")
+            self.hidden_size = 4096
+            num_layers = 32
+            self.num_attention_heads = 32
+            self.num_key_value_heads = 8
+            intermediate_size = 14336
+        else:
+            raise ValueError(
+                f"Could not determine Llama3 variant (8B or 70B) from model name: '{model_name}'. "
+                "Please ensure '8b' or '70b' is in the model path.")
+
+        dtype = jnp.bfloat16
+        self.head_dim = 128
+        rope_theta = 500000.0
+        vocab_size = 128256
+        rms_norm_eps = 1e-5
+
+        self.embedder = Embedder(vocab_size=vocab_size,
+                                 hidden_size=self.hidden_size,
+                                 dtype=dtype,
+                                 rngs=self.rng,
+                                 random_init=force_random_weights,
+                                 vd_sharding=("model", None))
+
+        self.layers = []
+        kv_cache_dtype = self.vllm_config.cache_config.cache_dtype
+        for _ in range(num_layers):
+            self.layers.append(
+                TransformerBlock(
+                    pre_attention_norm=RMSNorm(
+                        dims=self.hidden_size,
+                        random_init=force_random_weights,
+                        epsilon=rms_norm_eps,
+                        rngs=self.rng,
+                        with_scale=True,
+                        dtype=dtype,
+                    ),
+                    pre_mlp_norm=RMSNorm(
+                        dims=self.hidden_size,
+                        rngs=self.rng,
+                        random_init=force_random_weights,
+                        epsilon=rms_norm_eps,
+                        with_scale=True,
+                        dtype=dtype,
+                    ),
+                    attn=Attention(
+                        hidden_size=self.hidden_size,
+                        num_attention_heads=self.num_attention_heads,
+                        num_key_value_heads=self.num_key_value_heads,
+                        head_dim=self.head_dim,
+                        rope_theta=rope_theta,
+                        rope_scaling={},
+                        rngs=self.rng,
+                        dtype=dtype,
+                        # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                        kv_cache_dtype=kv_cache_dtype,
+                        mesh=self.mesh,
+                        random_init=force_random_weights,
+                        dnh_sharding=(None, "model", None),
+                        dkh_sharding=(None, "model", None),
+                        nhd_sharding=("model", None, None),
+                        query_tnh=P(None, "model", None),
+                        keyvalue_skh=P(None, "model", None),
+                        attn_o_tnh=P(None, "model", None),
+                    ),
+                    custom_module=DenseFFW(dtype=dtype,
+                                           hidden_act="silu",
+                                           hidden_size=self.hidden_size,
+                                           intermediate_size=intermediate_size,
+                                           rngs=self.rng,
+                                           df_sharding=(None, "model"),
+                                           fd_sharding=("model", None),
+                                           random_init=force_random_weights),
+                ))
+
+        self.final_norm = RMSNorm(
+            dims=self.hidden_size,
+            rngs=self.rng,
+            random_init=force_random_weights,
+            epsilon=rms_norm_eps,
+            with_scale=True,
+            dtype=dtype,
+        )
+
+        self.lm_head = LMhead(vocab_size=vocab_size,
+                              hidden_size=self.hidden_size,
+                              dtype=dtype,
+                              rngs=self.rng,
+                              dv_sharding=(None, 'model'),
+                              random_init=force_random_weights)
+
+    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+        # NOTE: Since we are using nnx.eval_shape to init the model,
+        # we have to pass dynamic arrays here for __call__'s usage.
+        self.rng = nnx.Rngs(rng)
+        weight_loader = Llama3WeightLoader(
+            vllm_config=self.vllm_config,
+            hidden_size=self.hidden_size,
+            attn_heads=self.num_attention_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            attn_head_dim=self.head_dim)
+
+        weight_loader.load_weights(self)
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        *args,
+    ) -> Tuple[List[KVCacheType], jax.Array]:
+        is_prefill = False
+        with jax.named_scope("llama_embed_input"):  #Embedding
+            x_TD = self.embedder.encode(input_ids)
+
+        with jax.named_scope("llama_model_transformer_blocks"):
+            for (i, layer) in enumerate(self.layers):
+                kv_cache = kv_caches[i]
+
+                # The first layer is unscoped to avoid JAX tracing issues.
+                # JAX's profiler may incorrectly apply the scope name from the first
+                # layer's kernel compilation to all subsequent layers. Skipping the
+                # first layer ensures distinct scope names for the remaining layers.
+                if i == 0:
+                    new_kv_cache, x_TD = layer(x_TD, is_prefill, kv_cache,
+                                               attention_metadata)
+                else:
+                    with jax.named_scope(f'layer_{i}'):
+                        new_kv_cache, x_TD = layer(x_TD, is_prefill, kv_cache,
+                                                   attention_metadata)
+
+                kv_caches[i] = new_kv_cache
+
+        with jax.named_scope(
+                "llama_final_norm"):  #Norm after last transformer block
+            final_activation_TD = self.final_norm(x_TD)
+
+        return kv_caches, final_activation_TD, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        with jax.named_scope("llama_lm_head_projection"
+                             ):  #LM head projection to produce logits
+            logits_TV = jnp.dot(hidden_states,
+                                self.lm_head.input_embedding_table_DV.value)
+
+        return logits_TV
+
+
+class Llama3WeightLoader:
+
+    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+                 num_key_value_heads, attn_head_dim):
+        self._transpose_map = {
+            "lm_head": (1, 0),
+            "gate_proj": (1, 0),
+            "up_proj": (1, 0),
+            "down_proj": (1, 0),
+            "q_proj": (2, 0, 1),
+            "k_proj": (2, 0, 1),
+            "v_proj": (2, 0, 1),
+            "o_proj": (1, 2, 0),
+        }
+        self._weight_shape_map = {
+            "q_proj": (attn_heads, -1, hidden_size),
+            "k_proj": (num_key_value_heads, -1, hidden_size),
+            "v_proj": (num_key_value_heads, -1, hidden_size),
+            "o_proj": (hidden_size, attn_heads, -1),
+        }
+        self._bias_shape_map = {
+            "q_proj.bias": (attn_heads, attn_head_dim),
+            "k_proj.bias": (num_key_value_heads, attn_head_dim),
+            "v_proj.bias": (num_key_value_heads, attn_head_dim)
+        }
+
+        # Set the mappings from loaded parameter keys to standardized names.
+        self._loaded_to_standardized_keys = {
+            "model.embed_tokens": "embedder.input_embedding_table_VD",
+            "model.layers.*.input_layernorm":
+            "layers.*.pre_attention_norm.scale",
+            "model.layers.*.mlp.down_proj":
+            "layers.*.custom_module.kernel_down_proj_FD",
+            "model.layers.*.mlp.gate_proj":
+            "layers.*.custom_module.kernel_gating_DF",
+            "model.layers.*.mlp.up_proj":
+            "layers.*.custom_module.kernel_up_proj_DF",
+            "model.layers.*.post_attention_layernorm":
+            "layers.*.pre_mlp_norm.scale",
+            "model.layers.*.self_attn.k_proj":
+            "layers.*.attn.kernel_k_proj_DKH",
+            "model.layers.*.self_attn.o_proj":
+            "layers.*.attn.kernel_o_proj_NHD",
+            "model.layers.*.self_attn.q_proj":
+            "layers.*.attn.kernel_q_proj_DNH",
+            "model.layers.*.self_attn.v_proj":
+            "layers.*.attn.kernel_v_proj_DKH",
+            "model.norm": "final_norm.scale",
+            "lm_head": "lm_head.input_embedding_table_DV"
+        }
+        self.vllm_config = vllm_config
+
+    def load_weights(self, model_for_loading: nnx.Module):
+        model_params = nnx.state(model_for_loading)
+        metadata_map = MetadataMap(name_map=self._loaded_to_standardized_keys,
+                                   reshape_map=self._weight_shape_map,
+                                   bias_reshape_map=self._bias_shape_map,
+                                   transpose_map=self._transpose_map)
+        load_hf_weights(vllm_config=self.vllm_config,
+                        model=model_for_loading,
+                        metadata_map=metadata_map,
+                        mesh=model_for_loading.mesh)
+
+        # TODO: validate that all of the model_params were accounted for as well.
+        nnx.update(model_for_loading, model_params)
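Llama3WeightLoader describes weight loading declaratively: name_map rewrites HuggingFace checkpoint keys (with * standing for the layer index) into the model's parameter names, while reshape_map and transpose_map describe how each 2D checkpoint matrix is folded into the named axes (D=hidden, N=query heads, K=KV heads, H=head_dim, F=intermediate, V=vocab). Below is a small NumPy shape sketch of the q_proj case as encoded in those maps; the toy sizes are illustrative and the actual transformation happens inside load_hf_weights:

import numpy as np

hidden_size, attn_heads, head_dim = 16, 4, 4  # toy sizes, not Llama3's

# HF checkpoint layout: q_proj.weight has shape (num_heads * head_dim, hidden).
hf_q_proj = np.zeros((attn_heads * head_dim, hidden_size))

# reshape_map["q_proj"] = (attn_heads, -1, hidden_size)  -> (N, H, D)
reshaped = hf_q_proj.reshape(attn_heads, -1, hidden_size)

# transpose_map["q_proj"] = (2, 0, 1)  -> target kernel_q_proj_DNH is (D, N, H)
kernel_q_proj_DNH = reshaped.transpose(2, 0, 1)

print(kernel_q_proj_DNH.shape)  # (16, 4, 4)

The same pattern applies to k_proj and v_proj (K heads instead of N), and, with transpose (1, 2, 0), to o_proj, which lands in kernel_o_proj_NHD.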