sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +9 -6
  3. sglang/bench_serving.py +46 -22
  4. sglang/global_config.py +1 -1
  5. sglang/lang/backend/runtime_endpoint.py +60 -49
  6. sglang/lang/compiler.py +2 -2
  7. sglang/lang/interpreter.py +4 -2
  8. sglang/lang/ir.py +16 -7
  9. sglang/srt/constrained/base_tool_cache.py +1 -1
  10. sglang/srt/constrained/fsm_cache.py +12 -2
  11. sglang/srt/constrained/jump_forward.py +13 -2
  12. sglang/srt/layers/activation.py +32 -0
  13. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  14. sglang/srt/layers/extend_attention.py +9 -2
  15. sglang/srt/layers/fused_moe/__init__.py +1 -0
  16. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  17. sglang/srt/layers/fused_moe/layer.py +587 -0
  18. sglang/srt/layers/layernorm.py +65 -0
  19. sglang/srt/layers/logits_processor.py +7 -2
  20. sglang/srt/layers/pooler.py +50 -0
  21. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  22. sglang/srt/layers/radix_attention.py +40 -16
  23. sglang/srt/managers/detokenizer_manager.py +31 -9
  24. sglang/srt/managers/io_struct.py +63 -0
  25. sglang/srt/managers/policy_scheduler.py +173 -25
  26. sglang/srt/managers/schedule_batch.py +115 -97
  27. sglang/srt/managers/tokenizer_manager.py +194 -112
  28. sglang/srt/managers/tp_worker.py +290 -359
  29. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  30. sglang/srt/mem_cache/chunk_cache.py +43 -20
  31. sglang/srt/mem_cache/memory_pool.py +2 -2
  32. sglang/srt/mem_cache/radix_cache.py +74 -40
  33. sglang/srt/model_executor/cuda_graph_runner.py +71 -25
  34. sglang/srt/model_executor/forward_batch_info.py +293 -156
  35. sglang/srt/model_executor/model_runner.py +77 -57
  36. sglang/srt/models/chatglm.py +2 -2
  37. sglang/srt/models/commandr.py +1 -1
  38. sglang/srt/models/deepseek.py +2 -2
  39. sglang/srt/models/deepseek_v2.py +7 -6
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +11 -6
  42. sglang/srt/models/grok.py +50 -396
  43. sglang/srt/models/internlm2.py +2 -7
  44. sglang/srt/models/llama2.py +4 -4
  45. sglang/srt/models/llama_embedding.py +88 -0
  46. sglang/srt/models/minicpm.py +2 -2
  47. sglang/srt/models/mixtral.py +56 -254
  48. sglang/srt/models/mixtral_quant.py +1 -4
  49. sglang/srt/models/qwen.py +2 -2
  50. sglang/srt/models/qwen2.py +2 -2
  51. sglang/srt/models/qwen2_moe.py +2 -13
  52. sglang/srt/models/stablelm.py +1 -1
  53. sglang/srt/openai_api/adapter.py +187 -48
  54. sglang/srt/openai_api/protocol.py +37 -1
  55. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  56. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  57. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  58. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  59. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  60. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  61. sglang/srt/sampling_params.py +31 -8
  62. sglang/srt/server.py +91 -29
  63. sglang/srt/server_args.py +32 -19
  64. sglang/srt/utils.py +32 -15
  65. sglang/test/run_eval.py +10 -1
  66. sglang/test/runners.py +81 -73
  67. sglang/test/simple_eval_humaneval.py +2 -8
  68. sglang/test/simple_eval_mgsm.py +203 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  70. sglang/test/test_layernorm.py +60 -0
  71. sglang/test/test_programs.py +36 -7
  72. sglang/test/test_utils.py +24 -2
  73. sglang/utils.py +0 -1
  74. sglang/version.py +1 -1
  75. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
  76. sglang-0.2.13.dist-info/RECORD +112 -0
  77. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  78. sglang/srt/layers/linear.py +0 -884
  79. sglang/srt/layers/quantization/__init__.py +0 -64
  80. sglang/srt/layers/quantization/fp8.py +0 -677
  81. sglang/srt/model_loader/model_loader.py +0 -292
  82. sglang/srt/model_loader/utils.py +0 -275
  83. sglang-0.2.11.dist-info/RECORD +0 -102
  84. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  85. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/test/runners.py CHANGED
@@ -15,6 +15,7 @@ limitations under the License.
 
 import json
 import multiprocessing
+import os
 from dataclasses import dataclass
 from typing import List, Union
 
@@ -23,21 +24,23 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from sglang.srt.server import Runtime
+from sglang.srt.utils import is_generation_model
 
 DEFAULT_PROMPTS = [
-    "The capital of France is",
+    # the output of gemma-2-2b from SRT is unstable on the commented prompt
+    # "The capital of France is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
+    "AI is a field of computer science focused on",
+    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
 ]
 
-NUM_TOP_LOGPROBS = 5
-
+dirpath = os.path.dirname(__file__)
+with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
+    long_prompt = f.read()
+DEFAULT_PROMPTS.append(long_prompt)
 
-def is_embedding_model(model_path):
-    # FIXME incomplete list
-    if "e5-mistral-7b-instruct" in model_path.lower():
-        return True
-    return False
+NUM_TOP_LOGPROBS = 5
 
 
 def get_dtype_str(torch_dtype):
@@ -49,10 +52,11 @@ def get_dtype_str(torch_dtype):
 
 @dataclass
 class ModelOutput:
-    output_strs: str = None
-    top_input_logprobs: torch.Tensor = None
-    top_output_logprobs: torch.Tensor = None
-    embed_logits: torch.Tensor = None
+    output_strs: List[str] = None
+    output_ids: List[int] = None
+    top_input_logprobs: List[torch.Tensor] = None
+    top_output_logprobs: List[torch.Tensor] = None
+    embed_logits: List[torch.Tensor] = None
 
 
 class HFRunner:
@@ -60,7 +64,7 @@ class HFRunner:
         self,
         model_path,
         torch_dtype=torch.float16,
-        is_embedding_model=None,
+        is_generation_model=None,
     ):
         self.in_queue = multiprocessing.Queue()
         self.out_queue = multiprocessing.Queue()
@@ -72,13 +76,13 @@ class HFRunner:
                 self.out_queue,
                 model_path,
                 torch_dtype,
-                is_embedding_model,
+                is_generation_model,
             ),
         )
         self.model_proc.start()
 
     def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
+        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
    ):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,
@@ -86,12 +90,12 @@ class HFRunner:
             trust_remote_code=True,
         )
 
-        self.is_embedding_model = (
-            is_embedding_model(model_path)
-            if is_embedding_model is None
-            else is_embedding_model
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
         )
-        if not self.is_embedding_model:
+        if self.is_generation_model:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
@@ -103,13 +107,13 @@ class HFRunner:
 
             self.model = SentenceTransformer(
                 model_path,
-                device="cpu",
-            ).to(dtype=torch_dtype)
+                model_kwargs={"torch_dtype": torch_dtype},
+            )
 
         while True:
             prompts, max_new_tokens = in_queue.get()
             if prompts is not None:
-                if not self.is_embedding_model:
+                if self.is_generation_model:
                     output_strs = []
                     prefill_logprobs = []
                     for p in prompts:
@@ -123,19 +127,19 @@
                         output_ids = self.model.generate(
                             input_ids, do_sample=False, max_new_tokens=max_new_tokens
                         )
-                        output_strs.append(self.tokenizer.decode(output_ids[0]))
+                        output_strs.append(
+                            self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                        )
 
                         logits = self.model.forward(input_ids).logits[0]
-                        logprobs = F.log_softmax(
-                            logits, dim=-1, dtype=torch.float32
-                        ).tolist()
-                        # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
-                        # print("index", index_of_max)
-                        logprobs = [
-                            sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
-                            for token_logprobs in logprobs
-                        ]
-                        prefill_logprobs.append(logprobs)
+                        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+                        logprobs, top_indices = torch.topk(
+                            logprobs, k=NUM_TOP_LOGPROBS, dim=-1
+                        )
+                        # print("index", top_indices)
+                        prefill_logprobs.append(logprobs.tolist())
+                        del logits
+                        del logprobs
 
                     out_queue.put(
                         ModelOutput(
@@ -144,7 +148,6 @@ class HFRunner:
                     )
 
                 else:
-                    assert isinstance(prompts, List[str])
                     logits = self.model.encode(prompts).tolist()
 
                     out_queue.put(ModelOutput(embed_logits=logits))
@@ -152,7 +155,7 @@
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=64,
+        max_new_tokens=8,
     ):
         self.in_queue.put((prompts, max_new_tokens))
         return self.out_queue.get()
@@ -175,59 +178,64 @@ class SRTRunner:
         model_path,
         tp_size=1,
         torch_dtype=torch.float16,
-        is_embedding_model=None,
+        is_generation_model=None,
+        port=5157,
     ):
-        self.is_embedding_model = (
-            is_embedding_model(model_path)
-            if is_embedding_model is None
-            else is_embedding_model
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
         )
-        if self.is_embedding_model:
-            raise NotImplementedError()
-
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
+            port=port,
+            mem_fraction_static=0.7,
         )
 
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=64,
+        max_new_tokens=8,
     ):
-        # the return value contains logprobs from prefill
-        output_strs = []
-        top_input_logprobs = []
-        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-        for prompt in prompts:
-            response = self.runtime.generate(
-                prompt,
-                sampling_params=sampling_params,
-                return_logprob=True,
-                top_logprobs_num=NUM_TOP_LOGPROBS,
-            )
-            response = json.loads(response)
-            output_strs.append(response["text"])
-            top_input_logprobs.append(
-                [
-                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
-                    for x in response["meta_info"]["input_top_logprobs"][1:]
-                ]
-                + [
+        if self.is_generation_model:
+            # the return value contains logprobs from prefill
+            output_strs = []
+            top_input_logprobs = []
+            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+            for prompt in prompts:
+                response = self.runtime.generate(
+                    prompt,
+                    sampling_params=sampling_params,
+                    return_logprob=True,
+                    top_logprobs_num=NUM_TOP_LOGPROBS,
+                )
+                response = json.loads(response)
+                output_strs.append(response["text"])
+                top_input_logprobs.append(
                     [
-                        tup[0]
-                        for tup in response["meta_info"]["output_top_logprobs"][0][
-                            :NUM_TOP_LOGPROBS
+                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                        for x in response["meta_info"]["input_top_logprobs"][1:]
+                    ]
+                    + [
+                        [
+                            tup[0]
+                            for tup in response["meta_info"]["output_top_logprobs"][0][
+                                :NUM_TOP_LOGPROBS
+                            ]
                         ]
                     ]
-                ]
-            )
-            # print(response["meta_info"]["output_top_logprobs"][0])
+                )
 
-        return ModelOutput(
-            output_strs=output_strs, top_input_logprobs=top_input_logprobs
-        )
+            return ModelOutput(
+                output_strs=output_strs, top_input_logprobs=top_input_logprobs
+            )
+        else:
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
+            return ModelOutput(embed_logits=logits)
 
     def __enter__(self):
         return self
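For reference, a minimal sketch of how the reworked runners could be driven after this change. The names (HFRunner, SRTRunner, forward, is_generation_model, port, ModelOutput) come from the diff above; the model path and the chosen port/dtype values are illustrative placeholders, and SRTRunner also expects the long_prompt.txt fixture next to runners.py. Process cleanup for HFRunner is omitted here.

# Hypothetical usage sketch based on the diff above; model path and prompt are placeholders.
import torch
from sglang.test.runners import HFRunner, SRTRunner

prompts = ["Today is a sunny day and I like"]

# HuggingFace reference runner; is_generation_model=True selects the causal-LM path.
hf = HFRunner("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, is_generation_model=True)
hf_out = hf.forward(prompts, max_new_tokens=8)

# SRT runner now takes an explicit port; embedding models go through runtime.encode() instead.
with SRTRunner(
    "meta-llama/Llama-2-7b-hf",
    tp_size=1,
    torch_dtype=torch.float16,
    is_generation_model=True,
    port=5157,
) as srt:
    srt_out = srt.forward(prompts, max_new_tokens=8)

# Both calls return ModelOutput; tests compare output_strs and the top-5 prefill logprobs.
print(hf_out.output_strs, srt_out.output_strs)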
sglang/test/simple_eval_humaneval.py CHANGED
@@ -6,21 +6,15 @@ Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de
 https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
 """
 
-import json
-import logging
-import multiprocessing
 import random
 import re
-from collections import Counter, defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from io import BytesIO
-from typing import Any, Dict, List, Tuple
+from typing import Dict, List
 
-import blobfile as bf
 import tqdm
 
 try:
-    from human_eval.data import HUMAN_EVAL, read_problems
+    from human_eval.data import read_problems
     from human_eval.evaluation import estimate_pass_at_k
     from human_eval.execution import check_correctness  # , unsafe_execute
 except (ImportError, ModuleNotFoundError):
sglang/test/simple_eval_mgsm.py ADDED
@@ -0,0 +1,203 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+MGSM: Multilingual Grade School Math Benchmark (MGSM) is a benchmark of grade-school math problems.
+Language Models are Multilingual Chain-of-Thought Reasoners
+Freda Shi, Mirac Suzgun, Markus Freitag, Xuezhi Wang, Suraj Srivats, Soroush Vosoughi, Hyung Won Chung, Yi Tay, Sebastian Ruder, Denny Zhou, Dipanjan Das, Jason Wei
+https://arxiv.org/abs/2210.03057 reference: https://github.com/google-research/url-nlp
+"""
+
+import re
+import urllib
+from typing import Optional
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+)
+
+ALL_LANGUAGES = ["bn", "de", "en", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"]
+LATIN_LANGUAGES = ["de", "en", "es", "fr", "sw"]
+NON_LATIN_LANGUAGES = ["bn", "ja", "ru", "te", "th", "zh"]
+
+LANG_TO_FPATH = {
+    "bn": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_bn.tsv",
+    "de": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_de.tsv",
+    "en": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_en.tsv",
+    "es": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_es.tsv",
+    "fr": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_fr.tsv",
+    "ja": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ja.tsv",
+    "ru": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ru.tsv",
+    "sw": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_sw.tsv",
+    "te": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_te.tsv",
+    "th": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_th.tsv",
+    "zh": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_zh.tsv",
+}
+LANG_TO_INSTRUCTIONS = {
+    "en": """Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:". Do not add anything other than the integer answer after "Answer:".
+
+{input}""",
+    "bn": """এই গণিতের সমস্যাটি সমাধান করুন। চূড়ান্ত উত্তর দেওয়ার আগে যুক্তিসম্পন্ন পদক্ষেপ প্রদান করুন। চূড়ান্ত উত্তরটি একক সংখ্যা হিসাবে "উত্তর:" এর পরে শেষ লাইনে দিন। "উত্তর:" এর পরে অন্য কিছু যুক্ত করবেন না।.
+
+{input}""",
+    "de": """Löse dieses Mathematikproblem. Gib die Schritte zur Begründung an, bevor du die endgültige Antwort in der letzten Zeile alleine im Format "Antwort:" gibst. Füge nichts anderes als die ganzzahlige Antwort nach "Antwort:" hinzu.
+
+{input}""",
+    "es": """Resuelve este problema matemático. Proporciona los pasos de razonamiento antes de dar la respuesta final en la última línea por sí misma en el formato de "Respuesta:". No añadas nada más que la respuesta entera después de "Respuesta:".
+
+{input}""",
+    "fr": """Résolvez ce problème de mathématiques. Donnez les étapes de raisonnement avant de fournir la réponse finale sur la dernière ligne elle-même dans le format de "Réponse:". N'ajoutez rien d'autre que la réponse entière après "Réponse:".
+
+{input}""",
+    "ja": """の数学の問題を解いてください。最終的な答えを出す前に、解答の推論過程を記述してください。そして最後の行には "答え:" の形式で答えを記述し、その後には整数の答え以外何も追加しないでください。
+
+{input}""",
+    "ru": """Решите эту математическую задачу. Объясните шаги рассуждения перед тем, как дать окончательный ответ в последней строке сам по себе в формате "Ответ:". Не добавляйте ничего, кроме целочисленного ответа после "Ответ:".
+
+{input}""",
+    "sw": """Suluhisha tatizo hili la hesabu. Toa hatua za mantiki kabla ya kutoa jibu la mwisho kwenye mstari wa mwisho peke yake katika muundo wa "Jibu:". Usiongeze chochote kingine isipokuwa jibu la integer baada ya "Jibu:".
+
+{input}""",
+    "te": """ఈ గణిత సమస్యను పరిష్కరించండి. చివరి సమాధానాన్ని ఇవ్వదానికి ముందు తర్కాత్మక అదుగులను ఇవ్వండి. చివరి పంక్తిలో మాత్రమే 'సమాధానం:' అనే ఆకారంలో చివరి సమాధానాద్ని ఇవ్వండి సమాధానం: తర్వాత పూర్ణాంక సమాధానానికి తప్పించి ఎదేనా చేర్చవద్దు.
+
+{input}""",
+    "th": """แก้ปัญหาคณิตศาสตร์นี้ ให้ให้ขั้นตอนการใช้เหตุผลก่อนที่จะให้คำตอบสุดท้ายในบรรทัดสุดท้ายโดยอยู่ในรูปแบบ "คำตอบ:" ไม่ควรเพิ่มอะไรนอกจากคำตอบที่เป็นจำนวนเต็มหลังจาก "คำตอบ:"
+
+{input}""",
+    "zh": """解决这个数学问题。在最后一行给出答案前,请提供推理步骤。最后一行应该以 "答案: " 的形式独立给出答案。在 "答案:" 后不要添加除整数答案之外的任何内容。
+
+{input}""",
+}
+
+LANG_TO_ANSWER_PREFIX = {
+    "en": "Answer",
+    "bn": "উত্তর",
+    "de": "Antwort",
+    "es": "Respuesta",
+    "fr": "Réponse",
+    "ja": "答え",
+    "ru": "Ответ",
+    "sw": "Jibu",
+    "te": "సమాధానం",
+    "th": "คำตอบ",
+    "zh": "答案",
+}
+
+
+def parse_answer(answer: str, answer_prefix: str) -> str:
+    if answer_prefix not in answer:
+        return ""
+
+    answer_text = answer.split(answer_prefix)[-1].strip()
+
+    # find all the numbers (including decimals) in the string
+    numbers = re.findall(r"\d+\.?\d*", answer_text.replace(",", ""))
+
+    # return the first number (removing trailing decimal point if present),
+    # or an empty string if there were no numbers
+    return numbers[-1].rstrip(".") if numbers else ""
+
+
+def score_mgsm(target: str, prediction: str) -> bool:
+    if "." in prediction:
+        prediction = prediction.rstrip("0").rstrip(".")
+
+    target = target.replace(",", "")
+    prediction = prediction.replace(",", "")
+
+    return target == prediction
+
+
+def get_lang_examples(lang: str) -> list[dict[str, str]]:
+    fpath = LANG_TO_FPATH[lang]
+    examples = []
+    with urllib.request.urlopen(fpath) as f:
+        for line in f.read().decode("utf-8").splitlines():
+            inputs, targets = line.strip().split("\t")
+            if "." in targets:
+                raise ValueError(f"targets {targets} contains a decimal point.")
+            # targets = int(targets.replace(",", ""))
+            examples.append({"inputs": inputs, "targets": targets, "lang": lang})
+    return examples
+
+
+def get_all_examples() -> list[dict[str, str]]:
+    examples = []
+    for lang in ALL_LANGUAGES:
+        if lang != "en":
+            continue
+        examples += get_lang_examples(lang)
+    return examples
+
+
+class MGSMEval(Eval):
+    def __init__(
+        self,
+        num_examples_per_lang: int = 250,  # restrict to a subset of the data for debugging
+        num_threads: int = 64,
+        languages: Optional[list[str]] = ALL_LANGUAGES,
+    ):
+        if languages is None:
+            languages = ALL_LANGUAGES
+        else:
+            for language in languages:
+                if language not in ALL_LANGUAGES:
+                    raise ValueError(
+                        f"language {language} is not a valid language. "
+                        f"It should be one in {ALL_LANGUAGES}"
+                    )
+        self._languages = languages
+        self._num_examples_per_lang = num_examples_per_lang
+        self._num_threads = num_threads
+
+        examples = []
+        for lang in self._languages:
+            lang_examples = get_lang_examples(lang)
+            examples.extend(lang_examples[: self._num_examples_per_lang])
+        self.examples = examples
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(example: dict[str, str]):
+            language = example["lang"]
+            latin_language = (
+                "group_latin" if language in LATIN_LANGUAGES else "group_non_latin"
+            )
+            correct_answer = example["targets"]
+            instructoin = LANG_TO_INSTRUCTIONS[language]
+            prompt_messages = [
+                sampler._pack_message(
+                    content=instructoin.format(input=example["inputs"]), role="user"
+                )
+            ]
+            try:
+                response_text = sampler(prompt_messages)
+            except Exception as e:
+                response_text = ""
+
+            answer_prefix = LANG_TO_ANSWER_PREFIX[language]
+            extracted_answer = parse_answer(response_text, answer_prefix)
+
+            score = score_mgsm(correct_answer, extracted_answer)
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=correct_answer,
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(
+                html=html,
+                score=score,
+                convo=convo,
+                metrics={language: score, latin_language: score},
+            )
+
+        results = common.map_with_progress(
+            fn, self.examples, num_threads=self._num_threads
+        )
+        return common.aggregate_results(results, default_stats=("mean", "std"))
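As a quick sanity check of the answer-parsing helpers added above (standalone, no sampler or model call; the response strings are invented examples, but the behavior follows directly from the parse_answer and score_mgsm code in the new module):

# Toy demonstration of parse_answer / score_mgsm; the response strings are made up.
from sglang.test.simple_eval_mgsm import parse_answer, score_mgsm

response = "The class has 9 rows of 8 desks, so 9 * 8 = 72.\nAnswer: 72"
extracted = parse_answer(response, "Answer")   # takes the last number after the prefix -> "72"
print(score_mgsm("72", extracted))             # True
print(score_mgsm("72", parse_answer("The model rambled with no result.", "Answer")))  # False: prefix missing -> ""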