tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in that registry.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
@@ -0,0 +1,214 @@
+import math
+from typing import Any, Dict
+
+import jax
+import jax.numpy as jnp
+
+
+def apply_rope(
+        # (seq_len, num_heads, head_dim)
+        inputs: jax.Array,
+        # (3, seq_len) for M-RoPE, otherwise (seq_len,)
+        positions: jax.Array,
+        head_dim: int,
+        rope_theta: float = 10000,
+        rope_scaling: Dict[str, Any] = None,
+        rope_input_ordering: str = "split",
+) -> jax.Array:
+    """
+    Applies Rotary Positional Embedding using the sine and cosine strategy.
+
+    This implementation assumes the input tensor has a shape that might include
+    padding on the last dimension (head_dim).
+    RoPE is applied only to the first `head_dim` features, and the result is
+    padded back to the original dimension if necessary.
+    If rope_input_ordering is "split", then the input pairs for rotation are taken one from the
+    first and one from the second half of the head_dim. If it is "interleaved" then
+    adjacent values are used as inputs for rotation.
+    """
+
+    # M-RoPE support for Qwen2.5-VL
+    if positions.ndim == 2 and positions.shape[0] == 3:
+        mrope_section = rope_scaling.get("mrope_section",
+                                         None) if rope_scaling else None
+        # NOTE: We assume mrope_section is always available
+        # as Qwen2.5-VL is the only model using mrope
+        assert mrope_section is not None
+
+        split_indices = [mrope_section[0], mrope_section[0] + mrope_section[1]]
+
+        # Indices for the features to be rotated (first half of head_dim)
+        all_freq_indices = jnp.arange(head_dim // 2)
+
+        # Split the indices according to mrope_section. This is valid because split_indices are static.
+        freq_indices_split = jnp.split(all_freq_indices, split_indices)
+        # freq_indices_split is a list of 3 JAX arrays.
+
+        cos_list = []
+        sin_list = []
+
+        for i in range(3):  # For each of the 3 position dimensions
+            current_indices = freq_indices_split[i]
+
+            if current_indices.size == 0:
+                # This section is empty, skip.
+                continue
+
+            # inv_freq shape: (mrope_section[i],)
+            inv_freq = 1.0 / (rope_theta**(current_indices * 2.0 / head_dim))
+
+            # positions[i]: (seq_len,)
+            # freqs shape: (seq_len, mrope_section[i])
+            freqs = jnp.outer(positions[i], inv_freq)
+
+            cos_list.append(jnp.cos(freqs))
+            sin_list.append(jnp.sin(freqs))
+
+        # Concatenate along the feature dimension
+        # cos, sin shape: (seq_len, head_dim//2)
+        cos = jnp.concatenate(cos_list, axis=1)
+        sin = jnp.concatenate(sin_list, axis=1)
+
+        # Add num_heads dimension for broadcasting
+        cos = cos[:, jnp.newaxis, :]  # Shape: (seq_len, 1, head_dim//2)
+        sin = sin[:, jnp.newaxis, :]  # Shape: (seq_len, 1, head_dim//2)
+
+        # Apply rotation
+        inputs_real = inputs[..., :head_dim // 2]
+        inputs_imag = inputs[..., head_dim // 2:head_dim]
+
+        outputs_real = inputs_real * cos - inputs_imag * sin
+        outputs_imag = inputs_real * sin + inputs_imag * cos
+
+        out = jnp.concatenate([outputs_real, outputs_imag], axis=-1)
+
+    # Standard RoPE
+    else:
+        # Calculate inverse frequencies (timescale)
+        fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
+        timescale = 1.0 / (rope_theta**fraction)
+
+        # Apply scaling if provided
+        if rope_scaling:
+            timescale = apply_rope_scaling(timescale, rope_scaling)
+
+        # Prepare for rotation by calculating sin and cos values
+        # `sinusoid_inp` gets shape (batch * seq_len, head_dim/2)
+        sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
+
+        # Broadcast over the 'heads' dimension, assuming shape (batch*seq, heads, head_dim)
+        sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
+        sin = jnp.sin(sinusoid_inp)
+        cos = jnp.cos(sinusoid_inp)
+
+        if rope_input_ordering == "interleaved":
+            # Reshape to group adjacent features for rotation, matching new_apply_rope
+            rotary_inputs = inputs[
+                ..., :head_dim]  # Take just the non-padded amount.
+            reshaped_inputs = rotary_inputs.reshape(*rotary_inputs.shape[:-1],
+                                                    -1, 2)
+
+            # Apply the rotation
+            first_half = reshaped_inputs[..., 0]
+            second_half = reshaped_inputs[..., 1]
+        else:
+            first_half = inputs[..., :head_dim // 2]
+            second_half = inputs[..., head_dim // 2:head_dim]
+
+        first_part = first_half * cos - second_half * sin
+        second_part = second_half * cos + first_half * sin
+
+        # Combine the rotated parts and reshape back
+        if rope_input_ordering == "interleaved":
+            out_stacked = jnp.stack([first_part, second_part], axis=-1)
+            out = out_stacked.reshape(rotary_inputs.shape)
+        else:
+            out = jnp.concatenate([first_part, second_part], axis=-1)
+
+    # If the original input was padded, pad the output with zeros to match.
+    padded_head_dim = inputs.shape[-1]
+    if padded_head_dim > head_dim:
+        pad_width = padded_head_dim - head_dim
+        pad_config = [(0, 0)] * (out.ndim - 1) + [(0, pad_width)]
+        out = jnp.pad(out, pad_config)
+
+    return out.astype(inputs.dtype)
+
+
+def apply_longrope(
+        inputs: jax.Array,
+        positions: jax.Array,
+        head_dim: int,
+        rope_scaling: Dict[str, Any],
+        original_max_position_embeddings: int,
+        max_position_embeddings: int,
+        rope_theta: float = 10000,
+) -> jax.Array:
+    # LongRoPE implementation specific to Phi-3
+    # Implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py#L197-L235
+
+    scale = max_position_embeddings / original_max_position_embeddings
+    if scale <= 1.0:
+        mscale = 1.0
+    else:
+        mscale = jnp.sqrt(1 + (jnp.log(scale) /
+                               jnp.log(original_max_position_embeddings)))
+
+    seq_len = inputs.shape[0]
+    if seq_len > original_max_position_embeddings:
+        long_factor = jnp.array(rope_scaling.get("long_factor"))
+        timescale = 1.0 / (long_factor * (rope_theta**(
+            (2 * jnp.arange(0, head_dim // 2)) / head_dim)))
+    else:
+        short_factor = jnp.array(rope_scaling.get("short_factor"))
+        timescale = 1.0 / (short_factor * (rope_theta**(
+            (2 * jnp.arange(0, head_dim // 2)) / head_dim)))
+
+    # Calculate RoPE positions
+    sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
+    sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
+    sin = jnp.sin(sinusoid_inp) * mscale
+    cos = jnp.cos(sinusoid_inp) * mscale
+
+    # Padding logic
+    padded_head_dim = inputs.shape[-1]
+
+    # Apply RoPE mechanism
+    first_half = inputs[..., :head_dim // 2]
+    second_half = inputs[..., head_dim // 2:head_dim]
+
+    first_part = first_half * cos - second_half * sin
+    second_part = second_half * cos + first_half * sin
+    out = jnp.concatenate([first_part, second_part], axis=-1)
+
+    if padded_head_dim > head_dim:
+        out = jnp.pad(out, ((0, 0), (0, 0), (0, padded_head_dim - head_dim)))
+
+    return out.astype(inputs.dtype)
+
+
+def apply_rope_scaling(freqs: jax.Array, rope_scaling: Dict[str,
+                                                            Any]) -> jax.Array:
+    # Values obtained from grid search
+    scale_factor = rope_scaling.get("scale_factor", 8.0)
+    low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+    high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+    old_context_len = rope_scaling.get("original_max_position_embeddings",
+                                       8192)
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+
+    wavelen = 2 * math.pi / freqs
+    smooth = (old_context_len / wavelen -
+              low_freq_factor) / (high_freq_factor - low_freq_factor)
+
+    high_freqs = jnp.where(wavelen < high_freq_wavelen, freqs, 0)
+    low_freqs = jnp.where(wavelen > low_freq_wavelen, freqs / scale_factor, 0)
+    mid_freqs = jnp.where(
+        (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
+        (1 - smooth) * freqs / scale_factor + smooth * freqs,
+        0,
+    )
+    new_freqs = high_freqs + low_freqs + mid_freqs
+    return new_freqs
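
For orientation, a minimal usage sketch of the apply_rope hunk above. This is illustrative only and not part of the wheel; the module path tpu_inference.layers.jax.rope_interface is an assumption (it is the only +214-line file in the listing), and the shapes and values are arbitrary.

import jax
import jax.numpy as jnp

# Assumed import path; adjust to wherever this hunk actually lives in the wheel.
from tpu_inference.layers.jax.rope_interface import apply_rope

seq_len, num_heads, head_dim = 8, 4, 64
padded_dim = 128  # the last axis may be padded beyond head_dim, per the docstring

q = jax.random.normal(jax.random.PRNGKey(0),
                      (seq_len, num_heads, padded_dim),
                      dtype=jnp.bfloat16)
positions = jnp.arange(seq_len)  # 1-D positions -> standard RoPE branch

out = apply_rope(q, positions, head_dim=head_dim, rope_theta=10000.0)
assert out.shape == q.shape  # features beyond head_dim come back zero-padded

# M-RoPE (Qwen2.5-VL) would instead pass positions of shape (3, seq_len) and a
# rope_scaling dict containing "mrope_section".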