tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from dataclasses import InitVar, dataclass
|
|
16
|
+
from typing import Tuple
|
|
17
|
+
|
|
18
|
+
import jax
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
from flax import nnx
|
|
21
|
+
from flax.typing import Sharding
|
|
22
|
+
from jax.sharding import Mesh
|
|
23
|
+
from jax.sharding import PartitionSpec as P
|
|
24
|
+
|
|
25
|
+
from tpu_inference import utils
|
|
26
|
+
from tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 import \
|
|
27
|
+
ragged_paged_attention_hd64
|
|
28
|
+
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
29
|
+
from tpu_inference.layers.jax.base import create_param
|
|
30
|
+
from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding
|
|
31
|
+
|
|
32
|
+
# Type used to annotate the per-layer KV cache arguments/returns below,
# declared as a (key, value) pair of arrays.
# NOTE(review): GptOssAttention.attention passes `kv_cache` to the kernel as a
# single shard_map operand with a single PartitionSpec, so the runtime value
# may be one array rather than a tuple — confirm and update this alias.
KVCache = Tuple[jax.Array, jax.Array]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass(kw_only=True)
class GptOssAttention(nnx.Module):
    """JAX implementation of the GPT-OSS attention block.

    The forward pass:
      1. projects the input with biased Q/K/V kernels,
      2. optionally applies GPT-OSS rotary embeddings,
      3. optionally quantizes K/V when a quantized KV-cache dtype is set,
      4. runs the ``ragged_paged_attention_hd64`` kernel (with per-head
         attention sinks) under ``jax.shard_map``, and
      5. applies a biased output projection.

    All constructor arguments are keyword-only. ``rngs`` is an ``InitVar``:
    it is consumed by ``__post_init__`` to create the parameters and is not
    stored on the instance.

    Axis letters in parameter names: D = hidden_size, N = num_attention_heads,
    K = num_key_value_heads, H = head_dim, T = tokens.
    """
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int
    dtype: jnp.dtype
    rngs: InitVar[nnx.Rngs]

    # RoPE configuration, forwarded unchanged to GptOssRotaryEmbedding.
    rope_theta: float
    initial_context_length: int = 4096
    rope_scaling_factor: float = 32.0
    rope_ntk_alpha: float = 1.0
    rope_ntk_beta: float = 32.0
    # "auto" disables KV quantization; any other value is converted with
    # utils.get_jax_dtype_from_str_dtype in __post_init__.
    kv_cache_dtype: str

    # PartitionSpecs used as shard_map in/out specs for q, k/v and the output.
    query_tnh: P = P()
    keyvalue_skh: P = P()
    attn_o_tnh: P = P()
    # Shardings passed to create_param for each weight/bias (suffix = axes).
    dnh_sharding: Sharding = ()
    dkh_sharding: Sharding = ()
    nhd_sharding: Sharding = ()
    n_sharding: Sharding = ()
    nh_sharding: Sharding = ()
    kh_sharding: Sharding = ()
    d_sharding: Sharding = ()

    random_init: bool = False
    mesh: Mesh

    # Static quantization scales. Only k/v scales are currently passed to the
    # kernel; see the TODO in __call__.
    _q_scale: float = 1.0
    _k_scale: float = 1.0
    _v_scale: float = 1.0
    # Not annotated, so this is a plain class attribute (not a dataclass field
    # or constructor argument). __post_init__ replaces it on the instance with
    # a JAX dtype when kv_cache_dtype != "auto"; otherwise it stays None.
    kv_cache_quantized_dtype = None

    def __post_init__(self, rngs: nnx.Rngs):
        """Initializes weights, biases, attention sinks and the RoPE module.

        Args:
            rngs: RNG streams forwarded to ``create_param`` for every
                parameter created here.
        """
        # Standard 1/sqrt(head_dim) softmax scale, passed to the kernel.
        self.sm_scale = 1.0 / (self.head_dim**0.5)

        # One attention-sink value per query head, kept in float32.
        self.sinks_N = create_param(
            rngs,
            shape=(self.num_attention_heads, ),
            dtype=jnp.float32,
            sharding=self.n_sharding,
            random_init=self.random_init,
        )

        # Q, K, V projection kernels and biases.
        self.kernel_q_DNH = create_param(
            rngs,
            shape=(self.hidden_size, self.num_attention_heads, self.head_dim),
            dtype=self.dtype,
            sharding=self.dnh_sharding,
            random_init=self.random_init,
        )
        self.bias_q_NH = create_param(
            rngs,
            shape=(self.num_attention_heads, self.head_dim),
            dtype=self.dtype,
            sharding=self.nh_sharding,
            random_init=self.random_init,
        )
        self.kernel_k_DKH = create_param(
            rngs,
            shape=(self.hidden_size, self.num_key_value_heads, self.head_dim),
            dtype=self.dtype,
            sharding=self.dkh_sharding,
            random_init=self.random_init,
        )
        self.bias_k_KH = create_param(
            rngs,
            shape=(self.num_key_value_heads, self.head_dim),
            dtype=self.dtype,
            sharding=self.kh_sharding,
            random_init=self.random_init,
        )
        self.kernel_v_DKH = create_param(
            rngs,
            shape=(self.hidden_size, self.num_key_value_heads, self.head_dim),
            dtype=self.dtype,
            sharding=self.dkh_sharding,
            random_init=self.random_init,
        )
        self.bias_v_KH = create_param(
            rngs,
            shape=(self.num_key_value_heads, self.head_dim),
            dtype=self.dtype,
            sharding=self.kh_sharding,
            random_init=self.random_init,
        )
        # Output projection kernel and bias.
        self.kernel_o_proj_NHD = create_param(
            rngs,
            shape=(self.num_attention_heads, self.head_dim, self.hidden_size),
            dtype=self.dtype,
            sharding=self.nhd_sharding,
            random_init=self.random_init,
        )
        self.bias_o_D = create_param(
            rngs,
            shape=(self.hidden_size, ),
            dtype=self.dtype,
            sharding=self.d_sharding,
            random_init=self.random_init,
        )

        # RoPE module.
        self.rope = GptOssRotaryEmbedding(
            head_dim=self.head_dim,
            rope_theta=self.rope_theta,
            dtype=self.dtype,
            initial_context_length=self.initial_context_length,
            rope_scaling_factor=self.rope_scaling_factor,
            rope_ntk_alpha=self.rope_ntk_alpha,
            rope_ntk_beta=self.rope_ntk_beta)

        if self.kv_cache_dtype != "auto":
            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
                self.kv_cache_dtype)

    def attention(
        self,
        kv_cache: KVCache,
        q_TNH: jax.Array,
        k_SKH: jax.Array,
        v_SKH: jax.Array,
        sinks: jax.Array,
        attention_metadata: AttentionMetadata,
        mesh: Mesh,
        q_scale: float | None = None,
        k_scale: float | None = None,
        v_scale: float | None = None,
    ) -> Tuple[KVCache, jax.Array]:
        """Runs the ragged_paged_attention_hd64 kernel under shard_map.

        Args:
            kv_cache: KV cache passed to the kernel as a single operand with
                PartitionSpec ("data", None, "model").
            q_TNH: Query projections, sharded with ``self.query_tnh``.
            k_SKH: Key projections, sharded with ``self.keyvalue_skh``.
            v_SKH: Value projections, sharded with ``self.keyvalue_skh``.
            sinks: Attention-sink values; ``__call__`` passes
                ``self.sinks_N`` (one per query head), sharded on "model".
            attention_metadata: Supplies seq_lens, block_tables,
                query_start_loc, request_distribution and sliding_window.
            mesh: Device mesh used by shard_map.
            q_scale, k_scale, v_scale: Optional quantization scales forwarded
                to the kernel.

        Returns:
            ``(updated_kv_cache, attention_output_TNH)``.
        """
        md = attention_metadata
        kv_cache_spec = P("data", None, "model")

        in_specs = (
            self.query_tnh,  # q
            self.keyvalue_skh,  # k
            self.keyvalue_skh,  # v
            kv_cache_spec,  # kv_cache
            P("data"),  # md.seq_lens
            P("data"),  # page_indices_flat
            P("data"),  # query_start_loc
            P("data"),  # distribution
            P("model"),  # sinks
        )
        # The kernel returns (output, kv_cache) in this order.
        out_specs = (self.attn_o_tnh, kv_cache_spec)

        def _ragged_paged_attention_wrapper(*args):
            # Bind the GPT-OSS specific, non-array parameters of the kernel.
            return ragged_paged_attention_hd64(
                *args,
                sm_scale=self.sm_scale,
                sliding_window=md.sliding_window,
                q_scale=q_scale,
                k_scale=k_scale,
                v_scale=v_scale,
            )

        output_TNH, kv_cache = jax.jit(
            jax.shard_map(
                _ragged_paged_attention_wrapper,
                mesh=mesh,
                in_specs=in_specs,
                out_specs=out_specs,
                check_vma=False,
            ))(
                q_TNH,
                k_SKH,
                v_SKH,
                kv_cache,
                md.seq_lens,
                md.block_tables,
                md.query_start_loc,
                md.request_distribution,
                sinks,
            )
        return kv_cache, output_TNH

    def __call__(self,
                 x_TD,
                 is_prefill,
                 kv_cache: KVCache,
                 attention_metadata: AttentionMetadata,
                 use_attention_rope: bool = True):
        """Forward pass of the GPT-OSS attention block.

        Args:
            x_TD: Input activations of shape (T, hidden_size); cast to
                ``self.dtype`` before use.
            is_prefill: Accepted for interface compatibility; not used here.
            kv_cache: KV cache for this layer.
            attention_metadata: Metadata with input positions and paging info.
            use_attention_rope: Whether to apply rotary embeddings to Q and K.

        Returns:
            ``(new_kv_cache, output_TD)`` where ``output_TD`` has shape
            (T, hidden_size).
        """
        md = attention_metadata
        x_TD = jnp.asarray(x_TD, self.dtype)

        with jax.named_scope("q_proj"):
            q_TNH = jnp.einsum("TD,DNH->TNH", x_TD, self.kernel_q_DNH.value)
            q_TNH += self.bias_q_NH.value

        with jax.named_scope("k_proj"):
            k_TKH = jnp.einsum("TD,DKH->TKH", x_TD, self.kernel_k_DKH.value)
            k_TKH += self.bias_k_KH.value

        with jax.named_scope("v_proj"):
            v_TKH = jnp.einsum("TD,DKH->TKH", x_TD, self.kernel_v_DKH.value)
            v_TKH += self.bias_v_KH.value

        if use_attention_rope:
            q_TNH, k_TKH = self.rope(q_TNH, k_TKH, md.input_positions)

        q_scale = k_scale = v_scale = None
        # Compare against None explicitly: the attribute is either None or a
        # dtype object, and some dtype objects (e.g. numpy.dtype instances,
        # whose len() is 0) can evaluate as False.
        if self.kv_cache_quantized_dtype is not None:
            # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
            # q_scale = self._q_scale
            k_scale = self._k_scale
            v_scale = self._v_scale
            k_TKH, v_TKH = utils.quantize_kv(k_TKH, v_TKH,
                                             self.kv_cache_quantized_dtype,
                                             k_scale, v_scale)

        with jax.named_scope("attn_op"):
            new_kv_cache, attn_out_TNH = self.attention(
                kv_cache=kv_cache,
                q_TNH=q_TNH,
                k_SKH=k_TKH,
                v_SKH=v_TKH,
                sinks=self.sinks_N.value,
                attention_metadata=md,
                mesh=self.mesh,
                q_scale=q_scale,
                k_scale=k_scale,
                v_scale=v_scale,
            )
            # Keep only the first head_dim features of the kernel output.
            # NOTE(review): presumably the hd64 kernel may return a padded
            # head dimension — confirm.
            attn_out_TNH = attn_out_TNH[..., :self.head_dim]

        with jax.named_scope("o_proj"):
            output_TD = jnp.einsum("TNH,NHD->TD", attn_out_TNH,
                                   self.kernel_o_proj_NHD.value)
            output_TD += self.bias_o_D.value

        return new_kv_cache, output_TD
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
|
|
17
|
+
import jax
|
|
18
|
+
import jax.numpy as jnp
|
|
19
|
+
from flax import nnx
|
|
20
|
+
from jax.sharding import Sharding
|
|
21
|
+
|
|
22
|
+
from tpu_inference import utils
|
|
23
|
+
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
24
|
+
from tpu_inference.layers.jax.attention.attention import Attention, KVCache
|
|
25
|
+
from tpu_inference.layers.jax.rope_interface import apply_rope
|
|
26
|
+
from tpu_inference.logger import init_logger
|
|
27
|
+
|
|
28
|
+
# Module-level logger; currently not referenced by any code in this module.
logger = init_logger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class L2Norm(nnx.Module):
    """Parameter-free L2 normalization over the last axis.

    Adapted from MaxText (maxtext/MaxText/layers/attentions.py). The input is
    multiplied by ``rsqrt(mean(x**2, axis=-1) + eps)``. The arithmetic is done
    in the input's own dtype.

    Attributes:
        eps: Small constant added before the reciprocal square root for
            numerical stability; the 1e-6 default should suit most cases.
    """

    def __init__(self, eps: float = 1e-6):
        self.eps = eps

    def __call__(self, x):
        # Mean of squares along the feature axis, kept for broadcasting.
        mean_sq = jnp.mean(x**2, axis=-1, keepdims=True)
        inv_norm = jax.lax.rsqrt(mean_sq + self.eps)
        return x * inv_norm
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass(kw_only=True)
class Llama4Attention(Attention):
    """Llama 4 attention layer.

    Extends the base ``Attention`` with two options:

    * ``use_qk_norm``: L2-normalize queries and keys after RoPE (only on
      layers that apply RoPE).
    * ``temperature_tuning``: on layers without RoPE (NoPE), scale queries by
      a position-dependent factor to counter attention-score dilution in long
      contexts.
    """
    use_qk_norm: bool
    temperature_tuning: bool
    temperature_tuning_floor_scale: float
    temperature_tuning_scale: float
    # Sharding constraint for the K/V projection input.
    activation_attention_td: Sharding
    # Sharding constraint for the output-projection result.
    activation_attention_out_td: Sharding

    def __call__(self,
                 x,
                 is_prefill,
                 kv_cache: KVCache,
                 attention_metadata: AttentionMetadata,
                 use_attention_rope: bool = True):
        """Performs the forward pass of the attention module.

        Projects ``x`` to queries, keys and values, applies RoPE (and
        optionally L2Norm) or, for NoPE layers, optional temperature tuning,
        runs attention, and projects the result back to the model dimension.

        Args:
            x: The input tensor of shape `(seq_len, d_model)`.
            is_prefill: Whether the operation mode is prefill (otherwise it is
                generate); forwarded to ``self.attention``.
            kv_cache: The key-value cache for storing past attention states.
            attention_metadata: Metadata for attention, such as input positions.
            use_attention_rope: Whether to use RoPE.

        Returns:
            A tuple containing:
            - The updated KV cache.
            - The attention output tensor of shape `(seq_len, d_model)`.
        """
        md = attention_metadata
        x = jnp.asarray(x, self.dtype)
        x_SD = nnx.with_sharding_constraint(x, self.activation_attention_td)
        x_q_TD = nnx.with_sharding_constraint(x, self.activation_q_td)
        rope_scaling = self.rope_scaling
        rope_theta = self.rope_theta
        H = self.head_dim
        l2_norm = L2Norm()

        with jax.named_scope("q_proj"):
            q_TNH = jnp.einsum('TD,DNH -> TNH', x_q_TD,
                               self.kernel_q_proj_DNH.value)
            if use_attention_rope:
                q_TNH = apply_rope(q_TNH, md.input_positions, H, rope_theta,
                                   rope_scaling, self.rope_input_ordering)

                # Apply normalization after RoPE
                if self.use_qk_norm:
                    q_TNH = l2_norm(q_TNH)
            else:
                if self.temperature_tuning:
                    q_TNH = self.apply_temperature_tuning(md, q_TNH)

            q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)

        with jax.named_scope("k_proj"):
            k_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
                               self.kernel_k_proj_DKH.value)
            if use_attention_rope:
                k_SKH = apply_rope(k_SKH, md.input_positions, H, rope_theta,
                                   rope_scaling, self.rope_input_ordering)

                # Apply normalization after RoPE
                if self.use_qk_norm:
                    k_SKH = l2_norm(k_SKH)
            k_SKH = nnx.with_sharding_constraint(k_SKH, self.keyvalue_skh)

        with jax.named_scope("v_proj"):
            v_SKH = jnp.einsum('SD,DKH -> SKH', x_SD,
                               self.kernel_v_proj_DKH.value)
            v_SKH = nnx.with_sharding_constraint(v_SKH, self.keyvalue_skh)

        q_scale = k_scale = v_scale = None
        if self.kv_cache_quantized_dtype:
            # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
            # q_scale = self._q_scale
            k_scale = self._k_scale
            v_scale = self._v_scale
            k_SKH, v_SKH = utils.quantize_kv(k_SKH, v_SKH,
                                             self.kv_cache_quantized_dtype,
                                             k_scale, v_scale)

        with jax.named_scope("attn_op"):
            new_kv_cache, outputs_TNH = self.attention(
                is_prefill,
                kv_cache,
                q_TNH,
                k_SKH,
                v_SKH,
                attention_metadata,
                self.mesh,
                q_scale=q_scale,
                k_scale=k_scale,
                v_scale=v_scale,
            )

        with jax.named_scope("o_proj"):
            o_TD = jnp.einsum('TNH,NHD -> TD', outputs_TNH,
                              self.kernel_o_proj_NHD.value)
            o_TD = nnx.with_sharding_constraint(
                o_TD, self.activation_attention_out_td)
            return new_kv_cache, o_TD

    def apply_temperature_tuning(self, md: AttentionMetadata,
                                 input_arr_TNH: jax.Array) -> jax.Array:
        """Applies temperature tuning to the input array of shape (T, N, H).

        Each token at position ``p`` is scaled by
        ``log(floor((p + 1) / floor_scale) + 1) * scale + 1``.

        The scales are computed in float32 rather than the model dtype. Integer
        positions above 256 are not exactly representable in bfloat16, which
        would shift the ``floor`` bucket boundaries and lose precision in the
        ``log``. The scaled result is cast back to ``input_arr_TNH``'s dtype.

        Args:
            md: AttentionMetadata object containing the input positions
                (presumably 1-D with one entry per token in T).
            input_arr_TNH: Input array of shape (T, N, H) which will have
                scaled temperatures applied.

        Returns:
            ``input_arr_TNH`` scaled per token, in ``input_arr_TNH``'s dtype.
        """
        positions = md.input_positions.astype(jnp.float32)
        attn_scales = (jnp.log(
            jnp.floor((positions + 1.0) / self.temperature_tuning_floor_scale) +
            1.0) * self.temperature_tuning_scale + 1.0)
        scaled_TNH = input_arr_TNH * attn_scales[:, None, None]
        return scaled_TNH.astype(input_arr_TNH.dtype)
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import dataclasses
|
|
16
|
+
from dataclasses import dataclass, fields
|
|
17
|
+
from typing import Any, Callable, Mapping
|
|
18
|
+
|
|
19
|
+
import jax
|
|
20
|
+
import jax.numpy as jnp
|
|
21
|
+
from flax import nnx
|
|
22
|
+
from flax.typing import Sharding
|
|
23
|
+
from jax.sharding import PartitionSpec as P
|
|
24
|
+
|
|
25
|
+
from tpu_inference.logger import init_logger
|
|
26
|
+
|
|
27
|
+
logger = init_logger(__name__)

# Alias for initializer callables (key, shape, dtype, ...) -> jax.Array,
# used to keep type hints short.
Initializer = Callable[..., jax.Array]

# Singleton initializers, created once at import time and shared by every
# call, to avoid re-compilation.
_scale_initializer = nnx.initializers.ones  # ones-valued init
_sharded_initializer = nnx.initializers.xavier_normal()  # Xavier-normal init
_init_fn = nnx.initializers.uniform()  # uniform init
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
|
|
38
|
+
class Config:
|
|
39
|
+
"""Base configuration class with a robust factory method.
|
|
40
|
+
|
|
41
|
+
This class provides a `from_cfg` classmethod that allows creating a config
|
|
42
|
+
instance from a dictionary, ensuring that all required fields are present
|
|
43
|
+
and ignoring any extraneous keys.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
@classmethod
|
|
47
|
+
def from_cfg(cls, cfg: dict[str, Any] | None = None, **kwargs):
|
|
48
|
+
"""Creates a config instance from a dictionary and/or keyword arguments.
|
|
49
|
+
|
|
50
|
+
This factory method validates that all fields without default values
|
|
51
|
+
are provided in the input dictionary or keyword arguments.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
cfg: A dictionary of configuration parameters.
|
|
55
|
+
**kwargs: Additional configuration parameters passed as keyword arguments.
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
An instance of the configuration class.
|
|
59
|
+
|
|
60
|
+
Raises:
|
|
61
|
+
ValueError: If any required parameters are missing.
|
|
62
|
+
"""
|
|
63
|
+
if cfg is None:
|
|
64
|
+
cfg = {}
|
|
65
|
+
cfg.update(kwargs)
|
|
66
|
+
|
|
67
|
+
required_params = {
|
|
68
|
+
f.name
|
|
69
|
+
for f in fields(cls) if f.default is dataclasses.MISSING
|
|
70
|
+
and f.default_factory is dataclasses.MISSING
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
# Check if any of the truly required parameters are missing from the provided config.
|
|
74
|
+
missing_params = required_params - set(cfg.keys())
|
|
75
|
+
if missing_params:
|
|
76
|
+
raise ValueError(
|
|
77
|
+
f"Missing required parameters for {cls.__name__}: {', '.join(sorted(list(missing_params)))}"
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
known_params = {f.name for f in fields(cls)}
|
|
81
|
+
filtered_cfg = {k: v for k, v in cfg.items() if k in known_params}
|
|
82
|
+
|
|
83
|
+
return cls(**filtered_cfg)
|
|
84
|
+
|
|
85
|
+
# TODO: check logic with some unit tests.
|
|
86
|
+
def maybe_apply_overrides(self):
|
|
87
|
+
"""Update the args with additional_configs, hf_overrides, and override_generation_config settings.
|
|
88
|
+
If there is overlap in overrides between the configs, then print a warning declaring which
|
|
89
|
+
overrides will take precedent."""
|
|
90
|
+
|
|
91
|
+
if not getattr(self, "vllm_config"):
|
|
92
|
+
return
|
|
93
|
+
|
|
94
|
+
def _overrides_str(original: str, original_val: Any,
|
|
95
|
+
new_val: Any) -> str:
|
|
96
|
+
return f"{original}: {original_val} ---> {new_val}"
|
|
97
|
+
|
|
98
|
+
def _get_overrides_dict(self) -> Mapping[str, Any]:
|
|
99
|
+
"""Return the overrides from all of the possible vllm sections."""
|
|
100
|
+
overrides_dict = {}
|
|
101
|
+
vllm_model_config = self.vllm_config.model_config
|
|
102
|
+
|
|
103
|
+
for override_type in ordered_override_types:
|
|
104
|
+
if override_type == "additional_config":
|
|
105
|
+
overrides_dict[
|
|
106
|
+
override_type] = self.vllm_config.additional_config
|
|
107
|
+
else:
|
|
108
|
+
overrides_dict[override_type] = getattr(
|
|
109
|
+
vllm_model_config, override_type)
|
|
110
|
+
return overrides_dict
|
|
111
|
+
|
|
112
|
+
ordered_override_types = [
|
|
113
|
+
"additional_config", "hf_overrides", "override_generation_config"
|
|
114
|
+
]
|
|
115
|
+
|
|
116
|
+
overrides_dict = _get_overrides_dict(self)
|
|
117
|
+
|
|
118
|
+
# Override the config values using the vLLM sections with highest
|
|
119
|
+
# precedence first.
|
|
120
|
+
for field in fields(self):
|
|
121
|
+
selected_type = None
|
|
122
|
+
for override_type in reversed(ordered_override_types):
|
|
123
|
+
if field.name in overrides_dict[override_type]:
|
|
124
|
+
setattr(self, field.name,
|
|
125
|
+
overrides_dict[override_type][field.name])
|
|
126
|
+
selected_type = override_type
|
|
127
|
+
break
|
|
128
|
+
if selected_type is None:
|
|
129
|
+
continue
|
|
130
|
+
|
|
131
|
+
# If multiple vLLM sections contain overrides, print a warning.
|
|
132
|
+
for override_type in ordered_override_types:
|
|
133
|
+
if override_type == selected_type:
|
|
134
|
+
break
|
|
135
|
+
else:
|
|
136
|
+
if field.name in overrides_dict[override_type]:
|
|
137
|
+
overriden_keys_str = _overrides_str(
|
|
138
|
+
field.name,
|
|
139
|
+
overrides_dict[override_type][field.name],
|
|
140
|
+
overrides_dict[selected_type][field.name])
|
|
141
|
+
logger.warning(
|
|
142
|
+
f"Overriding {override_type} arguments with the following {selected_type} args: {overriden_keys_str}"
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
def __post_init__(self):
|
|
146
|
+
self.maybe_apply_overrides()
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def create_param(rngs: nnx.Rngs,
                 shape: tuple[int, ...],
                 sharding: Sharding = (),
                 dtype: Any = jnp.float32,
                 random_init=False) -> nnx.Param:
    """Build an `nnx.Param` of the given shape and dtype.

    A key is always drawn from `rngs.params()` first.

    Args:
        rngs: RNG streams; the `params` stream supplies the init key.
        shape: Shape of the parameter.
        sharding: Partition-spec axes, passed both to `P(*sharding)` for the
            jitted initializer's output sharding and to `nnx.Param`.
        dtype: Parameter dtype.
        random_init: When False, the shared uniform initializer is called
            eagerly. When True, 1-D shapes use the ones initializer, all other
            shapes use the Xavier-normal initializer, and the initializer is
            run under `jax.jit` with `out_shardings=P(*sharding)`.

    Returns:
        The newly created `nnx.Param`.
    """
    param_key = rngs.params()

    if not random_init:
        values = _init_fn(param_key, shape, dtype)
        return nnx.Param(values, sharding=sharding)

    if len(shape) == 1:
        chosen_init = _scale_initializer
    else:
        chosen_init = _sharded_initializer

    sharded_init = jax.jit(
        chosen_init,
        static_argnames=('shape', 'dtype'),
        out_shardings=P(*sharding),
    )
    values = sharded_init(param_key, shape, dtype)
    return nnx.Param(values, sharding=sharding)
|