sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +9 -6
  3. sglang/bench_serving.py +46 -22
  4. sglang/global_config.py +1 -1
  5. sglang/lang/backend/runtime_endpoint.py +60 -49
  6. sglang/lang/compiler.py +2 -2
  7. sglang/lang/interpreter.py +4 -2
  8. sglang/lang/ir.py +16 -7
  9. sglang/srt/constrained/base_tool_cache.py +1 -1
  10. sglang/srt/constrained/fsm_cache.py +12 -2
  11. sglang/srt/constrained/jump_forward.py +13 -2
  12. sglang/srt/layers/activation.py +32 -0
  13. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  14. sglang/srt/layers/extend_attention.py +9 -2
  15. sglang/srt/layers/fused_moe/__init__.py +1 -0
  16. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  17. sglang/srt/layers/fused_moe/layer.py +587 -0
  18. sglang/srt/layers/layernorm.py +65 -0
  19. sglang/srt/layers/logits_processor.py +7 -2
  20. sglang/srt/layers/pooler.py +50 -0
  21. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  22. sglang/srt/layers/radix_attention.py +40 -16
  23. sglang/srt/managers/detokenizer_manager.py +31 -9
  24. sglang/srt/managers/io_struct.py +63 -0
  25. sglang/srt/managers/policy_scheduler.py +173 -25
  26. sglang/srt/managers/schedule_batch.py +115 -97
  27. sglang/srt/managers/tokenizer_manager.py +194 -112
  28. sglang/srt/managers/tp_worker.py +290 -359
  29. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  30. sglang/srt/mem_cache/chunk_cache.py +43 -20
  31. sglang/srt/mem_cache/memory_pool.py +2 -2
  32. sglang/srt/mem_cache/radix_cache.py +74 -40
  33. sglang/srt/model_executor/cuda_graph_runner.py +71 -25
  34. sglang/srt/model_executor/forward_batch_info.py +293 -156
  35. sglang/srt/model_executor/model_runner.py +77 -57
  36. sglang/srt/models/chatglm.py +2 -2
  37. sglang/srt/models/commandr.py +1 -1
  38. sglang/srt/models/deepseek.py +2 -2
  39. sglang/srt/models/deepseek_v2.py +7 -6
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +11 -6
  42. sglang/srt/models/grok.py +50 -396
  43. sglang/srt/models/internlm2.py +2 -7
  44. sglang/srt/models/llama2.py +4 -4
  45. sglang/srt/models/llama_embedding.py +88 -0
  46. sglang/srt/models/minicpm.py +2 -2
  47. sglang/srt/models/mixtral.py +56 -254
  48. sglang/srt/models/mixtral_quant.py +1 -4
  49. sglang/srt/models/qwen.py +2 -2
  50. sglang/srt/models/qwen2.py +2 -2
  51. sglang/srt/models/qwen2_moe.py +2 -13
  52. sglang/srt/models/stablelm.py +1 -1
  53. sglang/srt/openai_api/adapter.py +187 -48
  54. sglang/srt/openai_api/protocol.py +37 -1
  55. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  56. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  57. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  58. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  59. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  60. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  61. sglang/srt/sampling_params.py +31 -8
  62. sglang/srt/server.py +91 -29
  63. sglang/srt/server_args.py +32 -19
  64. sglang/srt/utils.py +32 -15
  65. sglang/test/run_eval.py +10 -1
  66. sglang/test/runners.py +81 -73
  67. sglang/test/simple_eval_humaneval.py +2 -8
  68. sglang/test/simple_eval_mgsm.py +203 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  70. sglang/test/test_layernorm.py +60 -0
  71. sglang/test/test_programs.py +36 -7
  72. sglang/test/test_utils.py +24 -2
  73. sglang/utils.py +0 -1
  74. sglang/version.py +1 -1
  75. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
  76. sglang-0.2.13.dist-info/RECORD +112 -0
  77. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  78. sglang/srt/layers/linear.py +0 -884
  79. sglang/srt/layers/quantization/__init__.py +0 -64
  80. sglang/srt/layers/quantization/fp8.py +0 -677
  81. sglang/srt/model_loader/model_loader.py +0 -292
  82. sglang/srt/model_loader/utils.py +0 -275
  83. sglang-0.2.11.dist-info/RECORD +0 -102
  84. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  85. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
@@ -38,6 +38,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
@@ -52,7 +53,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
-    is_llama3_405b_fp8,
+    is_generation_model,
+    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
@@ -130,10 +132,12 @@ class ModelRunner:
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.init_flash_infer()
+        self.init_flashinfer()
 
-        # Capture cuda graphs
-        self.init_cuda_graphs()
+        if self.is_generation:
+            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
+            # Capture cuda graphs
+            self.init_cuda_graphs()
 
     def load_model(self):
         logger.info(
@@ -155,7 +159,7 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )
 
-        if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
+        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
             # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
             vllm_model_config.hf_config.num_key_value_heads = 8
@@ -165,15 +169,6 @@ class ModelRunner:
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
 
-        if (
-            self.server_args.efficient_weight_load
-            and "llama" in self.server_args.model_path.lower()
-            and self.server_args.quantization == "fp8"
-        ):
-            from sglang.srt.model_loader.model_loader import get_model
-        else:
-            from vllm.model_executor.model_loader import get_model
-
         self.model = get_model(
             model_config=vllm_model_config,
             device_config=device_config,
@@ -184,6 +179,15 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
+        self.is_generation = is_generation_model(
+            self.model_config.hf_config.architectures
+        )
+
         logger.info(
             f"[gpu={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
@@ -287,8 +291,11 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def init_flash_infer(self):
+    def init_flashinfer(self):
         if self.server_args.disable_flashinfer:
+            assert (
+                self.sliding_window_size is None
+            ), "turn on flashinfer to support window attention"
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
             self.flashinfer_decode_wrapper = None
@@ -302,20 +309,47 @@ class ModelRunner:
         else:
             use_tensor_cores = False
 
-        self.flashinfer_workspace_buffers = torch.empty(
-            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
-        )
-        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0], "NHD"
-        )
-        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[1], "NHD"
-        )
-        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0],
-            "NHD",
-            use_tensor_cores=use_tensor_cores,
-        )
+        if self.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(
+                    self.flashinfer_workspace_buffer, "NHD"
+                )
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer, "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_tensor_cores=use_tensor_cores,
+            )
+        else:
+            self.flashinfer_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = []
+            self.flashinfer_decode_wrapper = []
+            for i in range(2):
+                self.flashinfer_prefill_wrapper_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer, "NHD"
+                    )
+                )
+                self.flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=use_tensor_cores,
+                    )
+                )
 
     def init_cuda_graphs(self):
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
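Note on the hunk above: when the loaded model reports a sliding window, init_flashinfer now keeps lists of two paged prefill/decode wrappers instead of single wrappers. A minimal sketch of how a caller might choose between them follows; the helper name and the index convention (slot 0 for sliding-window layers, slot 1 for full-attention layers) are assumptions for illustration, not taken from this diff.

    # Hypothetical helper, not part of this diff: pick the decode wrapper for a layer.
    # Assumption: wrappers[0] serves sliding-window layers, wrappers[1] full-attention
    # layers; the real sglang convention may differ.
    def pick_decode_wrapper(model_runner, layer_sliding_window_size):
        if model_runner.sliding_window_size is None:
            # Single-wrapper path: every layer shares one wrapper.
            return model_runner.flashinfer_decode_wrapper
        wrappers = model_runner.flashinfer_decode_wrapper  # list of two wrappers
        return wrappers[0] if layer_sliding_window_size is not None else wrappers[1]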
@@ -350,33 +384,22 @@ class ModelRunner:
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.create(
+        input_metadata = InputMetadata.from_schedule_batch(
             self,
-            forward_mode=ForwardMode.DECODE,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+            batch,
+            ForwardMode.DECODE,
         )
+
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.create(
+        input_metadata = InputMetadata.from_schedule_batch(
             self,
+            batch,
             forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -384,24 +407,18 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.create(
+        input_metadata = InputMetadata.from_schedule_batch(
             self,
+            batch,
             forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            return_logprob=batch.return_logprob,
-            top_logprobs_nums=batch.top_logprobs_nums,
         )
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
             input_metadata,
-            batch.pixel_values,
-            batch.image_sizes,
-            batch.image_offsets,
+            input_metadata.pixel_values,
+            input_metadata.image_sizes,
+            input_metadata.image_offsets,
         )
 
     def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
@@ -429,8 +446,10 @@ def import_model_classes():
                 entry, list
             ):  # To support multiple model classes in one module
                 for tmp in entry:
+                    assert tmp.__name__ not in model_arch_name_to_cls
                     model_arch_name_to_cls[tmp.__name__] = tmp
             else:
+                assert entry.__name__ not in model_arch_name_to_cls
                 model_arch_name_to_cls[entry.__name__] = entry
 
     # compat: some models such as chatglm has incorrect class set in config.json
@@ -440,6 +459,7 @@ def import_model_classes():
         ):
             for remap in module.EntryClassRemapping:
                 if isinstance(remap, tuple) and len(remap) == 2:
+                    assert remap[0] not in model_arch_name_to_cls
                     model_arch_name_to_cls[remap[0]] = remap[1]
 
     return model_arch_name_to_cls
@@ -24,8 +24,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -43,6 +41,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -50,7 +50,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -62,6 +61,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.utils import set_weight_attrs
 
+from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -27,9 +27,7 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -44,6 +42,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -26,9 +26,7 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
@@ -43,6 +41,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -445,11 +445,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_input[..., : self.kv_lora_rank]
         torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
 
-        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
-        k_pe = k_input[..., self.kv_lora_rank :]
-        v_input = k_input[..., : self.kv_lora_rank]
-        v_input = self.kv_a_layernorm(v_input.contiguous())
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        v_input = latent_cache[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
+        k_input = latent_cache.unsqueeze(1)
         k_input[..., : self.kv_lora_rank] = v_input
+        k_pe = k_input[..., self.kv_lora_rank :]
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q_input[..., self.kv_lora_rank :] = q_pe
@@ -24,7 +24,6 @@ from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -35,6 +34,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -38,13 +38,18 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_window_size(config):
+    return config.sliding_window - 1
+
+
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.
@@ -201,17 +206,14 @@ class Gemma2Attention(nn.Module):
             dtype=torch.get_default_dtype(),
         )
 
-        # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
-        # odd layer, vLLM currently ignores it and uses global attention for
-        # all layers.
-        use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
-        del use_sliding_window  # Unused.
+        use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
+            sliding_window_size=get_window_size(config) if use_sliding_window else None,
             logit_cap=self.config.attn_logit_softcapping,
         )
 
@@ -404,6 +406,9 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
 
+    def get_window_size(self):
+        return get_window_size(self.config)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
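Together with the ModelRunner change earlier in this diff, exposing get_window_size on the causal-LM class is what flips the runner onto the windowed flashinfer path. A tiny sketch of that hand-off, using a made-up stand-in class rather than the real Gemma2ForCausalLM:

    # Stand-in showing how ModelRunner's duck-typed probe (earlier in this diff)
    # picks up a model's get_window_size(); FakeGemma2 is illustrative only.
    class FakeGemma2:
        sliding_window = 4096  # example HF-style config value

        def get_window_size(self):
            return self.sliding_window - 1  # exclusive convention, as in the diff

    model = FakeGemma2()
    sliding_window_size = (
        model.get_window_size() if hasattr(model, "get_window_size") else None
    )
    print(sliding_window_size)  # 4095 -> runner builds the windowed flashinfer wrappers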