tpu_inference-0.12.0.dev20251222-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/rope.py
@@ -0,0 +1,294 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from dataclasses import dataclass, field
+ from typing import Optional, Tuple
+
+ import jax
+ from flax import nnx
+ from jax import numpy as jnp
+ from jax.experimental.layout import Layout, with_layout_constraint
+ from jax.sharding import NamedSharding, PartitionSpec
+
+
+ @dataclass(kw_only=True)
+ class RotaryEmbedding(nnx.Module):
+     """
+     An implementation of the original rotary positional embedding.
+     """
+     rotary_dim: int
+     rope_theta: float
+     original_max_position_embeddings: int
+     dtype: jnp.dtype
+     sin_cos_cache: Optional[jax.Array] = field(init=False, default=None)
+
+     def initialize_cache(self):
+         """Computes and caches the sin/cos embeddings."""
+         if self.sin_cos_cache is None:
+             self.sin_cos_cache = self._compute_sin_cos()
+
+     def _compute_inv_freq(self):
+         fractions_H = jnp.arange(0, self.rotary_dim, 2,
+                                  dtype=jnp.float32) / self.rotary_dim
+         inv_freq_H = 1.0 / (self.rope_theta**fractions_H)
+         return inv_freq_H
+
+     def _compute_sin_cos(self):
+         inv_freq_H = self._compute_inv_freq()
+         t = jnp.arange(self.original_max_position_embeddings,
+                        dtype=jnp.float32)
+
+         freqs = jnp.einsum("...T,k->...Tk",
+                            t,
+                            inv_freq_H,
+                            precision=jax.lax.Precision.HIGHEST)
+         sin, cos = jnp.sin(freqs), jnp.cos(freqs)
+         cache = jnp.concatenate((cos, sin), axis=-1)
+         return cache
+
+     def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
+         assert x_TNH.ndim == 3
+         assert self.sin_cos_cache is not None, "RoPE cache not initialized."
+         cos_sin_TH = self.sin_cos_cache[positions]
+         # cos, sin: (T, H/2)
+         cos_TH, sin_TH = jnp.split(cos_sin_TH, 2, axis=-1)
+         assert sin_TH.ndim == 2 and cos_TH.ndim == 2
+         # cos, sin: (T, 1, H/2)
+         cos_T1H, sin_T1H = cos_TH[:, None, :], sin_TH[:, None, :]
+         # first_half, second_half: (T, N, H/2)
+         first_half_TNH, second_half_TNH = jnp.split(x_TNH, 2, axis=-1)
+         combined = jnp.concatenate([
+             first_half_TNH * cos_T1H - second_half_TNH * sin_T1H,
+             second_half_TNH * cos_T1H + first_half_TNH * sin_T1H
+         ],
+                                    axis=-1)
+         return combined.astype(self.dtype)
+
+
+ @dataclass(kw_only=True)
+ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+     """
+     Rotary embedding for DeepSeek, with scaling via the YaRN method.
+     """
+     scaling_factor: float
+     beta_fast: int = 32
+     beta_slow: int = 1
+     mscale_value: float = 1
+     mscale_all_dim: float = 0
+
+     def initialize_cache(self, mesh: jax.sharding.Mesh):
+         """Computes and caches the sin/cos embeddings."""
+         # The second condition is for the Qwix case, where we need to call
+         # `initialize_cache` on the abstract model. Thus, when we go to call
+         # `initialize_cache` on the concrete model, this method will have been
+         # called already, but we need to recompute the cache so that it's
+         # concrete (otherwise, it'll still be a jax.ShapeDtypeStruct).
+         if self.sin_cos_cache is not None and not isinstance(
+                 self.sin_cos_cache, jax.ShapeDtypeStruct):
+             return
+         mscale_val = _yarn_get_mscale(
+             self.scaling_factor, self.mscale_value) / _yarn_get_mscale(
+                 self.scaling_factor, self.mscale_all_dim)
+         replicated_sharding = NamedSharding(mesh, PartitionSpec())
+         self.mscale = jax.device_put(mscale_val, replicated_sharding)
+         self.sin_cos_cache = self._compute_sin_cos()
+
+     def _compute_inv_freq(self):
+         fractions = jnp.arange(0, self.rotary_dim, 2,
+                                dtype=jnp.float32) / self.rotary_dim
+         inv_freq_extrapolation = 1.0 / (self.rope_theta**fractions)
+         inv_freq_interpolation = 1.0 / (self.scaling_factor *
+                                         self.rope_theta**fractions)
+         low, high = _yarn_find_correction_range(
+             self.beta_fast, self.beta_slow, self.rotary_dim, self.rope_theta,
+             self.original_max_position_embeddings)
+
+         # Get n-d rotational scaling corrected for extrapolation.
+         inv_freq_mask = 1 - _yarn_linear_ramp_mask(
+             low, high, self.rotary_dim // 2).astype(jnp.float32)
+         inv_freq = inv_freq_interpolation * (
+             1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+         return inv_freq
+
+     @jax.jit
+     def _compute_sin_cos(self):
+         inv_freq_H = self._compute_inv_freq()
+         t = jnp.arange(self.original_max_position_embeddings *
+                        self.scaling_factor,
+                        dtype=jnp.float32)
+         freqs = jnp.einsum("...T,k->...Tk", t, inv_freq_H)
+         sin, cos = jnp.sin(freqs) * self.mscale, jnp.cos(freqs) * self.mscale
+         cache = jnp.concatenate((cos, sin), axis=-1)
+         H = cache.shape[1]
+         target_dim = ((H - 1) // 128 + 1) * 128
+         padding_amount = target_dim - self.rotary_dim
+         pad_width = ((0, 0), (0, padding_amount))
+         cache_padded = jnp.pad(cache, pad_width, mode='constant')
+         desired_layout = Layout(major_to_minor=(1, 0))
+         cache_padded = with_layout_constraint(cache_padded, desired_layout)
+         return cache_padded
+
+     def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
+         assert x_TNH.ndim == 3
+         assert self.sin_cos_cache is not None, "RoPE cache not initialized."
+         cos_sin_padded = self.sin_cos_cache[positions]
+         cos_sin_TH = cos_sin_padded[:, :self.rotary_dim]
+         # cos, sin: (T, H/2)
+         cos_TH, sin_TH = jnp.split(cos_sin_TH, 2, axis=-1)
+         assert sin_TH.ndim == 2 and cos_TH.ndim == 2
+         # cos, sin: (T, 1, H/2)
+         cos_T1H, sin_T1H = cos_TH[:, None, :], sin_TH[:, None, :]
+         # even, odd: (T, N, H/2)
+         even_TNH, odd_TNH = x_TNH[..., ::2], x_TNH[..., 1::2]
+         combined_TNH = jnp.stack([
+             even_TNH * cos_T1H - odd_TNH * sin_T1H,
+             odd_TNH * cos_T1H + even_TNH * sin_T1H
+         ],
+                                  axis=-1).reshape(x_TNH.shape)
+         return combined_TNH.astype(self.dtype)
+
+
+ # Calculates the temperature scaling factor for YaRN to adjust
+ # RoPE embedding magnitudes.
+ def _yarn_get_mscale(scale, mscale):
+     return jnp.where(scale <= 1, 1.0, 0.1 * mscale * jnp.log(scale) + 1.0)
+
+
+ # Inverts the dim formula to find dim based on the number of rotations.
+ def _yarn_find_correction_dim(num_rotations,
+                               dim,
+                               base=10000,
+                               max_position_embeddings=2048):
+     return (dim * math.log(max_position_embeddings /
+                            (num_rotations * 2 * math.pi))) / (2 *
+                                                               math.log(base))
+
+
+ # Finds dim range bounds based on rotations.
+ def _yarn_find_correction_range(low_rot,
+                                 high_rot,
+                                 dim,
+                                 base=10000,
+                                 max_position_embeddings=2048):
+     low = math.floor(
+         _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+     high = math.ceil(
+         _yarn_find_correction_dim(high_rot, dim, base,
+                                   max_position_embeddings))
+     return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+ # Creates a 1D mask that ramps linearly from 0 to 1 between min and max indices.
+ def _yarn_linear_ramp_mask(min, max, dim):
+     if min == max:
+         max += 0.001  # Prevent singularity
+
+     linear_func = (jnp.arange(dim, dtype=jnp.float32) - min) / (max - min)
+     ramp_func = jnp.clip(linear_func, 0, 1)
+     return ramp_func
+
+
+ @dataclass(kw_only=True)
+ class GptOssRotaryEmbedding(nnx.Module):
+     """
+     JAX implementation of rotary positional embedding with YaRN scaling.
+     """
+     head_dim: int
+     rope_theta: float
+     dtype: jnp.dtype
+     initial_context_length: int = 4096
+     rope_scaling_factor: float = 1.0
+     rope_ntk_alpha: float = 1.0
+     rope_ntk_beta: float = 32.0
+
+     def _compute_concentration_and_inv_freq(self) -> Tuple[float, jax.Array]:
+         """
+         Computes the inverse frequencies and concentration factor for YaRN.
+         See the YaRN paper: https://arxiv.org/abs/2309.00071
+         """
+         freq = self.rope_theta**(
+             jnp.arange(0, self.head_dim, 2, dtype=jnp.float32) / self.head_dim)
+
+         if self.rope_scaling_factor > 1.0:
+             concentration = 0.1 * jnp.log(self.rope_scaling_factor) + 1.0
+
+             d_half = self.head_dim / 2
+             # NTK by parts
+             low = (d_half * jnp.log(self.initial_context_length /
+                                     (self.rope_ntk_beta * 2 * jnp.pi)) /
+                    jnp.log(self.rope_theta))
+             high = (d_half * jnp.log(self.initial_context_length /
+                                      (self.rope_ntk_alpha * 2 * jnp.pi)) /
+                     jnp.log(self.rope_theta))
+
+             interpolation = 1.0 / (self.rope_scaling_factor * freq)
+             extrapolation = 1.0 / freq
+
+             ramp = (jnp.arange(d_half, dtype=jnp.float32) - low) / (high - low)
+             mask = 1 - jnp.clip(ramp, 0, 1)
+
+             inv_freq = interpolation * (1 - mask) + extrapolation * mask
+         else:
+             concentration = 1.0
+             inv_freq = 1.0 / freq
+
+         return concentration, inv_freq
+
+     def _compute_cos_sin(self,
+                          positions: jax.Array) -> Tuple[jax.Array, jax.Array]:
+         """Computes cosine and sine embeddings for the given positions."""
+         concentration, inv_freq_H = self._compute_concentration_and_inv_freq()
+
+         # freqs: (T, H/2)
+         freqs = jnp.einsum("T,H->TH",
+                            positions.astype(jnp.float32),
+                            inv_freq_H,
+                            precision=jax.lax.Precision.HIGHEST)
+
+         cos = jnp.cos(freqs) * concentration
+         sin = jnp.sin(freqs) * concentration
+         return cos, sin
+
+     def __call__(self, query_TNH: jax.Array, key_TNH: jax.Array,
+                  positions: jax.Array) -> Tuple[jax.Array, jax.Array]:
+         """
+         Applies rotary embeddings to query and key tensors.
+
+         Args:
+             query_TNH: Query tensor with shape (num_tokens, num_heads, head_dim)
+             key_TNH: Key tensor with shape (num_tokens, num_kv_heads, head_dim)
+             positions: A 1D array of token positions.
+         """
+         # cos, sin: (T, H/2)
+         cos_TH, sin_TH = self._compute_cos_sin(positions)
+
+         # Reshape for broadcasting: (T, 1, H/2)
+         cos_T1H = cos_TH[:, None, :]
+         sin_T1H = sin_TH[:, None, :]
+
+         def _apply_rotation(x_TNH: jax.Array) -> jax.Array:
+             # Split the last dimension.
+             first_half, second_half = jnp.split(x_TNH, 2, axis=-1)
+
+             # Apply the rotation.
+             rotated_x = jnp.concatenate([
+                 first_half * cos_T1H - second_half * sin_T1H,
+                 second_half * cos_T1H + first_half * sin_T1H
+             ],
+                                         axis=-1)
+             return rotated_x.astype(self.dtype)
+
+         rotated_query = _apply_rotation(query_TNH)
+         rotated_key = _apply_rotation(key_TNH)
+
+         return rotated_query, rotated_key
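For reference, the hunk above adds the RoPE module at tpu_inference/layers/jax/rope.py. The sketch below is not part of the package; it is a minimal, hypothetical usage example for RotaryEmbedding, assuming the module is importable under that path and that head_dim equals rotary_dim (which the half/half split in apply_rope requires). DeepseekScalingRotaryEmbedding is used the same way, except its initialize_cache additionally takes a jax.sharding.Mesh so the YaRN mscale can be replicated across devices.

# Hypothetical usage sketch (illustrative shapes and values, not from the package).
import jax
import jax.numpy as jnp

from tpu_inference.layers.jax.rope import RotaryEmbedding

rope = RotaryEmbedding(
    rotary_dim=128,  # rotate the full head_dim
    rope_theta=10000.0,
    original_max_position_embeddings=4096,
    dtype=jnp.bfloat16,
)
rope.initialize_cache()  # precompute the concatenated (cos | sin) table

T, N, H = 8, 4, 128  # num_tokens, num_heads, head_dim
positions = jnp.arange(T)  # one position per token, shape (T,)
x_TNH = jax.random.normal(jax.random.PRNGKey(0), (T, N, H))

rotated_TNH = rope.apply_rope(positions, x_TNH)  # (T, N, H), cast to bfloat16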
tpu_inference/layers/jax/rope_interface.py
@@ -0,0 +1,228 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Any, Dict, Optional
+
+ import jax
+ import jax.numpy as jnp
+
+
+ def apply_rope(
+         # (seq_len, num_heads, head_dim)
+         inputs: jax.Array,
+         # (3, seq_len) for M-RoPE, otherwise (seq_len,)
+         positions: jax.Array,
+         head_dim: int,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         rope_input_ordering: str = "split",
+ ) -> jax.Array:
+     """
+     Applies rotary positional embedding using the sine and cosine strategy.
+
+     This implementation assumes the input tensor may include padding on the
+     last dimension (head_dim). RoPE is applied only to the first `head_dim`
+     features, and the result is padded back to the original dimension if
+     necessary.
+
+     If rope_input_ordering is "split", the pairs of inputs for each rotation
+     are taken one from the first and one from the second half of head_dim.
+     If it is "interleaved", adjacent values are paired for rotation.
+     """
+
+     # M-RoPE support for Qwen2.5-VL
+     if positions.ndim == 2 and positions.shape[0] == 3:
+         mrope_section = rope_scaling.get("mrope_section",
+                                          None) if rope_scaling else None
+         # NOTE: We assume mrope_section is always available,
+         # as Qwen2.5-VL is the only model using M-RoPE.
+         assert mrope_section is not None
+
+         split_indices = [mrope_section[0], mrope_section[0] + mrope_section[1]]
+
+         # Indices for the features to be rotated (first half of head_dim).
+         all_freq_indices = jnp.arange(head_dim // 2)
+
+         # Split the indices according to mrope_section. This is valid because
+         # split_indices are static.
+         freq_indices_split = jnp.split(all_freq_indices, split_indices)
+         # freq_indices_split is a list of 3 JAX arrays.
+
+         cos_list = []
+         sin_list = []
+
+         for i in range(3):  # For each of the 3 position dimensions.
+             current_indices = freq_indices_split[i]
+
+             if current_indices.size == 0:
+                 # This section is empty, skip it.
+                 continue
+
+             # inv_freq shape: (mrope_section[i],)
+             inv_freq = 1.0 / (rope_theta**(current_indices * 2.0 / head_dim))
+
+             # positions[i]: (seq_len,)
+             # freqs shape: (seq_len, mrope_section[i])
+             freqs = jnp.outer(positions[i], inv_freq)
+
+             cos_list.append(jnp.cos(freqs))
+             sin_list.append(jnp.sin(freqs))
+
+         # Concatenate along the feature dimension.
+         # cos, sin shape: (seq_len, head_dim//2)
+         cos = jnp.concatenate(cos_list, axis=1)
+         sin = jnp.concatenate(sin_list, axis=1)
+
+         # Add a num_heads dimension for broadcasting.
+         cos = cos[:, jnp.newaxis, :]  # Shape: (seq_len, 1, head_dim//2)
+         sin = sin[:, jnp.newaxis, :]  # Shape: (seq_len, 1, head_dim//2)
+
+         # Apply the rotation.
+         inputs_real = inputs[..., :head_dim // 2]
+         inputs_imag = inputs[..., head_dim // 2:head_dim]
+
+         outputs_real = inputs_real * cos - inputs_imag * sin
+         outputs_imag = inputs_real * sin + inputs_imag * cos
+
+         out = jnp.concatenate([outputs_real, outputs_imag], axis=-1)
+
+     # Standard RoPE
+     else:
+         # Calculate inverse frequencies (timescale).
+         fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
+         timescale = 1.0 / (rope_theta**fraction)
+
+         # Apply scaling if provided.
+         if rope_scaling:
+             timescale = apply_rope_scaling(timescale, rope_scaling)
+
+         # Prepare for rotation by calculating sin and cos values.
+         # `sinusoid_inp` gets shape (batch * seq_len, head_dim/2).
+         sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
+
+         # Broadcast over the 'heads' dimension, assuming shape
+         # (batch * seq, heads, head_dim).
+         sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
+         sin = jnp.sin(sinusoid_inp)
+         cos = jnp.cos(sinusoid_inp)
+
+         if rope_input_ordering == "interleaved":
+             # Reshape to group adjacent features for rotation, matching
+             # new_apply_rope.
+             rotary_inputs = inputs[
+                 ..., :head_dim]  # Take just the non-padded amount.
+             reshaped_inputs = rotary_inputs.reshape(*rotary_inputs.shape[:-1],
+                                                     -1, 2)
+
+             # Apply the rotation.
+             first_half = reshaped_inputs[..., 0]
+             second_half = reshaped_inputs[..., 1]
+         else:
+             first_half = inputs[..., :head_dim // 2]
+             second_half = inputs[..., head_dim // 2:head_dim]
+
+         first_part = first_half * cos - second_half * sin
+         second_part = second_half * cos + first_half * sin
+
+         # Combine the rotated parts and reshape back.
+         if rope_input_ordering == "interleaved":
+             out_stacked = jnp.stack([first_part, second_part], axis=-1)
+             out = out_stacked.reshape(rotary_inputs.shape)
+         else:
+             out = jnp.concatenate([first_part, second_part], axis=-1)
+
+     # If the original input was padded, pad the output with zeros to match.
+     padded_head_dim = inputs.shape[-1]
+     if padded_head_dim > head_dim:
+         pad_width = padded_head_dim - head_dim
+         pad_config = [(0, 0)] * (out.ndim - 1) + [(0, pad_width)]
+         out = jnp.pad(out, pad_config)
+
+     return out.astype(inputs.dtype)
+
+
+ def apply_longrope(
+         inputs: jax.Array,
+         positions: jax.Array,
+         head_dim: int,
+         rope_scaling: Dict[str, Any],
+         original_max_position_embeddings: int,
+         max_position_embeddings: int,
+         rope_theta: float = 10000,
+ ) -> jax.Array:
+     # LongRoPE implementation specific to Phi-3, based on
+     # https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py#L197-L235
+
+     scale = max_position_embeddings / original_max_position_embeddings
+     if scale <= 1.0:
+         mscale = 1.0
+     else:
+         mscale = jnp.sqrt(1 + (jnp.log(scale) /
+                                jnp.log(original_max_position_embeddings)))
+
+     seq_len = inputs.shape[0]
+     if seq_len > original_max_position_embeddings:
+         long_factor = jnp.array(rope_scaling.get("long_factor"))
+         timescale = 1.0 / (long_factor * (rope_theta**(
+             (2 * jnp.arange(0, head_dim // 2)) / head_dim)))
+     else:
+         short_factor = jnp.array(rope_scaling.get("short_factor"))
+         timescale = 1.0 / (short_factor * (rope_theta**(
+             (2 * jnp.arange(0, head_dim // 2)) / head_dim)))
+
+     # Calculate the RoPE angles.
+     sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
+     sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
+     sin = jnp.sin(sinusoid_inp) * mscale
+     cos = jnp.cos(sinusoid_inp) * mscale
+
+     # Padding logic.
+     padded_head_dim = inputs.shape[-1]
+
+     # Apply the RoPE mechanism.
+     first_half = inputs[..., :head_dim // 2]
+     second_half = inputs[..., head_dim // 2:head_dim]
+
+     first_part = first_half * cos - second_half * sin
+     second_part = second_half * cos + first_half * sin
+     out = jnp.concatenate([first_part, second_part], axis=-1)
+
+     if padded_head_dim > head_dim:
+         out = jnp.pad(out, ((0, 0), (0, 0), (0, padded_head_dim - head_dim)))
+
+     return out.astype(inputs.dtype)
+
+
+ def apply_rope_scaling(freqs: jax.Array,
+                        rope_scaling: Dict[str, Any]) -> jax.Array:
+     # Values obtained from grid search.
+     scale_factor = rope_scaling.get("scale_factor", 8.0)
+     low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+     high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+     old_context_len = rope_scaling.get("original_max_position_embeddings",
+                                        8192)
+
+     low_freq_wavelen = old_context_len / low_freq_factor
+     high_freq_wavelen = old_context_len / high_freq_factor
+
+     wavelen = 2 * math.pi / freqs
+     smooth = (old_context_len / wavelen -
+               low_freq_factor) / (high_freq_factor - low_freq_factor)
+
+     high_freqs = jnp.where(wavelen < high_freq_wavelen, freqs, 0)
+     low_freqs = jnp.where(wavelen > low_freq_wavelen, freqs / scale_factor, 0)
+     mid_freqs = jnp.where(
+         (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
+         (1 - smooth) * freqs / scale_factor + smooth * freqs,
+         0,
+     )
+     new_freqs = high_freqs + low_freqs + mid_freqs
+     return new_freqs
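The second hunk adds the functional interface at tpu_inference/layers/jax/rope_interface.py. Again as a hedged sketch (illustrative shapes and scaling values, not taken from the package): apply_rope dispatches on positions.ndim, so a 1D positions array takes the standard RoPE path, and a rope_scaling dict with the keys read by apply_rope_scaling enables Llama-3-style frequency scaling.

# Hypothetical usage sketch (illustrative values, not from the package).
import jax
import jax.numpy as jnp

from tpu_inference.layers.jax.rope_interface import apply_rope

T, N, H = 8, 4, 128  # num_tokens, num_heads, head_dim
positions = jnp.arange(T)  # (T,): 1D positions select the standard RoPE path
inputs = jax.random.normal(jax.random.PRNGKey(0), (T, N, H))

# Plain RoPE over the full head_dim with the default "split" ordering.
out = apply_rope(inputs, positions, head_dim=H, rope_theta=10000.0)

# Llama-3-style frequency scaling: these keys are the ones apply_rope_scaling
# reads, and the values shown here are its documented fallbacks.
rope_scaling = {
    "scale_factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
out_scaled = apply_rope(inputs, positions, head_dim=H, rope_theta=500000.0,
                        rope_scaling=rope_scaling)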
tpu_inference/layers/jax/sample/__init__.py
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.