tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,7 @@ import jax
  from jax._src.interpreters import pxla
  from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput

+ from tpu_inference import envs
  from tpu_inference.logger import init_logger
  from tpu_inference.runner.input_batch import InputBatch

@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
  InferencePhase.BALANCED: False
  }
  self.default_profiling_options = jax.profiler.ProfileOptions()
- self.default_profiling_options.python_tracer_level = os.getenv(
- "PYTHON_TRACER_LEVEL", 0)
+ self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL

  self.current_phase: str = ""

@@ -6,13 +6,19 @@ from typing import Any, Optional
  import jax
  import jax.numpy as jnp
  import numpy as np
+ from flax import nnx
+ from jax import lax
+ from jax.sharding import NamedSharding, PartitionSpec
  from vllm.config import VllmConfig

  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.logger import init_logger
  from tpu_inference.models.common.model_loader import get_model
  from tpu_inference.runner import utils as runner_utils
  from tpu_inference.utils import device_array

+ logger = init_logger(__name__)
+

  class Eagle3Proposer:
  """A proposer for speculative decoding using the Eagle3 method.
@@ -51,8 +57,22 @@ class Eagle3Proposer:
  """Loads the draft model."""
  self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
  self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
- del self.state.model['embed_tokens']
- self.state.model.embed_tokens = target_model.model.embed
+
+ draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
+ if draft_embed_tokens is None or ~jnp.any(
+ draft_embed_tokens.embedding):
+ logger.info(
+ "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
+ )
+ self.state.model.embed_tokens = target_model.model.embed
+ elif jnp.array_equal(draft_embed_tokens.embedding,
+ target_model.model.embed.embedding):
+ logger.info(
+ "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
+ )
+ self.state.model.embed_tokens = target_model.model.embed
+ else:
+ logger.info("Draft model has its own embed_tokens.")

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _prepare_input_ids(
@@ -110,6 +130,17 @@ class Eagle3Proposer:
  max_num_blocks_per_req)
  new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)

+ positions = lax.with_sharding_constraint(
+ positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+ clamped_positions = lax.with_sharding_constraint(
+ clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+ new_seq_lens = lax.with_sharding_constraint(
+ new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+ query_start_loc = lax.with_sharding_constraint(
+ query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+ new_block_tables = lax.with_sharding_constraint(
+ new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
  return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables

  @functools.partial(jax.jit, static_argnums=(0, ))
@@ -121,6 +152,7 @@ class Eagle3Proposer:
  @functools.partial(jax.jit, static_argnums=(0, ))
  def _prepare_hidden_states_and_input_ids(
  self,
+ state: nnx.State,
  aux_hidden_states: tuple[jax.Array, ...],
  query_start_loc: jax.Array,
  target_token_ids: jax.Array,
@@ -129,7 +161,7 @@ class Eagle3Proposer:
  ) -> tuple[jax.Array, jax.Array, jax.Array]:
  target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
  target_hidden_states = self.combine_hidden_states_fn(
- self.state, target_hidden_states)
+ state, target_hidden_states)

  input_ids, last_token_indices = self._prepare_input_ids(
  query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -176,8 +208,8 @@ class Eagle3Proposer:
  block_tables=device_array(
  self.mesh, block_tables))
  target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
- aux_hidden_states, attn_metadata.query_start_loc, input_ids,
- next_token_ids, num_reqs)
+ self.state, aux_hidden_states, attn_metadata.query_start_loc,
+ input_ids, next_token_ids, num_reqs)
  return target_hidden_states, input_ids, last_token_indices, attn_metadata

  # Host copies from the metadata prepared by the runner.
@@ -241,12 +273,13 @@ class Eagle3Proposer:

  attn_metadata = replace(attn_metadata, block_tables=block_tables)
  return self._filter_token_and_prepare_initial_inputs(
- token_indices, query_start_loc, seq_lens, input_ids,
+ self.state, token_indices, query_start_loc, seq_lens, input_ids,
  aux_hidden_states, attn_metadata, next_token_ids, num_reqs)

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _filter_token_and_prepare_initial_inputs(
  self,
+ state: nnx.State,
  token_indices: jax.Array,
  query_start_loc: jax.Array,
  seq_lens: jax.Array,
@@ -274,35 +307,51 @@ class Eagle3Proposer:
  )

  target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
- [h[token_indices] for h in aux_hidden_states], query_start_loc,
- target_token_ids, next_token_ids, num_reqs)
+ state, [h[token_indices] for h in aux_hidden_states],
+ query_start_loc, target_token_ids, next_token_ids, num_reqs)

  return target_hidden_states, input_ids, last_token_indices, attn_metadata

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _select_draft_token_ids(
  self,
+ state: nnx.State,
  hidden_states: jax.Array,
  last_token_indices: jax.Array,
  ) -> jax.Array:
  sample_hidden_states = hidden_states[last_token_indices]
- return self._get_draft_token_ids(sample_hidden_states)
+ sample_hidden_states = lax.with_sharding_constraint(
+ sample_hidden_states,
+ NamedSharding(self.mesh, PartitionSpec(None, None)))
+ return self._get_draft_token_ids(state, sample_hidden_states)

  @functools.partial(jax.jit, static_argnums=(0, ))
- def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
+ def _get_draft_token_ids(self, state: nnx.State,
+ hidden_states: jax.Array) -> jax.Array:
  lora_metadata = None
- logits = self.compute_logits_fn(self.state, hidden_states,
- lora_metadata)
- return jnp.argmax(logits, axis=-1)
+ logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+ draft_token_ids = jnp.argmax(logits, axis=-1)
+ return lax.with_sharding_constraint(
+ draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _select_inputs_for_loop_speculation(
- self, positions: jax.Array, residual: jax.Array,
+ self, state: nnx.State, positions: jax.Array, residual: jax.Array,
  hidden_states: jax.Array,
  last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
- return positions[last_token_indices], residual[
- last_token_indices], self._select_draft_token_ids(
- hidden_states, last_token_indices)
+ positions = positions[last_token_indices]
+ residual = residual[last_token_indices]
+ draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+ last_token_indices)
+
+ positions = lax.with_sharding_constraint(
+ positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+ residual = lax.with_sharding_constraint(
+ residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+ draft_token_ids = lax.with_sharding_constraint(
+ draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+ return positions, residual, draft_token_ids

  def propose(
  self,
@@ -329,11 +378,11 @@ class Eagle3Proposer:

  if self.num_speculative_tokens == 1:
  return kv_caches, self._select_draft_token_ids(
- hidden_states, last_token_indices)
+ self.state, hidden_states, last_token_indices)

  positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
- attn_metadata.input_positions, residual[0], hidden_states,
- last_token_indices)
+ self.state, attn_metadata.input_positions, residual[0],
+ hidden_states, last_token_indices)

  draft_token_ids_list = [draft_token_ids]

@@ -358,7 +407,8 @@ class Eagle3Proposer:
  attn_metadata,
  )
  hidden_states = residual[0]
- draft_token_ids = self._get_draft_token_ids(new_hidden_states)
+ draft_token_ids = self._get_draft_token_ids(
+ self.state, new_hidden_states)
  draft_token_ids_list.append(draft_token_ids)

  # [batch_size, num_speculative_tokens]
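
Note: the lax.with_sharding_constraint calls added throughout Eagle3Proposer above pin intermediate arrays to an explicit layout on the device mesh so XLA does not insert reshards between the jitted steps. The following is a minimal standalone sketch of that pattern, not code from the package; the mesh axis name and shapes are illustrative:

import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all local devices; the axis name "x" is illustrative.
mesh = Mesh(jax.devices(), axis_names=("x", ))

@jax.jit
def select_last_tokens(positions, last_token_indices):
    selected = positions[last_token_indices]
    # PartitionSpec(None, ) keeps the single dimension unsharded (replicated),
    # mirroring the constraints added in the diff above.
    return lax.with_sharding_constraint(
        selected, NamedSharding(mesh, PartitionSpec(None, )))

positions = jnp.arange(1024, dtype=jnp.int32)
last_token_indices = jnp.array([3, 7, 15], dtype=jnp.int32)
print(select_last_tokens(positions, last_token_indices))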
tpu_inference/tpu_info.py CHANGED
@@ -3,6 +3,7 @@ import os
  import requests

+ from tpu_inference import envs
  from tpu_inference.logger import init_logger

  logger = init_logger(__name__)
@@ -32,14 +33,14 @@ def get_tpu_metadata(key: str = "") -> str:


  def get_tpu_type() -> str:
- tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None)
+ tpu_type = envs.TPU_ACCELERATOR_TYPE
  if tpu_type is None:
  tpu_type = get_tpu_metadata(key="accelerator-type")
  return tpu_type


  def get_node_name() -> str:
- tpu_name = os.getenv("TPU_NAME", None)
+ tpu_name = envs.TPU_NAME
  if not tpu_name:
  tpu_name = get_tpu_metadata(key="instance-id")
  return tpu_name
@@ -47,7 +48,7 @@ def get_node_name() -> str:

  def get_node_worker_id() -> int:
  """For multi-host TPU VM, this returns the worker id for the current node."""
- worker_id = os.getenv("TPU_WORKER_ID", None)
+ worker_id = envs.TPU_WORKER_ID
  if worker_id is None:
  worker_id = get_tpu_metadata(key="agent-worker-number")
  if worker_id is None:
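
The os.getenv calls removed above (and the PYTHON_TRACER_LEVEL lookup earlier) are replaced by attributes of the centralized tpu_inference.envs module, whose own changes are not shown in this hunk. Purely as an illustration of the pattern, a typed env registry of this shape can be built with a module-level __getattr__ (PEP 562); the parsing defaults below are assumptions, not the actual contents of tpu_inference/envs.py:

import os
from typing import Any, Callable

# Hypothetical lazy env registry; each access re-reads the process environment.
_ENV_PARSERS: dict[str, Callable[[], Any]] = {
    "PYTHON_TRACER_LEVEL": lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "0")),
    "TPU_ACCELERATOR_TYPE": lambda: os.getenv("TPU_ACCELERATOR_TYPE"),
    "TPU_NAME": lambda: os.getenv("TPU_NAME"),
    "TPU_WORKER_ID": lambda: os.getenv("TPU_WORKER_ID"),
    "TPU_MULTIHOST_BACKEND": lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
}

def __getattr__(name: str) -> Any:
    # Resolves envs.<NAME> lazily, so each access reflects the current environment.
    if name in _ENV_PARSERS:
        return _ENV_PARSERS[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")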
tpu_inference/utils.py CHANGED
@@ -1,5 +1,4 @@
  # SPDX-License-Identifier: Apache-2.0
- import os
  import time
  from collections import defaultdict
  from collections.abc import Sequence
@@ -9,34 +8,62 @@ from typing import Any, Callable, List, Tuple
  import jax
  import jax.numpy as jnp
  import numpy as np
+ import torch
  from jax._src import dtypes
  from jax._src import mesh as mesh_lib
  from jax._src import xla_bridge as xb
  from jax._src.lib import xla_client as xc
+ from jax._src.numpy.scalar_types import _ScalarMeta
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
- from vllm import envs, utils
+ from torchax.ops.mappings import j2t_dtype, t2j_dtype
+ from vllm import envs as vllm_envs
+ from vllm import utils

+ from tpu_inference import envs
  from tpu_inference.logger import init_logger

  GBYTES = 1024 * 1024 * 1024
  TPU_HEAD_SIZE_ALIGNMENT = 128
  TPU_SECOND_LAST_MINOR = 8

- # This is used to translate from a string name for a dtype
- # to formal jax.numpy DType. One use case for this is
- # converting the `--kv_cache_dtype` flag to a dtype.
- TPU_STR_DTYPE_TO_JAX_DTYPE = {
- "bfloat16": jnp.bfloat16,
- "fp8": jnp.float8_e4m3fn,
- "fp8_e4m3": jnp.float8_e4m3,
- "fp8_e5m2": jnp.float8_e5m2,
- "int8": jnp.int8,
+ # Map vllm dtype string that doesn't exactly match jax dtype string name.
+ _VLLM_DTYPE_STR_TO_JAX_DTYPE = {
+ "fp8": jnp.float8_e4m3fn.dtype,
+ "fp8_e4m3": jnp.float8_e4m3fn.dtype,
+ "fp8_e5m2": jnp.float8_e5m2.dtype,
  }

+
+ def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+ if isinstance(dtype, str):
+ if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+ return dict_dtype
+ return jnp.dtype(dtype)
+ elif isinstance(dtype, torch.dtype):
+ return t2j_dtype(dtype)
+ elif isinstance(dtype, jnp.dtype):
+ return dtype
+ elif isinstance(dtype, _ScalarMeta):
+ return dtype.dtype
+ else:
+ raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+ def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+ # Use jax dtype as an intermediate dtype which we'll be used to convert it
+ # into torch dtype.
+ dtype = to_jax_dtype(dtype)
+ return j2t_dtype(dtype)
+
+
  _megacore = False
  logger = init_logger(__name__)


+ def align_to(unpadded_dim, pad_multiple):
+ return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple
+
+
  def enable_megacore() -> None:
  global _megacore
  _megacore = True
@@ -57,10 +84,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:

  def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
  usage = []
- if envs.VLLM_TPU_USING_PATHWAYS:
+ if vllm_envs.VLLM_TPU_USING_PATHWAYS:
  return pathways_hbm_usage_gb(devices)

- multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+ multihost_backend = envs.TPU_MULTIHOST_BACKEND
  if multihost_backend == "ray":
  # MemoryStats is only supported for addressable PjRt devices.
  # Assume all the devices have similar memory usage for now.
@@ -163,7 +190,8 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:


  def get_dtype_packing(dtype):
- bits = dtypes.bit_width(dtype)
+ bits = (dtypes.bit_width(dtype)
+ if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
  return 32 // bits


@@ -248,11 +276,11 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:

  def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
  """
- A wrapper function of vllm.utils.get_hash_fn_by_name to support builtin
+ A wrapper function of vllm.utils.hashing.get_hash_fn_by_name to support builtin
  """
  if hash_fn_name == "builtin":
  return hash
- return utils.get_hash_fn_by_name(hash_fn_name)
+ return utils.hashing.get_hash_fn_by_name(hash_fn_name)


  def quantize_kv(key: jax.Array, value: jax.Array,
@@ -294,8 +322,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
  Returns:
  jnp.dtype: The JAX dtype.
  """
- str_dtype = str_dtype.lower().strip()
- return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
+ # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
+ return to_jax_dtype(str_dtype)


  def time_function(func):
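
As a quick illustration of the new dtype helpers introduced above: to_jax_dtype normalizes a string, JAX dtype, or torch dtype into a jnp.dtype, and to_torch_dtype reuses that result via torchax's mappings. The expected values below follow from the hunk itself; treat this as a usage sketch rather than part of the package's test suite:

import jax.numpy as jnp
import torch

from tpu_inference.utils import to_jax_dtype, to_torch_dtype

# vLLM-style strings without an exact jax name go through _VLLM_DTYPE_STR_TO_JAX_DTYPE;
# everything else falls back to jnp.dtype(...).
assert to_jax_dtype("fp8") == jnp.float8_e4m3fn.dtype
assert to_jax_dtype("bfloat16") == jnp.dtype("bfloat16")

# torch dtypes are converted with torchax's t2j_dtype mapping.
assert to_jax_dtype(torch.bfloat16) == jnp.dtype("bfloat16")

# to_torch_dtype goes through the jax dtype as an intermediate step (j2t_dtype).
assert to_torch_dtype("bfloat16") == torch.bfloat16
assert to_torch_dtype(jnp.float32) == torch.float32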