tpu-inference 0.0.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/utils.py ADDED
@@ -0,0 +1,340 @@
+ # SPDX-License-Identifier: Apache-2.0
+ import time
+ from collections import defaultdict
+ from collections.abc import Sequence
+ from functools import wraps
+ from typing import Any, Callable, List, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import torch
+ from jax._src import dtypes
+ from jax._src import mesh as mesh_lib
+ from jax._src import xla_bridge as xb
+ from jax._src.lib import xla_client as xc
+ from jax._src.numpy.scalar_types import _ScalarMeta
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torchax.ops.mappings import j2t_dtype, t2j_dtype
+ from vllm import envs as vllm_envs
+ from vllm import utils
+
+ from tpu_inference import envs
+ from tpu_inference.logger import init_logger
+
+ GBYTES = 1024 * 1024 * 1024
+ TPU_HEAD_SIZE_ALIGNMENT = 128
+ TPU_SECOND_LAST_MINOR = 8
+
+ # Map vllm dtype strings that don't exactly match a jax dtype string name.
+ _VLLM_DTYPE_STR_TO_JAX_DTYPE = {
+     "fp8": jnp.float8_e4m3fn,
+     "fp8_e4m3": jnp.float8_e4m3fn,
+     "fp8_e5m2": jnp.float8_e5m2,
+ }
+
+
+ def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+     if isinstance(dtype, str):
+         if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+             return dict_dtype
+         return jnp.dtype(dtype)
+     elif isinstance(dtype, torch.dtype):
+         return t2j_dtype(dtype)
+     elif isinstance(dtype, jnp.dtype):
+         return dtype
+     elif isinstance(dtype, _ScalarMeta):
+         return dtype.dtype
+     else:
+         raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+ def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+     # Use the jax dtype as an intermediate representation, which is then
+     # converted into the torch dtype.
+     dtype = to_jax_dtype(dtype)
+     return j2t_dtype(dtype)
+
+
+ _megacore = False
+ logger = init_logger(__name__)
+
+
+ def enable_megacore() -> None:
+     global _megacore
+     _megacore = True
+
+
+ def get_megacore() -> bool:
+     return _megacore
+
+
+ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
+     if tp_size <= num_kv_heads:
+         assert num_kv_heads % tp_size == 0
+         return num_kv_heads
+     else:
+         assert tp_size % num_kv_heads == 0
+         return tp_size
+
+
+ def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
+     usage = []
+     if vllm_envs.VLLM_TPU_USING_PATHWAYS:
+         return pathways_hbm_usage_gb(devices)
+
+     multihost_backend = envs.TPU_MULTIHOST_BACKEND
+     if multihost_backend == "ray":
+         # MemoryStats is only supported for addressable PjRt devices.
+         # Assume all the devices have similar memory usage for now.
+         # TODO(ranlihao): find a proper way to get the memory usage of each device.
+         for device in devices:
+             try:
+                 hbm_used = device.memory_stats()["bytes_in_use"]
+                 hbm_limit = device.memory_stats()["bytes_limit"]
+                 logger.info(
+                     "Get memory stats for device %s. Assuming all devices have the same usage.",
+                     device)
+                 usage.extend([(hbm_used, hbm_limit)] * len(devices))
+                 break
+             except Exception as e:
+                 logger.warning(
+                     "Failed to get memory stats for device %s: %s. ", device,
+                     e)
+     else:
+         for device in devices:
+             hbm_used = device.memory_stats()["bytes_in_use"]
+             hbm_limit = device.memory_stats()["bytes_limit"]
+             usage.append((hbm_used, hbm_limit))
+
+     return usage
+
+
+ def get_device_name(num_devices: int | None = None):
+     kind = jax.devices()[0].device_kind
+     if 'TPU' not in kind:
+         raise RuntimeError('Expected TPU devices')
+     suffix = ''
+     if kind.endswith(' lite'):
+         kind = kind[:-len(' lite')]
+         suffix = 'e'
+     elif kind.endswith('e'):
+         kind = kind[:-1]
+         suffix = 'e'
+     elif kind.endswith('p'):
+         kind = kind[:-1]
+         suffix = 'p'
+     elif kind == 'TPU7x':
+         kind = 'TPU v7'
+     assert kind[:-1] == 'TPU v', kind
+     kind += suffix
+     if num_devices is not None:
+         kind += f'-{num_devices}'
+     return kind
+
+
+ def get_device_hbm_limit() -> int:
+
+     device_kind = get_device_name()
+     if device_kind == "TPU v5p" or device_kind == "TPU v5":
+         return 95 * GBYTES
+     elif device_kind == "TPU v5e":
+         return 16 * GBYTES
+     elif device_kind == "TPU v6e" or device_kind == "TPU v4":
+         return 32 * GBYTES
+     elif device_kind == "TPU v7":
+         # 192 * GBYTES / 2 because each JAX device (v7x core) has
+         # 1/2 of the total chip HBM
+         return 96 * GBYTES
+     else:
+         raise ValueError(f"Unknown device kind: {device_kind}")
+
+
+ def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
+     live_arrays = jax.live_arrays()
+     hbm_used = defaultdict(int)
+     hbm_limit = get_device_hbm_limit()
+     for array in live_arrays:
+         for buffer in array.addressable_shards:
+             hbm_used[buffer.data.device] += buffer.data.nbytes
+     return [(hbm_used[device], hbm_limit) for device in devices]
+
+
+ def hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
+     usage = hbm_usage_bytes(devices)
+     usage = [(round(used / GBYTES, 2), round(limit / GBYTES, 2))
+              for used, limit in usage]
+     return usage
+
+
+ def get_padded_head_dim(head_dim: int) -> int:
+     """Pads head_dim up to the nearest multiple of 128 for kernel performance."""
+     # When head_dim == 64, we use a kernel specifically optimized for it, which
+     # does not require any padding.
+     if head_dim == 64:
+         return 64
+     return (head_dim + 127) // 128 * 128
+
+
+ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
+     if num_heads >= sharding_size:
+         assert num_heads % sharding_size == 0
+     else:
+         assert sharding_size % num_heads == 0
+         num_heads = sharding_size
+     return num_heads
+
+
+ def get_dtype_packing(dtype):
+     bits = dtypes.bit_width(dtype)
+     return 32 // bits
+
+
+ def make_optimized_mesh(axis_shapes: Sequence[int],
+                         axis_names: Sequence[str],
+                         *,
+                         devices: Sequence[xc.Device] | None = None):
+     if devices is None:
+         devices = xb.devices()
+     # Sort the devices in case they are passed in an arbitrary order.
+     devices = sorted(devices, key=lambda x: x.coords)
+
+     def _is_1D(axis_shapes):
+         return sum(x > 1 for x in axis_shapes) == 1
+
+     if _is_1D(axis_shapes):
+         dev_kind = devices[0].device_kind
+         device_num = len(devices)
+         if dev_kind == "TPU v6 lite":
+             ordered_devices = None
+             # NOTE(chengjiyao):
+             # The coords of v6e-8 are
+             # (0,0,0)
+             # (1,0,0)
+             # (0,1,0)
+             # (1,1,0)
+             # (0,2,0)
+             # (1,2,0)
+             # (0,3,0)
+             # (1,3,0)
+             if device_num == 8:
+                 ordered_devices = np.array([
+                     devices[0],
+                     devices[1],
+                     devices[2],
+                     devices[3],
+                     devices[7],
+                     devices[6],
+                     devices[5],
+                     devices[4],
+                 ])
+             # NOTE(chengjiyao):
+             # The coords of v6e-4 are
+             # (0,0,0)
+             # (1,0,0)
+             # (0,1,0)
+             # (1,1,0)
+             elif device_num == 4:
+                 ordered_devices = np.array([
+                     devices[0],
+                     devices[1],
+                     devices[3],
+                     devices[2],
+                 ])
+             if ordered_devices is not None:
+                 ordered_devices = np.array(ordered_devices)
+                 ordered_devices = ordered_devices.reshape(axis_shapes)
+                 mesh = mesh_lib.Mesh(ordered_devices, axis_names)
+                 logger.info("Use customized mesh: %s", mesh)
+                 return mesh
+
+     return jax.make_mesh(axis_shapes, axis_names, devices=devices)
+
+
+ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
+     """
+     Create a device array with the specified mesh and sharding.
+
+     Args:
+         mesh: The JAX mesh to use for device placement
+         *args: Positional arguments to pass to jax.device_put
+         sharding: Optional sharding specification. If None, uses PartitionSpec(None)
+         **kwargs: Keyword arguments to pass to jax.device_put
+
+     Returns:
+         A JAX array placed on the specified devices
+     """
+     if sharding is None:
+         sharding = NamedSharding(mesh, PartitionSpec(None))
+     return jax.device_put(*args, device=sharding, **kwargs)
+
+
+ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
+     """
+     A wrapper around vllm.utils.get_hash_fn_by_name that also supports the
+     builtin hash function.
+     """
+     if hash_fn_name == "builtin":
+         return hash
+     return utils.get_hash_fn_by_name(hash_fn_name)
+
+
+ def quantize_kv(key: jax.Array, value: jax.Array,
+                 kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
+                 v_scale: float) -> Tuple[jax.Array, jax.Array]:
+     """
+     Quantize the key and value tensors.
+
+     Args:
+         key: The key tensor to quantize.
+         value: The value tensor to quantize.
+         kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
+         k_scale: The scale to quantize the key tensor by.
+         v_scale: The scale to quantize the value tensor by.
+
+     Returns:
+         Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
+     """
+     dtype_info = jnp.finfo(kv_cache_quantized_dtype)
+     minval, maxval = float(dtype_info.min), float(dtype_info.max)
+     key = key.astype(jnp.float32) / k_scale
+     key = jnp.clip(key, minval, maxval)
+     key = key.astype(kv_cache_quantized_dtype)
+     value = value.astype(jnp.float32) / v_scale
+     value = jnp.clip(value, minval, maxval)
+     value = value.astype(kv_cache_quantized_dtype)
+
+     return key, value
+
+
+ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
+     """
+     Get the JAX dtype from a string dtype.
+
+     Args:
+         str_dtype: The string dtype to get the JAX dtype from.
+
+     Returns:
+         jnp.dtype: The JAX dtype.
+     """
+     # TODO(kyuyeunk): Replace all references to this function with TpuDtype.
+     return to_jax_dtype(str_dtype)
+
+
+ def time_function(func):
+     """
+     A decorator to measure the execution time of a function.
+     """
+
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         start_time = time.perf_counter()
+         result = func(*args, **kwargs)
+         end_time = time.perf_counter()
+         execution_time = end_time - start_time
+         logger.debug(
+             f"Function '{func.__name__}' executed in {execution_time:.4f} seconds."
+         )
+         return result
+
+     return wrapper
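
For orientation, a short usage sketch of the dtype and padding helpers in the added tpu_inference/utils.py. This is illustrative only and not part of the wheel; it assumes the wheel and its dependencies (jax, torch, torchax, vllm) are importable:

# Illustrative sketch, not part of the package: exercises the dtype and
# padding helpers from tpu_inference/utils.py.
import jax.numpy as jnp
import torch

from tpu_inference import utils

# vLLM-style dtype strings and torch dtypes are normalized to JAX dtypes.
print(utils.to_jax_dtype("fp8"))           # float8_e4m3fn
print(utils.to_jax_dtype(torch.bfloat16))  # bfloat16
print(utils.to_torch_dtype("float16"))     # torch.float16

# Head dims are padded up to a multiple of 128, except the specially handled 64.
assert utils.get_padded_head_dim(80) == 128
assert utils.get_padded_head_dim(64) == 64

# Packing: how many elements of a dtype fit in a 32-bit word.
assert utils.get_dtype_packing(jnp.bfloat16) == 2
assert utils.get_dtype_packing(jnp.int8) == 4

# KV heads are replicated up to the tensor-parallel size when needed.
assert utils.get_num_kv_heads_by_tp(num_kv_heads=8, tp_size=4) == 8
assert utils.get_num_kv_heads_by_tp(num_kv_heads=2, tp_size=8) == 8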
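
quantize_kv divides each tensor by its per-tensor scale, clips to the finite range of the target dtype, and casts; dequantization is the inverse multiply. A minimal sketch, illustrative only and not part of the wheel (the scale values here are arbitrary placeholders, not calibrated):

# Illustrative sketch, not part of the package: fp8 KV quantization and the
# matching dequantization step.
import jax
import jax.numpy as jnp

from tpu_inference.utils import quantize_kv

key = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 128), dtype=jnp.bfloat16)
value = jax.random.normal(jax.random.PRNGKey(1), (4, 8, 128), dtype=jnp.bfloat16)

# Per-tensor scales would normally come from calibration; fixed values here.
k_scale, v_scale = 0.05, 0.05
q_key, q_value = quantize_kv(key, value, jnp.float8_e4m3fn, k_scale, v_scale)

# Dequantize by multiplying the scale back; the error is bounded by fp8 rounding.
deq_key = q_key.astype(jnp.float32) * k_scale
print(float(jnp.max(jnp.abs(deq_key - key.astype(jnp.float32)))))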
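
time_function wraps a callable, logs its wall-clock duration at DEBUG level through the package logger, and returns the wrapped function's result unchanged. A minimal sketch, illustrative only and not part of the wheel (pad_head_dims is a hypothetical helper for this example; whether the timing message appears depends on how the package logger is configured):

# Illustrative sketch, not part of the package: timing a helper with the
# time_function decorator.
from tpu_inference.utils import get_padded_head_dim, time_function


@time_function
def pad_head_dims(head_dims):
    # Hypothetical helper, defined only for this example.
    return [get_padded_head_dim(d) for d in head_dims]


print(pad_head_dims([64, 80, 96, 200]))  # [64, 128, 128, 256]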