tpu_inference-0.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,771 @@
+ import functools
+ import os
+ import random
+ from contextlib import nullcontext
+ from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+
+ import jax
+ import jax.numpy as jnp
+ import jaxtyping
+ import numpy as np
+ import torch
+ import vllm.envs as envs
+ from flax import nnx
+ from torchax.ops.mappings import j2t_dtype
+ from vllm.config import VllmConfig
+ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                           has_kv_transfer_group)
+ from vllm.forward_context import set_forward_context
+ from vllm.sequence import IntermediateTensors
+ from vllm.tasks import SupportedTask
+ from vllm.utils import cdiv
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+ from vllm.v1.kv_cache_interface import KVCacheConfig
+ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
+                              ModelRunnerOutput)
+ from vllm.v1.request import Request
+ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+ from vllm.v1.worker.kv_connector_model_runner_mixin import \
+     KVConnectorModelRunnerMixin
+ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+
+ from tpu_inference import utils as common_utils
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
+ from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
+                                                       gather_logprobs, sample)
+ from tpu_inference.layers.jax.sample.sampling_metadata import \
+     TPUSupportedSamplingMetadata
+ from tpu_inference.layers.jax.sharding import build_mesh
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.common.model_loader import get_model
+ from tpu_inference.models.jax.utils.weight_utils import (
+     shard_put, transfer_state_with_mappings)
+ from tpu_inference.runner import utils as runner_utils
+ from tpu_inference.runner.compilation_manager import CompilationManager
+ from tpu_inference.runner.input_batch_jax import CachedRequestState, InputBatch
+ from tpu_inference.runner.kv_cache_manager import KVCacheManager
+ from tpu_inference.runner.lora_utils import LoraUtils
+ from tpu_inference.runner.multimodal_manager import MultiModalManager
+ from tpu_inference.runner.persistent_batch_manager import \
+     PersistentBatchManager
+ from tpu_inference.runner.speculative_decoding_manager import \
+     SpeculativeDecodingManager
+ from tpu_inference.runner.structured_decoding_manager import \
+     StructuredDecodingManager
+ from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
+ from tpu_inference.utils import device_array, make_optimized_mesh
+
+ logger = init_logger(__name__)
+
+ INVALID_TOKEN_ID = -1
+ # Smallest output size
+ MIN_NUM_SEQS = 8
+
+ DUMMY_METADATA = AttentionMetadata(
+     input_positions=[],
+     block_tables=[],
+     request_distribution=[0, 0, 0],
+ )
+
+ TPU_STR_DTYPE_TO_TORCH_DTYPE = {
+     "half": torch.half,
+     "bfloat16": torch.bfloat16,
+     "float": torch.float,
+     "fp8": torch.float8_e4m3fn,
+     "fp8_e4m3": torch.float8_e4m3fn,
+     "fp8_e5m2": torch.float8_e5m2,
+     "int8": torch.int8,
+     "uint8": torch.uint8,
+ }
+
+
+ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
+
+     def __init__(
+         self,
+         vllm_config: VllmConfig,
+         devices: List[Any],
+     ):
+         self.vllm_config = vllm_config
+         self.model_config = vllm_config.model_config
+         # TODO(jevinjiang): override block size based on RPA v3.
+         self.cache_config = vllm_config.cache_config
+         self.lora_config = vllm_config.lora_config
+         self.load_config = vllm_config.load_config
+         self.parallel_config = vllm_config.parallel_config
+         self.scheduler_config = vllm_config.scheduler_config
+         self.speculative_config = vllm_config.speculative_config
+         self.observability_config = vllm_config.observability_config
+         self.device_config = vllm_config.device_config
+
+         self.devices = devices
+         self.dtype = self.model_config.dtype
+         self.maybe_forbid_compile = runner_utils.ForbidCompile(
+         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
+
+         self._init_random()
+         self._init_mesh()
+         self._init_phased_profiling()
+         self._init_mm()
+         self._init_inputs()
+         self._init_speculative_decoding()
+
+         # Delegate functions to specific manager classes.
+         self.compilation_manager = CompilationManager(self)
+         self.speculative_decoding_manager = SpeculativeDecodingManager(self)
+         self.structured_decoding_manager = StructuredDecodingManager(self)
+         self.kv_cache_manager = KVCacheManager(self)
+         self.mm_manager = MultiModalManager(self)
+         self.persistent_batch_manager = PersistentBatchManager(
+             self.requests, self.input_batch, self.encoder_cache,
+             self.uses_mrope, self.model_config)
+         self.lora_utils = LoraUtils(self)
+
+         cache_config = self.cache_config
+         if cache_config.cache_dtype == "auto":
+             model_dtype = self.dtype
+             if isinstance(model_dtype, str):
+                 self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+             elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
+                 self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
+             elif isinstance(model_dtype, torch.dtype):
+                 self.kv_cache_dtype = model_dtype
+             else:
+                 raise ValueError(
+                     "KV cache is unsupported for model_dtype of %s",
+                     model_dtype)
+         else:
+             self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
+                 cache_config.cache_dtype]
+
+     def _init_random(self):
+         if self.model_config.seed is None:
+             self.model_config.seed = 0
+         random.seed(self.model_config.seed)
+         np.random.seed(self.model_config.seed)
+         self.rng_key = jax.random.key(self.model_config.seed)
+
+     def _init_mesh(self) -> None:
+         try:
+             # TODO: Update override steps.
+             sharding_strategy = \
+                 self.vllm_config.additional_config["sharding"]["sharding_strategy"]
+         except KeyError:
+             sharding_strategy = {"tensor_parallelism": len(self.devices)}
+
+         if os.getenv("NEW_MODEL_DESIGN", False):
+             self.mesh = build_mesh(self.devices, sharding_strategy)
+         else:
+             try:
+                 dp = sharding_strategy["data_parallelism"]
+             except KeyError:
+                 dp = 1
+             try:
+                 tp = sharding_strategy["tensor_parallelism"]
+             except KeyError:
+                 tp = len(self.devices)
+
+             axis_names = ("data", "model")
+             mesh_shape = (dp, tp)
+
+             self.mesh = make_optimized_mesh(mesh_shape,
+                                             axis_names,
+                                             devices=self.devices)
+         logger.info(f"Init mesh | mesh={self.mesh}")
+
+     def _init_phased_profiling(self) -> None:
+         self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
+         self.phase_based_profiler = None
+         if self.phased_profiling_dir:
+             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
+                 self.phased_profiling_dir)
+
+     def _init_mm(self) -> None:
+         self.is_multimodal_model = None
+         self.uses_mrope = self.model_config.uses_mrope
+
+     def _init_speculative_decoding(self) -> None:
+         self.drafter = None
+         if self.speculative_config:
+             if self.speculative_config.method == "ngram":
+                 self.drafter = NgramProposer(self.vllm_config)
+             elif self.speculative_config.method == "eagle3":
+                 self.drafter = Eagle3Proposer(self.vllm_config, self)
+             else:
+                 raise NotImplementedError(
+                     "Unsupported speculative decoding method: "
+                     f"{self.speculative_config.method}")
+             self.rejection_sampler = RejectionSampler()
+
+     def _init_inputs(self) -> None:
+         model_config = self.model_config
+         cache_config = self.cache_config
+         scheduler_config = self.scheduler_config
+
+         self.sliding_window = model_config.get_sliding_window()
+         self.block_size = cache_config.block_size
+         self.max_model_len = model_config.max_model_len
+         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
+         # InputBatch needs to work with sampling tensors greater than padding
+         # to avoid dynamic shapes. Also, avoid suboptimal alignment.
+         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
+         # [16, 32, 64, 128, 256, 512, 1024, 2048]
+         self.num_tokens_paddings = runner_utils.get_token_paddings(
+             min_token_size=16,
+             max_token_size=scheduler_config.max_num_batched_tokens,
+             padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+         # In case `max_num_tokens < max(num_tokens_paddings)` use the actual
+         # padded max value to pre-allocate data structures and pre-compile.
+         self.max_num_tokens = self.num_tokens_paddings[-1]
+
+         # Request states.
+         self.requests: dict[str, CachedRequestState] = {}
+         # mm_hash -> encoder_output
+         self.encoder_cache: dict[str, jax.Array] = {}
+         self.input_batch = InputBatch(
+             max_num_reqs=self.max_num_reqs,
+             max_model_len=self.max_model_len,
+             max_num_batched_tokens=self.max_num_tokens,
+             pin_memory=False,
+             vocab_size=self.model_config.get_vocab_size(),
+             block_sizes=[self.block_size],
+             is_spec_decode=bool(self.vllm_config.speculative_config),
+         )
+
+         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
+         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
+         self.block_table_cpu = np.zeros(
+             (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
+         self.query_start_loc_cpu = np.zeros(self.max_num_tokens + 1,
+                                             dtype=np.int32)
+         self.seq_lens_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
+         # Range tensor with values [0 .. self.max_num_tokens - 1].
+         # Used to initialize positions / context_lens / seq_lens
+         # Keep in int64 to avoid overflow with long context
+         self.arange_cpu = np.arange(self.max_num_tokens, dtype=np.int64)
+         self.num_reqs_paddings = runner_utils.get_req_paddings(
+             min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
+
+         # Padding for logits. Without speculative decoding, each request has one position to select from.
+         # With speculative decoding, each request has multiple positions to select from.
+         max_logits_per_req = 1
+         if self.speculative_config:
+             max_logits_per_req = self.speculative_config.num_speculative_tokens + 1  # Including bonus token
+             self.num_logits_paddings = runner_utils.get_token_paddings(
+                 min_token_size=MIN_NUM_SEQS,
+                 max_token_size=self.max_num_reqs * max_logits_per_req,
+                 padding_gap=0)
+         else:
+             self.num_logits_paddings = None
+
+         self.temperatures_cpu = np.zeros(self.max_num_tokens, dtype=np.float32)
+         self.top_ps_cpu = np.zeros(self.max_num_tokens, dtype=np.float32)
+         self.top_ks_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
+
+         # tensors for structured decoding
+         self.vocab_size = self.model_config.get_vocab_size()
+         if self.lora_config is not None:
+             # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
+             self.vocab_size += self.lora_config.lora_extra_vocab_size
+         self.grammar_bitmask_cpu = np.zeros(
+             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
+             dtype=np.int32,
+         )
+         self.require_structured_out_cpu = np.zeros(
+             (self.max_num_reqs, 1),
+             dtype=np.bool_,
+         )
+         self.structured_decode_arange = np.arange(0, 32, dtype=np.int32)
+
+         # multi-modal support
+         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+
+         # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+         # the modality of inputs. For text-only inputs, each dimension has
+         # identical position IDs, making M-RoPE functionally equivalent to
+         # 1D-RoPE.
+         # See page 5 of https://arxiv.org/abs/2409.12191
+         self.mrope_positions_cpu = np.zeros((3, self.max_num_tokens),
+                                             dtype=np.int64)
+
+     def load_model(self):
+         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, self.get_multimodal_embeddings_fn, self.get_input_embeddings_fn, self.get_mrope_input_positions_fn, self.state, self.lora_manager, self.model = get_model(
+             self.vllm_config,
+             self.rng_key,
+             self.mesh,
+         )
+
+         if self.drafter is not None:
+             logger.info("Loading drafter model...")
+             self.drafter.load_model(self.state)
+
+         self.rng_params_for_sampling = nnx.Rngs(
+             jax.random.key(self.model_config.seed)).params()
+         self.is_multimodal_model = (self.model_config.is_multimodal_model
+                                     and self.get_multimodal_embeddings_fn
+                                     is not None)
+
+         logger.info(f"Init model | "
+                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
+
+     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+         return ("generate", )
+
+     def get_kv_cache_spec(self):
+         return self.kv_cache_manager.get_kv_cache_spec()
+
+     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+         self.kv_cache_config = kv_cache_config
+         self.kv_caches = []
+         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
+         if has_kv_transfer_group():
+             get_kv_transfer_group().register_runner(self)
+
+     def capture_model(self) -> None:
+         self.compilation_manager.capture_model()
+
+     def execute_model(
+         self,
+         scheduler_output: "VllmSchedulerOutput",
+         intermediate_tensors: Optional[IntermediateTensors] = None,
+     ) -> ModelRunnerOutput:
+         return self._execute_model(scheduler_output)[1]
+
+     def _execute_model(
+         self,
+         scheduler_output: "VllmSchedulerOutput",
+     ) -> tuple[AttentionMetadata, ModelRunnerOutput]:
+         self.persistent_batch_manager.update_states(
+             scheduler_output, self.get_mrope_input_positions_fn)
+         if not scheduler_output.total_num_scheduled_tokens:
+             if has_kv_transfer_group():
+                 return DUMMY_METADATA, self.kv_connector_no_forward(
+                     scheduler_output, self.vllm_config)
+
+             # Return empty ModelRunnerOutput if there's no work to do.
+             # TODO(fhzhang): We rely on empty cycles to remove requests in input batch. Fix it to reduce overhead.
+             logger.debug(f"Nothing scheduled: {scheduler_output}!")
+             # NOTE(pooyam): There is no guarantee that scheduler is not sending empty output: https://github.com/vllm-project/vllm/blob/7cfea0df390c154c1026f77d3682e2733ca4aca8/vllm/v1/engine/core.py#L275
+             # Why they are not preventing that is not clear to me.
+             if len(scheduler_output.finished_req_ids) == 0:
+                 logger.warning(
+                     "Should not schedule a request that does nothing!")
+                 # raise Exception(
+                 #     "Should not schedule a request that does nothing!")
+             return DUMMY_METADATA, EMPTY_MODEL_RUNNER_OUTPUT,
+
+         (input_ids, attn_metadata, sampling_metadata, logits_indices,
+          spec_decode_metadata) = self._prepare_inputs(scheduler_output)
+
+         # multi-modal support
+         if self.is_multimodal_model:
+             # Run the multimodal encoder if any.
+             # We have the modality embeds at this time.
+             self.mm_manager.execute_mm_encoder(scheduler_output)
+             mm_embeds = self.mm_manager.gather_mm_embeddings(scheduler_output)
+         else:
+             mm_embeds = []
+
+         # NOTE(Wenlong): For multi-modal model,
+         # it will embed the text tokens and merge with the existing modality embeds
+         # Later, the multi-modality model will take the embedding as the input.
+         # For text-only model, this does nothing. It will input the input_ids and
+         # leave the mebedding job inside the forward pass
+         input_ids, inputs_embeds = self._get_input_ids_embeds(
+             input_ids, mm_embeds)
+
+         lora_metadata = self.lora_utils.extract_lora_metadata()
+         # TODO: make _get_input_ids_embeds within this context
+         # NOTE: right now, mm model will use embeddings as the input,
+         # but text-only model will use input_ids
+         with self.maybe_forbid_compile:
+
+             with set_forward_context(
+                     None,
+                     self.vllm_config,
+             ), self.maybe_get_kv_connector_output(
+                     scheduler_output) as kv_connector_output:
+                 # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
+                 # but one of them would be `None`
+
+                 (self.kv_caches, hidden_states,
+                  aux_hidden_states) = self.model_fn(
+                      self.state,
+                      self.kv_caches,
+                      input_ids,
+                      attn_metadata,
+                      inputs_embeds,
+                      tuple(self.layer_name_to_kvcache_index.items()),
+                      lora_metadata,
+                  )
+
+                 hidden_states = self._select_from_array_fn(hidden_states,
+                                                            logits_indices)
+                 logits = self.compute_logits_fn(
+                     self.state,
+                     hidden_states,
+                     lora_metadata,
+                 )
+                 if scheduler_output.grammar_bitmask is not None:
+                     (
+                         require_struct_decoding, grammar_bitmask_padded, arange
+                     ) = self.structured_decoding_manager.prepare_structured_decoding_input(
+                         logits, scheduler_output)
+                     logits = self.structured_decoding_manager.structured_decode_fn(
+                         require_struct_decoding,
+                         grammar_bitmask_padded,
+                         logits,
+                         arange,
+                     )
+                 tpu_sampling_metadata = sampling_metadata
+                 if spec_decode_metadata is None:
+                     next_tokens = sample(
+                         self.rng_params_for_sampling,
+                         self.mesh,
+                         logits,
+                         tpu_sampling_metadata,
+                     )
+                 else:
+                     bonus_logits = self._select_from_array_fn(
+                         logits, spec_decode_metadata.bonus_logits_indices)
+                     bonus_token_ids = sample(
+                         self.rng_params_for_sampling,
+                         self.mesh,
+                         bonus_logits,
+                         tpu_sampling_metadata,
+                     )
+                     target_logits = self._select_from_array_fn(
+                         logits, spec_decode_metadata.target_logits_indices)
+                     next_tokens = self.rejection_sampler(
+                         draft_token_ids=spec_decode_metadata.draft_token_ids,
+                         num_draft_tokens=spec_decode_metadata.draft_lengths,
+                         draft_probs=None,
+                         target_logits=target_logits,
+                         bonus_token_ids=bonus_token_ids,
+                         sampling_metadata=tpu_sampling_metadata,
+                         key=self.rng_params_for_sampling,
+                     )
+
+                 if tpu_sampling_metadata.logprobs:
+                     logprobs = self._compute_and_gather_logprobs(
+                         logits, next_tokens, self.model_config.max_logprobs)
+                 else:
+                     logprobs = None
+
+         num_reqs = self.input_batch.num_reqs
+
+         # Update the cache state concurrently. Code above will not block until
+         # we use `selected_token_ids`. Add mark_step if post-processing changes
+         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
+         discard_sampled_tokens_req_indices = []
+         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
+             assert req_id is not None
+             req_state = self.requests[req_id]
+             seq_len = (req_state.num_computed_tokens +
+                        scheduler_output.num_scheduled_tokens[req_id])
+             if seq_len >= req_state.num_tokens:
+                 request_seq_lens.append((i, req_state, seq_len))
+             else:
+                 # Ignore the sampled token from the partial request.
+                 # Rewind the generator state as if the token was not sampled.
+                 generator = self.input_batch.generators.get(i)
+                 if generator is not None:
+                     # This relies on cuda-specific torch-internal impl details
+                     generator.set_offset(generator.get_offset() - 4)
+
+                 # Record the index of the request that should not be sampled,
+                 # so that we could clear the sampled tokens before returning.
+                 discard_sampled_tokens_req_indices.append(i)
+
+         assert all(
+             req_id is not None for req_id in
+             self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
+         req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
+
+         prompt_logprobs_dict = {}
+         for req_id in self.input_batch.req_ids[:num_reqs]:
+             prompt_logprobs_dict[req_id] = None
+
+         if spec_decode_metadata is None:
+             next_tokens = np.asarray(jax.device_get(next_tokens))
+             selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
+             valid_sampled_token_ids = selected_token_ids.tolist()
+         else:
+             valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                 next_tokens, self.input_batch.vocab_size,
+                 spec_decode_metadata.draft_lengths_cpu, num_reqs,
+                 spec_decode_metadata.draft_token_ids.shape[0])
+
+         # Mask out the sampled tokens that should not be sampled.
+         for i in discard_sampled_tokens_req_indices:
+             valid_sampled_token_ids[i].clear()
+         # Append sampled tokens
+         for req_idx, req_state, _ in request_seq_lens:
+             sampled_ids = valid_sampled_token_ids[req_idx]
+             if not sampled_ids:
+                 continue
+
+             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+             end_idx = start_idx + len(sampled_ids)
+             assert end_idx <= self.max_model_len, (
+                 "Sampled token IDs exceed the max model length. "
+                 f"Total number of tokens: {end_idx} > max_model_len: "
+                 f"{self.max_model_len}")
+
+             self.input_batch.token_ids_cpu[req_idx,
+                                            start_idx:end_idx] = sampled_ids
+             self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+             self.input_batch.num_tokens[req_idx] = end_idx
+             req_state.output_token_ids.extend(sampled_ids)
+
+         if logprobs is not None:
+             logprobs_lists = logprobs.tolists()
+         else:
+             logprobs_lists = None
+
+         if self.speculative_config:
+             with self.maybe_forbid_compile:
+                 self.speculative_decoding_manager.propose_draft_token_ids(
+                     valid_sampled_token_ids,
+                     aux_hidden_states,
+                     attn_metadata,
+                     spec_decode_metadata,
+                     scheduler_output,
+                     input_ids,
+                 )
+
+         model_runner_output = ModelRunnerOutput(
+             req_ids=req_ids,
+             req_id_to_index=self.input_batch.req_id_to_index,
+             sampled_token_ids=valid_sampled_token_ids,
+             logprobs=logprobs_lists,
+             prompt_logprobs_dict=prompt_logprobs_dict,
+             pooler_output=[],
+             kv_connector_output=kv_connector_output,
+         )
+         return attn_metadata, model_runner_output
+
+     @functools.partial(jax.jit, static_argnums=(0, ))
+     def _select_from_array_fn(self, array, indices_to_select):
+         return array[indices_to_select]
+
+     @staticmethod
+     @functools.partial(jax.jit, static_argnames=("max_logprobs", ))
+     def _compute_and_gather_logprobs(logits, next_tokens, max_logprobs):
+         logprobs = compute_logprobs(logits)
+         return gather_logprobs(logprobs, next_tokens, max_logprobs)
+
+     def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
+         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+         assert total_num_scheduled_tokens > 0
+         num_reqs = self.input_batch.num_reqs
+         assert num_reqs > 0
+
+         # Get the number of scheduled tokens for each request.
+         num_scheduled_tokens_per_req = []
+         max_num_scheduled_tokens_all_reqs = 0
+         for req_id in self.input_batch.req_ids[:num_reqs]:
+             assert req_id is not None
+             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+             num_scheduled_tokens_per_req.append(num_tokens)
+             max_num_scheduled_tokens_all_reqs = max(
+                 max_num_scheduled_tokens_all_reqs, num_tokens)
+         num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
+                                                 dtype=np.int32)
+         assert max_num_scheduled_tokens_all_reqs > 0
+         padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+             num_reqs, self.max_num_reqs)
+
+         # Get request indices.
+         # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
+         # For each scheduled token, what are the corresponding req index.
+         req_indices = np.repeat(self.arange_cpu[:num_reqs],
+                                 num_scheduled_tokens_per_req)
+
+         # Get batched arange.
+         # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+         # For each scheduled token, what is its position in corresponding req.
+         arange = np.concatenate(
+             [self.arange_cpu[:n] for n in num_scheduled_tokens_per_req])
+
+         # Get positions.
+         positions_np = self.positions_cpu[:total_num_scheduled_tokens]
+         np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
+                arange,
+                out=positions_np)
+
+         # Multi-modal support
+         # Calculate M-RoPE positions.
+         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
+         if self.uses_mrope:
+             self.mm_manager.calc_mrope_positions(scheduler_output)
+
+         # Get token indices.
+         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+         # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
+         # where M is the max_model_len.
+         token_indices = (positions_np +
+                          req_indices * self.input_batch.token_ids_cpu.shape[1])
+
+         # NOTE(woosuk): We use torch.index_select instead of np.take here
+         # because torch.index_select is much faster than np.take for large
+         # tensors.
+         np.take(self.input_batch.token_ids_cpu.flatten(),
+                 token_indices,
+                 out=self.input_ids_cpu[:total_num_scheduled_tokens])
+
+         # Prepare the attention metadata.
+         self.query_start_loc_cpu[0] = 0
+         np.cumsum(num_scheduled_tokens_per_req,
+                   out=self.query_start_loc_cpu[1:num_reqs + 1])
+         self.query_start_loc_cpu[num_reqs + 1:] = 1
+
+         self.seq_lens_cpu[:num_reqs] = (
+             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
+             num_scheduled_tokens_per_req)
+
+         # Do the padding and copy the tensors to the TPU.
+         padded_total_num_scheduled_tokens = runner_utils.get_padded_token_len(
+             self.num_tokens_paddings, total_num_scheduled_tokens)
+         # Zero out to avoid spurious values from prev iteration (last cp chunk)
+         self.input_ids_cpu[
+             total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
+
+         # Please see runner_utils.PhasedBasedProfiler for details
+         if self.phase_based_profiler:
+             batch_composition_stats = runner_utils.get_batch_composition_stats(
+                 self.input_batch, total_num_scheduled_tokens, num_reqs,
+                 padded_total_num_scheduled_tokens, scheduler_output)
+
+             self.phase_based_profiler.step(batch_composition_stats)
+
+         # Inputs
+         input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens]
+         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
+         mrope_positions = self.mrope_positions_cpu[:, :
+                                                    padded_total_num_scheduled_tokens]
+         block_tables = self.block_table_cpu[:self.max_num_reqs]
+         block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
+             self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
+
+         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
+         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
+         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
+         request_distribution = np.array(self.input_batch.request_distribution)
+         use_spec_decode = len(
+             scheduler_output.scheduled_spec_decode_tokens) > 0
+         if not use_spec_decode:
+             logits_indices = self.query_start_loc_cpu[1:padded_num_reqs +
+                                                       1] - 1
+             spec_decode_metadata = None
+         else:
+             num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+             for req_id, draft_token_ids in (
+                     scheduler_output.scheduled_spec_decode_tokens.items()):
+                 req_idx = self.input_batch.req_id_to_index[req_id]
+                 num_draft_tokens[req_idx] = len(draft_token_ids)
+
+             spec_decode_metadata = self.speculative_decoding_manager.get_spec_decode_metadata(
+                 num_draft_tokens, self.query_start_loc_cpu[1:num_reqs + 1],
+                 padded_num_reqs)
+             logits_indices = spec_decode_metadata.final_logits_indices
+
+         # Put to device
+         sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
+             self.mesh, self.input_batch, padded_num_reqs)
+         if self.uses_mrope:
+             positions = mrope_positions
+
+         # Convert block_tables to 1D on cpu.
+         block_tables = block_tables.reshape(-1)
+
+         query_start_loc_cpu = query_start_loc
+         seq_lens_cpu = seq_lens
+         (input_ids, positions, block_tables, query_start_loc, seq_lens,
+          logits_indices, request_distribution) = device_array(
+              self.mesh, (input_ids, positions, block_tables, query_start_loc,
+                          seq_lens, logits_indices, request_distribution))
+
+         if self.lora_config is not None:
+             self.lora_utils.set_active_loras(
+                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,
+                 padded_total_num_scheduled_tokens)
+
+         attention_metadata = AttentionMetadata(
+             input_positions=positions,
+             block_tables=block_tables,
+             seq_lens=seq_lens,
+             query_start_loc=query_start_loc,
+             request_distribution=request_distribution)
+
+         # This is for making these cpu buffers hidden during tracing
+         attention_metadata.query_start_loc_cpu = query_start_loc_cpu
+         attention_metadata.seq_lens_cpu = seq_lens_cpu
+
+         return (
+             input_ids,
+             attention_metadata,
+             sampling_metadata,
+             logits_indices,
+             spec_decode_metadata,
+         )
+
+     def _get_input_ids_embeds(self, input_ids: jax.Array,
+                               mm_embeds: list[jax.Array]):
+         if self.is_multimodal_model:
+             inputs_embeds = self.get_input_embeddings_fn(
+                 self.state,
+                 input_ids=input_ids,
+                 multimodal_embeddings=mm_embeds,
+             )
+             return None, inputs_embeds
+         else:
+             return input_ids, None
+
+     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+         return self.speculative_decoding_manager.take_draft_token_ids()
+
+     ###### Local disagg utilities ######
+
+     def get_kv_cache_for_block_ids(
+         self,
+         block_ids: List[int],
+     ) -> List[jax.Array]:
+         return self.kv_cache_manager.get_kv_cache_for_block_ids(block_ids)
+
+     def transfer_kv_cache(self,
+                           kv_cache_slices: List[jax.Array]) -> List[jax.Array]:
+         return self.kv_cache_manager.transfer_kv_cache(kv_cache_slices)
+
+     def insert_request_with_kv_cache(
+         self,
+         request: "Request",
+         kv_cache_slices: List[jax.Array],
+         block_ids: List[List[int]],
+     ):
+         return self.kv_cache_manager.insert_request_with_kv_cache(
+             request, kv_cache_slices, block_ids)
+
+     ###### RL framework integration ######
+
+     def _sync_weights(
+         self,
+         updated_weights: jaxtyping.PyTree,
+         mappings: Dict[str, Tuple[str, Tuple[str]]],
+         transpose_keys: Dict[str, Tuple[int]],
+         reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
+                              jaxtyping.PyTree] = None
+     ) -> None:
+         """For RL framework integration."""
+         if reshard_fn is not None:
+             updated_weights = reshard_fn(updated_weights, self.state)
+             shard = None
+         else:
+             shard = functools.partial(shard_put, mesh=self.mesh)
+         self.state = transfer_state_with_mappings(
+             src_state=updated_weights,
+             tgt_state=self.state,
+             mappings=mappings,
+             transpose_keys=transpose_keys,
+             shard=shard)
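
The densest part of tpu_jax_runner.py above is the index arithmetic in _prepare_inputs, which flattens a ragged batch (a different number of scheduled tokens per request) into flat position and token-gather indices. The standalone NumPy sketch below reproduces only that arithmetic on a toy batch so the worked examples in the diff's comments ([2, 5, 3] -> ...) can be checked directly; the buffer names mirror the runner's, but the sizes and values are invented for illustration and nothing here is part of the package.

import numpy as np

# Toy stand-ins for the runner's CPU buffers (sizes are arbitrary here).
max_model_len = 8                                # plays the role of token_ids_cpu.shape[1]
num_computed_tokens_cpu = np.array([4, 0, 2])    # tokens already processed per request
num_scheduled_tokens_per_req = np.array([2, 5, 3], dtype=np.int32)
num_reqs = len(num_scheduled_tokens_per_req)
arange_cpu = np.arange(16, dtype=np.int64)

# Request index for every scheduled token:
# [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(arange_cpu[:num_reqs], num_scheduled_tokens_per_req)

# Position of every scheduled token inside its own request:
# [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = np.concatenate(
    [arange_cpu[:n] for n in num_scheduled_tokens_per_req])

# Absolute position in the sequence = already-computed prefix + offset.
positions = num_computed_tokens_cpu[req_indices] + arange

# Same formula as token_indices in the diff: a row-major index into a
# (num_reqs, max_model_len) token table, so one flat gather can pick the
# scheduled tokens for the whole batch.
token_indices = positions + req_indices * max_model_len

print(req_indices)    # [0 0 1 1 1 1 1 2 2 2]
print(arange)         # [0 1 0 1 2 3 4 0 1 2]
print(positions)      # [4 5 0 1 2 3 4 2 3 4]
print(token_indices)  # [ 4  5  8  9 10 11 12 18 19 20]

With zero computed tokens per request the last line reduces to the "[0, 1, M, M + 1, ...]" example given in the comments, where M is max_model_len; the runner then feeds token_indices to np.take over the flattened token table to build input_ids_cpu in one shot.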