sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +2 -2
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +9 -7
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +48 -43
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +7 -2
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +227 -120
- sglang/srt/disaggregation/nixl/conn.py +1 -0
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +7 -1
- sglang/srt/entrypoints/engine.py +17 -2
- sglang/srt/entrypoints/http_server.py +17 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +1 -1
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +72 -71
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +76 -24
- sglang/srt/managers/schedule_policy.py +0 -3
- sglang/srt/managers/scheduler.py +113 -88
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +133 -34
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/memory_pool.py +2 -0
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +19 -14
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +23 -20
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +5 -6
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +30 -4
- sglang/srt/openai_api/protocol.py +0 -8
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +34 -4
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +6 -5
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +89 -14
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/lang/chat_template.py
CHANGED
@@ -1,3 +1,4 @@
+import re
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import Callable, Dict, List, Tuple
@@ -71,9 +72,9 @@ def get_chat_template(name):
 
 def get_chat_template_by_model_path(model_path):
     for matching_func in matching_function_registry:
-        template = matching_func(model_path)
-        if template is not None:
-            return template
+        template_name = matching_func(model_path)
+        if template_name is not None:
+            return get_chat_template(template_name)
     return get_chat_template("default")
 
 
@@ -193,6 +194,21 @@ register_chat_template(
     )
 )
 
+# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
+register_chat_template(
+    ChatTemplate(
+        name="mistral",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
+            "user": ("[INST] ", " [/INST]"),
+            "assistant": ("", " </s><s>"),
+        },
+        stop_str=("</s>",),
+        image_token="[IMG]",
+    )
+)
+
 register_chat_template(
     ChatTemplate(
         name="llama-3-instruct",
@@ -479,134 +495,118 @@ register_chat_template(
 
 @register_chat_template_matching_function
 def match_deepseek(model_path: str):
-    if (
-        "deepseek-v3" in model_path.lower() or "deepseek-r1" in model_path.lower()
-    ) and "base" not in model_path.lower():
-        return get_chat_template("deepseek-v3")
+    if re.search(r"deepseek-(v3|r1)", model_path, re.IGNORECASE) and not re.search(
+        r"base", model_path, re.IGNORECASE
+    ):
+        return "deepseek-v3"
 
 
 @register_chat_template_matching_function
 def match_deepseek_janus_pro(model_path: str):
-    if "janus" in model_path.lower():
-        return get_chat_template("janus-pro")
+    if re.search(r"janus", model_path, re.IGNORECASE):
+        return "janus-pro"
 
 
 @register_chat_template_matching_function
 def match_dbrx(model_path: str):
-    if "dbrx" in model_path.lower() and "instruct" in model_path.lower():
-        return get_chat_template("dbrx-instruct")
+    if re.search(r"dbrx", model_path, re.IGNORECASE) and re.search(
+        r"instruct", model_path, re.IGNORECASE
+    ):
+        return "dbrx-instruct"
 
 
 @register_chat_template_matching_function
 def match_vicuna(model_path: str):
-    if "vicuna" in model_path.lower():
-        return get_chat_template("vicuna_v1.1")
-    if "llava-v1.5" in model_path.lower():
-        return get_chat_template("vicuna_v1.1")
-    if "llava-next-video-7b" in model_path.lower():
-        return get_chat_template("vicuna_v1.1")
+    if re.search(r"vicuna|llava-v1\.5|llava-next-video-7b", model_path, re.IGNORECASE):
+        return "vicuna_v1.1"
 
 
 @register_chat_template_matching_function
 def match_llama2_chat(model_path: str):
-    model_path = model_path.lower()
-    if "llama-2" in model_path and "chat" in model_path:
-        return get_chat_template("llama-2-chat")
-    if (
-        "mistral" in model_path or "mixtral" in model_path
-    ) and "instruct" in model_path:
-        return get_chat_template("llama-2-chat")
-    if "codellama" in model_path and "instruct" in model_path:
-        return get_chat_template("llama-2-chat")
+    if re.search(
+        r"llama-2.*chat|codellama.*instruct",
+        model_path,
+        re.IGNORECASE,
+    ):
+        return "llama-2-chat"
+
+
+@register_chat_template_matching_function
+def match_mistral(model_path: str):
+    if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
+        return "mistral"
 
 
 @register_chat_template_matching_function
 def match_llama3_instruct(model_path: str):
-    model_path = model_path.lower()
-    if "llama-3" in model_path and "instruct" in model_path:
-        return get_chat_template("llama-3-instruct")
+    if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):
+        return "llama-3-instruct"
 
 
 @register_chat_template_matching_function
 def match_chat_ml(model_path: str):
-    # import pdb;pdb.set_trace()
-    model_path = model_path.lower()
-    if "tinyllama" in model_path:
-        return get_chat_template("chatml")
-    # Now the qwen2-vl chat template is matched first
-    if "qwen2" in model_path and "vl" in model_path:
-        return get_chat_template("qwen2-vl")
-    if "qwen" in model_path:
-        if "vl" in model_path:
-            return get_chat_template("qwen2-vl")
-        if ("chat" in model_path or "instruct" in model_path) and (
-            "llava" not in model_path
-        ):
-            return get_chat_template("qwen")
-    if (
-        "llava-v1.6-34b" in model_path
-        or "llava-v1.6-yi-34b" in model_path
-        or "llava-next-video-34b" in model_path
-        or "llava-onevision-qwen2" in model_path
+    if re.search(r"tinyllama", model_path, re.IGNORECASE):
+        return "chatml"
+    if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
+        return "qwen2-vl"
+    if re.search(r"qwen.*(chat|instruct)", model_path, re.IGNORECASE) and not re.search(
+        r"llava", model_path, re.IGNORECASE
     ):
-        return get_chat_template("chatml-llava")
+        return "qwen"
+    if re.search(
+        r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
+        model_path,
+        re.IGNORECASE,
+    ):
+        return "chatml-llava"
 
 
 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
-    model_path = model_path.lower()
-    if "yi-vl" in model_path and "llava" not in model_path:
-        return get_chat_template("yi-vl")
-    elif "yi-1.5" in model_path and "chat" in model_path:
-        return get_chat_template("yi-1.5")
+    if re.search(r"yi-vl", model_path, re.IGNORECASE) and not re.search(
+        r"llava", model_path, re.IGNORECASE
    ):
+        return "yi-vl"
+    elif re.search(r"yi-1\.5.*chat", model_path, re.IGNORECASE):
+        return "yi-1.5"
 
 
 @register_chat_template_matching_function
 def match_gemma_it(model_path: str):
-    model_path = model_path.lower()
-    if "gemma" in model_path and "it" in model_path:
-        return get_chat_template("gemma-it")
+    if re.search(r"gemma.*it", model_path, re.IGNORECASE):
+        return "gemma-it"
 
 
 @register_chat_template_matching_function
 def match_openbmb_minicpm(model_path: str):
-    model_path = model_path.lower()
-    if "minicpm-v" in model_path:
-        return get_chat_template("minicpmv")
-    elif "minicpm-o" in model_path:
-        return get_chat_template("minicpmo")
+    if re.search(r"minicpm-v", model_path, re.IGNORECASE):
+        return "minicpmv"
+    elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
+        return "minicpmo"
 
 
 @register_chat_template_matching_function
 def match_c4ai_command_r(model_path: str):
-    model_path = model_path.lower()
-    if "c4ai-command-r" in model_path:
-        return get_chat_template("c4ai-command-r")
+    if re.search(r"c4ai-command-r", model_path, re.IGNORECASE):
        return "c4ai-command-r"
 
 
 @register_chat_template_matching_function
 def match_granite_instruct(model_path: str):
-    model_path = model_path.lower()
-    # When future versions of Granite are released, this code may
-    # need to be updated. For now, assume that the Granite 3.0
-    # template works across the board.
-    if "granite" in model_path and "instruct" in model_path:
-        return get_chat_template("granite-3-instruct")
+    if re.search(r"granite.*instruct", model_path, re.IGNORECASE):
+        return "granite-3-instruct"
 
 
 @register_chat_template_matching_function
 def match_gemma3_instruct(model_path: str):
-    model_path = model_path.lower()
-    if "gemma-3" in model_path:
-        # gemma-3-1b-it is completion model
-        return get_chat_template("gemma-it")
+    if re.search(r"gemma-3", model_path, re.IGNORECASE):
+        return "gemma-it"
 
 
 @register_chat_template_matching_function
 def match_internvl_chat(model_path: str):
-    model_path = model_path.lower()
-    if "internvl2_5" in model_path:
-        return get_chat_template("internvl-2-5")
+    if re.search(r"internvl2_5", model_path, re.IGNORECASE):
+        return "internvl-2-5"
 
 
 if __name__ == "__main__":
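The refactor above replaces ad-hoc `model_path.lower()` substring checks with `re.search(..., re.IGNORECASE)` patterns, and matching functions now return a template name rather than a `ChatTemplate` object, which `get_chat_template_by_model_path` then resolves. A minimal standalone sketch of the new matching behavior, using the `match_mistral` pattern from the diff (the model paths are illustrative, not an exhaustive list):

    import re

    def match_mistral(model_path: str):
        # Matchers now return a template *name* (or None); the caller
        # resolves the name via get_chat_template afterwards.
        if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
            return "mistral"

    assert match_mistral("mistralai/Mistral-Small-3.1-24B-Instruct-2503") == "mistral"
    assert match_mistral("mistralai/Mixtral-8x7B-Instruct-v0.1") == "mistral"
    assert match_mistral("mistralai/Mistral-7B-v0.1") is None  # base model, no match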
sglang/lang/tracer.py
CHANGED
@@ -38,7 +38,7 @@ def extract_prefix_by_tracing(program, backend):
         with TracingScope(tracer):
             tracer.ret_value = program.func(tracer, **arguments)
     except (StopTracing, TypeError, AttributeError):
-        # Some exceptions may not be catched
+        # Some exceptions may not be caught
         pass
 
     # Run and cache prefix
sglang/srt/configs/deepseekvl2.py
CHANGED
@@ -416,9 +416,9 @@ class DeepseekVLV2Processor(ProcessorMixin):
             h = w = math.ceil(
                 (self.image_size // self.patch_size) / self.downsample_ratio
             )
-            # global views tokens h * (w + 1), 1 is for line seperator
+            # global views tokens h * (w + 1), 1 is for line separator
             tokenized_image = [self.image_token_id] * h * (w + 1)
-            # add a seperator between global and local views
+            # add a separator between global and local views
             tokenized_image += [self.image_token_id]
             # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
             tokenized_image += (
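For scale, a worked example of the token-count arithmetic in this hunk, using assumed (not verified against the model config) DeepSeek-VL2 values of `image_size=384`, `patch_size=16`, and `downsample_ratio=2`:

    import math

    image_size, patch_size, downsample_ratio = 384, 16, 2  # assumed values
    h = w = math.ceil((image_size // patch_size) / downsample_ratio)  # ceil(24 / 2) = 12
    global_view_tokens = h * (w + 1)  # 12 * 13 = 156; the "+ 1" is the line separator
    print(h, global_view_tokens)      # 12 156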
sglang/srt/constrained/base_grammar_backend.py
CHANGED
@@ -14,10 +14,9 @@
 """The baseclass of a backend for grammar-guided constrained decoding."""
 
 import logging
-from abc import ABC, abstractmethod
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from threading import Event, Lock
+from threading import Event
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -27,11 +26,42 @@ from sglang.srt.server_args import ServerArgs
 logger = logging.getLogger(__name__)
 
 
-class BaseGrammarObject(ABC):
+class BaseGrammarObject:
 
     def __init__(self):
         self._finished = False
 
+    def accept_token(self, token: int) -> None:
+        """
+        Accept a token in the grammar.
+        """
+        raise NotImplementedError()
+
+    def rollback(self, k: int):
+        raise NotImplementedError()
+
+    def is_terminated(self):
+        return False
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        raise NotImplementedError()
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        raise NotImplementedError()
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        raise NotImplementedError()
+
+    def copy(self) -> "BaseGrammarObject":
+        raise NotImplementedError()
+
     @property
     def finished(self):
         return self._finished
@@ -40,7 +70,6 @@ class BaseGrammarObject(ABC):
     def finished(self, finished):
         self._finished = finished
 
-    @abstractmethod
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         """
         Try to jump forward in the grammar.
@@ -49,9 +78,8 @@ class BaseGrammarObject(ABC):
         A jump forward helper which may be used in `jump_forward_str_state`.
         None if the jump forward is not possible.
         """
-        raise NotImplementedError
+        raise NotImplementedError()
 
-    @abstractmethod
     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
         """
         Jump forward for the grammar.
@@ -60,47 +88,15 @@ class BaseGrammarObject(ABC):
         A tuple of the jump forward string and the next state of the grammar
         (which can be used in `jump_and_retokenize` if needed).
         """
-        raise NotImplementedError
+        raise NotImplementedError()
 
-    @abstractmethod
     def jump_and_retokenize(
         self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
     ) -> None:
         """
         Jump forward occurs, and update the grammar state if needed.
         """
-        raise NotImplementedError
-
-    @abstractmethod
-    def accept_token(self, token: int) -> None:
-        """
-        Accept a token in the grammar.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-    @abstractmethod
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        raise NotImplementedError
-
-    @staticmethod
-    @abstractmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        raise NotImplementedError
-
-    @abstractmethod
-    def copy(self) -> "BaseGrammarObject":
-        raise NotImplementedError
+        raise NotImplementedError()
 
 
 @dataclass
@@ -113,10 +109,9 @@ class BaseGrammarBackend:
     def __init__(self):
         self.executor = ThreadPoolExecutor()
         self.cache: Dict[Tuple[str, str], CacheEntry] = {}
-        self.cache_lock = Lock()
 
     def _not_supported(self, key_type: str, key_string: str) -> None:
-        logger.warning(f"Skip unsupported {key_type}: {key_string}")
+        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")
 
     def dispatch_fallback(
         self, key_type: str, key_string: str
@@ -148,40 +143,25 @@ class BaseGrammarBackend:
             return self.dispatch_ebnf(key_string)
         elif key_type == "structural_tag":
             return self.dispatch_structural_tag(key_string)
+        elif key_type == "structural_pattern":
+            return self.dispatch_structural_pattern(key_string)
         else:
             return self.dispatch_fallback(key_type, key_string)
 
-    def init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            if key in self.cache:
-                cache_hit = True
-                entry = self.cache[key]
-            else:
-                cache_hit = False
-                entry = CacheEntry(None, Event())
-                self.cache[key] = entry
-
-        if cache_hit:
-            entry.event.wait()
-        else:
-            entry.value = self._init_value_dispatch(key)
-            entry.event.set()
-        return entry.value.copy() if entry.value else None
-
-    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
-        with self.cache_lock:
-            entry = self.cache.get(key)
-            if not entry or not entry.event.is_set():
-                return None
-            val = self.cache[key].value
-            return val.copy() if val else None
+    def get_cached_or_future_value(
+        self, key: Tuple[str, str]
+    ) -> Optional[BaseGrammarObject]:
+        value = self.cache.get(key)
+        if value:
+            return value.copy(), True
+        value = self.executor.submit(self._init_value_dispatch, key)
+        return value, False
 
-    def get_future_value(self, key: Tuple[str, str]) -> Future:
-        return self.executor.submit(self.init_value, key)
+    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
+        self.cache[key] = value
 
     def reset(self):
-        with self.cache_lock:
-            self.cache.clear()
+        self.cache.clear()
 
 
 def create_grammar_backend(
@@ -211,9 +191,12 @@ def create_grammar_backend(
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
 
     if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
-        from .reasoner_grammar_backend import ReasonerGrammarBackend
+        from sglang.srt.constrained.reasoner_grammar_backend import (
+            ReasonerGrammarBackend,
+        )
 
         grammar_backend = ReasonerGrammarBackend(
             grammar_backend, tokenizer.think_end_id
         )
+
     return grammar_backend
sglang/srt/constrained/llguidance_backend.py
CHANGED
@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject):
         self.finished = False
         self.bitmask = None
 
-    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-        ff_tokens = self.ll_matcher.compute_ff_tokens()
-        if ff_tokens:
-            return ff_tokens, ""
-        else:
-            return None
-
-    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
-        return "", -1
-
-    def jump_and_retokenize(
-        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
-    ):
-        pass
-
     def accept_token(self, token: int):
         if not self.ll_matcher.consume_token(token):
             logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject):
             serialized_grammar=self.serialized_grammar,
         )
 
+    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
+        ff_tokens = self.ll_matcher.compute_ff_tokens()
+        if ff_tokens:
+            return ff_tokens, ""
+        else:
+            return None
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        return "", -1
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        pass
+
 
 class GuidanceBackend(BaseGrammarBackend):
 
@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend):
         return None
 
     def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
-        serialized_grammar = LLMatcher.grammar_from_json_schema(
-            key_string,
-            defaults={
-                "whitespace_pattern": self.whitespace_pattern,
-            },
-        )
+        try:
+            serialized_grammar = LLMatcher.grammar_from_json_schema(
+                key_string,
+                defaults={
+                    "whitespace_pattern": self.whitespace_pattern,
+                },
+            )
+        except Exception as e:
+            logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
+            return None
        return self._from_serialized(serialized_grammar)
 
     def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
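With the `try`/`except` added above, an invalid JSON schema no longer propagates an exception out of `dispatch_json`; it is logged and the method returns `None`. A hedged sketch of what a caller can now rely on (backend construction elided; assumes llguidance is installed):

    grammar = backend.dispatch_json("{ not a valid json schema }")
    if grammar is None:
        # The schema was rejected at compile time; the request can be
        # failed gracefully instead of crashing the grammar worker.
        ...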
sglang/srt/constrained/outlines_backend.py
CHANGED
@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject):
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
 
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
+        vocab_mask.fill_(1)
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))
+
+    def copy(self):
+        return OutlinesGrammar(self.guide, self.jump_forward_map)
+
     def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
         if not self.jump_forward_map:
             return None
@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state
 
-    def allocate_vocab_mask(
-        self, vocab_size: int, batch_size: int, device
-    ) -> torch.Tensor:
-        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
-
-    @staticmethod
-    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
-        return vocab_mask
-
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        tokens = torch.tensor(
-            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
-        ).to(vocab_mask.device, non_blocking=True)
-        vocab_mask = vocab_mask[idx]
-        vocab_mask.fill_(1)
-        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
-
-    @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
-        logits.masked_fill_(vocab_mask, float("-inf"))
-
-    def copy(self):
-        return OutlinesGrammar(self.guide, self.jump_forward_map)
-
 
 class OutlinesGrammarBackend(BaseGrammarBackend):
     def __init__(
@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                 key_string,
                 whitespace_pattern=self.whitespace_pattern,
            )
-        except (NotImplementedError, json.decoder.JSONDecodeError) as e:
-            logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
+        except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
+            logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
+            return None
         return self._compile_regex(regex)
 
     def dispatch_regex(self, key_string: str):