sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (63)
  1. sglang/bench_latency.py +6 -4
  2. sglang/bench_serving.py +46 -22
  3. sglang/lang/compiler.py +2 -2
  4. sglang/lang/ir.py +3 -3
  5. sglang/srt/constrained/base_tool_cache.py +1 -1
  6. sglang/srt/constrained/fsm_cache.py +12 -2
  7. sglang/srt/layers/activation.py +33 -0
  8. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  9. sglang/srt/layers/extend_attention.py +6 -1
  10. sglang/srt/layers/layernorm.py +65 -0
  11. sglang/srt/layers/logits_processor.py +5 -0
  12. sglang/srt/layers/pooler.py +50 -0
  13. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  14. sglang/srt/layers/radix_attention.py +2 -2
  15. sglang/srt/managers/detokenizer_manager.py +31 -9
  16. sglang/srt/managers/io_struct.py +63 -0
  17. sglang/srt/managers/policy_scheduler.py +173 -25
  18. sglang/srt/managers/schedule_batch.py +110 -87
  19. sglang/srt/managers/tokenizer_manager.py +193 -111
  20. sglang/srt/managers/tp_worker.py +289 -352
  21. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  22. sglang/srt/mem_cache/chunk_cache.py +43 -20
  23. sglang/srt/mem_cache/memory_pool.py +2 -2
  24. sglang/srt/mem_cache/radix_cache.py +74 -40
  25. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  26. sglang/srt/model_executor/forward_batch_info.py +168 -105
  27. sglang/srt/model_executor/model_runner.py +24 -37
  28. sglang/srt/models/gemma2.py +0 -1
  29. sglang/srt/models/internlm2.py +2 -7
  30. sglang/srt/models/llama2.py +4 -4
  31. sglang/srt/models/llama_embedding.py +88 -0
  32. sglang/srt/models/qwen2_moe.py +0 -11
  33. sglang/srt/openai_api/adapter.py +155 -27
  34. sglang/srt/openai_api/protocol.py +37 -1
  35. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  36. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  37. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  39. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  40. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  41. sglang/srt/sampling_params.py +31 -4
  42. sglang/srt/server.py +69 -15
  43. sglang/srt/server_args.py +26 -19
  44. sglang/srt/utils.py +31 -13
  45. sglang/test/run_eval.py +10 -1
  46. sglang/test/runners.py +63 -63
  47. sglang/test/simple_eval_humaneval.py +2 -8
  48. sglang/test/simple_eval_mgsm.py +203 -0
  49. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  50. sglang/test/test_layernorm.py +60 -0
  51. sglang/test/test_programs.py +4 -2
  52. sglang/test/test_utils.py +20 -2
  53. sglang/utils.py +0 -1
  54. sglang/version.py +1 -1
  55. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
  56. sglang-0.2.12.dist-info/RECORD +112 -0
  57. sglang/srt/layers/linear.py +0 -884
  58. sglang/srt/layers/quantization/__init__.py +0 -64
  59. sglang/srt/layers/quantization/fp8.py +0 -677
  60. sglang-0.2.11.dist-info/RECORD +0 -102
  61. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
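Items 35-41 above introduce a new sglang/srt/sampling/penaltylib package (frequency, presence, and repetition penalties, plus a min-new-tokens constraint) wired into sampling_params.py and the OpenAI-compatible server. As rough orientation, below is a minimal sketch of exercising these knobs over HTTP; the server address and the exact set of accepted parameter names are assumptions inferred from the module names, not confirmed by this diff.

# Hypothetical sketch: exercising the new penaltylib-backed sampling knobs
# through the OpenAI-compatible endpoint. Host/port and parameter names are
# assumptions inferred from the module names above.
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",  # assumed default sglang address
    json={
        "model": "default",
        "prompt": "List three prime numbers:",
        "max_tokens": 64,
        "frequency_penalty": 0.5,   # penalize tokens by how often they were emitted
        "presence_penalty": 0.3,    # penalize tokens that appeared at all
        "repetition_penalty": 1.1,  # multiplicative penalty on repeats (assumed name)
        "min_new_tokens": 8,        # suppress EOS for the first 8 tokens (assumed name)
    },
)
print(resp.json()["choices"][0]["text"])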
@@ -6,21 +6,15 @@ Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de
  https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
  """

- import json
- import logging
- import multiprocessing
  import random
  import re
- from collections import Counter, defaultdict
  from concurrent.futures import ThreadPoolExecutor, as_completed
- from io import BytesIO
- from typing import Any, Dict, List, Tuple
+ from typing import Dict, List

- import blobfile as bf
  import tqdm

  try:
-     from human_eval.data import HUMAN_EVAL, read_problems
+     from human_eval.data import read_problems
      from human_eval.evaluation import estimate_pass_at_k
      from human_eval.execution import check_correctness  # , unsafe_execute
  except (ImportError, ModuleNotFoundError):
@@ -0,0 +1,203 @@
+ # Adapted from https://github.com/openai/simple-evals/
+
+ """
+ MGSM: Multilingual Grade School Math Benchmark (MGSM) is a benchmark of grade-school math problems.
+ Language Models are Multilingual Chain-of-Thought Reasoners
+ Freda Shi, Mirac Suzgun, Markus Freitag, Xuezhi Wang, Suraj Srivats, Soroush Vosoughi, Hyung Won Chung, Yi Tay, Sebastian Ruder, Denny Zhou, Dipanjan Das, Jason Wei
+ https://arxiv.org/abs/2210.03057 reference: https://github.com/google-research/url-nlp
+ """
+
+ import re
+ import urllib.request
+ from typing import Optional
+
+ from sglang.test import simple_eval_common as common
+ from sglang.test.simple_eval_common import (
+     HTML_JINJA,
+     Eval,
+     EvalResult,
+     SamplerBase,
+     SingleEvalResult,
+ )
+
+ ALL_LANGUAGES = ["bn", "de", "en", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"]
+ LATIN_LANGUAGES = ["de", "en", "es", "fr", "sw"]
+ NON_LATIN_LANGUAGES = ["bn", "ja", "ru", "te", "th", "zh"]
+
+ LANG_TO_FPATH = {
+     "bn": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_bn.tsv",
+     "de": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_de.tsv",
+     "en": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_en.tsv",
+     "es": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_es.tsv",
+     "fr": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_fr.tsv",
+     "ja": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ja.tsv",
+     "ru": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ru.tsv",
+     "sw": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_sw.tsv",
+     "te": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_te.tsv",
+     "th": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_th.tsv",
+     "zh": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_zh.tsv",
+ }
+ LANG_TO_INSTRUCTIONS = {
+     "en": """Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:". Do not add anything other than the integer answer after "Answer:".
+
+ {input}""",
+     "bn": """এই গণিতের সমস্যাটি সমাধান করুন। চূড়ান্ত উত্তর দেওয়ার আগে যুক্তিসম্পন্ন পদক্ষেপ প্রদান করুন। চূড়ান্ত উত্তরটি একক সংখ্যা হিসাবে "উত্তর:" এর পরে শেষ লাইনে দিন। "উত্তর:" এর পরে অন্য কিছু যুক্ত করবেন না।.
+
+ {input}""",
+     "de": """Löse dieses Mathematikproblem. Gib die Schritte zur Begründung an, bevor du die endgültige Antwort in der letzten Zeile alleine im Format "Antwort:" gibst. Füge nichts anderes als die ganzzahlige Antwort nach "Antwort:" hinzu.
+
+ {input}""",
+     "es": """Resuelve este problema matemático. Proporciona los pasos de razonamiento antes de dar la respuesta final en la última línea por sí misma en el formato de "Respuesta:". No añadas nada más que la respuesta entera después de "Respuesta:".
+
+ {input}""",
+     "fr": """Résolvez ce problème de mathématiques. Donnez les étapes de raisonnement avant de fournir la réponse finale sur la dernière ligne elle-même dans le format de "Réponse:". N'ajoutez rien d'autre que la réponse entière après "Réponse:".
+
+ {input}""",
+     "ja": """の数学の問題を解いてください。最終的な答えを出す前に、解答の推論過程を記述してください。そして最後の行には "答え:" の形式で答えを記述し、その後には整数の答え以外何も追加しないでください。
+
+ {input}""",
+     "ru": """Решите эту математическую задачу. Объясните шаги рассуждения перед тем, как дать окончательный ответ в последней строке сам по себе в формате "Ответ:". Не добавляйте ничего, кроме целочисленного ответа после "Ответ:".
+
+ {input}""",
+     "sw": """Suluhisha tatizo hili la hesabu. Toa hatua za mantiki kabla ya kutoa jibu la mwisho kwenye mstari wa mwisho peke yake katika muundo wa "Jibu:". Usiongeze chochote kingine isipokuwa jibu la integer baada ya "Jibu:".
+
+ {input}""",
+     "te": """ఈ గణిత సమస్యను పరిష్కరించండి. చివరి సమాధానాన్ని ఇవ్వదానికి ముందు తర్కాత్మక అదుగులను ఇవ్వండి. చివరి పంక్తిలో మాత్రమే 'సమాధానం:' అనే ఆకారంలో చివరి సమాధానాద్ని ఇవ్వండి సమాధానం: తర్వాత పూర్ణాంక సమాధానానికి తప్పించి ఎదేనా చేర్చవద్దు.
+
+ {input}""",
+     "th": """แก้ปัญหาคณิตศาสตร์นี้ ให้ให้ขั้นตอนการใช้เหตุผลก่อนที่จะให้คำตอบสุดท้ายในบรรทัดสุดท้ายโดยอยู่ในรูปแบบ "คำตอบ:" ไม่ควรเพิ่มอะไรนอกจากคำตอบที่เป็นจำนวนเต็มหลังจาก "คำตอบ:"
+
+ {input}""",
+     "zh": """解决这个数学问题。在最后一行给出答案前,请提供推理步骤。最后一行应该以 "答案: " 的形式独立给出答案。在 "答案:" 后不要添加除整数答案之外的任何内容。
+
+ {input}""",
+ }
+
+ LANG_TO_ANSWER_PREFIX = {
+     "en": "Answer",
+     "bn": "উত্তর",
+     "de": "Antwort",
+     "es": "Respuesta",
+     "fr": "Réponse",
+     "ja": "答え",
+     "ru": "Ответ",
+     "sw": "Jibu",
+     "te": "సమాధానం",
+     "th": "คำตอบ",
+     "zh": "答案",
+ }
+
+
+ def parse_answer(answer: str, answer_prefix: str) -> str:
+     if answer_prefix not in answer:
+         return ""
+
+     answer_text = answer.split(answer_prefix)[-1].strip()
+
+     # find all the numbers (including decimals) in the string
+     numbers = re.findall(r"\d+\.?\d*", answer_text.replace(",", ""))
+
+     # return the last number (removing a trailing decimal point if present),
+     # or an empty string if there were no numbers
+     return numbers[-1].rstrip(".") if numbers else ""
+
+
+ def score_mgsm(target: str, prediction: str) -> bool:
+     if "." in prediction:
+         prediction = prediction.rstrip("0").rstrip(".")
+
+     target = target.replace(",", "")
+     prediction = prediction.replace(",", "")
+
+     return target == prediction
+
+
+ def get_lang_examples(lang: str) -> list[dict[str, str]]:
+     fpath = LANG_TO_FPATH[lang]
+     examples = []
+     with urllib.request.urlopen(fpath) as f:
+         for line in f.read().decode("utf-8").splitlines():
+             inputs, targets = line.strip().split("\t")
+             if "." in targets:
+                 raise ValueError(f"targets {targets} contains a decimal point.")
+             # targets = int(targets.replace(",", ""))
+             examples.append({"inputs": inputs, "targets": targets, "lang": lang})
+     return examples
+
+
+ def get_all_examples() -> list[dict[str, str]]:
+     examples = []
+     for lang in ALL_LANGUAGES:
+         if lang != "en":
+             continue
+         examples += get_lang_examples(lang)
+     return examples
+
+
+ class MGSMEval(Eval):
+     def __init__(
+         self,
+         num_examples_per_lang: int = 250,  # restrict to a subset of the data for debugging
+         num_threads: int = 64,
+         languages: Optional[list[str]] = ALL_LANGUAGES,
+     ):
+         if languages is None:
+             languages = ALL_LANGUAGES
+         else:
+             for language in languages:
+                 if language not in ALL_LANGUAGES:
+                     raise ValueError(
+                         f"language {language} is not a valid language. "
+                         f"It should be one in {ALL_LANGUAGES}"
+                     )
+         self._languages = languages
+         self._num_examples_per_lang = num_examples_per_lang
+         self._num_threads = num_threads
+
+         examples = []
+         for lang in self._languages:
+             lang_examples = get_lang_examples(lang)
+             examples.extend(lang_examples[: self._num_examples_per_lang])
+         self.examples = examples
+
+     def __call__(self, sampler: SamplerBase) -> EvalResult:
+         def fn(example: dict[str, str]):
+             language = example["lang"]
+             latin_language = (
+                 "group_latin" if language in LATIN_LANGUAGES else "group_non_latin"
+             )
+             correct_answer = example["targets"]
+             instruction = LANG_TO_INSTRUCTIONS[language]
+             prompt_messages = [
+                 sampler._pack_message(
+                     content=instruction.format(input=example["inputs"]), role="user"
+                 )
+             ]
+             try:
+                 response_text = sampler(prompt_messages)
+             except Exception:
+                 response_text = ""
+
+             answer_prefix = LANG_TO_ANSWER_PREFIX[language]
+             extracted_answer = parse_answer(response_text, answer_prefix)
+
+             score = score_mgsm(correct_answer, extracted_answer)
+             html = common.jinja_env.from_string(HTML_JINJA).render(
+                 prompt_messages=prompt_messages,
+                 next_message=dict(content=response_text, role="assistant"),
+                 score=score,
+                 correct_answer=correct_answer,
+                 extracted_answer=extracted_answer,
+             )
+             convo = prompt_messages + [dict(content=response_text, role="assistant")]
+             return SingleEvalResult(
+                 html=html,
+                 score=score,
+                 convo=convo,
+                 metrics={language: score, latin_language: score},
+             )
+
+         results = common.map_with_progress(
+             fn, self.examples, num_threads=self._num_threads
+         )
+         return common.aggregate_results(results, default_stats=("mean", "std"))
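For orientation, a minimal usage sketch of the new MGSM eval above; the sampler object and the score/metrics fields on the returned EvalResult are assumptions based on the simple_eval_common interfaces, not confirmed by this diff.

# Hypothetical sketch: running a small MGSM smoke test.
from sglang.test.simple_eval_mgsm import MGSMEval

# Restrict to two languages and a few examples per language.
mgsm = MGSMEval(num_examples_per_lang=8, num_threads=4, languages=["en", "de"])

# `sampler` must be a SamplerBase implementation providing _pack_message()
# and __call__(); constructing one is out of scope for this sketch.
# result = mgsm(sampler)
# print(result.score, result.metrics)  # assumed EvalResult fields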
@@ -0,0 +1,337 @@
+ import dataclasses
+ import enum
+ import typing
+ import unittest
+
+ import torch
+
+ from sglang.srt.sampling.penaltylib.orchestrator import (
+     BatchedPenalizerOrchestrator,
+     _BatchedPenalizer,
+     _BatchLike,
+ )
+
+
+ @dataclasses.dataclass
+ class MockSamplingParams:
+     frequency_penalty: float = 0.0
+     min_new_tokens: int = 0
+     stop_token_ids: typing.Optional[typing.List[int]] = None
+     presence_penalty: float = 0.0
+     repetition_penalty: float = 1.0
+
+
+ @dataclasses.dataclass
+ class MockTokenizer:
+     eos_token_id: int
+
+
+ @dataclasses.dataclass
+ class MockReq:
+     origin_input_ids: typing.List[int]
+     sampling_params: MockSamplingParams
+     tokenizer: MockTokenizer
+
+
+ class StepType(enum.Enum):
+     INPUT = "input"
+     OUTPUT = "output"
+
+
+ @dataclasses.dataclass
+ class Step:
+     type: StepType
+     token_ids: typing.List[int]
+     expected_tensors: typing.Dict[str, torch.Tensor]
+     # assume initial logits are all 1
+     expected_logits: torch.Tensor
+
+
+ @dataclasses.dataclass
+ class Subject:
+     sampling_params: MockSamplingParams
+     # the first step must be an input step, which will be converted to a Req
+     steps: typing.List[Step]
+     eos_token_id: int = -1
+
+     def __post_init__(self):
+         if self.steps[0].type != StepType.INPUT:
+             raise ValueError("First step must be input")
+
+         # every step should have the same expected_tensors.keys()
+         for i in range(1, len(self.steps)):
+             if self.tensor_keys(i) != self.tensor_keys():
+                 raise ValueError(
+                     f"Expected tensor keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for step {i} and {self.steps[0].expected_tensors.keys()}"
+                 )
+
+     def tensor_keys(self, i: int = 0) -> typing.Set[str]:
+         return set(self.steps[i].expected_tensors.keys())
+
+     def to_req(self) -> MockReq:
+         return MockReq(
+             origin_input_ids=self.steps[0].token_ids,
+             sampling_params=self.sampling_params,
+             tokenizer=MockTokenizer(eos_token_id=self.eos_token_id),
+         )
+
+
+ @dataclasses.dataclass
+ class Case:
+     enabled: bool
+     test_subjects: typing.List[Subject]
+
+     def __post_init__(self):
+         # every subject's steps should have the same expected_tensors.keys()
+         for i in range(1, len(self.test_subjects)):
+             if self.tensor_keys(i) != self.tensor_keys():
+                 raise ValueError(
+                     f"Expected tensor keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for subject {i} and {self.test_subjects[0].tensor_keys()}"
+                 )
+
+     def tensor_keys(self, i: int = 0) -> typing.Set[str]:
+         return set(self.test_subjects[i].tensor_keys())
+
+
+ class BaseBatchedPenalizerTest(unittest.TestCase):
+     Penalizer: typing.Type[_BatchedPenalizer]
+     device = "cuda"
+     vocab_size = 5
+
+     enabled: Subject = None
+     disabled: Subject = None
+
+     def setUp(self):
+         if self.__class__ == BaseBatchedPenalizerTest:
+             self.skipTest("Base class for penalizer tests")
+
+         self.create_test_subjects()
+         self.create_test_cases()
+
+     def tensor(self, data, **kwargs) -> torch.Tensor:
+         """
+         Shortcut to create a tensor with device=self.device.
+         """
+         return torch.tensor(data, **kwargs, device=self.device)
+
+     def create_test_subjects(self) -> typing.List[Subject]:
+         raise NotImplementedError()
+
+     def create_test_cases(self):
+         self.test_cases = [
+             Case(enabled=True, test_subjects=[self.enabled]),
+             Case(enabled=False, test_subjects=[self.disabled]),
+             Case(enabled=True, test_subjects=[self.enabled, self.disabled]),
+         ]
+
+     def _create_penalizer(
+         self, case: Case
+     ) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
+         orchestrator = BatchedPenalizerOrchestrator(
+             vocab_size=self.vocab_size,
+             batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
+             device=self.device,
+             Penalizers={self.Penalizer},
+         )
+
+         return orchestrator, orchestrator.penalizers[self.Penalizer]
+
+     def test_is_required(self):
+         for case in self.test_cases:
+             with self.subTest(case=case):
+                 _, penalizer = self._create_penalizer(case)
+                 self.assertEqual(case.enabled, penalizer.is_required())
+
+     def test_prepare(self):
+         for case in self.test_cases:
+             with self.subTest(case=case):
+                 orchestrator, penalizer = self._create_penalizer(case)
+                 self.assertEqual(case.enabled, penalizer.is_prepared())
+
+                 if case.enabled:
+                     for key, tensor in {
+                         key: torch.cat(
+                             tensors=[
+                                 subject.steps[0].expected_tensors[key]
+                                 for subject in case.test_subjects
+                             ],
+                         )
+                         for key in case.tensor_keys()
+                     }.items():
+                         torch.testing.assert_close(
+                             actual=getattr(penalizer, key),
+                             expected=tensor,
+                             msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
+                         )
+
+                 actual = orchestrator.apply(
+                     torch.ones(
+                         size=(len(case.test_subjects), self.vocab_size),
+                         dtype=torch.float32,
+                         device=self.device,
+                     )
+                 )
+                 expected = torch.cat(
+                     tensors=[
+                         subject.steps[0].expected_logits
+                         for subject in case.test_subjects
+                     ],
+                 )
+                 torch.testing.assert_close(
+                     actual=actual,
+                     expected=expected,
+                     msg=f"logits\nactual={actual}\nexpected={expected}",
+                 )
+
+     def test_teardown(self):
+         for case in self.test_cases:
+             with self.subTest(case=case):
+                 _, penalizer = self._create_penalizer(case)
+                 penalizer.teardown()
+
+                 for key in case.test_subjects[0].steps[0].expected_tensors.keys():
+                     self.assertIsNone(getattr(penalizer, key, None))
+
+     def test_filter(self):
+         for case in self.test_cases:
+             with self.subTest(case=case):
+                 orchestrator, penalizer = self._create_penalizer(case)
+
+                 indices_to_keep = [0]
+                 orchestrator.filter(indices_to_keep=indices_to_keep)
+
+                 filtered_subjects = [case.test_subjects[i] for i in indices_to_keep]
+
+                 if penalizer.is_required():
+                     self.assertTrue(penalizer.is_prepared())
+                     for key, tensor in {
+                         key: torch.cat(
+                             tensors=[
+                                 subject.steps[0].expected_tensors[key]
+                                 for subject in filtered_subjects
+                             ],
+                         )
+                         for key in case.tensor_keys()
+                     }.items():
+                         torch.testing.assert_close(
+                             actual=getattr(penalizer, key),
+                             expected=tensor,
+                             msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
+                         )
+
+                 actual_logits = orchestrator.apply(
+                     torch.ones(
+                         size=(len(filtered_subjects), self.vocab_size),
+                         dtype=torch.float32,
+                         device=self.device,
+                     )
+                 )
+                 filtered_expected_logits = torch.cat(
+                     tensors=[
+                         subject.steps[0].expected_logits
+                         for subject in filtered_subjects
+                     ],
+                 )
+                 torch.testing.assert_close(
+                     actual=actual_logits,
+                     expected=filtered_expected_logits,
+                     msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
+                 )
+
+     def test_merge_enabled_with_disabled(self):
+         enabled_test_case = self.test_cases[0]
+         disabled_test_case = self.test_cases[1]
+
+         orchestrator, penalizer = self._create_penalizer(enabled_test_case)
+         theirs, _ = self._create_penalizer(disabled_test_case)
+
+         orchestrator.merge(theirs)
+
+         for key, tensor in {
+             key: torch.cat(
+                 tensors=[
+                     enabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
+                     disabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
+                 ],
+             )
+             for key in enabled_test_case.tensor_keys()
+         }.items():
+             torch.testing.assert_close(
+                 actual=getattr(penalizer, key),
+                 expected=tensor,
+                 msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
+             )
+
+     def test_cumulate_apply_repeat(self):
+         for case in self.test_cases:
+             with self.subTest(case=case):
+                 orchestrator, penalizer = self._create_penalizer(case)
+
+                 max_step = max(len(subject.steps) for subject in case.test_subjects)
+                 for i in range(1, max_step):
+                     orchestrator.filter(
+                         indices_to_keep=[
+                             j
+                             for j, subject in enumerate(case.test_subjects)
+                             if i < len(subject.steps)
+                         ]
+                     )
+
+                     filtered_subjects = [
+                         subject
+                         for subject in case.test_subjects
+                         if i < len(subject.steps)
+                     ]
+
+                     inputs: typing.List[typing.List[int]] = []
+                     outputs: typing.List[typing.List[int]] = []
+                     for subject in filtered_subjects:
+                         step = subject.steps[i]
+                         if step.type == StepType.INPUT:
+                             inputs.append(step.token_ids)
+                             outputs.append([])
+                         else:
+                             inputs.append([])
+                             outputs.append(step.token_ids)
+
+                     if any(inputs):
+                         orchestrator.cumulate_input_tokens(inputs)
+
+                     if any(outputs):
+                         orchestrator.cumulate_output_tokens(outputs)
+
+                     if penalizer.is_required():
+                         self.assertTrue(penalizer.is_prepared())
+                         for key, tensor in {
+                             key: torch.cat(
+                                 tensors=[
+                                     subject.steps[i].expected_tensors[key]
+                                     for subject in filtered_subjects
+                                 ],
+                             )
+                             for key in case.tensor_keys()
+                         }.items():
+                             torch.testing.assert_close(
+                                 actual=getattr(penalizer, key),
+                                 expected=tensor,
+                                 msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
+                             )
+
+                     actual_logits = orchestrator.apply(
+                         torch.ones(
+                             size=(len(filtered_subjects), self.vocab_size),
+                             dtype=torch.float32,
+                             device=self.device,
+                         )
+                     )
+                     filtered_expected_logits = torch.cat(
+                         tensors=[
+                             subject.steps[i].expected_logits
+                             for subject in filtered_subjects
+                         ],
+                     )
+                     torch.testing.assert_close(
+                         actual=actual_logits,
+                         expected=filtered_expected_logits,
+                         msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
+                     )
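To show how the harness above is meant to be subclassed, here is a hypothetical concrete test. The penalizer module path matches the files added in this release, but the class name, the expected tensor key ("presence_penalties"), and its values are illustrative assumptions, not taken from the real test suite; the harness classes (BaseBatchedPenalizerTest, Subject, Step, and friends) are assumed to be in scope.

# Hypothetical sketch of a concrete penalizer test built on the harness above.
import torch

from sglang.srt.sampling.penaltylib.penalizers.presence_penalty import (
    BatchedPresencePenalizer,  # class name assumed from the module added above
)


class TestBatchedPresencePenalizer(BaseBatchedPenalizerTest):
    Penalizer = BatchedPresencePenalizer

    def create_test_subjects(self):
        # One request with a nonzero presence penalty ...
        self.enabled = Subject(
            sampling_params=MockSamplingParams(presence_penalty=0.5),
            steps=[
                Step(
                    type=StepType.INPUT,
                    token_ids=[0, 1, 2],
                    # tensor key and shape are illustrative assumptions
                    expected_tensors={
                        "presence_penalties": self.tensor(
                            [[0.5] * self.vocab_size], dtype=torch.float32
                        ),
                    },
                    # no output tokens yet, so logits of all ones pass through
                    expected_logits=self.tensor(
                        [[1.0] * self.vocab_size], dtype=torch.float32
                    ),
                ),
            ],
        )
        # ... and one with the penalty disabled (identity behavior).
        self.disabled = Subject(
            sampling_params=MockSamplingParams(presence_penalty=0.0),
            steps=[
                Step(
                    type=StepType.INPUT,
                    token_ids=[0, 1, 2],
                    expected_tensors={
                        "presence_penalties": self.tensor(
                            [[0.0] * self.vocab_size], dtype=torch.float32
                        ),
                    },
                    expected_logits=self.tensor(
                        [[1.0] * self.vocab_size], dtype=torch.float32
                    ),
                ),
            ],
        )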
@@ -0,0 +1,60 @@
+ import itertools
+ import unittest
+
+ import torch
+
+ from sglang.srt.layers.layernorm import RMSNorm
+
+
+ class TestRMSNorm(unittest.TestCase):
+     DTYPES = [torch.half, torch.bfloat16]
+     NUM_TOKENS = [7, 83, 4096]
+     HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
+     ADD_RESIDUAL = [False, True]
+     SEEDS = [0]
+
+     @classmethod
+     def setUpClass(cls):
+         if not torch.cuda.is_available():
+             raise unittest.SkipTest("CUDA is not available")
+         torch.set_default_device("cuda")
+
+     def _run_rms_norm_test(self, num_tokens, hidden_size, add_residual, dtype, seed):
+         torch.manual_seed(seed)
+
+         layer = RMSNorm(hidden_size).to(dtype=dtype)
+         layer.weight.data.normal_(mean=1.0, std=0.1)
+         scale = 1 / (2 * hidden_size)
+         x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
+         residual = torch.randn_like(x) * scale if add_residual else None
+
+         with torch.inference_mode():
+             ref_out = layer.forward_native(x, residual)
+             out = layer(x, residual)
+
+         if add_residual:
+             self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2))
+             self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2))
+         else:
+             self.assertTrue(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2))
+
+     def test_rms_norm(self):
+         for params in itertools.product(
+             self.NUM_TOKENS,
+             self.HIDDEN_SIZES,
+             self.ADD_RESIDUAL,
+             self.DTYPES,
+             self.SEEDS,
+         ):
+             with self.subTest(
+                 num_tokens=params[0],
+                 hidden_size=params[1],
+                 add_residual=params[2],
+                 dtype=params[3],
+                 seed=params[4],
+             ):
+                 self._run_rms_norm_test(*params)
+
+
+ if __name__ == "__main__":
+     unittest.main(verbosity=2)
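The test above checks the layer's fused forward path, layer(x, residual), against layer.forward_native(...). For reference, a minimal sketch of the standard RMSNorm computation that forward_native is expected to implement; the eps value and the fused-residual return convention are assumptions, not read from this diff.

# Reference RMSNorm sketch (standard formulation; eps assumed to be 1e-6).
import torch

def rms_norm_reference(x, weight, residual=None, eps=1e-6):
    # Fused residual add: normalize (x + residual) and also return the sum.
    if residual is not None:
        x = x + residual
        residual = x
    # Scale by the reciprocal root-mean-square over the hidden dimension.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out if residual is None else (out, residual)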
@@ -149,7 +149,7 @@ def test_decode_json():
      assert isinstance(js_obj["population"], int)


- def test_expert_answer():
+ def test_expert_answer(check_answer=True):
      @sgl.function
      def expert_answer(s, question):
          s += "Question: " + question + "\n"
@@ -167,7 +167,9 @@ def test_expert_answer():
      )

      ret = expert_answer.run(question="What is the capital of France?", temperature=0.1)
-     assert "paris" in ret.text().lower()
+
+     if check_answer:
+         assert "paris" in ret.text().lower(), f"Answer: {ret.text()}"


  def test_tool_use():
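The new check_answer flag above lets callers exercise the program end-to-end without asserting on the model's answer, roughly:

# Hypothetical invocation: run the sgl program against whatever model is
# being served, but skip the correctness assert (useful for small models
# whose answers may vary).
from sglang.test.test_programs import test_expert_answer

test_expert_answer(check_answer=False)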