sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +9 -6
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +22 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +38 -33
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +665 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +607 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +11 -2
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/torchao_utils.py +58 -45
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +39 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +236 -197
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -11
- sglang/srt/model_executor/model_runner.py +24 -9
- sglang/srt/model_parallel.py +67 -10
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -13
- sglang/srt/models/llama.py +22 -5
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +4 -4
- sglang/srt/server_args.py +62 -13
- sglang/srt/utils.py +57 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_parallel.py
CHANGED
@@ -2,18 +2,18 @@
 Common utilities for torch model parallelism.
 """

-from typing import Optional
+from typing import Optional, Sequence

 import torch
+import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh

 try:
-    from torch.distributed.tensor import DTensor, Shard
+    import torch.distributed.tensor as dt
 except ImportError:
     # torch 2.4 or older
-    from torch.distributed._tensor import DTensor, Shard
+    import torch.distributed._tensor as dt

-from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
@@ -21,6 +21,50 @@ from torch.distributed.tensor.parallel import (
 )


+def _shard_tensor(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[dt.Shard],
+) -> "dt.DTensor":
+    """
+    Locally shards a full tensor based on indicated sharding arrangement, and
+    returns a DTensor containing the local shard.
+
+    .. warning:: This is a private API that is subject to change. It skips the
+        communication otherwise required by `distribute_tensor`. It is only
+        applicable to cases where all ranks have the same `full_tensor`. For
+        example, in distributed inference all ranks load from the same
+        checkpoint. This API will not check for data equality between ranks, it
+        is thus user's responsibility to ensure the `full_tensor` is the same
+        across ranks.
+
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
+            DTensor. Must have same dimension as the number of placements.
+        placements (Sequence[:class:`Shard`]): the placements that
+            describes how to place the local tensor on DeviceMesh.
+
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])
+    """
+    shape, offset = dt._utils.compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return dt.DTensor.from_local(local_tensor, device_mesh, placements)
+
+
 class ColwiseParallelSharded(ColwiseParallel):
     """
     A version of ColwiseParallel where the local weight has been already
@@ -34,7 +78,7 @@ class ColwiseParallelSharded(ColwiseParallel):
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
         for name, param in module.named_parameters():
-            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
             dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
             module.register_parameter(name, dist_param)

@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
     AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
     """

+    def _partition_linear_fn(self, name, module, device_mesh):
+        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+        # means Rowwise as nn.Linear is input * weight^T + bias, where
+        # weight would become Shard(0)
+        module.register_parameter(
+            "weight",
+            nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
+        )
+        if getattr(module, "bias", None) is not None:
+            # The Linear module has bias
+            module.register_parameter(
+                "bias",
+                nn.Parameter(
+                    dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
+                ),
+            )
+
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
         outputs = super(
@@ -54,11 +115,7 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
         )._prepare_output_fn(
             output_layouts, use_local_output, mod, outputs, device_mesh
         )
-
-        if isinstance(outputs, AsyncCollectiveTensor):
-            return outputs.wait()
-        else:
-            return outputs
+        return torch.distributed._functional_collectives.wait_tensor(outputs)


 def tensor_parallel(
sglang/srt/models/commandr.py
CHANGED
@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_compiler_backend, set_weight_attrs


-@torch.compile
+@torch.compile(backend=get_compiler_backend())
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -21,6 +21,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -30,6 +31,7 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -112,12 +114,12 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )

-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
@@ -453,7 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale

-        self.attn = RadixAttention(
+        self.attn_mqa = RadixAttention(
             self.num_local_heads,
             self.kv_lora_rank + self.qk_rope_head_dim,
             self.scaling,
@@ -462,6 +464,15 @@ class DeepseekV2AttentionMLA(nn.Module):
             v_head_dim=self.kv_lora_rank,
         )

+        self.attn_mha = RadixAttention(
+            self.num_local_heads,
+            self.qk_nope_head_dim + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=self.num_local_heads,
+            layer_id=layer_id,
+            v_head_dim=self.v_head_dim,
+        )
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -471,6 +482,63 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # Use normal computation for prefill and use weight absorption for extend/decode
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.extend_prefix_lens.sum() == 0
+        ):
+            return self.forward_normal(positions, hidden_states, forward_batch)
+        else:
+            return self.forward_absorb(positions, hidden_states, forward_batch)
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+
+        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+        # Save latent cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        )
+        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward_absorb(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -508,7 +576,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe

-        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -767,7 +835,8 @@ class DeepseekV2ForCausalLM(nn.Module):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -828,14 +897,25 @@ class DeepseekV2ForCausalLM(nn.Module):
         if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
-                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                 if hasattr(self_attn.kv_b_proj, "weight_scale"):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                del self_attn.kv_b_proj


 EntryClass = DeepseekV2ForCausalLM
sglang/srt/models/gemma2.py
CHANGED
@@ -355,6 +355,40 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

+    def get_hidden_dim(self, module_name):
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "qkv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_attention_heads,
+            )
+        elif module_name in ["o_proj"]:
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name in ["kv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
+    def get_module_name(self, name):
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return params_mapping.get(name, name)
+
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)

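The new `get_module_name` helper maps per-projection checkpoint names onto the fused runtime modules, and `get_hidden_dim` reports each module's (input_dim, output_dim) pair. A self-contained sketch of what they compute (not part of the diff), using made-up config values:

```python
from types import SimpleNamespace

# Made-up config values for illustration only.
config = SimpleNamespace(
    hidden_size=2304,
    head_dim=256,
    num_attention_heads=8,
    num_key_value_heads=4,
    intermediate_size=9216,
)

# Mirrors the params_mapping added in the diff: per-projection names
# collapse onto the fused qkv_proj / gate_up_proj modules.
PARAMS_MAPPING = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}


def get_module_name(name):
    return PARAMS_MAPPING.get(name, name)


def get_hidden_dim(module_name):
    # Returns (input_dim, output_dim) for each module, as in the diff.
    if module_name in ("q_proj", "qkv_proj"):
        return config.hidden_size, config.head_dim * config.num_attention_heads
    if module_name == "o_proj":
        return config.head_dim * config.num_attention_heads, config.hidden_size
    if module_name == "kv_proj":
        return config.hidden_size, config.head_dim * config.num_key_value_heads
    if module_name == "gate_up_proj":
        return config.hidden_size, config.intermediate_size
    if module_name == "down_proj":
        return config.intermediate_size, config.hidden_size
    raise NotImplementedError(module_name)


print(get_module_name("k_proj"))   # -> qkv_proj
print(get_hidden_dim("qkv_proj"))  # -> (2304, 2048)
```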
sglang/srt/models/gemma2_reward.py
CHANGED
@@ -32,7 +32,6 @@ class Gemma2ForSequenceClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.num_labels = config.num_labels
         self.model = Gemma2Model(config, quant_config=quant_config)