tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference may be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_utils.py +16 -24
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/core_tpu.py +9 -17
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +11 -31
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/{common → jax}/sharding.py +5 -5
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/__init__.py +3 -7
- tpu_inference/layers/vllm/quantization/awq.py +3 -4
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
- tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +12 -46
- tpu_inference/models/jax/llama3.py +3 -4
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +2 -3
- tpu_inference/models/jax/qwen2_5_vl.py +50 -165
- tpu_inference/models/jax/qwen3.py +2 -3
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
- tpu_inference/platforms/tpu_platform.py +34 -47
- tpu_inference/runner/compilation_manager.py +60 -145
- tpu_inference/runner/kv_cache.py +2 -2
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +135 -283
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +15 -38
- tpu_inference/worker/tpu_worker.py +26 -163
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
- tests/test_envs.py +0 -203
- tpu_inference/layers/common/quant_methods.py +0 -8
- tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0

tpu_inference/layers/vllm/quantization/awq.py
@@ -18,7 +18,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped, unpack_quantized_values_into_int32)
 from vllm.scalar_type import scalar_types
 
-from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
 from tpu_inference.layers.vllm.linear_common import (
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
 from tpu_inference.layers.vllm.quantization.common import (
@@ -30,12 +29,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config(
+@register_quantization_config("jax-awq")
 class VllmAWQConfig(AWQConfig, JaxCommonConfig):
 
     @classmethod
-    def get_name(cls):
-        return
+    def get_name(cls) -> str:
+        return "jax-awq"
 
     def get_supported_act_dtypes(self) -> list[torch.dtype]:
         # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
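
The change above registers the config under the explicit name "jax-awq" and makes get_name() return the same literal. As a purely illustrative sketch (not vLLM's actual registry machinery), a name-keyed decorator registry shows why the decorator argument and get_name() have to agree:

    from typing import Callable, Dict

    QUANT_CONFIGS: Dict[str, type] = {}  # hypothetical registry, illustration only

    def register_quantization_config(name: str) -> Callable[[type], type]:
        # Register the decorated config class under an explicit string key.
        def wrap(cls: type) -> type:
            QUANT_CONFIGS[name] = cls
            return cls
        return wrap

    @register_quantization_config("jax-awq")
    class DemoAWQConfig:
        @classmethod
        def get_name(cls) -> str:
            # Must match the key used at registration so both lookup paths agree.
            return "jax-awq"

    assert QUANT_CONFIGS["jax-awq"].get_name() == "jax-awq"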

tpu_inference/layers/vllm/quantization/common.py
@@ -61,12 +61,7 @@ class JaxCommonLinearConfig:
                 " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-
-            self.n_shards = 1
-            for axis in self.weight_sharding[0]:
-                self.n_shards *= self.mesh.shape.get(axis, 1)
-        else:
-            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sequence_parallelism:
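
The simplified n_shards computation reads the size of a single named mesh axis, falling back to 1 when the weight dimension is unsharded. A minimal sketch of that lookup, assuming a one-axis "model" mesh like the one used elsewhere in this diff:

    import jax
    import numpy as np
    from jax.sharding import Mesh, PartitionSpec as P

    # One-dimensional device mesh with a single "model" axis.
    mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

    weight_sharding = P("model", None)  # shard dim 0 over the "model" axis
    # Mesh.shape maps axis names to sizes; an unsharded dim (None) falls back to 1.
    n_shards = mesh.shape.get(weight_sharding[0], 1)
    print(n_shards)  # number of devices along the "model" axis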

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -16,8 +16,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, should_ignore_layer)
 
-from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
-                                                        get_tpu_quant_method)
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
     VllmCompressedTensorsW8A8Fp8MoEMethod
@@ -32,12 +30,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config(
+@register_quantization_config("jax-compressed-tensors")
 class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
 
     @classmethod
     def get_name(cls) -> str:
-        return
+        return "jax-compressed-tensors"
 
     def get_scheme(self,
                    layer: torch.nn.Module,

tpu_inference/layers/vllm/quantization/unquantized.py
@@ -23,9 +23,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 
 from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
-from tpu_inference.layers.
-    get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import fused_moe_func
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
@@ -36,12 +34,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config(
+@register_quantization_config("jax-unquantized")
 class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
 
     @classmethod
     def get_name(cls) -> str:
-        return
+        return "jax-unquantized"
 
     @classmethod
     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
@@ -108,8 +106,6 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        assert isinstance(layer, LinearBase)
-
         with jax.named_scope(layer._get_name()):
             if in_sharding := self.jax_config.get_input_sharding(x):
                 x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
@@ -168,18 +164,18 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                  ep_axis_name: str = 'model'):
         super().__init__(moe)
         self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL
+        self.use_kernel = envs.USE_MOE_EP_KERNEL
         self.ep_axis_name = ep_axis_name
         # TODO: Use autotune table once we have it.
         self.block_size = {
-            "bt":
-            "bf":
-            "bd1":
-            "bd2":
-            "btc":
-            "bfc":
-            "bd1c":
-            "bd2c":
+            "bt": 16,
+            "bf": 384,
+            "bd1": 512,
+            "bd2": 512,
+            "btc": 16,
+            "bfc": 384,
+            "bd1c": 256,
+            "bd2c": 256,
         }
 
     def select_gemm_impl(
@@ -193,11 +189,10 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
+
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
-        num_experts, hidden_size, intermediate_size = w2_weight.shape
-
         if self.moe.has_bias:
             w13_bias = t2j(layer.w13_bias, use_dlpack=False)
             w2_bias = t2j(layer.w2_bias, use_dlpack=False)
@@ -216,56 +211,76 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             w3_bias = w13_bias[:, 1::2]
             w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-        if self.use_kernel:
+        if self.use_kernel and layer.use_ep:
             # Kernel expects:
             # w13: (num_experts, 2, hidden_size, intermediate_size)
             # w2: (num_experts, intermediate_size, hidden_size)
             # Current format:
             # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
             # w2_weight: (num_experts, hidden_size, intermediate_size)
+            num_experts = w13_weight.shape[0]
+            intermediate_size = w13_weight.shape[1] // 2
+            hidden_size = w13_weight.shape[2]
 
+            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
             w13_reshaped = w13_weight.reshape(num_experts, 2,
                                               intermediate_size, hidden_size)
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
 
-            # Transpose
-
-            w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
+            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
 
             # Apply EP sharding
-            ep_sharding = NamedSharding(self.mesh, P("model"))
-
             w13_weight = jax.device_put(
-                w13_weight_transposed,
-
-
-
+                w13_weight_transposed,
+                Format(Layout((0, 1, 2, 3)),
+                       NamedSharding(self.mesh, P("model", None, None, None))))
+            w2_weight = jax.device_put(
+                w2_weight_transposed,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
 
             if self.moe.has_bias:
                 w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+
+                # Apply EP sharding
                 w13_bias = jax.device_put(
-                    w13_bias,
-
-
-
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
 
+        else:
+            # Original logic for non-kernel path
             if layer.use_ep:
-                ep_sharding = NamedSharding(self.mesh, P("model"))
                 w13_weight = jax.device_put(
-                    w13_weight,
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
                 w2_weight = jax.device_put(
-                    w2_weight,
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
 
                 if self.moe.has_bias:
                     w13_bias = jax.device_put(
-                        w13_bias,
+                        w13_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
                     w2_bias = jax.device_put(
-                        w2_bias,
+                        w2_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
 
             else:
+                intermediate_size = w13_weight.shape[1] // 2
+                assert intermediate_size == w2_weight.shape[-1]
                 output_sizes = [intermediate_size, intermediate_size]
                 n_shards = self.mesh.shape["model"]
                 assert intermediate_size % n_shards == 0
-
                 w13_weight = reorder_concatenated_tensor_for_sharding(
                     w13_weight, output_sizes, n_shards, dim=1)
                 w13_weight = jax.device_put(
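
For the kernel path, the hunk above recovers num_experts, intermediate_size, and hidden_size from w13_weight, relays w13 out to (num_experts, 2, hidden_size, intermediate_size) and w2 to (num_experts, intermediate_size, hidden_size), and then places both with the expert dimension sharded over the "model" mesh axis. A small self-contained sketch of the same reshape, transpose, and placement, using toy sizes and plain NamedSharding (the explicit Format/Layout pinning shown in the diff is omitted here):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    num_experts, intermediate_size, hidden_size = 8, 64, 32  # toy sizes

    # Checkpoint layout: w13 is (E, 2*intermediate, hidden), w2 is (E, hidden, intermediate).
    w13 = jnp.zeros((num_experts, 2 * intermediate_size, hidden_size))
    w2 = jnp.zeros((num_experts, hidden_size, intermediate_size))

    # Kernel layout: w13 -> (E, 2, hidden, intermediate), w2 -> (E, intermediate, hidden).
    w13_kernel = jnp.transpose(
        w13.reshape(num_experts, 2, intermediate_size, hidden_size), (0, 1, 3, 2))
    w2_kernel = jnp.transpose(w2, (0, 2, 1))
    assert w13_kernel.shape == (num_experts, 2, hidden_size, intermediate_size)
    assert w2_kernel.shape == (num_experts, intermediate_size, hidden_size)

    # Expert-parallel placement: shard the leading expert axis over the "model" axis.
    mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
    w13_kernel = jax.device_put(
        w13_kernel, NamedSharding(mesh, P("model", None, None, None)))
    w2_kernel = jax.device_put(
        w2_kernel, NamedSharding(mesh, P("model", None, None)))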
@@ -326,40 +341,30 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-        x = jax_view(x)
-        w13_weight = jax_view(layer.w13_weight)
-        w2_weight = jax_view(layer.w2_weight)
-        w13_bias = w2_bias = None
-        if self.moe.has_bias:
-            w13_bias = jax_view(layer.w13_bias)
-            w2_bias = jax_view(layer.w2_bias)
-        gating_output = jax_view(router_logits)
-
         if self.use_kernel and layer.use_ep:
             output = fused_ep_moe(
                 mesh=self.mesh,
-                tokens=x,
-                w1=w13_weight,
-                w2=w2_weight,
-
-                b2=w2_bias,
-                gating_output=gating_output,
+                tokens=jax_view(x),
+                w1=jax_view(layer.w13_weight),
+                w2=jax_view(layer.w2_weight),
+                gating_output=jax_view(router_logits),
                 top_k=top_k,
                 ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=renormalize,
-                act_fn=activation,
                 **self.block_size,
             )
         else:
-
-
-
-
-
-
-
+            # Use the original implementation
+            output = fused_moe_func_padded(
+                jax_view(x),
+                jax_view(layer.w13_weight),
+                jax_view(layer.w2_weight),
+                jax_view(layer.w13_bias) if self.moe.has_bias else None,
+                jax_view(layer.w2_bias) if self.moe.has_bias else None,
+                jax_view(router_logits),
                 topk=top_k,
+                global_num_experts=global_num_experts,
                 renormalize=renormalize,
+                reduce_results=layer.reduce_results,
                 mesh=self.mesh,
                 use_ep=layer.use_ep,
                 activation=activation,
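
The non-kernel branch above passes top_k, renormalize, and global_num_experts through to fused_moe_func_padded, and the surrounding check restricts scoring to softmax. As a generic illustration of what top-k routing with renormalization means (not the actual fused_ep_moe or fused_moe_func_padded implementation), the gating step looks roughly like this:

    import jax
    import jax.numpy as jnp

    def topk_softmax_routing(router_logits, top_k, renormalize=True):
        # Softmax over the expert dimension, then keep the top_k weights per token.
        probs = jax.nn.softmax(router_logits, axis=-1)
        topk_probs, topk_idx = jax.lax.top_k(probs, top_k)
        if renormalize:
            # Rescale the selected weights so they sum to 1 for each token.
            topk_probs = topk_probs / jnp.sum(topk_probs, axis=-1, keepdims=True)
        return topk_probs, topk_idx

    weights, experts = topk_softmax_routing(
        jnp.array([[0.1, 2.0, -1.0, 0.5]]), top_k=2)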

tpu_inference/layers/vllm/sharding.py
@@ -19,7 +19,6 @@ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
-from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -212,7 +211,8 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-
+    import os
+    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
     if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)
 

tpu_inference/lora/torch_punica_tpu.py
@@ -239,6 +239,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         lora_index_to_id: list[Optional[int]],
         max_loras: int,
         vocab_size: int,
+        extra_vocab_size: int,
     ):
         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we
@@ -257,7 +258,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             lora_index_to_id,
             max_loras,
             vocab_size,
-
+            extra_vocab_size,
             "cpu",
         )
         with torchax.default_env():

tpu_inference/mock/vllm_config_utils.py (new file)
@@ -0,0 +1,28 @@
+from dataclasses import dataclass, field
+from typing import Any, List, Mapping
+
+
+@dataclass
+class ModelConfig():
+    max_model_len: int = 2048
+    max_prefill_len: int = 1024
+    prefill_batch_size: int = 1
+    decode_batch_size: int = 1
+    block_size: int = 16
+    num_layers: int = 32
+    num_kv_heads: int = 32
+    head_dim: int = 128
+    vocab_size: int = 32000
+    model: str = "llama3"
+    hf_config: str = ""
+    architectures: List[str] = field(default_factory=list)
+    override_generation_config: dict[str, Any] = field(default_factory=dict)
+    hf_overrides: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class VllmConfig():
+    additional_config: Mapping[str, Any] = field(default_factory=dict)
+    # Set default max_model_len to turn off warnings.
+    model_config: ModelConfig = field(
+        default_factory=lambda: ModelConfig(max_model_len=1024))