sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/internvl.py
@@ -11,21 +11,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
 # Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
 import torch.nn.functional as F
-from einops import rearrange, repeat
-from sgl_kernel.flash_attn import flash_attn_varlen_func
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
+from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -40,75 +38,12 @@ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.utils import logger
 
 
-class FlashAttention(nn.Module):
-    """Implement the scaled dot product attention with softmax.
-    Arguments
-    ---------
-        softmax_scale: The temperature to use for the softmax attention.
-                       (default: 1/sqrt(d_keys) where d_keys is computed at
-                       runtime)
-        attention_dropout: The dropout rate to apply to the attention
-                           (default: 0.0)
-    """
-
+class InternAttention(nn.Module):
     def __init__(
-        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
-    ):
-        super().__init__()
-        self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
-
-    def forward(
         self,
-        qkv,
-        causal=False,
-        max_s=None,
+        config,
+        quant_config: QuantizationConfig = None,
     ):
-        """Implements the multihead softmax attention.
-        Arguments
-        ---------
-            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
-                if unpadded: (nnz, 3, h, d)
-        """
-        assert qkv.dtype in [torch.float16, torch.bfloat16]
-        assert qkv.is_cuda
-
-        batch_size, seqlen, _, nheads, d = qkv.shape
-        if batch_size == 0 or seqlen == 0:
-            output_shape = (batch_size, seqlen, nheads, d)
-            return (
-                torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device),
-                None,
-            )
-
-        qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
-        q, k, v = qkv_reshaped.unbind(1)
-
-        max_s = seqlen
-        cu_seqlens = torch.arange(
-            0,
-            (batch_size + 1) * seqlen,
-            step=seqlen,
-            dtype=torch.int32,
-            device=qkv.device,
-        )
-        output_reshaped = flash_attn_varlen_func(
-            q,
-            k,
-            v,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            softmax_scale=self.softmax_scale,
-            causal=causal,
-        )
-        output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size)
-        return output, None
-
-
-class InternAttention(nn.Module):
-    def __init__(self, config):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -116,7 +51,19 @@ class InternAttention(nn.Module):
         self.head_dim = self.embed_dim // self.num_heads
 
         self.scale = self.head_dim**-0.5
-        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+
+        self.attn = VisionAttention(
+            qkv_backend="fa3",
+            embed_dim=self.embed_dim,
+            num_heads=self.num_heads,
+            projection_size=self.embed_dim,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=getattr(config, "dropout", 0.0),
+            proj_bias=getattr(config, "qkv_bias", True),
+            flatten_batch=False,
+        )
+
         self.proj_drop = nn.Dropout(config.dropout)
 
         self.qk_normalization = config.qk_normalization
@@ -125,36 +72,15 @@ class InternAttention(nn.Module):
             self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
             self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
 
-        self.inner_attn = FlashAttention(softmax_scale=self.scale)
-
-        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
-
-    def _flash_attn(
+    def forward(
         self,
-        x,
-    ):
-        qkv = self.qkv(x)
-        qkv = rearrange(
-            qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
-        )
-
-        if self.qk_normalization:
-            q, k, v = qkv.unbind(2)
-            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
-            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
-            qkv = torch.stack([q, k, v], dim=2)
-
-        context, _ = self.inner_attn(
-            qkv,
-        )
-        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
-        outs = self.proj_drop(outs)
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+    ) -> torch.Tensor:
+        out = self.attn(hidden_states, cu_seqlens=cu_seqlens)
+        outs = self.proj_drop(out)
         return outs
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        x = self._flash_attn(hidden_states)
-        return x
-
 
 class InternVisionEmbeddings(nn.Module):
     def __init__(self, config: PretrainedConfig):
@@ -286,6 +212,7 @@ class InternVisionEncoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
     ) -> Tuple[
         torch.FloatTensor,
         Optional[torch.FloatTensor],
@@ -295,8 +222,12 @@ class InternVisionEncoderLayer(nn.Module):
         Args:
             hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
         """
+
         hidden_states = hidden_states + self.drop_path1(
-            self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+            self.attn(
+                self.norm1(hidden_states).to(hidden_states.dtype), cu_seqlens=cu_seqlens
+            )
+            * self.ls1
         )
 
         hidden_states = hidden_states + self.drop_path2(
@@ -363,12 +294,12 @@ class InternVisionEncoder(nn.Module):
         encoder_states = () if output_hidden_states else None
         hidden_states = inputs_embeds
 
+        cu_seqlens = SingletonCache()
+
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
-                hidden_states,
-            )
+            layer_outputs = encoder_layer(hidden_states, cu_seqlens=cu_seqlens)
             hidden_states = layer_outputs
 
         if output_hidden_states:
@@ -625,6 +556,7 @@ class InternVLChatModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -641,6 +573,11 @@ class InternVLChatModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                if "vision_model" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(r"attn.", r"attn.attn.")
+                    name = name.replace(r"qkv.", r"qkv_proj.")
+
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
@@ -665,6 +602,13 @@ class InternVLChatModel(nn.Module):
                     param, "weight_loader", default_weight_loader
                 )
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                f"Some weights are not initialized from checkpoints: {unloaded_params}"
+            )
+        return loaded_params
 
 
 EntryClass = InternVLChatModel
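
Note: the load_weights change above renames vision-tower checkpoint keys to match the new VisionAttention module layout before looking them up in params_dict. A minimal, self-contained sketch of that renaming (the checkpoint keys below are hypothetical examples, not taken from a real InternVL checkpoint):

# Sketch only: mirrors the "adapt to VisionAttention" renaming in load_weights above.
def remap_vision_key(name: str) -> str:
    if "vision_model" in name:
        name = name.replace("attn.", "attn.attn.")  # inner attention now lives under .attn.attn
        name = name.replace("qkv.", "qkv_proj.")    # fused qkv linear is exposed as qkv_proj
    return name


print(remap_vision_key("vision_model.encoder.layers.0.attn.qkv.weight"))
# -> vision_model.encoder.layers.0.attn.attn.qkv_proj.weight
print(remap_vision_key("language_model.model.layers.0.self_attn.qkv_proj.weight"))
# -> unchanged: only vision_model.* keys are rewritten
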
sglang/srt/models/roberta.py
@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 
-from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -16,6 +16,23 @@ from sglang.srt.models.bert import BertEncoder
 RobertaConfig = None
 
 
+# Adapted from transformers
+class RobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config: RobertaConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.out_proj(x)
+        return x
+
+
 class RobertaEmbedding(nn.Module):
 
     def __init__(self, config: RobertaConfig):
@@ -51,8 +68,7 @@ class RobertaEmbedding(nn.Module):
         input_ids: torch.Tensor,
         seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
-        inputs_embeds=None,
-        token_type_ids: Optional[torch.Tensor] = None,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
@@ -82,6 +98,8 @@ class RobertaEmbedding(nn.Module):
 
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
+
+        token_type_ids = forward_batch.token_type_ids
         if token_type_ids is None:
             token_type_ids = torch.zeros(
                 input_shape, dtype=torch.long, device=inputs_embeds.device
@@ -93,20 +111,25 @@ class RobertaEmbedding(nn.Module):
         return embeddings
 
 
-class XLMRobertaModel(nn.Module):
+class XLMRobertaBaseModel(nn.Module):
     def __init__(
         self,
        *,
         config: RobertaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        add_pooling_layer: bool = False,
     ):
         super().__init__()
 
         self.config = config
         self.embeddings = RobertaEmbedding(config)
         self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
-        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+        self.pooler = (
+            Pooler(pooling_type=PoolingType.CLS, normalize=True)
+            if add_pooling_layer
+            else None
+        )
 
     @torch.no_grad()
     def forward(
@@ -124,11 +147,12 @@ class XLMRobertaModel(nn.Module):
             input_ids=input_ids,
             position_ids=positions,
             seq_lens=forward_batch.seq_lens,
+            forward_batch=forward_batch,
         )
 
         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        pooler_out = self.pooler(hidden_states, forward_batch)
-        return pooler_out
+
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -141,7 +165,7 @@ class XLMRobertaModel(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if self.pooler is None and "pooler" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
 
@@ -175,4 +199,88 @@ def create_position_ids_from_input_ids(
     return incremental_indices.long() + padding_idx
 
 
-EntryClass = [XLMRobertaModel]
+class XLMRobertaModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.roberta.load_weights(weights)
+
+
+class XLMRobertaForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.classifier = RobertaClassificationHead(config)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.roberta.pooler)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> torch.Tensor:
+        assert (
+            get_embedding
+        ), "XLMRobertaForSequenceClassification is only used for rerank"
+
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("roberta."):
+                    yield (name[len("roberta.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.roberta.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [XLMRobertaModel, XLMRobertaForSequenceClassification]
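
Note: the weight_filter pattern in XLMRobertaForSequenceClassification.load_weights above splits checkpoint entries by the "roberta." prefix: backbone tensors are streamed to XLMRobertaBaseModel with the prefix stripped, while the remaining (classification-head) tensors are buffered and loaded locally. A small stand-alone sketch of that split, using made-up tensor names:

import torch

# Hypothetical checkpoint entries; a real XLM-RoBERTa reranker checkpoint has many more.
weights = [
    ("roberta.embeddings.word_embeddings.weight", torch.zeros(8, 4)),
    ("classifier.dense.weight", torch.zeros(4, 4)),
    ("classifier.out_proj.weight", torch.zeros(2, 4)),
]

self_weights = []


def weight_filter():
    # Yield backbone tensors without the "roberta." prefix; keep the rest for the head.
    for name, weight in weights:
        if name.startswith("roberta."):
            yield (name[len("roberta.") :], weight)
        else:
            self_weights.append((name, weight))


backbone = dict(weight_filter())  # what the base model's load_weights would receive
print(sorted(backbone))              # ['embeddings.word_embeddings.weight']
print([n for n, _ in self_weights])  # ['classifier.dense.weight', 'classifier.out_proj.weight']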