tpu-inference 0.11.1rc1__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl

This diff shows the content of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.

Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.
Files changed (50)
  1. tpu_inference/kernels/collectives/__init__.py +0 -0
  2. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  3. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  4. tpu_inference/kernels/collectives/util.py +47 -0
  5. tpu_inference/layers/__init__.py +0 -0
  6. tpu_inference/layers/common/__init__.py +0 -0
  7. tpu_inference/layers/common/attention_metadata.py +34 -0
  8. tpu_inference/layers/jax/__init__.py +0 -0
  9. tpu_inference/layers/jax/attention/__init__.py +0 -0
  10. tpu_inference/layers/jax/attention/attention.py +254 -0
  11. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  12. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  13. tpu_inference/layers/jax/attention_interface.py +356 -0
  14. tpu_inference/layers/jax/base.py +151 -0
  15. tpu_inference/layers/jax/binary_search.py +295 -0
  16. tpu_inference/layers/jax/constants.py +88 -0
  17. tpu_inference/layers/jax/layers.py +301 -0
  18. tpu_inference/layers/jax/misc.py +16 -0
  19. tpu_inference/layers/jax/moe/__init__.py +0 -0
  20. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  21. tpu_inference/layers/jax/moe/moe.py +209 -0
  22. tpu_inference/layers/jax/rope.py +172 -0
  23. tpu_inference/layers/jax/rope_interface.py +214 -0
  24. tpu_inference/layers/jax/sample/__init__.py +0 -0
  25. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  26. tpu_inference/layers/jax/sample/sampling.py +95 -0
  27. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  28. tpu_inference/layers/jax/sharding.py +406 -0
  29. tpu_inference/layers/jax/transformer_block.py +76 -0
  30. tpu_inference/layers/vllm/__init__.py +0 -0
  31. tpu_inference/layers/vllm/attention.py +184 -0
  32. tpu_inference/layers/vllm/fused_moe.py +399 -0
  33. tpu_inference/layers/vllm/linear_common.py +186 -0
  34. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  35. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  36. tpu_inference/layers/vllm/quantization/common.py +105 -0
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  38. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  39. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  40. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  41. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  42. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  43. tpu_inference/layers/vllm/sharding.py +151 -0
  44. tpu_inference/models/common/__init__.py +0 -0
  45. tpu_inference/models/common/model_loader.py +433 -0
  46. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
  47. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
  48. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +0 -0
  49. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
  50. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
tpu_inference/layers/jax/attention_interface.py
@@ -0,0 +1,356 @@
+ import functools
+ import math
+ from typing import Any, Callable, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from jax.experimental import shard_map
+ from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
+ from jax.experimental.pallas.ops.tpu.splash_attention import \
+     splash_attention_kernel as splash
+ from jax.experimental.pallas.ops.tpu.splash_attention import \
+     splash_attention_mask as mask_lib
+ from jax.sharding import Mesh
+ from jax.sharding import PartitionSpec as P
+
+ import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
+ from tpu_inference.kernels.flash_attention.kernel import flash_attention
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.utils import get_megacore
+
+ MAX_ALLOWED_PAGE_INDICES_N = (
+     128 * 1024
+ )  # Based on experiments on v5e: 256x1024 results in SMEM OOM but 128x1024 does not. TODO: Adjust this based on the TPU version.
+
+ ragged_paged_attention = rpa.ragged_paged_attention
+ get_kv_cache_shape = rpa.get_kv_cache_shape
+
+
+ def sharded_flash_attention(
+     mesh: Mesh,
+     causal: bool = True,
+     sm_scale: Optional[float] = None,
+     vmem_limit_bytes: int | None = None,
+ ) -> Callable[..., Any]:
+     in_specs = (
+         P("data", "model", None, None),  # q
+         P("data", "model", None, None),  # k
+         P("data", "model", None, None),  # v
+         P(),  # segment_ids
+     )
+     out_specs = P("data", "model", None, None)
+
+     def _flash_attention(q, k, v, segment_ids):
+         return flash_attention(q,
+                                k,
+                                v,
+                                segment_ids=segment_ids,
+                                sm_scale=sm_scale,
+                                causal=causal,
+                                vmem_limit_bytes=vmem_limit_bytes)
+
+     return jax.jit(
+         shard_map.shard_map(_flash_attention,
+                             mesh=mesh,
+                             in_specs=in_specs,
+                             out_specs=out_specs,
+                             check_rep=False))
+
+
+ def sharded_paged_attention(
+     mesh: Mesh,
+     attn_logits_soft_cap: Optional[float] = None,
+ ) -> Callable[..., Any]:
+     """Shards GQA PagedAttention along KV heads."""
+     in_specs = (
+         P(None, "model", None),  # q
+         P("model", None, None, None),  # k
+         P("model", None, None, None),  # v
+         P(),  # lengths
+         P(),  # page_indices
+     )
+     out_specs = P(None, "model", None)
+
+     def _paged_attention_fn(q, k, v, lengths, page_indices):
+         if page_indices.size > MAX_ALLOWED_PAGE_INDICES_N:
+             raise ValueError(
+                 "This will result in SMEM OOM. Use `paged_attention_with_guarded_smem` to run with minibatches."
+             )
+         return paged_attention(
+             q,
+             k,
+             v,
+             lengths,
+             page_indices,
+             attn_logits_soft_cap=attn_logits_soft_cap,
+             pages_per_compute_block=min(
+                 16, page_indices.shape[1]),  # 512 / page_size:32
+             megacore_mode="kv_head" if get_megacore() else None,
+         )
+
+     return jax.jit(
+         shard_map.shard_map(
+             _paged_attention_fn,
+             mesh=mesh,
+             in_specs=in_specs,
+             out_specs=out_specs,
+             check_rep=False,
+         ))
+
+
+ # TODO(xiangxu): merge this with sharded_paged_attention
+ @functools.partial(jax.jit, static_argnums=[0])
+ def paged_attention_with_guarded_smem(
+     paged_attention_kernel: Callable,
+     q: jax.Array,
+     k_pages: jax.Array,
+     v_pages: jax.Array,
+     lengths: jax.Array,
+     page_indices: jax.Array,
+ ):
+     # Addresses b/336316706. Summary:
+     # The paged attention kernel stores `lengths` (batch_size * 4 bytes) and `page_indices` (batch_size * num_blocks_per_seq * 4 bytes) in SMEM.
+     # SMEM capacity is quite limited and also depends on the TPU version. Models with a longer context length or a larger batch size can cause OOM in SMEM.
+     # There are two solutions:
+     # 1. Reduce the blocks per sequence by increasing the page size.
+     # 2. Split the batch into several minibatches (higher perf based on my benchmark).
+
+     batch_size, blocks_per_seq = page_indices.shape
+
+     if page_indices.size <= MAX_ALLOWED_PAGE_INDICES_N:
+         return paged_attention_kernel(q, k_pages, v_pages, lengths,
+                                       page_indices)
+
+     mini_batch_size = MAX_ALLOWED_PAGE_INDICES_N // blocks_per_seq
+
+     # If batch_size is not divisible by mini_batch_size,
+     # we set mini_batch_size to a smaller value, i.e. the GCD,
+     # which triggers more kernel launches but is fine.
+     # TODO: Fix --decode_seqs_padding with this limitation.
+     mini_batch_size = math.gcd(batch_size, mini_batch_size)
+
+     num_kernel_launches = batch_size // mini_batch_size
+
+     outputs = jnp.zeros_like(q).reshape(
+         (num_kernel_launches, mini_batch_size, *q.shape[1:]))
+     q = q.reshape((num_kernel_launches, mini_batch_size, *q.shape[1:]))
+     seq_lens = lengths.reshape((num_kernel_launches, mini_batch_size))
+     block_indices = page_indices.reshape(
+         (num_kernel_launches, mini_batch_size, page_indices.shape[1]))
+
+     for i in range(num_kernel_launches):
+         outputs = outputs.at[i].set(
+             paged_attention_kernel(q[i], k_pages, v_pages, seq_lens[i],
+                                    block_indices[i]))
+
+     outputs = outputs.reshape((batch_size, *outputs.shape[2:]))
+
+     return outputs
+
+
+ # ruff: noqa: E741
+ def update_cache(
+     is_prefill,
+     cache,
+     indices,
+     operand,
+     prefill_seq_len=None,
+     sliding_window=None,
+ ) -> jax.Array:
+
+     # (8, 55640, 32, 128) (1, 8, 256, 128) -> K (8, 8, 32, 128)
+     # I = B * T // S
+     # k cache, operand
+
+     B, K, T, H = operand.shape
+     K_c, L, S, H = cache.shape
+     assert K == K_c
+     # NOTE: Updating the cache is pretty tricky:
+     # 1. A random-access update of the cache is not as performant as a slice update.
+     #    If random access is necessary, make sure the indexing count is as small as possible.
+     # 2. A random-access update may trigger an extra transpose (memory copy) of the cache,
+     #    which is a disaster because the cache is huge. This is a data-formatting op inserted by
+     #    the XLA compiler and not well documented.
+     # To mitigate the issues above:
+     # For prefill:
+     #    We reshape the operand so that we can update the cache block-wise, which only requires the block indices.
+     # For decode:
+     #    We reshape the cache so that we can update the cache token-wise, which only requires the token indices (block_id + offset).
+     if is_prefill:
+         # In the sliding-window case, we should select sliding_window tokens from the actual prompt, not from the padded tokens.
+         if sliding_window and T > sliding_window:
+             assert B == 1
+             start_index = jax.lax.max(0, prefill_seq_len - sliding_window)
+             operand = jax.lax.dynamic_slice_in_dim(
+                 operand, start_index, sliding_window,
+                 axis=2)  # TODO: @pooyam Perf check this.
+             T = sliding_window
+
+         I = B * T // S
+         # cache: (K, L, S, H)
+         # operand: (B, K, T, H) -> (K, I, S, H)
+         # indices: (B, T // S) -> (I,)
+         operand = jnp.swapaxes(operand, 0, 1).reshape(K, I, S, H)
+         indices = indices.reshape(I)
+         cache = cache.at[:, indices, :, :].set(operand)
+     else:
+         # cache: (K, L, S, H) -> (K, L * S, H)
+         # operand: (B, K, 1, H) -> (K, B, H)
+         # indices: (B,)
+         cache = cache.reshape(K, L * S, H)
+         operand = jnp.swapaxes(operand, 0, 1).reshape(K, B, H)
+         # NOTE: `cache.at[:, indices, :].set()` would trigger the extra transpose of the cache.
+         # The `jnp.arange(K)[..., None]` trick avoids it.
+         cache = cache.at[jnp.arange(K)[..., None], indices, :].set(operand)
+         cache = cache.reshape(K, L, S, H)
+     return cache
+
+
+ @functools.partial(
+     jax.jit, static_argnames=["window_size", "attn_logits_soft_cap", "is_mqa"])
+ def apply_splash(q, k, v, window_size, attn_logits_soft_cap,
+                  is_mqa) -> jax.Array:
+     # q: (batch_size, num_heads, seq_len, head_dim)
+     num_heads = q.shape[1]
+     q_seq_len = q.shape[2]
+     kv_seq_len = k.shape[2]
+     assert kv_seq_len >= q_seq_len
+
+     masks = [
+         mask_lib.LocalMask((q_seq_len, kv_seq_len), (window_size, 0),
+                            kv_seq_len - q_seq_len) for _ in range(num_heads)
+     ]
+     mask = mask_lib.MultiHeadMask(tuple((m for m in masks)))
+     block_sizes = splash.BlockSizes.get_default()
+
+     if is_mqa:
+         attn = splash.make_splash_mqa_single_device(
+             mask,
+             block_sizes=block_sizes,
+             attn_logits_soft_cap=attn_logits_soft_cap)
+     else:
+         attn = splash.make_splash_mha_single_device(
+             mask,
+             block_sizes=block_sizes,
+             attn_logits_soft_cap=attn_logits_soft_cap)
+     attn = jax.vmap(attn)
+     outputs = attn(q, k, v, None)
+
+     return outputs
+
+
+ def sharded_splash_attention(
+     mesh: Mesh,
+     window_size: Optional[int] = None,
+     attn_logits_soft_cap: Optional[float] = None,
+     is_mqa: bool = False,
+ ) -> Callable[..., Any]:
+     in_specs = (
+         P("data", "model", None, None),  # q
+         P("data", "model", None, None),  # k
+         P("data", "model", None, None),  # v
+     )
+     out_specs = P("data", "model", None, None)
+     return jax.jit(
+         shard_map.shard_map(
+             functools.partial(
+                 apply_splash,
+                 window_size=window_size,
+                 attn_logits_soft_cap=attn_logits_soft_cap,
+                 is_mqa=is_mqa,
+             ),
+             mesh=mesh,
+             in_specs=in_specs,
+             out_specs=out_specs,
+             check_rep=False,
+         ))
+
+
+ def sharded_ragged_paged_attention(
+     sm_scale: float,
+     mesh: Mesh,
+     attention_chunk_size: int | None = None,
+     q_scale: float | None = None,
+     k_scale: float | None = None,
+     v_scale: float | None = None,
+ ):
+     """Shards along KV heads."""
+     qkv_spec = P(None, "model", None)
+     kv_cache_spec = P(None, None, "model")
+     in_specs = (
+         qkv_spec,  # q
+         qkv_spec,  # k
+         qkv_spec,  # v
+         kv_cache_spec,  # kv cache
+         P(),  # kv_lens
+         P(),  # page_indices
+         P(),  # cu_q_lens
+         P(),  # distribution
+     )
+     out_specs = (qkv_spec, kv_cache_spec)
+
+     def _ragged_paged_attention(*args):
+         return ragged_paged_attention(
+             *args,
+             sm_scale=sm_scale,
+             sliding_window=attention_chunk_size,
+             q_scale=q_scale,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+
+     return jax.jit(
+         shard_map.shard_map(
+             _ragged_paged_attention,
+             mesh=mesh,
+             in_specs=in_specs,
+             out_specs=out_specs,
+             check_rep=False,
+         ))
+
+
+ def attention(
+     kv_cache: jax.Array,
+     q: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     attention_metadata: AttentionMetadata,
+     mesh: Mesh,
+     head_dim_original: int | None = None,  # before padding
+     attention_chunk_size: int | None = None,
+     q_scale: float | None = None,
+     k_scale: float | None = None,
+     v_scale: float | None = None,
+ ) -> Tuple[jax.Array, jax.Array]:
+     # T: seq_len
+     # N: num_heads
+     # K: num_kv_heads
+     # D: hidden_size
+     # H: head_dim
+     # L: num_blocks
+     # S: block_size
+
+     # TODO(jevinjiang, cuiq): transpose q weight offline.
+     # q: (T, N, H)
+     # k, v: (T, K, H)
+
+     if head_dim_original is None:
+         head_dim_original = q.shape[-1]
+
+     md = attention_metadata
+
+     # (T, N, H)
+     output, kv_cache = sharded_ragged_paged_attention(
+         head_dim_original**-0.5, mesh, attention_chunk_size, q_scale, k_scale,
+         v_scale)(
+             q,
+             k,
+             v,
+             kv_cache,
+             md.seq_lens,
+             md.block_tables,
+             md.query_start_loc,
+             md.request_distribution,
+         )
+
+     return kv_cache, output
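
The comments in paged_attention_with_guarded_smem above describe how a batch whose page_indices table would overflow SMEM is split into minibatches using a GCD rule. The short sketch below reproduces only that size arithmetic in plain Python so the split can be checked without a TPU; the helper name plan_minibatches and the concrete batch/block numbers are illustrative and not part of the package.

import math

MAX_ALLOWED_PAGE_INDICES_N = 128 * 1024  # mirrors the constant in the hunk above


def plan_minibatches(batch_size: int, blocks_per_seq: int) -> tuple[int, int]:
    # Same rule as paged_attention_with_guarded_smem: if the page_indices table
    # fits in SMEM, use a single kernel launch; otherwise shrink the minibatch
    # to the GCD so that batch_size divides evenly across launches.
    if batch_size * blocks_per_seq <= MAX_ALLOWED_PAGE_INDICES_N:
        return batch_size, 1
    mini_batch_size = MAX_ALLOWED_PAGE_INDICES_N // blocks_per_seq
    mini_batch_size = math.gcd(batch_size, mini_batch_size)
    return mini_batch_size, batch_size // mini_batch_size


# Hypothetical numbers: 512 sequences, 512 page-table entries per sequence.
# 512 * 512 = 262,144 > 131,072, so the batch is split into 2 launches of 256.
print(plan_minibatches(512, 512))  # -> (256, 2)

This also makes the TODO in the hunk concrete: when batch_size and the SMEM-derived minibatch size share only a small GCD, the number of kernel launches grows accordingly.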
@@ -0,0 +1,151 @@
+ import dataclasses
+ from dataclasses import dataclass, fields
+ from typing import Any, Callable, Mapping
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jax.sharding import PartitionSpec as P
+
+ from tpu_inference.logger import init_logger
+
+ # Type alias for Initializer for cleaner type hints
+ Initializer = Callable[..., jax.Array]
+ logger = init_logger(__name__)
+
+ # Define singleton initializers to avoid re-compilation.
+ _scale_initializer = nnx.initializers.ones
+ _sharded_initializer = nnx.initializers.xavier_normal()
+ _init_fn = nnx.initializers.uniform()
+
+
+ @dataclass
+ class Config:
+     """Base configuration class with a robust factory method.
+
+     This class provides a `from_cfg` classmethod that allows creating a config
+     instance from a dictionary, ensuring that all required fields are present
+     and ignoring any extraneous keys.
+     """
+
+     @classmethod
+     def from_cfg(cls, cfg: dict[str, Any] | None = None, **kwargs):
+         """Creates a config instance from a dictionary and/or keyword arguments.
+
+         This factory method validates that all fields without default values
+         are provided in the input dictionary or keyword arguments.
+
+         Args:
+             cfg: A dictionary of configuration parameters.
+             **kwargs: Additional configuration parameters passed as keyword arguments.
+
+         Returns:
+             An instance of the configuration class.
+
+         Raises:
+             ValueError: If any required parameters are missing.
+         """
+         if cfg is None:
+             cfg = {}
+         cfg.update(kwargs)
+
+         required_params = {
+             f.name
+             for f in fields(cls) if f.default is dataclasses.MISSING
+             and f.default_factory is dataclasses.MISSING
+         }
+
+         # Check whether any of the truly required parameters are missing from the provided config.
+         missing_params = required_params - set(cfg.keys())
+         if missing_params:
+             raise ValueError(
+                 f"Missing required parameters for {cls.__name__}: {', '.join(sorted(list(missing_params)))}"
+             )
+
+         known_params = {f.name for f in fields(cls)}
+         filtered_cfg = {k: v for k, v in cfg.items() if k in known_params}
+
+         return cls(**filtered_cfg)
+
+     # TODO: check logic with some unit tests.
+     def maybe_apply_overrides(self):
+         """Update the args with additional_config, hf_overrides, and override_generation_config settings.
+         If overrides overlap between these configs, print a warning declaring which
+         overrides take precedence."""
+
+         if not getattr(self, "vllm_config"):
+             return
+
+         def _overrides_str(original: str, original_val: Any,
+                            new_val: Any) -> str:
+             return f"{original}: {original_val} ---> {new_val}"
+
+         def _get_overrides_dict(self) -> Mapping[str, Any]:
+             """Return the overrides from all of the possible vLLM sections."""
+             overrides_dict = {}
+             vllm_model_config = self.vllm_config.model_config
+
+             for override_type in ordered_override_types:
+                 if override_type == "additional_config":
+                     overrides_dict[
+                         override_type] = self.vllm_config.additional_config
+                 else:
+                     overrides_dict[override_type] = getattr(
+                         vllm_model_config, override_type)
+             return overrides_dict
+
+         ordered_override_types = [
+             "additional_config", "hf_overrides", "override_generation_config"
+         ]
+
+         overrides_dict = _get_overrides_dict(self)
+
+         # Override the config values using the vLLM sections with the highest
+         # precedence first.
+         for field in fields(self):
+             selected_type = None
+             for override_type in reversed(ordered_override_types):
+                 if field.name in overrides_dict[override_type]:
+                     setattr(self, field.name,
+                             overrides_dict[override_type][field.name])
+                     selected_type = override_type
+                     break
+             if selected_type is None:
+                 continue
+
+             # If multiple vLLM sections contain overrides, print a warning.
+             for override_type in ordered_override_types:
+                 if override_type == selected_type:
+                     break
+                 else:
+                     if field.name in overrides_dict[override_type]:
+                         overriden_keys_str = _overrides_str(
+                             field.name,
+                             overrides_dict[override_type][field.name],
+                             overrides_dict[selected_type][field.name])
+                         logger.warning(
+                             f"Overriding {override_type} arguments with the following {selected_type} args: {overriden_keys_str}"
+                         )
+
+     def __post_init__(self):
+         self.maybe_apply_overrides()
+
+
+ def create_param(rngs: nnx.Rngs,
+                  shape: tuple[int, ...],
+                  sharding: Sharding = (),
+                  dtype: Any = jnp.float32,
+                  random_init=False) -> nnx.Param:
+     key = rngs.params()
+     if random_init:
+         initializer = _scale_initializer if len(
+             shape) == 1 else _sharded_initializer
+
+         jitted_initializer = jax.jit(initializer,
+                                      static_argnames=('shape', 'dtype'),
+                                      out_shardings=P(*sharding))
+         param_data = jitted_initializer(key, shape, dtype)
+         return nnx.Param(param_data, sharding=sharding)
+     else:
+         return nnx.Param(_init_fn(key, shape, dtype), sharding=sharding)
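
As a usage sketch for the Config factory added above: a dataclass subclass declares its fields, from_cfg fills them from a dict plus keyword arguments, silently drops unknown keys, and raises ValueError when a field without a default is missing. The AttentionConfig subclass, its fields, and the import path are hypothetical illustrations, not part of the wheel; note that __post_init__ reads a vllm_config attribute via getattr, so the sketch gives it a None default.

from dataclasses import dataclass
from typing import Any

# Import path assumed for illustration; the +151 hunk above does not name its file.
from tpu_inference.layers.jax.base import Config


@dataclass
class AttentionConfig(Config):  # hypothetical subclass
    num_heads: int
    head_dim: int
    dropout: float = 0.0
    vllm_config: Any = None  # __post_init__ expects this attribute to exist


# Unknown keys ("unused") are filtered out; kwargs are merged on top of the dict.
cfg = AttentionConfig.from_cfg({"num_heads": 8, "unused": True}, head_dim=128)
assert (cfg.num_heads, cfg.head_dim, cfg.dropout) == (8, 128, 0.0)

# Omitting a field that has no default raises a ValueError naming the missing field.
try:
    AttentionConfig.from_cfg({"num_heads": 8})
except ValueError as exc:
    print(exc)  # Missing required parameters for AttentionConfig: head_dim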