tpu_inference-0.11.1.dev202511150811-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of tpu-inference might be problematic; see the registry page for details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/layers.py (new file, +301 lines)
@@ -0,0 +1,301 @@
from dataclasses import InitVar, dataclass
from typing import Any

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Sharding
from jaxtyping import Float, Int

from tpu_inference.layers.jax.base import create_param


# A dummy for modeling_flax_utils which might contain activation functions
class FlaxUtils:
    """A dummy class to namespace activation functions, mimicking external utilities."""
    ACT2FN = {
        'silu': nnx.silu,
        'gelu': nnx.gelu,
        'relu': nnx.relu,
        'sigmoid': nnx.sigmoid,
        'softmax': nnx.softmax
    }


modeling_flax_utils = FlaxUtils()


@dataclass
class RuntimeParams:
    """A container for runtime parameters needed by neural network blocks.

    This dataclass acts as a flexible container to pass objects that are only
    available at runtime (like a pre-allocated KV cache or dynamic sharding
    configurations) into the initialization of stateful modules. This avoids
    having to update the constructor signature of every module when a new
    runtime dependency is introduced.

    Attributes:
        kv_cache: The key-value cache object for attention layers.
        sharding_cfg: The configuration for tensor sharding.
        quantization: Configuration for quantization schemes.
    """
    kv_cache: Any = None
    sharding_cfg: Any = None
    quantization: Any = None


@dataclass(kw_only=True)
class RMSNorm(nnx.Module):
    """An implementation of Root Mean Square Layer Normalization.

    Attributes:
        dims: The feature dimension to normalize over.
        epsilon: A small float added to the variance to avoid division by zero.
        with_scale: If True, learns a multiplicative scale parameter.
        dtype: The data type for computations.
    """
    dims: int
    activation_ffw_td: Sharding = ()
    random_init: bool = False
    epsilon: float = 1e-6
    with_scale: bool = True
    dtype: Any = jnp.float32

    rngs: InitVar[nnx.Rngs]

    def __call__(self, x_TD: Float, op_mode='generate') -> Float:
        """Applies RMS Normalization to the input tensor.

        Args:
            x_TD: The input tensor. The normalization is applied over the last
                dimension.

        Returns:
            The normalized tensor with the same shape as the input.
        """
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)

        with jax.named_scope("rms_norm_variance"):
            var_T1 = jnp.mean(jnp.square(x_TD), axis=-1, keepdims=True)
        with jax.named_scope("rms_norm_rsqrt"):
            normed_x_TD = x_TD * jax.lax.rsqrt(var_T1 + self.epsilon)

        with jax.named_scope("rms_norm_scale_apply"):
            normed_x_TD *= self.scale.value
            normed_x_TD = nnx.with_sharding_constraint(normed_x_TD,
                                                       self.activation_ffw_td)
        return normed_x_TD.astype(self.dtype)

    def __post_init__(self, rngs: nnx.Rngs):
        self.scale = create_param(rngs,
                                  shape=(self.dims, ),
                                  dtype=self.dtype,
                                  random_init=self.random_init)


@dataclass(kw_only=True)
class DenseFFW(nnx.Module):
    """A Gated Feed-Forward Network (FFN) layer.

    This module consists of two linear projections (gating and up-projection),
    an element-wise multiplication of the activated gating projection and the
    up-projection, followed by a final downward projection.

    Attributes:
        dtype: The data type for computations.
        hidden_act: Name of the activation applied to the gating projection.
        hidden_size: The model dimension D.
        intermediate_size: The FFN hidden dimension F.
        df_sharding: Sharding for the (D, F) projection kernels.
        fd_sharding: Sharding for the (F, D) down-projection kernel.
        activation_ffw_td: Sharding constraint for the (T, D) activations.
        random_init: Whether to randomly initialize the kernels.
    """
    dtype: jnp.dtype
    hidden_act: str
    hidden_size: int
    intermediate_size: int
    df_sharding: Sharding = ()
    fd_sharding: Sharding = ()
    activation_ffw_td: Sharding = ()
    random_init: bool = False

    rngs: InitVar[nnx.Rngs]

    def __call__(self, x_TD):
        """Performs the forward pass of the FFW layer.

        Args:
            x_TD: The input tensor of shape `(sequence, d_model)`.

        Returns:
            The output tensor of shape `(sequence, d_model)`.
        """
        # TODO: consider creating factories for einsum(?)
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
        with jax.named_scope("wi_0"):
            gating_TF = jnp.einsum('TD,DF -> TF', x_TD,
                                   self.kernel_gating_DF.value)
            activated_gating_TF = modeling_flax_utils.ACT2FN[self.hidden_act](
                gating_TF)
        with jax.named_scope("wi_1"):
            up_proj_TF = jnp.einsum('TD,DF -> TF', x_TD,
                                    self.kernel_up_proj_DF.value)
            fuse_TF = activated_gating_TF * up_proj_TF
        with jax.named_scope("wo"):
            output_TD = jnp.einsum('TF,FD -> TD', fuse_TF,
                                   self.kernel_down_proj_FD.value)

        return output_TD

    def __post_init__(self, rngs: nnx.Rngs):
        D = self.hidden_size
        F = self.intermediate_size

        self.kernel_gating_DF = create_param(rngs,
                                             shape=(D, F),
                                             dtype=self.dtype,
                                             sharding=self.df_sharding,
                                             random_init=self.random_init)
        self.kernel_up_proj_DF = create_param(rngs,
                                              shape=(D, F),
                                              dtype=self.dtype,
                                              sharding=self.df_sharding,
                                              random_init=self.random_init)
        self.kernel_down_proj_FD = create_param(rngs,
                                                shape=(F, D),
                                                dtype=self.dtype,
                                                sharding=self.fd_sharding,
                                                random_init=self.random_init)


@dataclass(kw_only=True)
class Embedder(nnx.Module):
    """A module for token embedding and, optionally, decoding (tied embeddings).

    This class handles both the "encoding" step of converting token IDs to dense
    vectors and the "decoding" step of projecting model outputs back to logits
    over the vocabulary.
    """
    vocab_size: int
    hidden_size: int
    dtype: jnp.dtype
    prelogit_td: Sharding = ()
    vd_sharding: Sharding = ()
    random_init: bool = False
    normalize_embeddings: bool = False

    rngs: InitVar[nnx.Rngs]

    def __post_init__(self, rngs: nnx.Rngs):
        self.input_embedding_table_VD = create_param(
            rngs,
            shape=(self.vocab_size, self.hidden_size),
            sharding=self.vd_sharding,
            dtype=self.dtype,
            random_init=self.random_init)

    def __call__(self, x, decode=False):
        """Dispatches to either the encode or decode method.

        Args:
            x: The input tensor. Either token IDs for encoding or hidden states
                for decoding.
            decode: A boolean flag. If False (default), performs encoding. If
                True, performs decoding.

        Returns:
            Either embedding vectors or logit scores.
        """
        if decode:
            return self.decode(x)
        else:
            return self.encode(x)

    def decode(self, x_TD: Float) -> Float:
        """Projects hidden states to vocabulary logits.

        Args:
            x_TD: The input tensor of hidden states from the model backbone,
                with shape `(sequence, d_model)`.

        Returns:
            The output logits over the vocabulary, with shape
            `(sequence, vocab_size)`.
        """
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.prelogit_td)

        with jax.named_scope("embedder_decode_projection"):
            logits_TV = jnp.einsum('VD,TD -> TV',
                                   self.input_embedding_table_VD.value, x_TD)
        return logits_TV

    def encode(self, x_T: Int) -> Float:
        """Converts integer token IDs to dense embedding vectors.

        Args:
            x_T: The input tensor of token IDs, with shape `(sequence, )`.

        Returns:
            The corresponding embedding vectors, with shape
            `(sequence, d_model)`.
        """
        with jax.named_scope("embedder_encode_lookup"):
            embedding_TD = jnp.take(self.input_embedding_table_VD.value,
                                    x_T,
                                    axis=0)

        if self.normalize_embeddings:
            with jax.named_scope("embedder_normalize_embeddings"):
                embedding_TD *= jnp.sqrt(self.hidden_size).astype(self.dtype)
        return embedding_TD


@dataclass(kw_only=True)
class LMhead(Embedder):
    """An Embedder that uses a (D, V) shaped embedding table, inheriting from
    the base Embedder class.

    This implementation overrides the kernel creation and decoding methods to
    work with the transposed embedding matrix layout.
    """
    dv_sharding: Sharding

    def __post_init__(self, rngs: nnx.Rngs):
        self.input_embedding_table_DV = create_param(
            rngs,
            shape=(self.hidden_size, self.vocab_size),
            sharding=self.dv_sharding,
            dtype=self.dtype,
            random_init=self.random_init)

    def __call__(self, x):
        """Dispatches to the decode method.

        Args:
            x: The input tensor of hidden states for decoding.

        Returns:
            Logit scores over the vocabulary.
        """
        return self.decode(x)

    def decode(self, x_TD: Float) -> Float:
        """Projects hidden states to vocabulary logits.

        Args:
            x_TD: The input tensor of hidden states from the model backbone,
                with shape `(sequence, d_model)`.

        Returns:
            The output logits over the vocabulary, with shape
            `(sequence, vocab_size)`.
        """
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.prelogit_td)

        with jax.named_scope("lmhead_decode_projection"):
            logits_TV = jnp.einsum('DV,TD -> TV',
                                   self.input_embedding_table_DV.value, x_TD)
        return logits_TV
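For orientation, a minimal usage sketch of the modules above follows. It is not part of the wheel: the import path tpu_inference.layers.jax.layers is inferred from the file listing, the sizes and RNG seed are illustrative, and it assumes create_param materializes weights when random_init=True and that the default empty sharding annotations are usable without an explicit device mesh.

# Hedged usage sketch (not from the package); shapes, sizes, and import path are assumptions.
import jax.numpy as jnp
from flax import nnx

from tpu_inference.layers.jax.layers import DenseFFW, Embedder, RMSNorm  # assumed module path

rngs = nnx.Rngs(params=0)

embed = Embedder(vocab_size=32000, hidden_size=1024, dtype=jnp.float32,
                 random_init=True, rngs=rngs)
norm = RMSNorm(dims=1024, dtype=jnp.float32, random_init=True, rngs=rngs)
ffw = DenseFFW(dtype=jnp.float32, hidden_act='silu', hidden_size=1024,
               intermediate_size=4096, random_init=True, rngs=rngs)

token_ids_T = jnp.array([1, 2, 3])       # (sequence,)
x_TD = embed(token_ids_T)                # token lookup -> (sequence, d_model)
x_TD = x_TD + ffw(norm(x_TD))            # pre-norm gated FFN with residual add
logits_TV = embed(x_TD, decode=True)     # tied-embedding logits -> (sequence, vocab_size)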
tpu_inference/layers/jax/misc.py (new file, +16 lines)
@@ -0,0 +1,16 @@
import math
from typing import Tuple

import jax
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


# TODO(xiang): move this to weight_utils.py
def shard_put(x: jax.Array, sharding_names: Tuple[str, ...] | P,
              mesh: jax.sharding.Mesh) -> jax.Array:
    # Single device sharding requires this special handling
    # to avoid the recursive jit error.
    if math.prod(mesh.axis_sizes) == 1:
        return jax.device_put(x, mesh.devices.flatten()[0])
    return jax.device_put(x, NamedSharding(mesh, P(*sharding_names)))
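A short, hypothetical call site for shard_put is sketched below; the mesh axis name, array shape, and import path are assumptions for illustration and are not taken from the package.

# Hypothetical call site for shard_put; axis names, shapes, and import path are assumptions.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

from tpu_inference.layers.jax.misc import shard_put  # assumed module path

# Build a 1-D mesh over all visible devices under an assumed axis name "model".
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

w_DF = jnp.ones((1024, 4096))
# Replicate dim 0 and shard dim 1 across the "model" axis; on a one-device
# mesh the helper falls back to a plain jax.device_put onto that device.
w_DF = shard_put(w_DF, P(None, "model"), mesh)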