tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0

tpu_inference/mock/vllm_logger.py (deleted)
@@ -1,212 +0,0 @@
- # SPDX-License-Identifier: Apache-2.0
- # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
- """Logging configuration for vLLM."""
- import datetime
- import json
- import logging
- import sys
- from collections.abc import Hashable
- from functools import lru_cache, partial
- from logging import Logger
- from logging.config import dictConfig
- from os import path
- from types import MethodType
- from typing import Any, cast
-
- import tpu_inference.mock.vllm_envs as envs
-
- VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
- VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
- VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
- VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX
-
- _FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
-            "[%(filename)s:%(lineno)d] %(message)s")
- _DATE_FORMAT = "%m-%d %H:%M:%S"
-
- DEFAULT_LOGGING_CONFIG = {
-     "formatters": {
-         "vllm": {
-             "class": "tpu_inference.vllm_logging_utils.NewLineFormatter",
-             "datefmt": _DATE_FORMAT,
-             "format": _FORMAT,
-         },
-     },
-     "handlers": {
-         "vllm": {
-             "class": "logging.StreamHandler",
-             "formatter": "vllm",
-             "level": VLLM_LOGGING_LEVEL,
-             "stream": "ext://sys.stdout",
-         },
-     },
-     "loggers": {
-         "vllm": {
-             "handlers": ["vllm"],
-             "level": "DEBUG",
-             "propagate": False,
-         },
-     },
-     "version": 1,
-     "disable_existing_loggers": False
- }
-
-
- @lru_cache
- def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None:
-     # Set the stacklevel to 2 to print the original caller's line info
-     logger.debug(msg, *args, stacklevel=2)
-
-
- @lru_cache
- def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
-     # Set the stacklevel to 2 to print the original caller's line info
-     logger.info(msg, *args, stacklevel=2)
-
-
- @lru_cache
- def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None:
-     # Set the stacklevel to 2 to print the original caller's line info
-     logger.warning(msg, *args, stacklevel=2)
-
-
- class _VllmLogger(Logger):
-     """
-     Note:
-         This class is just to provide type information.
-         We actually patch the methods directly on the [`logging.Logger`][]
-         instance to avoid conflicting with other libraries such as
-         `intel_extension_for_pytorch.utils._logger`.
-     """
-
-     def debug_once(self, msg: str, *args: Hashable) -> None:
-         """
-         As [`debug`][logging.Logger.debug], but subsequent calls with
-         the same message are silently dropped.
-         """
-         _print_debug_once(self, msg, *args)
-
-     def info_once(self, msg: str, *args: Hashable) -> None:
-         """
-         As [`info`][logging.Logger.info], but subsequent calls with
-         the same message are silently dropped.
-         """
-         _print_info_once(self, msg, *args)
-
-     def warning_once(self, msg: str, *args: Hashable) -> None:
-         """
-         As [`warning`][logging.Logger.warning], but subsequent calls with
-         the same message are silently dropped.
-         """
-         _print_warning_once(self, msg, *args)
-
-
- # Pre-defined methods mapping to avoid repeated dictionary creation
- _METHODS_TO_PATCH = {
-     "debug_once": _print_debug_once,
-     "info_once": _print_info_once,
-     "warning_once": _print_warning_once,
- }
-
-
- def _configure_vllm_root_logger() -> None:
-     logging_config = dict[str, Any]()
-
-     if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
-         raise RuntimeError(
-             "VLLM_CONFIGURE_LOGGING evaluated to false, but "
-             "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
-             "implies VLLM_CONFIGURE_LOGGING. Please enable "
-             "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.")
-
-     if VLLM_CONFIGURE_LOGGING:
-         logging_config = DEFAULT_LOGGING_CONFIG
-
-     if VLLM_LOGGING_CONFIG_PATH:
-         if not path.exists(VLLM_LOGGING_CONFIG_PATH):
-             raise RuntimeError(
-                 "Could not load logging config. File does not exist: %s",
-                 VLLM_LOGGING_CONFIG_PATH)
-         with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
-             custom_config = json.loads(file.read())
-
-         if not isinstance(custom_config, dict):
-             raise ValueError("Invalid logging config. Expected dict, got %s.",
-                              type(custom_config).__name__)
-         logging_config = custom_config
-
-     for formatter in logging_config.get("formatters", {}).values():
-         # This provides backwards compatibility after #10134.
-         if formatter.get(
-                 "class"
-         ) == "tpu_inference.vllm_logging_utils.NewLineFormatter":
-             formatter[
-                 "class"] = "tpu_inference.mock.vllm_logging_utils.NewLineFormatter"
-
-     if logging_config:
-         dictConfig(logging_config)
-
-
- # The root logger is initialized when the module is imported.
- # This is thread-safe as the module is only imported once,
- # guaranteed by the Python GIL.
- _configure_vllm_root_logger()
-
-
- def init_logger(name: str) -> _VllmLogger:
-     """The main purpose of this function is to ensure that loggers are
-     retrieved in such a way that we can be sure the root vllm logger has
-     already been configured."""
-
-     logger = logging.getLogger(name)
-
-     for method_name, method in _METHODS_TO_PATCH.items():
-         setattr(logger, method_name, MethodType(method, logger))
-
-     return cast(_VllmLogger, logger)
-
-
- logger = init_logger(__name__)
-
-
- def _trace_calls(log_path, root_dir, frame, event, arg=None):
-     if event in ['call', 'return']:
-         # Extract the filename, line number, function name, and the code object
-         filename = frame.f_code.co_filename
-         lineno = frame.f_lineno
-         func_name = frame.f_code.co_name
-         if not filename.startswith(root_dir):
-             # only log the functions in the vllm root_dir
-             return
-         # Log every function call or return
-         try:
-             last_frame = frame.f_back
-             if last_frame is not None:
-                 last_filename = last_frame.f_code.co_filename
-                 last_lineno = last_frame.f_lineno
-                 last_func_name = last_frame.f_code.co_name
-             else:
-                 # initial frame
-                 last_filename = ""
-                 last_lineno = 0
-                 last_func_name = ""
-             with open(log_path, 'a') as f:
-                 ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
-                 if event == 'call':
-                     f.write(f"{ts} Call to"
-                             f" {func_name} in {filename}:{lineno}"
-                             f" from {last_func_name} in {last_filename}:"
-                             f"{last_lineno}\n")
-                 else:
-                     f.write(f"{ts} Return from"
-                             f" {func_name} in {filename}:{lineno}"
-                             f" to {last_func_name} in {last_filename}:"
-                             f"{last_lineno}\n")
-         except NameError:
-             # modules are deleted during shutdown
-             pass
-     return partial(_trace_calls, log_path, root_dir)
-
-
- def enable_trace_function_call(log_file_path: str, root_dir: str):
-     sys.settrace(partial(_trace_calls, log_file_path, root_dir))
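
The removed mock logger built its "log once" behaviour by wrapping plain logging.Logger calls in functools.lru_cache and patching the wrappers onto logger instances with MethodType. The sketch below shows that pattern in isolation, assuming only the standard library; the _info_once helper and the "demo" logger name are illustrative and not part of the tpu-inference API.

    # Minimal sketch of the "log once" pattern used by the removed mock logger.
    import logging
    from collections.abc import Hashable
    from functools import lru_cache
    from types import MethodType

    logging.basicConfig(level=logging.INFO)


    @lru_cache
    def _info_once(logger: logging.Logger, msg: str, *args: Hashable) -> None:
        # lru_cache keys on (logger, msg, args); repeated calls hit the cache
        # and the record is only emitted the first time.
        logger.info(msg, *args, stacklevel=2)


    log = logging.getLogger("demo")
    log.info_once = MethodType(_info_once, log)  # patch the instance, as the mock did

    log.info_once("compiling graph for batch size %d", 8)   # printed
    log.info_once("compiling graph for batch size %d", 8)   # silently dropped
    log.info_once("compiling graph for batch size %d", 16)  # new args, printed once

Because the cache key includes the arguments, a call with new values is still emitted once, which matches the docstrings of debug_once / info_once / warning_once above.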

tpu_inference/mock/vllm_logging_utils.py (deleted)
@@ -1,15 +0,0 @@
- import logging
-
-
- class NewLineFormatter(logging.Formatter):
-     """Adds logging prefix to newlines to align multi-line messages."""
-
-     def __init__(self, fmt, datefmt=None, style="%"):
-         logging.Formatter.__init__(self, fmt, datefmt, style)
-
-     def format(self, record):
-         msg = logging.Formatter.format(self, record)
-         if record.message != "":
-             parts = msg.split(record.message)
-             msg = msg.replace("\n", "\r\n" + parts[0])
-         return msg
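
The removed NewLineFormatter re-applies the handler prefix after every newline so multi-line messages stay aligned. A minimal usage sketch, assuming only the standard library; the format string and logger name are illustrative rather than taken from the package configuration.

    # Hedged demo of the multi-line alignment behaviour described above.
    import logging
    import sys

    FMT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s"


    class NewLineFormatter(logging.Formatter):

        def format(self, record):
            msg = super().format(record)
            if record.message != "":
                prefix = msg.split(record.message)[0]
                msg = msg.replace("\n", "\r\n" + prefix)
            return msg


    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(NewLineFormatter(FMT, datefmt="%m-%d %H:%M:%S"))

    log = logging.getLogger("demo.multiline")
    log.addHandler(handler)
    log.setLevel(logging.INFO)

    # Without the formatter the second line would start at column 0; with it,
    # every continuation line repeats the "INFO <time> [...]" prefix.
    log.info("loading weights\nsharding across devices")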

tpu_inference/models/jax/phi3.py (deleted)
@@ -1,376 +0,0 @@
- from typing import List, Optional, Tuple
-
- import jax
- import jax.numpy as jnp
- from flax import nnx
- from jax.sharding import Mesh
- from transformers import Phi3Config, modeling_flax_utils
- from vllm.config import VllmConfig
-
- from tpu_inference import utils
- from tpu_inference.layers.common.attention_interface import attention
- from tpu_inference.layers.common.attention_metadata import AttentionMetadata
- from tpu_inference.layers.jax.rope_interface import apply_longrope, apply_rope
- from tpu_inference.logger import init_logger
- from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
-                                                           load_hf_weights)
-
- logger = init_logger(__name__)
-
- init_fn = nnx.initializers.uniform()
-
-
- class Phi3MLP(nnx.Module):
-
-     def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs):
-         hidden_size = config.hidden_size
-         intermediate_size = config.intermediate_size
-         act = config.hidden_act
-
-         self.gate_up_proj = nnx.Linear(
-             hidden_size,
-             2 * intermediate_size,
-             use_bias=False,
-             param_dtype=dtype,
-             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
-             rngs=rng,
-         )
-         self.down_proj = nnx.Linear(
-             intermediate_size,
-             hidden_size,
-             use_bias=False,
-             param_dtype=dtype,
-             kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
-             rngs=rng,
-         )
-         self.act_fn = modeling_flax_utils.ACT2FN[act]
-
-     def __call__(self, x: jax.Array) -> jax.Array:
-         gate_up = self.gate_up_proj(x)
-         gate, up = jnp.split(gate_up, 2, axis=-1)
-         fuse = up * self.act_fn(gate)
-         result = self.down_proj(fuse)
-         return result
-
-
- class Phi3Attention(nnx.Module):
-
-     def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs,
-                  mesh: Mesh, kv_cache_dtype: str):
-         self.hidden_size = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.num_kv_heads = config.num_key_value_heads
-         self.rope_theta = config.rope_theta
-         self.rope_scaling = config.rope_scaling
-         self.original_max_position_embeddings = config.original_max_position_embeddings
-         self.max_position_embeddings = config.max_position_embeddings
-
-         self.head_dim_original = getattr(config, "head_dim",
-                                          self.hidden_size // self.num_heads)
-         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
-
-         sharding_size = mesh.shape["model"]
-         self.num_heads = utils.get_padded_num_heads(self.num_heads,
-                                                     sharding_size)
-         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
-                                                        sharding_size)
-
-         self.mesh = mesh
-
-         self.qkv_proj = nnx.Einsum(
-             "TD,DNH->TNH",
-             (self.hidden_size, self.num_heads + self.num_kv_heads * 2,
-              self.head_dim),
-             param_dtype=dtype,
-             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
-             rngs=rng,
-         )
-         self.o_proj = nnx.Einsum(
-             "TNH,NHD->TD",
-             (self.num_heads, self.head_dim, self.hidden_size),
-             param_dtype=dtype,
-             kernel_init=nnx.with_partitioning(init_fn, ("model", None, None)),
-             rngs=rng,
-         )
-
-         self._q_scale = 1.0
-         self._k_scale = 1.0
-         self._v_scale = 1.0
-         self.kv_cache_quantized_dtype = None
-         if kv_cache_dtype != "auto":
-             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
-                 kv_cache_dtype)
-
-     def __call__(
-         self,
-         kv_cache: Optional[jax.Array],
-         x: jax.Array,
-         attention_metadata: AttentionMetadata,
-     ) -> Tuple[jax.Array, jax.Array]:
-         md = attention_metadata
-         # qkv: (T, N + K * 2, H)
-         qkv = self.qkv_proj(x)
-         q, k, v = jnp.split(
-             qkv, [self.num_heads, self.num_heads + self.num_kv_heads], axis=1)
-         if self.rope_scaling:
-             q = apply_longrope(q, md.input_positions, self.head_dim_original,
-                                self.rope_scaling,
-                                self.original_max_position_embeddings,
-                                self.max_position_embeddings, self.rope_theta)
-             k = apply_longrope(k, md.input_positions, self.head_dim_original,
-                                self.rope_scaling,
-                                self.original_max_position_embeddings,
-                                self.max_position_embeddings, self.rope_theta)
-         else:
-             q = apply_rope(q, md.input_positions, self.head_dim_original,
-                            self.rope_theta, self.rope_scaling)
-             k = apply_rope(k, md.input_positions, self.head_dim_original,
-                            self.rope_theta, self.rope_scaling)
-         # o: (T, N, H)
-         q_scale = k_scale = v_scale = None
-         if self.kv_cache_quantized_dtype:
-             # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
-             # q_scale = self._q_scale
-             k_scale = self._k_scale
-             v_scale = self._v_scale
-             k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                      k_scale, v_scale)
-         new_kv_cache, outputs = attention(
-             kv_cache,
-             q,
-             k,
-             v,
-             attention_metadata,
-             self.mesh,
-             self.head_dim_original,
-             q_scale=q_scale,
-             k_scale=k_scale,
-             v_scale=v_scale,
-         )
-         # (T, D)
-         o = self.o_proj(outputs)
-         return new_kv_cache, o
-
-
- class Phi3DecoderLayer(nnx.Module):
-
-     def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs,
-                  mesh: Mesh, kv_cache_dtype: str):
-         rms_norm_eps = config.rms_norm_eps
-         hidden_size = config.hidden_size
-
-         self.input_layernorm = nnx.RMSNorm(
-             hidden_size,
-             epsilon=rms_norm_eps,
-             param_dtype=dtype,
-             scale_init=nnx.with_partitioning(init_fn, (None, )),
-             rngs=rng,
-         )
-         self.self_attn = Phi3Attention(config=config,
-                                        dtype=dtype,
-                                        rng=rng,
-                                        mesh=mesh,
-                                        kv_cache_dtype=kv_cache_dtype)
-         self.post_attention_layernorm = nnx.RMSNorm(
-             hidden_size,
-             epsilon=rms_norm_eps,
-             param_dtype=dtype,
-             scale_init=nnx.with_partitioning(init_fn, (None, )),
-             rngs=rng,
-         )
-         self.mlp = Phi3MLP(
-             config=config,
-             dtype=dtype,
-             rng=rng,
-         )
-
-     def __call__(
-         self,
-         kv_cache: jax.Array,
-         x: jax.Array,
-         attention_metadata: AttentionMetadata,
-     ) -> Tuple[jax.Array, jax.Array]:
-         hidden_states = self.input_layernorm(x)
-         kv_cache, attn_output = self.self_attn(
-             kv_cache,
-             hidden_states,
-             attention_metadata,
-         )
-         attn_output += x
-
-         residual = attn_output
-         attn_output = self.post_attention_layernorm(attn_output)
-         outputs = self.mlp(attn_output)
-         outputs = residual + outputs
-         return kv_cache, outputs
-
-
- class Phi3Model(nnx.Module):
-
-     def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
-                  mesh: Mesh) -> None:
-         model_config = vllm_config.model_config
-         hf_config = model_config.hf_config
-         vocab_size = model_config.get_vocab_size()
-         dtype = model_config.dtype
-         rms_norm_eps = hf_config.rms_norm_eps
-         hidden_size = hf_config.hidden_size
-
-         self.embed = nnx.Embed(
-             num_embeddings=vocab_size,
-             features=hidden_size,
-             param_dtype=dtype,
-             embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
-             rngs=rng,
-         )
-         self.layers = [
-             Phi3DecoderLayer(
-                 config=hf_config,
-                 dtype=dtype,
-                 rng=rng,
-                 mesh=mesh,
-                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
-             for _ in range(hf_config.num_hidden_layers)
-         ]
-         self.norm = nnx.RMSNorm(
-             hidden_size,
-             epsilon=rms_norm_eps,
-             param_dtype=dtype,
-             scale_init=nnx.with_partitioning(init_fn, (None, )),
-             rngs=rng,
-         )
-         if model_config.hf_config.tie_word_embeddings:
-             self.lm_head = self.embed.embedding
-         else:
-             self.lm_head = nnx.Param(
-                 init_fn(rng.params(), (hidden_size, vocab_size), dtype),
-                 sharding=(None, "model"),
-             )
-
-     def __call__(
-         self,
-         kv_caches: List[jax.Array],
-         input_ids: jax.Array,
-         attention_metadata: AttentionMetadata,
-     ) -> Tuple[List[jax.Array], jax.Array]:
-         x = self.embed(input_ids)
-         for i, layer in enumerate(self.layers):
-             kv_cache = kv_caches[i]
-             kv_cache, x = layer(
-                 kv_cache,
-                 x,
-                 attention_metadata,
-             )
-             kv_caches[i] = kv_cache
-         x = self.norm(x)
-         return kv_caches, x
-
-
- class Phi3ForCausalLM(nnx.Module):
-
-     def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
-                  mesh: Mesh) -> None:
-         self.vllm_config = vllm_config
-         self.rng = nnx.Rngs(rng_key)
-         self.mesh = mesh
-
-         self.model = Phi3Model(
-             vllm_config=vllm_config,
-             rng=self.rng,
-             mesh=mesh,
-         )
-
-     def __call__(
-         self,
-         kv_caches: List[jax.Array],
-         input_ids: jax.Array,
-         attention_metadata: AttentionMetadata,
-         *args,
-     ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
-         kv_caches, x = self.model(
-             kv_caches,
-             input_ids,
-             attention_metadata,
-         )
-         return kv_caches, x, []
-
-     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
-         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
-             logits = jnp.dot(hidden_states, self.model.lm_head.value.T)
-         else:
-             logits = jnp.dot(hidden_states, self.model.lm_head.value)
-         return logits
-
-     def get_metadata_map(self) -> MetadataMap:
-         sharding_size = self.mesh.shape["model"]
-
-         model_config = self.vllm_config.model_config
-         hf_config = model_config.hf_config
-
-         num_heads = hf_config.num_attention_heads
-         num_kv_heads = hf_config.num_key_value_heads
-         qkv_heads = num_heads + num_kv_heads * 2
-         hidden_size = model_config.get_hidden_size()
-
-         # Pad head_dim for kernel performance.
-         head_dim_original = model_config.get_head_size()
-
-         # Key: path to a HF layer weight
-         # Value: path to a nnx layer weight
-         name_map = {
-             "model.embed_tokens": "model.embed.embedding",
-             "model.layers.*.input_layernorm":
-             "model.layers.*.input_layernorm.scale",
-             "model.layers.*.mlp.down_proj":
-             "model.layers.*.mlp.down_proj.kernel",
-             "model.layers.*.mlp.gate_up_proj":
-             "model.layers.*.mlp.gate_up_proj.kernel",
-             "model.layers.*.post_attention_layernorm":
-             "model.layers.*.post_attention_layernorm.scale",
-             "model.layers.*.self_attn.qkv_proj":
-             "model.layers.*.self_attn.qkv_proj.kernel",
-             "model.layers.*.self_attn.o_proj":
-             "model.layers.*.self_attn.o_proj.kernel",
-             "model.norm": "model.norm.scale",
-         }
-         if not self.vllm_config.model_config.hf_config.tie_word_embeddings:
-             name_map.update({
-                 "lm_head": "model.lm_head",
-             })
-
-         reshape_keys: dict[str, tuple[int, ...]] = {
-             "qkv_proj": (qkv_heads, head_dim_original, hidden_size),
-             "o_proj": (hidden_size, num_heads, head_dim_original),
-         }
-         transpose_keys: dict[str, tuple[int, ...]] = {
-             "lm_head": (1, 0),
-             "gate_up_proj": (1, 0),
-             "down_proj": (1, 0),
-             "qkv_proj": (2, 0, 1),
-             "o_proj": (1, 2, 0),
-         }
-
-         # key: (padding_dim, padding_size)
-         pad_keys: dict[str, tuple[int, ...]] = {
-             "qkv_proj": (1, sharding_size // num_heads),
-             "o_proj": (0, sharding_size // num_heads),
-         }
-
-         return MetadataMap(name_map=name_map,
-                            reshape_map=reshape_keys,
-                            bias_reshape_map={},
-                            transpose_map=transpose_keys,
-                            pad_map=pad_keys,
-                            bias_pad_map={})
-
-     def load_weights(self, rng_key: jax.Array):
-         # NOTE: Since we are using nnx.eval_shape to init the model,
-         # we have to pass dynamic arrays here for __call__'s usage.
-         self.rng = nnx.Rngs(rng_key)
-
-         metadata_map = self.get_metadata_map()
-         load_hf_weights(vllm_config=self.vllm_config,
-                         model=self,
-                         metadata_map=metadata_map,
-                         mesh=self.mesh)
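
For reference, the removed Phi3Attention fused Q, K and V into a single "TD,DNH->TNH" einsum and then split the result along the head axis at [num_heads, num_heads + num_kv_heads]. The toy-shape sketch below illustrates that split, assuming only jax is installed; the sizes are made up rather than taken from any Phi-3 config.

    # Toy-shape sketch of the fused QKV projection and head-axis split above.
    import jax
    import jax.numpy as jnp

    T, D, H = 4, 32, 8                  # tokens, hidden size, (padded) head dim
    num_heads, num_kv_heads = 4, 2
    N = num_heads + 2 * num_kv_heads    # fused Q + K + V heads

    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (T, D))
    w_qkv = jax.random.normal(key, (D, N, H))

    qkv = jnp.einsum("TD,DNH->TNH", x, w_qkv)   # (T, N, H)
    q, k, v = jnp.split(qkv, [num_heads, num_heads + num_kv_heads], axis=1)

    print(q.shape, k.shape, v.shape)            # (4, 4, 8) (4, 2, 8) (4, 2, 8)

The same split indices appear in the removed __call__ above, where num_heads and num_kv_heads have first been padded to the mesh's "model" sharding size.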