tpu-inference 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/sharding.py
@@ -0,0 +1,406 @@
+ import json
+ from dataclasses import dataclass
+
+ import numpy as np
+ from jax.sharding import Mesh
+ from vllm.config import VllmConfig
+
+ BATCH_AXIS_NAME = 'data'
+ SEQUENCE_AXIS_NAME = 'data'
+ DATA_AXIS_NAME = 'data'
+ ATTN_HEAD_AXIS_NAME = 'model'
+ ATTN_TENSOR_AXIS_NAME = None
+ MLP_TENSOR_AXIS_NAME = ('model', 'expert')
+ MOE_TENSOR_AXIS_NAME = 'model'
+ EXPERT_AXIS_NAME = 'expert'
+ VOCAB_AXIS_NAME = ('data', 'expert', 'model')
+
+
+ @dataclass
+ class ShardingStrategy:
+     """Defines the high-level parallelism strategy.
+
+     This class specifies how many ways each type of parallelism (tensor, expert,
+     sequence, data) should be distributed across the available devices.
+
+     Attributes:
+         tensor_parallelism: The degree of tensor parallelism (e.g., splitting
+             weights of a single layer).
+         expert_parallelism: The degree of expert parallelism for MoE models.
+         sequence_parallelism: The degree of sequence parallelism (splitting
+             activations along the sequence length dimension).
+         data_parallelism: The degree of data parallelism (splitting the batch
+             across devices).
+     """
+     tensor_parallelism: int = 1
+     expert_parallelism: int = 1
+     sequence_parallelism: int = 1
+     data_parallelism: int = 1
+
+
+ # TODO: split this into block-specific sharding configs, i.e. AttentionShardingConfig, MoEShardingConfig
+ @dataclass
+ class ShardingRulesConfig:
+     """Holds the detailed sharding configuration (logical rules) for individual tensors.
+
+     Each attribute in this class corresponds to a specific weight or activation
+     tensor within a transformer model. The value of each attribute is a
+     tuple of logical mesh axis names (e.g., 'data', 'seq', 'model'), which defines
+     how the corresponding tensor's dimensions are partitioned across the device mesh.
+     The dimension order in the attribute name (e.g., `btd` for batch, sequence,
+     d_model) maps directly to the sharding tuple.
+
+     TODO: update the mesh axis names to be clear and reduce confusion between prefill & generate
+     """
+
+     # Activation for attn input: (Batch * Sequence, Dim)
+     activation_attention_td: tuple = (None, None)
+     # Activation for attn out: (Batch * Sequence, Dim)
+     activation_attention_out_td: tuple = (None, None)
+     # Activation for q projection input: (Batch * Sequence, Dim)
+     activation_q_td: tuple = (None, None)
+     # Attention Out activation after projection: (Batch * Sequence, NumHeads, HeadDim)
+     attn_o_tnh: tuple = (None, None, None)
+     # Q vector: (Batch * Sequence, NumHeads, HeadDim)
+     query_tnh: tuple = (None, None, None)
+     # K/V vector: (Batch * Sequence, NumKVHeads, HeadDim)
+     keyvalue_skh: tuple = (None, None, None)
+
+     # Attention Q weight: (Dim, NumHeads, HeadDim)
+     attn_q_weight_dnh: tuple = (None, None, None)
+     # Attention K weight: (Dim, NumKVHeads, HeadDim)
+     attn_k_weight_dkh: tuple = (None, None, None)
+     # Attention V weight: (Dim, NumKVHeads, HeadDim)
+     attn_v_weight_dkh: tuple = (None, None, None)
+     # Attention Out weight: (NumHeads, HeadDim, Dim)
+     attn_o_weight_nhd: tuple = (None, None, None)
+
+     # Activation for ffw input: (Batch * Sequence, Dim)
+     activation_ffw_td: tuple = (None, None)
+
+     # Activation for ffw input: (Batch * Sequence, Expert, Dim)
+     activation_ffw_ted: tuple = (None, None, None)
+
+     # FFW hidden activation: (Batch * Sequence, FfwDim)
+     ffw_hidden_tf: tuple = (None, None)
+
+     # FFW up/gate weight: (Dim, FfwDim)
+     ffw_weight_df: tuple = (None, None)
+     # FFW down weight: (FfwDim, Dim)
+     ffw_weight_fd: tuple = (None, None)
+     # MoE gate/up weights: (NumExperts, Dim, FfwDim)
+     moe_weights_edf: tuple = (None, None, None)
+     # MoE down weights: (NumExperts, FfwDim, Dim)
+     moe_weights_efd: tuple = (None, None, None)
+     # MoE router weights: (Dim, NumExperts)
+     moe_router_de: tuple = (None, None)
+     # MoE router bias weights: (NumExperts,)
+     moe_router_bias_e: tuple = (None, )
+
+     # Embedding weight: (VocabSize, Dim)
+     emb_weight_vd: tuple = (None, None)
+     # Activation between layers: (Batch * Sequence, Dim)
+     activation_td: tuple = (None, None)
+     # Final activation before logits: (Batch * Sequence, Dim)
+     prelogit_td: tuple = (None, None)
+     # Logit activation: (Batch * Sequence, VocabSize)
+     logits_tv: tuple = (None, None)
+     # RMS norm scale weight: (Dim,)
+     norm_scale: tuple = (None, )
+     # Vocab projection weight (tied embeddings): (Dim, VocabSize)
+     vocab_vd: tuple = (None, None)
+     vocab_dv: tuple = (None, None)
+
+
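Each rule tuple has one entry (a mesh axis name, a tuple of axis names, or None) per dimension of the tensor it describes. In plain JAX terms such a rule can be turned into a NamedSharding by splatting it into a PartitionSpec; the sketch below only illustrates that mapping, the helper `sharding_for_rule` is hypothetical and not part of this package:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def sharding_for_rule(mesh: Mesh, rule: tuple) -> NamedSharding:
    # One mesh-axis entry (or None) per tensor dimension.
    return NamedSharding(mesh, PartitionSpec(*rule))

# Trivial mesh that works on any device count: every device on the 'model' axis.
devs = np.asarray(jax.devices()).reshape(1, 1, -1)
mesh = Mesh(devs, axis_names=('data', 'expert', 'model'))

# moe_weights_edf rule: (NumExperts, Dim, FfwDim) -> shard experts and the FFW dim.
spec = sharding_for_rule(mesh, ('expert', None, 'model'))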
+ class ShardingConfig:
+     """Container for operation-specific sharding configurations.
+
+     This class holds two separate `ShardingRulesConfig` objects, one for the
+     'prefill' phase and one for the 'generate' (or decode) phase of model
+     execution. This allows tailoring sharding strategies to the different
+     computational patterns of each phase.
+
+     Example Sharding Strategy and Configuration:
+
+     A sharding strategy defines the high-level parallelism dimensions.
+     For a device mesh like `Mesh((2, 4, 4, 4), ('data', 'seq', 'expert', 'model'))` on 128 devices:
+         - data: Data Parallelism (2-way)
+         - seq: Sequence Parallelism (4-way)
+         - expert: Expert Parallelism (4-way)
+         - model: Tensor Parallelism (4-way)
+
+     ShardingConfig then maps tensor dimensions to these logical mesh axes.
+     For example, a tensor with shape (Batch, Sequence, Dimension) could be sharded
+     differently for prefill and decode/generate operations:
+
+     - Prefill (long sequences, small batch):
+         Sharding the sequence dim on the 'seq' axis is often efficient.
+         `prefill_rules.activation_attention_btd = (None, 'seq', 'model')`
+
+     - Generate (short sequences, large batch):
+         Sharding the batch dim on the 'data' axis is often efficient.
+         `generate_rules.activation_attention_btd = ('data', None, 'model')`
+     """
+
+     def __init__(self,
+                  prefill_rules=None,
+                  generate_rules=None,
+                  default_rules_cls=ShardingRulesConfig):
+         """Initializes the ShardingConfig.
+
+         Args:
+             prefill_rules: A `ShardingRulesConfig` for the prefill phase.
+                 If None, a default config is created.
+             generate_rules: A `ShardingRulesConfig` for the generate phase.
+                 If None, a default config is created.
+             default_rules_cls: The default sharding rules class to use.
+         """
+         # Use a factory pattern to avoid mutable default arguments.
+         self.default_rules_cls = default_rules_cls
+         self.prefill_rules = prefill_rules if prefill_rules is not None else default_rules_cls(
+         )
+         self.generate_rules = generate_rules if generate_rules is not None else default_rules_cls(
+         )
+
+
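For reference, a ShardingConfig can also be built directly with a custom rules object, without going through the Sharding manager defined below; a minimal sketch:

# Minimal sketch: shard prefill attention activations on the 'data' axis and
# keep the generate rules at their (unsharded) defaults.
custom_prefill = ShardingRulesConfig(activation_attention_td=('data', None))
cfg = ShardingConfig(prefill_rules=custom_prefill)
assert cfg.generate_rules.activation_attention_td == (None, None)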
+ def build_mesh(devices, strategy: dict[str, int]) -> Mesh:
+     """Constructs a JAX device mesh from a sharding strategy.
+
+     This function creates a logical grid of devices based on the parallelism
+     degrees defined in the strategy. The logical axis names ('data', 'expert',
+     'seq', 'model') are used to map tensor dimensions to the physical device grid.
+
+     Args:
+         devices: The JAX devices to arrange into the mesh.
+         strategy: A dictionary of parallelism degrees from the upper-level config.
+
+     Returns:
+         A JAX `Mesh` object.
+     """
+
+     axis_order = {
+         "data": strategy.get("data_parallelism", 1),
+         "expert": strategy.get("expert_parallelism", 1),
+         "seq": strategy.get("sequence_parallelism", 1),
+         "model": strategy.get("tensor_parallelism", 1),
+     }
+     # TODO: add logic to infer an axis degree when it is given as -1.
+     mesh_axis_names = []
+     mesh_shape = []
+     for axis, dim in axis_order.items():
+         mesh_axis_names.append(axis)
+         mesh_shape.append(dim)
+
+     if not mesh_shape:
+         mesh_shape = [1]
+         mesh_axis_names = [
+             'data'
+         ]  # default to data parallelism if no other strategy is specified
+
+     devices = np.asarray(devices).reshape(mesh_shape)
+     return Mesh(devices, axis_names=tuple(mesh_axis_names))
+
+
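To make the strategy-to-mesh mapping concrete, here is a small usage sketch; the 8-chip split into 4-way tensor and 2-way data parallelism is an assumed example, and `dataclasses.asdict` simply converts a ShardingStrategy into the dictionary form build_mesh expects:

import dataclasses

import jax

# Assumes exactly 8 visible devices (2 * 1 * 1 * 4); adjust the degrees otherwise.
strategy = ShardingStrategy(tensor_parallelism=4, data_parallelism=2)
mesh = build_mesh(jax.devices(), dataclasses.asdict(strategy))
# Resulting axis sizes: data=2, expert=1, seq=1, model=4.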
+ class Sharding:
+     """Generates and manages sharding configurations based on a high-level strategy.
+
+     This class populates a `ShardingConfig` with detailed tensor sharding
+     rules for both prefill and generation phases. It also allows for runtime
+     overrides of these rules.
+
+     Attributes:
+         sharding_cfg: The generated `ShardingConfig` with detailed rules.
+     """
+
+     def __init__(self,
+                  prefill_rules: dict | None = None,
+                  generate_rules: dict | None = None,
+                  default_rules_cls=ShardingRulesConfig,
+                  vllm_config: VllmConfig = None):
+         """Initializes the Sharding manager.
+
+         Args:
+             prefill_rules: A dictionary of overrides for the prefill
+                 sharding config. Keys are attribute names in `ShardingRulesConfig`,
+                 and values are the new sharding tuples.
+             generate_rules: A dictionary of overrides for the generate
+                 sharding config.
+             default_rules_cls: The default sharding rules class to use.
+             vllm_config: The vLLM config; its `additional_config` may carry overrides.
+         """
+         self.vllm_config = vllm_config
+         self.default_rules_cls = default_rules_cls
+         self.sharding_cfg = self.make_sharding_config(
+             default_rules_cls=default_rules_cls,
+             prefill_overrides=prefill_rules,
+             generate_overrides=generate_rules)
+
+     def _get_overrides(self, sharding_phase: str):
+         """Returns the overrides from the vLLM config for the given sharding phase."""
+         overrides = {}
+         try:
+             overrides = self.vllm_config.additional_config["sharding"][
+                 "logical_rules"]["all"]
+         except (AttributeError, KeyError, TypeError):
+             # No vllm_config, or no override section; fall back to empty overrides.
+             pass
+
+         try:
+             additional_overrides = self.vllm_config.additional_config[
+                 "sharding"]["logical_rules"][f"{sharding_phase}"]
+             overrides.update(additional_overrides)
+         except (AttributeError, KeyError, TypeError):
+             pass
+         return overrides
+
+     def __str__(self):
+         """Succinct representation of relevant Sharding settings and overrides."""
+         output_str = f" Using {self.default_rules_cls.__name__} logical rules.\n"
+         output_str += f" {self.__class__.__name__} overrides:\n"
+         output_str += f" prefill logical_rule overrides:\n {json.dumps(self._get_overrides('prefill'), indent=4, default=str)}\n\n"
+         output_str += f" generate logical_rule overrides:\n {json.dumps(self._get_overrides('generate'), indent=4, default=str)}\n\n"
+         return output_str
+
+     def validate_sharding_strategy(self):
+         """Validates that the sharding strategy is compatible with the environment.
+
+         This method is currently a placeholder; it will check that the product of
+         parallelism degrees matches the number of available devices.
+         """
+         # TODO: check num_devices % parallelism == 0
+         # TODO: check num_devices == multiply(parallelism (with inferred degrees))
+         return
+
+     def get_sharding_cfg(self) -> ShardingConfig:
+         """Returns the generated sharding configuration."""
+         return self.sharding_cfg
+
+     def _apply_overrides(self, config_obj: ShardingRulesConfig,
+                          overrides: dict | None):
+         """Applies runtime overrides to a sharding configuration object.
+
+         Args:
+             config_obj: The sharding configuration object (e.g., prefill_rules)
+                 to be updated.
+             overrides: A dictionary where keys are attribute names of the config
+                 object and values are the new sharding tuples.
+
+         Raises:
+             AttributeError: If a key in the overrides dictionary is not a valid
+                 attribute of the configuration object.
+         """
+         for key, value in overrides.items():
+             if hasattr(config_obj, key):
+                 setattr(config_obj, key, value)
+             else:
+                 # Raise an error for invalid keys to prevent silent failures.
+                 raise AttributeError(
+                     f"'{key}' is not a valid attribute of {type(config_obj).__name__}"
+                 )
+
+     def _make_default_sharding_config(self, prefill_rules, generate_rules):
+
+         # Populate the prefill config.
+         # During prefill the token (batch * sequence) dimension is long, so we
+         # shard it along the data axis.
+         prefill_rules.activation_attention_td = (DATA_AXIS_NAME,
+                                                  ATTN_TENSOR_AXIS_NAME)
+         prefill_rules.activation_attention_out_td = (DATA_AXIS_NAME,
+                                                      ATTN_TENSOR_AXIS_NAME)
+         prefill_rules.activation_q_td = (DATA_AXIS_NAME, ATTN_TENSOR_AXIS_NAME)
+         # TODO: by default, QKV and the KV cache are sharded on the head dim.
+         # We may change this after we finalize the KV cache design.
+         prefill_rules.attn_o_tnh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME, None)
+         prefill_rules.query_tnh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME, None)
+         prefill_rules.keyvalue_skh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME,
+                                       None)
+
+         # Populate the generate (decode) config.
+         # During decode the batch is the large dimension, so we shard along the
+         # batch (data) axis.
+         generate_rules.activation_attention_td = (DATA_AXIS_NAME,
+                                                   ATTN_TENSOR_AXIS_NAME)
+         generate_rules.activation_attention_out_td = (DATA_AXIS_NAME,
+                                                       ATTN_TENSOR_AXIS_NAME)
+         generate_rules.activation_q_td = (DATA_AXIS_NAME,
+                                           ATTN_TENSOR_AXIS_NAME)
+         # TODO: by default, QKV and the KV cache are sharded on the head dim.
+         # We may change this after we finalize the KV cache design.
+         generate_rules.attn_o_tnh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME, None)
+         generate_rules.query_tnh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME, None)
+         generate_rules.keyvalue_skh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME,
+                                        None)
+         generate_rules.attn_q_weight_dnh = (None, ATTN_HEAD_AXIS_NAME,
+                                             ATTN_TENSOR_AXIS_NAME)
+         generate_rules.attn_k_weight_dkh = (None, ATTN_HEAD_AXIS_NAME,
+                                             ATTN_TENSOR_AXIS_NAME)
+         generate_rules.attn_v_weight_dkh = (None, ATTN_HEAD_AXIS_NAME,
+                                             ATTN_TENSOR_AXIS_NAME)
+         generate_rules.attn_o_weight_nhd = (ATTN_HEAD_AXIS_NAME, None,
+                                             ATTN_TENSOR_AXIS_NAME)
+         generate_rules.activation_ffw_td = (DATA_AXIS_NAME, None)
+         generate_rules.activation_ffw_ted = (DATA_AXIS_NAME, EXPERT_AXIS_NAME,
+                                              None)
+         generate_rules.ffw_hidden_tf = (DATA_AXIS_NAME, MLP_TENSOR_AXIS_NAME)
+         # FFW weights are typically sharded along the hidden dimension (F).
+         generate_rules.ffw_weight_df = (None, MLP_TENSOR_AXIS_NAME)
+         generate_rules.ffw_weight_fd = (MLP_TENSOR_AXIS_NAME, None)
+         # MoE weights are sharded along the expert axis and the hidden dimension.
+         generate_rules.moe_weights_edf = (EXPERT_AXIS_NAME, None,
+                                           MOE_TENSOR_AXIS_NAME)
+         generate_rules.moe_weights_efd = (EXPERT_AXIS_NAME,
+                                           MOE_TENSOR_AXIS_NAME, None)
+         generate_rules.moe_router_de = (None, EXPERT_AXIS_NAME)
+
+         # Embedding weight: (VocabSize, Dim)
+         generate_rules.emb_weight_vd = (MLP_TENSOR_AXIS_NAME, None)
+         generate_rules.activation_td = (DATA_AXIS_NAME, ATTN_TENSOR_AXIS_NAME)
+         generate_rules.prelogit_td = (DATA_AXIS_NAME, ATTN_TENSOR_AXIS_NAME)
+         generate_rules.logits_tv = (DATA_AXIS_NAME, MLP_TENSOR_AXIS_NAME)
+         generate_rules.vocab_vd = (VOCAB_AXIS_NAME, None)
+         generate_rules.vocab_dv = (None, VOCAB_AXIS_NAME)
+
+     def make_sharding_config(
+             self,
+             default_rules_cls: ShardingRulesConfig,
+             prefill_overrides: dict | None = None,
+             generate_overrides: dict | None = None) -> ShardingConfig:
+         """Creates the detailed `ShardingConfig` with specific partitioning rules
+         and applies any runtime overrides.
+
+         This method populates `prefill_rules` and `generate_rules` with hardcoded
+         sharding rules that are generally effective for transformer models, and
+         then updates them with any provided overrides.
+
+         Args:
+             default_rules_cls: The sharding rules class used for the defaults.
+             prefill_overrides: A dictionary with attribute names and their new values
+                 for the prefill sharding configuration.
+             generate_overrides: A dictionary with attribute names and their new values
+                 for the generate sharding configuration.
+
+         Returns:
+             The populated and overridden `ShardingConfig` object.
+         """
+         # TODO: organize into update_prefill() and update_decode() for each axis.
+         # TODO: verify the sharding axes.
+         sharding_cfg = ShardingConfig(default_rules_cls=default_rules_cls)
+         prefill_rules = sharding_cfg.prefill_rules
+         generate_rules = sharding_cfg.generate_rules
+
+         # Extract the overrides from the vllm_config if they are not provided programmatically.
+         if prefill_overrides is None:
+             prefill_overrides = self._get_overrides("prefill")
+         if generate_overrides is None:
+             generate_overrides = self._get_overrides("generate")
+
+         # Apply the default sharding configs.
+         self._make_default_sharding_config(prefill_rules, generate_rules)
+
+         # Apply the runtime sharding rule overrides.
+         self._apply_overrides(prefill_rules, prefill_overrides)
+         self._apply_overrides(generate_rules, generate_overrides)
+
+         return sharding_cfg
+
+     # TODO: Add __repr__
+
+
+ class ShardingInfo:
+     # TODO: a sharding info class for visualizing & debugging sharding performance.
+     # Will be implemented in the next version.
+     pass
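As _get_overrides above shows, logical-rule overrides can also arrive through VllmConfig.additional_config, with an "all" section applied to both phases and per-phase sections layered on top. The concrete rule values below are illustrative, not defaults shipped by the package:

# Shape of the override structure read by Sharding._get_overrides().
additional_config = {
    "sharding": {
        "logical_rules": {
            "all": {"activation_td": ("data", None)},       # applied to both phases
            "generate": {"logits_tv": ("data", "model")},   # layered on generate only
        }
    }
}

# Programmatic overrides take precedence over the config-based ones.
sharding = Sharding(prefill_rules={"query_tnh": ("data", "model", None)})
assert sharding.get_sharding_cfg().prefill_rules.query_tnh == ("data", "model", None)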
tpu_inference/layers/jax/transformer_block.py
@@ -0,0 +1,76 @@
+ from dataclasses import dataclass
+ from typing import Any, Tuple
+
+ # Flax and JAX sharding imports
+ import jax
+ from flax import nnx
+
+ from tpu_inference.layers.jax.attention.attention import (AttentionMetadata,
+                                                            KVCache)
+ from tpu_inference.layers.jax.layers import DenseFFW
+ from tpu_inference.layers.jax.moe.moe import MoE
+
+
+ @dataclass(kw_only=True)
+ class TransformerBlock(nnx.Module):
+     """
+     A heavyweight module which serves as a stateful live block in serving.
+
+     custom_module can be either a dense module (i.e., DenseFFW) or MoE.
+     """
+     pre_attention_norm: nnx.Module
+     pre_mlp_norm: nnx.Module
+     custom_module: nnx.Module
+     attn: nnx.Module
+     use_attention_rope: bool = True
+     quant: Any | None = None
+
+     def __call__(
+         self, x_TD: jax.Array, is_prefill: bool, kv_cache: KVCache,
+         attention_metadata: AttentionMetadata
+     ) -> Tuple[KVCache, jax.Array]:
+         # Attn Block
+         attn_residual_TD = x_TD
+         x_TD = self.pre_attention_norm(x_TD)
+         new_cache, attn_output_TD = self.attn(x_TD, is_prefill, kv_cache,
+                                               attention_metadata,
+                                               self.use_attention_rope)
+         attn_output_TD += attn_residual_TD
+
+         # FFW Block
+         ffw_residual_TD = attn_output_TD
+         normed_ffw_input_TD = self.pre_mlp_norm(attn_output_TD)
+         logits_TD = self.custom_module(normed_ffw_input_TD)
+         logits_TD += ffw_residual_TD
+         return new_cache, logits_TD
+
+
+ @dataclass(kw_only=True)
+ class SharedExpertsTransformerBlock(TransformerBlock):
+     """A TransformerBlock variant that sums the MoE layer output with the shared expert output."""
+     shared_experts: nnx.Module
+
+     def __call__(self, x_TD, is_prefill, kv_cache, attention_metadata):
+         # Attn Block
+         attn_residual_TD = x_TD
+         x_TD = self.pre_attention_norm(x_TD)
+         new_cache, attn_output_TD = self.attn(x_TD, is_prefill, kv_cache,
+                                               attention_metadata,
+                                               self.use_attention_rope)
+         attn_output_TD += attn_residual_TD
+
+         # FFW Block
+         ffw_residual_TD = attn_output_TD
+         normed_ffw_input_TD = self.pre_mlp_norm(attn_output_TD)
+         if isinstance(self.custom_module, MoE):
+             logits_TD = self.custom_module(normed_ffw_input_TD)
+             # Add the shared expert outputs to the MoE outputs.
+             shared_expert_output_TD = self.shared_experts(normed_ffw_input_TD)
+             logits_TD += shared_expert_output_TD
+         elif isinstance(self.custom_module, DenseFFW):
+             logits_TD = self.custom_module(normed_ffw_input_TD)
+         else:
+             raise ValueError(
+                 f"Invalid custom module type: {type(self.custom_module)}")
+         logits_TD += ffw_residual_TD
+         return new_cache, logits_TD
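Both block variants implement the standard pre-norm residual dataflow: normalize, apply the sub-layer, then add the input back. A self-contained sketch of that dataflow in plain JAX, using stand-in callables rather than this package's actual modules:

import jax.numpy as jnp

def pre_norm_block(x_TD, norm1, attn, norm2, ffw):
    attn_out_TD = attn(norm1(x_TD)) + x_TD        # attention sub-block + residual
    return ffw(norm2(attn_out_TD)) + attn_out_TD  # feed-forward sub-block + residual

# Stand-in callables purely for illustration.
rms_norm = lambda x: x / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + 1e-6)
identity = lambda x: x
y_TD = pre_norm_block(jnp.ones((4, 8)), rms_norm, identity, rms_norm, identity)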
tpu_inference/layers/vllm/__init__.py: file without changes
tpu_inference/layers/vllm/attention.py
@@ -0,0 +1,184 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ import functools
+ from typing import Optional, Tuple
+
+ import jax
+ import torch
+ from jax.sharding import Mesh
+ from torchax.interop import jax_view, torch_view
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                               AttentionLayer, AttentionType)
+
+ from tpu_inference import utils
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention_interface import attention
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.vllm.vllm_model_wrapper_context import \
+     get_vllm_model_wrapper_context
+
+ logger = init_logger(__name__)
+
+
+ class PallasAttentionBackend(AttentionBackend):
+
+     @staticmethod
+     def get_name() -> str:
+         return "PALLAS"
+
+     @staticmethod
+     def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
+         return PallasAttentionBackendImpl
+
+
+ class PallasAttentionBackendImpl(AttentionImpl):
+
+     def __init__(
+         self,
+         num_heads: int,
+         head_size: int,
+         scale: float,
+         num_kv_heads: int,
+         alibi_slopes: Optional[list[float]],
+         sliding_window: Optional[int],
+         kv_cache_dtype: str,
+         logits_soft_cap: Optional[float] = None,
+         attn_type: str = AttentionType.DECODER,
+         kv_sharing_target_layer_name: Optional[int] = None,
+         use_irope: bool = False,
+     ) -> None:
+         if use_irope:
+             logger.warning_once(
+                 "Using irope in Pallas is not supported yet, it will fall back "
+                 "to global attention for long context.")
+         self.num_heads = num_heads
+         self.head_size = head_size
+         self.scale = float(scale)
+         self.num_kv_heads = num_kv_heads
+         self.sliding_window = sliding_window
+         self.logits_soft_cap = logits_soft_cap
+         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
+
+         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+         if alibi_slopes is not None:
+             raise NotImplementedError("Alibi slopes are not supported.")
+         self.kv_cache_quantized_dtype = None
+         if kv_cache_dtype != "auto":
+             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                 kv_cache_dtype)
+
+         if attn_type != AttentionType.DECODER:
+             raise NotImplementedError("Encoder self-attention and "
+                                       "encoder/decoder cross-attention "
+                                       "are not implemented for "
+                                       "PallasAttentionBackendImpl")
+
+     def forward(
+         self,
+         layer: AttentionLayer,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         kv_cache: torch.Tensor,
+         attn_metadata: AttentionMetadata,
+         output: Optional[torch.Tensor] = None,
+         output_scale: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if output_scale is not None:
+             raise NotImplementedError(
+                 "fused output quantization is not yet supported for "
+                 "PallasAttentionBackendImpl")
+
+         if kv_cache.numel():
+             raise RuntimeError(
+                 "KV cache from the vLLM Attention layer should be empty but has "
+                 f"{kv_cache.numel()} elements.")
+
+         del kv_cache  # Use kv_cache from the vLLM wrapper context values instead.
+
+         vllm_model_wrapper_context = get_vllm_model_wrapper_context()
+         kv_cache_index = vllm_model_wrapper_context.layer_name_to_kvcache_index[
+             layer.layer_name]
+         kv_cache = vllm_model_wrapper_context.kv_caches[kv_cache_index]
+
+         mesh = vllm_model_wrapper_context.mesh
+
+         query, key, value = jax_view(query), jax_view(key), jax_view(value)
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             key, value = utils.quantize_kv(key, value,
+                                            self.kv_cache_quantized_dtype,
+                                            layer._k_scale_float,
+                                            layer._v_scale_float)
+             # TODO(kyuyeunk): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = layer._q_scale_float
+             k_scale = layer._k_scale_float
+             v_scale = layer._v_scale_float
+
+         new_kv_cache, outputs = _jax_attn_func(kv_cache, query, key, value,
+                                                attn_metadata, mesh, self.scale,
+                                                self.head_size, self.num_heads,
+                                                self.num_kv_heads, q_scale,
+                                                k_scale, v_scale)
+         vllm_model_wrapper_context.kv_caches[kv_cache_index] = new_kv_cache
+
+         return torch_view(outputs)
+
+
+ @functools.partial(
+     jax.jit,
+     static_argnums=(
+         5, 6, 7, 8, 9, 10, 11, 12
+     ),  # mesh, scale, head_size, num_heads, num_kv_heads, q_scale, k_scale, v_scale
+     donate_argnums=(0, ),  # donate kv_cache
+ )
+ def _jax_attn_func(
+     kv_cache: jax.Array,
+     q: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     attention_metadata: AttentionMetadata,
+     mesh: Mesh,
+     scale: float,
+     head_size: int,
+     num_heads: int,
+     num_kv_heads: int,
+     q_scale: Optional[float] = None,
+     k_scale: Optional[float] = None,
+     v_scale: Optional[float] = None,
+ ) -> Tuple[jax.Array, jax.Array]:
+     del scale  # Unused for now, as the attention function applies a default scale.
+
+     # Get shapes from vLLM.
+     q_len, q_compute_dim = q.shape
+     k_len, k_compute_dim = k.shape
+     assert k.shape == v.shape
+     assert q_compute_dim == head_size * num_heads
+     assert k_compute_dim == head_size * num_kv_heads
+
+     # Convert the shapes from vLLM's convention to what the attention function expects.
+     # (q_len, num_heads, head_size)
+     q = q.reshape(q_len, num_heads, head_size)
+     # (k_len, num_kv_heads, head_size)
+     k = k.reshape(k_len, num_kv_heads, head_size)
+     v = v.reshape(k_len, num_kv_heads, head_size)
+
+     new_kv_cache, outputs = attention(
+         kv_cache,
+         q,
+         k,
+         v,
+         attention_metadata,
+         mesh,
+         q_scale=q_scale,
+         k_scale=k_scale,
+         v_scale=v_scale,
+     )
+
+     # Convert the shape back to vLLM's convention.
+     assert outputs.shape[0] == q_len
+     assert outputs.shape[1] == num_heads
+     assert outputs.shape[2] == head_size
+     outputs = outputs.reshape(q_len, q_compute_dim)
+
+     return new_kv_cache, outputs
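The wrapper's data-layout work is just flattening and unflattening the head dimension: vLLM hands the backend (tokens, num_heads * head_size) tensors, while the JAX attention function works on (tokens, num_heads, head_size). A small sketch of that round trip with assumed sizes:

import jax.numpy as jnp

q_len, num_heads, head_size = 16, 8, 128             # assumed example sizes
q_flat = jnp.zeros((q_len, num_heads * head_size))   # vLLM layout
q = q_flat.reshape(q_len, num_heads, head_size)      # layout the attention function expects
out_flat = q.reshape(q_len, num_heads * head_size)   # back to vLLM layout
assert out_flat.shape == q_flat.shape

Donating argument 0 to jax.jit lets XLA reuse the incoming KV-cache buffer for new_kv_cache, so the cache is updated in place rather than allocating a second cache-sized array.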