tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,101 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Abbreviations currently used for tensor dimensions:
+ B: Batch size
+ T: Sequence Length (for Query tensors)
+ S: Sequence Length (for Key/Value tensors)
+ D: d_model, the embedding dimension of the model
+ F: d_ff, the hidden dimension of the feed-forward MLP layers
+ V: Vocab Size
+ H: Dimension of each attention head
+ N: Number of query heads in Attention
+ Q: Number of query heads (synonymous with N)
+ K: Number of Key/Value heads in Attention
+ C: Expert capacity in Mixture-of-Experts models
+ X: Number of activated experts per token in MoE
+ G: Number of groups in Grouped-Query Attention
+ E: Total number of experts in MoE
+ """
+
+ import enum
+ from typing import Tuple, TypeAlias
+
+ import jax
+
+ KVCacheType: TypeAlias = Tuple[jax.Array, jax.Array]
+
+
+ class RouterType(enum.Enum):
+     """Enum for router types."""
+     TOP_K = 'top_k'
+
+
+ class OPERATION_MODE(enum.Enum):
+     PREFILL = 1
+     DECODE = 2
+
+
+ class HuggingFaceArgNames(enum.Enum):
+     ## Modeling params
+     HIDDEN_ACT: str = "hidden_act"
+     HIDDEN_SIZE: str = "hidden_size"
+     NUM_HIDDEN_LAYERS: str = "num_hidden_layers"
+     RMS_NORM_EPS: str = "rms_norm_eps"
+     ROPE_SCALING: str = "rope_scaling"
+     ROPE_THETA: str = "rope_theta"
+     VOCAB_SIZE: str = "vocab_size"
+
+     # Block parameters
+     SHARED_EXPERTS: str = "shared_experts"
+
+     # FFW params
+     INTERMEDIATE_SIZE: str = "intermediate_size"
+
+     # Attention params
+     HEAD_DIM: str = "head_dim"
+     NUM_ATTENTION_HEADS: str = "num_attention_heads"
+     NUM_KEY_VALUE_HEADS: str = "num_key_value_heads"
+     ATTENTION_DROPOUT: str = "attention_dropout"
+     ATTENTION_BIAS: str = "attention_bias"
+     ATTENTION_CHUNK_SIZE: str = "attention_chunk_size"
+
+     ## Llama4 Attention Params
+     USE_QK_NORM: str = "use_qk_norm"
+     TEMPERATURE_TUNING: str = "temperature_tuning"
+     TEMPERATURE_TUNING_SCALE: str = "temperature_tuning_scale"
+     TEMPERATURE_TUNING_FLOOR_SCALE: str = "temperature_tuning_floor_scale"
+
+     # MLA params
+     KV_LORA_RANK: str = "kv_lora_rank"
+     Q_LORA_RANK: str = "q_lora_rank"
+     QK_NOPE_HEAD_DIM: str = "qk_nope_head_dim"
+     QK_ROPE_HEAD_DIM: str = "qk_rope_head_dim"
+     V_HEAD_DIM: str = "v_head_dim"
+
+     # MoE
+     INTERMEDIATE_SIZE_MOE: str = "intermediate_size_moe"
+     NUM_LOCAL_EXPERTS: str = "num_local_experts"  # Llama moe
+     NUM_EXPERTS_PER_TOKEN: str = "num_experts_per_token"
+     NUM_ROUTED_EXPERTS: str = "n_routed_experts"  # Deepseek moe
+     NUM_SHARED_ROUTED_EXPERTS: str = "n_shared_experts"
+     NUM_GROUPS: str = "n_group"
+     ROUTED_SCALING_FACTOR: str = "routed_scaling_factor"
+     TOPK_GROUP: str = "topk_group"
+     NORM_TOPK_PROB: str = "norm_topk_prob"
+     SCORING_FUNCTION: str = "scoring_func"
+
+     ## Sampling params
+     BOS_TOKEN_ID: str = "bos_token_id"
+     EOS_TOKEN_ID: str = "eos_token_id"
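Each HuggingFaceArgNames member stores, as its value, the attribute name used on the corresponding Hugging Face config object, so config fields can be looked up generically. A minimal sketch of that pattern (not part of the package; get_hf_arg and the SimpleNamespace config stub are hypothetical stand-ins):

import sys
from types import SimpleNamespace

from tpu_inference.layers.jax.constants import HuggingFaceArgNames

def get_hf_arg(config, arg: HuggingFaceArgNames, default=None):
    # Each enum member's value is the attribute name on the HF config.
    return getattr(config, arg.value, default)

# Hypothetical stand-in for a transformers PretrainedConfig.
hf_config = SimpleNamespace(hidden_size=4096, num_attention_heads=32)

assert get_hf_arg(hf_config, HuggingFaceArgNames.HIDDEN_SIZE) == 4096
# Missing fields fall back to the supplied default.
assert get_hf_arg(hf_config, HuggingFaceArgNames.HEAD_DIM, 128) == 128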
@@ -0,0 +1,315 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import InitVar, dataclass
+ from typing import Any
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jaxtyping import Float, Int
+
+ from tpu_inference.layers.jax.base import create_param
+
+
+ # A dummy for modeling_flax_utils which might contain activation functions
+ class FlaxUtils:
+     """A dummy class to namespace activation functions, mimicking external utilities."""
+     ACT2FN = {
+         'silu': nnx.silu,
+         'gelu': nnx.gelu,
+         'relu': nnx.relu,
+         'sigmoid': nnx.sigmoid,
+         'softmax': nnx.softmax
+     }
+
+
+ modeling_flax_utils = FlaxUtils()
+
+
+ @dataclass
+ class RuntimeParams:
+     """A container for runtime parameters needed by neural network blocks.
+
+     This dataclass acts as a flexible container to pass objects that are only
+     available at runtime (like a pre-allocated KV cache or dynamic sharding
+     configurations) into the initialization of stateful modules. This avoids
+     having to update the constructor signature of every module when a new
+     runtime dependency is introduced.
+
+     Attributes:
+         kv_cache: The key-value cache object for attention layers.
+         sharding_cfg: The configuration for tensor sharding.
+         quantization: Configuration for quantization schemes.
+     """
+     kv_cache: Any = None
+     sharding_cfg: Any = None
+     quantization: Any = None
+
+
+ @dataclass(kw_only=True)
+ class RMSNorm(nnx.Module):
+     """An implementation of Root Mean Square Layer Normalization.
+
+     Attributes:
+         dims: The feature dimension to normalize over.
+         epsilon: A small float added to the variance to avoid division by zero.
+         with_scale: If True, learns a multiplicative scale parameter.
+         dtype: The data type for computations.
+     """
+     dims: int
+     activation_ffw_td: Sharding = ()
+     random_init: bool = False
+     epsilon: float = 1e-6
+     with_scale: bool = True
+     dtype: Any = jnp.float32
+
+     rngs: InitVar[nnx.Rngs]
+
+     def __call__(self, x_TD: Float, op_mode='generate') -> Float:
+         """Applies RMS Normalization to the input tensor.
+
+         Args:
+             x_TD: The input tensor. The normalization is applied over the last dimension.
+
+         Returns:
+             The normalized tensor with the same shape as the input.
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+
+         with jax.named_scope("rms_norm_variance"):
+             var_T1 = jnp.mean(jnp.square(x_TD), axis=-1, keepdims=True)
+         with jax.named_scope("rms_norm_rsqrt"):
+             normed_x_TD = x_TD * jax.lax.rsqrt(var_T1 + self.epsilon)
+
+         with jax.named_scope("rms_norm_scale_apply"):
+             normed_x_TD *= self.scale.value
+             normed_x_TD = nnx.with_sharding_constraint(normed_x_TD,
+                                                        self.activation_ffw_td)
+         return normed_x_TD.astype(self.dtype)
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         self.scale = create_param(rngs,
+                                   shape=(self.dims, ),
+                                   dtype=self.dtype,
+                                   random_init=self.random_init)
+
+
+ @dataclass(kw_only=True)
+ class DenseFFW(nnx.Module):
+     """A Gated Feed-Forward Network (FFN) layer.
+
+     This module consists of two linear projections (gating and up-projection),
+     an element-wise multiplication of the activated gating projection and the
+     up-projection, followed by a final downward projection.
+
+     Attributes:
+         df_sharding, fd_sharding, activation_ffw_td: The kernel and activation shardings.
+     """
+     dtype: jnp.dtype
+     hidden_act: str
+     hidden_size: int
+     intermediate_size: int
+     df_sharding: Sharding = ()
+     fd_sharding: Sharding = ()
+     activation_ffw_td: Sharding = ()
+     random_init: bool = False
+
+     rngs: InitVar[nnx.Rngs]
+
+     def __call__(self, x_TD):
+         """Performs the forward pass of the FFW layer.
+
+         Args:
+             x_TD: The input tensor of shape `(sequence, d_model)`.
+
+         Returns:
+             The output tensor of shape `(sequence, d_model)`.
+         """
+         # TODO: consider creating factories for einsum(?)
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         with jax.named_scope("wi_0"):
+             gating_TF = jnp.einsum('TD,DF -> TF', x_TD,
+                                    self.kernel_gating_DF.value)
+         activated_gating_TF = modeling_flax_utils.ACT2FN[self.hidden_act](
+             gating_TF)
+         with jax.named_scope("wi_1"):
+             up_proj_TF = jnp.einsum('TD,DF -> TF', x_TD,
+                                     self.kernel_up_proj_DF.value)
+         fuse_TF = activated_gating_TF * up_proj_TF
+         with jax.named_scope("wo"):
+             output_TD = jnp.einsum('TF,FD -> TD', fuse_TF,
+                                    self.kernel_down_proj_FD.value)
+
+         return output_TD
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         D = self.hidden_size
+         F = self.intermediate_size
+
+         self.kernel_gating_DF = create_param(rngs,
+                                              shape=(D, F),
+                                              dtype=self.dtype,
+                                              sharding=self.df_sharding,
+                                              random_init=self.random_init)
+         self.kernel_up_proj_DF = create_param(rngs,
+                                               shape=(D, F),
+                                               dtype=self.dtype,
+                                               sharding=self.df_sharding,
+                                               random_init=self.random_init)
+         self.kernel_down_proj_FD = create_param(rngs,
+                                                 shape=(F, D),
+                                                 dtype=self.dtype,
+                                                 sharding=self.fd_sharding,
+                                                 random_init=self.random_init)
+
+
+ @dataclass(kw_only=True)
+ class Embedder(nnx.Module):
+     """A module for token embedding and, optionally, decoding (tied embeddings).
+
+     This class handles both the "encoding" step of converting token IDs to dense
+     vectors and the "decoding" step of projecting model outputs back to logits
+     over the vocabulary.
+
+     """
+     vocab_size: int
+     hidden_size: int
+     dtype: jnp.dtype
+     prelogit_td: Sharding = ()
+     vd_sharding: Sharding = ()
+     random_init: bool = False
+     normalize_embeddings: bool = False
+
+     rngs: InitVar[nnx.Rngs]
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         self.input_embedding_table_VD = create_param(
+             rngs,
+             shape=(self.vocab_size, self.hidden_size),
+             sharding=self.vd_sharding,
+             dtype=self.dtype,
+             random_init=self.random_init)
+
+     def __call__(self, x, decode=False):
+         """Dispatches to either the encode or decode method.
+
+         Args:
+             x: The input tensor. Either token IDs for encoding or hidden states
+                 for decoding.
+             decode: A boolean flag. If False (default), performs encoding. If
+                 True, performs decoding.
+
+         Returns:
+             Either embedding vectors or logit scores.
+         """
+         if decode:
+             return self.decode(x)
+         else:
+             return self.encode(x)
+
+     def decode(self, x_TD: Float) -> Float:
+         """Projects hidden states to vocabulary logits.
+
+         Args:
+             x_TD: The input tensor of hidden states from the model backbone, with
+                 shape `(sequence, d_model)`.
+
+         Returns:
+             The output logits over the vocabulary, with shape
+             `(sequence, vocab_size)`.
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.prelogit_td)
+
+         with jax.named_scope("embedder_decode_projection"):
+             logits_TV = jnp.einsum('VD,TD -> TV',
+                                    self.input_embedding_table_VD.value, x_TD)
+         return logits_TV
+
+     def encode(self, x_T: Int) -> Float:
+         """Converts integer token IDs to dense embedding vectors.
+
+         Args:
+             x_T: The input tensor of token IDs, with shape `(sequence, )`.
+
+         Returns:
+             The corresponding embedding vectors, with shape
+             `(sequence, d_model)`.
+         """
+         with jax.named_scope("embedder_encode_lookup"):
+             embedding_TD = jnp.take(self.input_embedding_table_VD.value,
+                                     x_T,
+                                     axis=0)
+
+         if self.normalize_embeddings:
+             with jax.named_scope("embedder_normalize_embeddings"):
+                 embedding_TD *= jnp.sqrt(self.hidden_size).astype(self.dtype)
+         return embedding_TD
+
+
+ @dataclass(kw_only=True)
+ class LMhead(Embedder):
+     """
+     An Embedder that uses a (D, V) shaped embedding table, inheriting from
+     the base Embedder class.
+
+     This implementation overrides the kernel creation, call dispatch, and
+     decoding methods to work with the transposed embedding matrix layout.
+     """
+     dv_sharding: Sharding
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         self.input_embedding_table_DV = create_param(
+             rngs,
+             shape=(self.hidden_size, self.vocab_size),
+             sharding=self.dv_sharding,
+             dtype=self.dtype,
+             random_init=self.random_init)
+
+     def __call__(self, x):
+         """Projects hidden states to logits by dispatching to decode.
+
+         Unlike the base Embedder, an LMhead only decodes, so there is no
+         `decode` flag.
+
+         Args:
+             x: The input tensor of hidden states.
+
+         Returns:
+             Logit scores over the vocabulary.
+         """
+         return self.decode(x)
+
+     def decode(self, x_TD: Float) -> Float:
+         """Projects hidden states to vocabulary logits.
+
+         Args:
+             x_TD: The input tensor of hidden states from the model backbone, with
+                 shape `(sequence, d_model)`.
+
+         Returns:
+             The output logits over the vocabulary, with shape
+             `(sequence, vocab_size)`.
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.prelogit_td)
+
+         with jax.named_scope("lmhead_decode_projection"):
+             logits_TV = jnp.einsum('DV,TD -> TV',
+                                    self.input_embedding_table_DV.value, x_TD)
+         return logits_TV
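RMSNorm, DenseFFW, and Embedder are kw_only dataclass modules whose parameters are created in __post_init__ from an nnx.Rngs InitVar, so wiring them together is plain keyword construction. A minimal sketch (not part of the package), assuming create_param with random_init=True yields randomly initialized parameters and that the default empty shardings are acceptable outside a mesh context; the sizes are illustrative:

import jax.numpy as jnp
from flax import nnx

from tpu_inference.layers.jax.layers import DenseFFW, Embedder, RMSNorm

rngs = nnx.Rngs(0)
norm = RMSNorm(dims=16, random_init=True, rngs=rngs)
ffw = DenseFFW(dtype=jnp.bfloat16, hidden_act='silu', hidden_size=16,
               intermediate_size=64, random_init=True, rngs=rngs)
embed = Embedder(vocab_size=256, hidden_size=16, dtype=jnp.bfloat16,
                 random_init=True, rngs=rngs)

tokens_T = jnp.array([1, 2, 3, 4])    # (T,) token IDs
x_TD = embed(tokens_T)                # encode: (T, D)
y_TD = ffw(norm(x_TD))                # (T, D)
logits_TV = embed(y_TD, decode=True)  # tied-embedding decode: (T, V)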
@@ -0,0 +1,30 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Tuple
+
+ import jax
+ from jax.sharding import NamedSharding
+ from jax.sharding import PartitionSpec as P
+
+
+ # TODO(xiang): move this to weight_utils.py
+ def shard_put(x: jax.Array, sharding_names: Tuple[str, ...] | P,
+               mesh: jax.sharding.Mesh) -> jax.Array:
+     # Single device sharding requires this special handling
+     # to avoid the recursive jit error.
+     if math.prod(mesh.axis_sizes) == 1:
+         return jax.device_put(x, mesh.devices.flatten()[0])
+     return jax.device_put(x, NamedSharding(mesh, P(*sharding_names)))
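A usage sketch for shard_put (not part of the package); the 'data' axis name is illustrative. On a single-device mesh the helper takes the plain device_put branch; on a larger mesh it applies NamedSharding(mesh, P('data')), sharding dimension 0 across the axis:

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh

from tpu_inference.layers.jax.misc import shard_put

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('data',))

# Length is a multiple of the device count, so dim 0 shards evenly.
x = jnp.arange(len(devices) * 4.0)
x_sharded = shard_put(x, ('data',), mesh)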
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.