sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +151 -40
  4. sglang/bench_serving.py +46 -22
  5. sglang/check_env.py +24 -2
  6. sglang/global_config.py +0 -1
  7. sglang/lang/backend/base_backend.py +3 -1
  8. sglang/lang/backend/openai.py +8 -3
  9. sglang/lang/backend/runtime_endpoint.py +46 -29
  10. sglang/lang/choices.py +164 -0
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +6 -13
  13. sglang/lang/ir.py +14 -5
  14. sglang/srt/constrained/base_tool_cache.py +1 -1
  15. sglang/srt/constrained/fsm_cache.py +12 -2
  16. sglang/srt/layers/activation.py +33 -0
  17. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  18. sglang/srt/layers/extend_attention.py +6 -1
  19. sglang/srt/layers/layernorm.py +65 -0
  20. sglang/srt/layers/logits_processor.py +6 -1
  21. sglang/srt/layers/pooler.py +50 -0
  22. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  23. sglang/srt/layers/radix_attention.py +4 -7
  24. sglang/srt/managers/detokenizer_manager.py +31 -9
  25. sglang/srt/managers/io_struct.py +63 -0
  26. sglang/srt/managers/policy_scheduler.py +173 -25
  27. sglang/srt/managers/schedule_batch.py +174 -380
  28. sglang/srt/managers/tokenizer_manager.py +197 -112
  29. sglang/srt/managers/tp_worker.py +299 -364
  30. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  31. sglang/srt/mem_cache/chunk_cache.py +43 -20
  32. sglang/srt/mem_cache/memory_pool.py +10 -15
  33. sglang/srt/mem_cache/radix_cache.py +74 -40
  34. sglang/srt/model_executor/cuda_graph_runner.py +27 -12
  35. sglang/srt/model_executor/forward_batch_info.py +319 -0
  36. sglang/srt/model_executor/model_runner.py +30 -47
  37. sglang/srt/models/chatglm.py +1 -1
  38. sglang/srt/models/commandr.py +1 -1
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/deepseek.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +1 -1
  42. sglang/srt/models/gemma.py +1 -1
  43. sglang/srt/models/gemma2.py +1 -2
  44. sglang/srt/models/gpt_bigcode.py +1 -1
  45. sglang/srt/models/grok.py +1 -1
  46. sglang/srt/models/internlm2.py +3 -8
  47. sglang/srt/models/llama2.py +5 -5
  48. sglang/srt/models/llama_classification.py +1 -1
  49. sglang/srt/models/llama_embedding.py +88 -0
  50. sglang/srt/models/llava.py +1 -2
  51. sglang/srt/models/llavavid.py +1 -2
  52. sglang/srt/models/minicpm.py +1 -1
  53. sglang/srt/models/mixtral.py +1 -1
  54. sglang/srt/models/mixtral_quant.py +1 -1
  55. sglang/srt/models/qwen.py +1 -1
  56. sglang/srt/models/qwen2.py +1 -1
  57. sglang/srt/models/qwen2_moe.py +1 -12
  58. sglang/srt/models/stablelm.py +1 -1
  59. sglang/srt/openai_api/adapter.py +189 -39
  60. sglang/srt/openai_api/protocol.py +43 -1
  61. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  62. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  63. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  64. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  65. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  66. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  67. sglang/srt/sampling_params.py +31 -4
  68. sglang/srt/server.py +93 -21
  69. sglang/srt/server_args.py +30 -19
  70. sglang/srt/utils.py +31 -13
  71. sglang/test/run_eval.py +10 -1
  72. sglang/test/runners.py +63 -63
  73. sglang/test/simple_eval_humaneval.py +2 -8
  74. sglang/test/simple_eval_mgsm.py +203 -0
  75. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  76. sglang/test/test_layernorm.py +60 -0
  77. sglang/test/test_programs.py +4 -2
  78. sglang/test/test_utils.py +21 -3
  79. sglang/utils.py +0 -1
  80. sglang/version.py +1 -1
  81. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
  82. sglang-0.2.12.dist-info/RECORD +112 -0
  83. sglang/srt/layers/linear.py +0 -884
  84. sglang/srt/layers/quantization/__init__.py +0 -64
  85. sglang/srt/layers/quantization/fp8.py +0 -677
  86. sglang-0.2.10.dist-info/RECORD +0 -100
  87. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  88. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  89. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
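Among the changes listed above, 0.2.12 introduces a sampling penalty subsystem (sglang/srt/sampling/penaltylib/) alongside updates to the OpenAI-compatible adapter and protocol. The following is a rough, hypothetical sketch of how those penalties could be exercised end to end; the server launch command and request fields are assumptions based on sglang's standard OpenAI-compatible interface, not details taken from this diff.

# Hypothetical sketch: frequency/presence penalties via the OpenAI-compatible API.
# Assumes a server launched separately, e.g.:
#   python -m sglang.launch_server --model-path <model> --port 30000
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "List three uses of AI."}],
    max_tokens=64,
    temperature=0,
    frequency_penalty=0.5,  # assumed to map onto penaltylib's frequency penalizer
    presence_penalty=0.3,   # assumed to map onto penaltylib's presence penalizer
)
print(response.choices[0].message.content)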
sglang/test/runners.py CHANGED
@@ -23,23 +23,19 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from sglang.srt.server import Runtime
+from sglang.srt.utils import is_generation_model
 
 DEFAULT_PROMPTS = [
-    "The capital of France is",
+    # the output of gemma-2-2b from SRT is unstable on the commented prompt
+    # "The capital of France is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
+    "AI is a field of computer science focused on",
 ]
 
 NUM_TOP_LOGPROBS = 5
 
 
-def is_embedding_model(model_path):
-    # FIXME incomplete list
-    if "e5-mistral-7b-instruct" in model_path.lower():
-        return True
-    return False
-
-
 def get_dtype_str(torch_dtype):
     if torch_dtype is torch.float16:
         return "float16"
@@ -49,10 +45,11 @@ def get_dtype_str(torch_dtype):
 
 @dataclass
 class ModelOutput:
-    output_strs: str = None
-    top_input_logprobs: torch.Tensor = None
-    top_output_logprobs: torch.Tensor = None
-    embed_logits: torch.Tensor = None
+    output_strs: List[str] = None
+    output_ids: List[int] = None
+    top_input_logprobs: List[torch.Tensor] = None
+    top_output_logprobs: List[torch.Tensor] = None
+    embed_logits: List[torch.Tensor] = None
 
 
 class HFRunner:
@@ -60,7 +57,7 @@ class HFRunner:
         self,
         model_path,
         torch_dtype=torch.float16,
-        is_embedding_model=None,
+        is_generation_model=None,
     ):
         self.in_queue = multiprocessing.Queue()
         self.out_queue = multiprocessing.Queue()
@@ -72,13 +69,13 @@ class HFRunner:
                 self.out_queue,
                 model_path,
                 torch_dtype,
-                is_embedding_model,
+                is_generation_model,
             ),
         )
         self.model_proc.start()
 
     def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
+        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
     ):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,
@@ -86,12 +83,12 @@ class HFRunner:
             trust_remote_code=True,
         )
 
-        self.is_embedding_model = (
-            is_embedding_model(model_path)
-            if is_embedding_model is None
-            else is_embedding_model
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
         )
-        if not self.is_embedding_model:
+        if self.is_generation_model:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
@@ -103,13 +100,13 @@
 
             self.model = SentenceTransformer(
                 model_path,
-                device="cpu",
-            ).to(dtype=torch_dtype)
+                model_kwargs={"torch_dtype": torch_dtype},
+            )
 
         while True:
             prompts, max_new_tokens = in_queue.get()
             if prompts is not None:
-                if not self.is_embedding_model:
+                if self.is_generation_model:
                     output_strs = []
                     prefill_logprobs = []
                     for p in prompts:
@@ -123,7 +120,9 @@
                         output_ids = self.model.generate(
                             input_ids, do_sample=False, max_new_tokens=max_new_tokens
                         )
-                        output_strs.append(self.tokenizer.decode(output_ids[0]))
+                        output_strs.append(
+                            self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                        )
 
                         logits = self.model.forward(input_ids).logits[0]
                         logprobs = F.log_softmax(
@@ -144,7 +143,6 @@
                     )
 
                 else:
-                    assert isinstance(prompts, List[str])
                     logits = self.model.encode(prompts).tolist()
 
                     out_queue.put(ModelOutput(embed_logits=logits))
@@ -152,7 +150,7 @@
     def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-       max_new_tokens=64,
+       max_new_tokens=8,
     ):
        self.in_queue.put((prompts, max_new_tokens))
        return self.out_queue.get()
@@ -175,16 +173,13 @@ class SRTRunner:
        model_path,
        tp_size=1,
        torch_dtype=torch.float16,
-       is_embedding_model=None,
+       is_generation_model=None,
     ):
-        self.is_embedding_model = (
-            is_embedding_model(model_path)
-            if is_embedding_model is None
-            else is_embedding_model
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
         )
-        if self.is_embedding_model:
-            raise NotImplementedError()
-
        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
@@ -194,40 +189,45 @@
    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-       max_new_tokens=64,
+       max_new_tokens=8,
    ):
-        # the return value contains logprobs from prefill
-        output_strs = []
-        top_input_logprobs = []
-        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-        for prompt in prompts:
-            response = self.runtime.generate(
-                prompt,
-                sampling_params=sampling_params,
-                return_logprob=True,
-                top_logprobs_num=NUM_TOP_LOGPROBS,
-            )
-            response = json.loads(response)
-            output_strs.append(response["text"])
-            top_input_logprobs.append(
-                [
-                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
-                    for x in response["meta_info"]["input_top_logprobs"][1:]
-                ]
-                + [
+        if self.is_generation_model:
+            # the return value contains logprobs from prefill
+            output_strs = []
+            top_input_logprobs = []
+            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+            for prompt in prompts:
+                response = self.runtime.generate(
+                    prompt,
+                    sampling_params=sampling_params,
+                    return_logprob=True,
+                    top_logprobs_num=NUM_TOP_LOGPROBS,
+                )
+                response = json.loads(response)
+                output_strs.append(response["text"])
+                top_input_logprobs.append(
                    [
-                        tup[0]
-                        for tup in response["meta_info"]["output_top_logprobs"][0][
-                            :NUM_TOP_LOGPROBS
+                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                        for x in response["meta_info"]["input_top_logprobs"][1:]
+                    ]
+                    + [
+                        [
+                            tup[0]
+                            for tup in response["meta_info"]["output_top_logprobs"][0][
+                                :NUM_TOP_LOGPROBS
+                            ]
                        ]
                    ]
-                ]
-            )
-            # print(response["meta_info"]["output_top_logprobs"][0])
+                )
 
-        return ModelOutput(
-            output_strs=output_strs, top_input_logprobs=top_input_logprobs
-        )
+            return ModelOutput(
+                output_strs=output_strs, top_input_logprobs=top_input_logprobs
+            )
+        else:
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
+            return ModelOutput(embed_logits=logits)
 
    def __enter__(self):
        return self
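The reworked runners give HFRunner and SRTRunner a matching forward() signature and a shared ModelOutput, with SRTRunner now routing embedding models through Runtime.encode(). A minimal test sketch built on what the diff above shows; the model path and tolerance are placeholders, HFRunner process cleanup is elided, and the exact logprob layout returned by the HuggingFace path is assumed to mirror the SRT path.

# Hypothetical sketch comparing HuggingFace and SRT prefill logprobs with the
# reworked runners; model path and tolerance are placeholders, not taken from the diff.
import torch
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

MODEL_PATH = "meta-llama/Llama-2-7b-hf"  # placeholder model

hf_runner = HFRunner(MODEL_PATH, torch_dtype=torch.float16)
hf_out = hf_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)

with SRTRunner(MODEL_PATH, tp_size=1, torch_dtype=torch.float16) as srt_runner:
    srt_out = srt_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)

# Both outputs carry top_input_logprobs for generation models, so the prefill
# logprobs can be compared per prompt within a loose tolerance.
for hf_lp, srt_lp in zip(hf_out.top_input_logprobs, srt_out.top_input_logprobs):
    assert torch.allclose(torch.tensor(hf_lp), torch.tensor(srt_lp), atol=1e-1)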
sglang/test/simple_eval_humaneval.py CHANGED
@@ -6,21 +6,15 @@ Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de
 https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
 """
 
-import json
-import logging
-import multiprocessing
 import random
 import re
-from collections import Counter, defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from io import BytesIO
-from typing import Any, Dict, List, Tuple
+from typing import Dict, List
 
-import blobfile as bf
 import tqdm
 
 try:
-    from human_eval.data import HUMAN_EVAL, read_problems
+    from human_eval.data import read_problems
     from human_eval.evaluation import estimate_pass_at_k
     from human_eval.execution import check_correctness  # , unsafe_execute
 except (ImportError, ModuleNotFoundError):
sglang/test/simple_eval_mgsm.py ADDED
@@ -0,0 +1,203 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+MGSM: Multilingual Grade School Math Benchmark (MGSM) is a benchmark of grade-school math problems.
+Language Models are Multilingual Chain-of-Thought Reasoners
+Freda Shi, Mirac Suzgun, Markus Freitag, Xuezhi Wang, Suraj Srivats, Soroush Vosoughi, Hyung Won Chung, Yi Tay, Sebastian Ruder, Denny Zhou, Dipanjan Das, Jason Wei
+https://arxiv.org/abs/2210.03057 reference: https://github.com/google-research/url-nlp
+"""
+
+import re
+import urllib
+from typing import Optional
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+)
+
+ALL_LANGUAGES = ["bn", "de", "en", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"]
+LATIN_LANGUAGES = ["de", "en", "es", "fr", "sw"]
+NON_LATIN_LANGUAGES = ["bn", "ja", "ru", "te", "th", "zh"]
+
+LANG_TO_FPATH = {
+    "bn": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_bn.tsv",
+    "de": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_de.tsv",
+    "en": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_en.tsv",
+    "es": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_es.tsv",
+    "fr": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_fr.tsv",
+    "ja": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ja.tsv",
+    "ru": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ru.tsv",
+    "sw": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_sw.tsv",
+    "te": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_te.tsv",
+    "th": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_th.tsv",
+    "zh": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_zh.tsv",
+}
+LANG_TO_INSTRUCTIONS = {
+    "en": """Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:". Do not add anything other than the integer answer after "Answer:".
+
+{input}""",
+    "bn": """এই গণিতের সমস্যাটি সমাধান করুন। চূড়ান্ত উত্তর দেওয়ার আগে যুক্তিসম্পন্ন পদক্ষেপ প্রদান করুন। চূড়ান্ত উত্তরটি একক সংখ্যা হিসাবে "উত্তর:" এর পরে শেষ লাইনে দিন। "উত্তর:" এর পরে অন্য কিছু যুক্ত করবেন না।.
+
+{input}""",
+    "de": """Löse dieses Mathematikproblem. Gib die Schritte zur Begründung an, bevor du die endgültige Antwort in der letzten Zeile alleine im Format "Antwort:" gibst. Füge nichts anderes als die ganzzahlige Antwort nach "Antwort:" hinzu.
+
+{input}""",
+    "es": """Resuelve este problema matemático. Proporciona los pasos de razonamiento antes de dar la respuesta final en la última línea por sí misma en el formato de "Respuesta:". No añadas nada más que la respuesta entera después de "Respuesta:".
+
+{input}""",
+    "fr": """Résolvez ce problème de mathématiques. Donnez les étapes de raisonnement avant de fournir la réponse finale sur la dernière ligne elle-même dans le format de "Réponse:". N'ajoutez rien d'autre que la réponse entière après "Réponse:".
+
+{input}""",
+    "ja": """の数学の問題を解いてください。最終的な答えを出す前に、解答の推論過程を記述してください。そして最後の行には "答え:" の形式で答えを記述し、その後には整数の答え以外何も追加しないでください。
+
+{input}""",
+    "ru": """Решите эту математическую задачу. Объясните шаги рассуждения перед тем, как дать окончательный ответ в последней строке сам по себе в формате "Ответ:". Не добавляйте ничего, кроме целочисленного ответа после "Ответ:".
+
+{input}""",
+    "sw": """Suluhisha tatizo hili la hesabu. Toa hatua za mantiki kabla ya kutoa jibu la mwisho kwenye mstari wa mwisho peke yake katika muundo wa "Jibu:". Usiongeze chochote kingine isipokuwa jibu la integer baada ya "Jibu:".
+
+{input}""",
+    "te": """ఈ గణిత సమస్యను పరిష్కరించండి. చివరి సమాధానాన్ని ఇవ్వదానికి ముందు తర్కాత్మక అదుగులను ఇవ్వండి. చివరి పంక్తిలో మాత్రమే 'సమాధానం:' అనే ఆకారంలో చివరి సమాధానాద్ని ఇవ్వండి సమాధానం: తర్వాత పూర్ణాంక సమాధానానికి తప్పించి ఎదేనా చేర్చవద్దు.
+
+{input}""",
+    "th": """แก้ปัญหาคณิตศาสตร์นี้ ให้ให้ขั้นตอนการใช้เหตุผลก่อนที่จะให้คำตอบสุดท้ายในบรรทัดสุดท้ายโดยอยู่ในรูปแบบ "คำตอบ:" ไม่ควรเพิ่มอะไรนอกจากคำตอบที่เป็นจำนวนเต็มหลังจาก "คำตอบ:"
+
+{input}""",
+    "zh": """解决这个数学问题。在最后一行给出答案前,请提供推理步骤。最后一行应该以 "答案: " 的形式独立给出答案。在 "答案:" 后不要添加除整数答案之外的任何内容。
+
+{input}""",
+}
+
+LANG_TO_ANSWER_PREFIX = {
+    "en": "Answer",
+    "bn": "উত্তর",
+    "de": "Antwort",
+    "es": "Respuesta",
+    "fr": "Réponse",
+    "ja": "答え",
+    "ru": "Ответ",
+    "sw": "Jibu",
+    "te": "సమాధానం",
+    "th": "คำตอบ",
+    "zh": "答案",
+}
+
+
+def parse_answer(answer: str, answer_prefix: str) -> str:
+    if answer_prefix not in answer:
+        return ""
+
+    answer_text = answer.split(answer_prefix)[-1].strip()
+
+    # find all the numbers (including decimals) in the string
+    numbers = re.findall(r"\d+\.?\d*", answer_text.replace(",", ""))
+
+    # return the first number (removing trailing decimal point if present),
+    # or an empty string if there were no numbers
+    return numbers[-1].rstrip(".") if numbers else ""
+
+
+def score_mgsm(target: str, prediction: str) -> bool:
+    if "." in prediction:
+        prediction = prediction.rstrip("0").rstrip(".")
+
+    target = target.replace(",", "")
+    prediction = prediction.replace(",", "")
+
+    return target == prediction
+
+
+def get_lang_examples(lang: str) -> list[dict[str, str]]:
+    fpath = LANG_TO_FPATH[lang]
+    examples = []
+    with urllib.request.urlopen(fpath) as f:
+        for line in f.read().decode("utf-8").splitlines():
+            inputs, targets = line.strip().split("\t")
+            if "." in targets:
+                raise ValueError(f"targets {targets} contains a decimal point.")
+            # targets = int(targets.replace(",", ""))
+            examples.append({"inputs": inputs, "targets": targets, "lang": lang})
+    return examples
+
+
+def get_all_examples() -> list[dict[str, str]]:
+    examples = []
+    for lang in ALL_LANGUAGES:
+        if lang != "en":
+            continue
+        examples += get_lang_examples(lang)
+    return examples
+
+
+class MGSMEval(Eval):
+    def __init__(
+        self,
+        num_examples_per_lang: int = 250,  # restrict to a subset of the data for debugging
+        num_threads: int = 64,
+        languages: Optional[list[str]] = ALL_LANGUAGES,
+    ):
+        if languages is None:
+            languages = ALL_LANGUAGES
+        else:
+            for language in languages:
+                if language not in ALL_LANGUAGES:
+                    raise ValueError(
+                        f"language {language} is not a valid language. "
+                        f"It should be one in {ALL_LANGUAGES}"
+                    )
+        self._languages = languages
+        self._num_examples_per_lang = num_examples_per_lang
+        self._num_threads = num_threads
+
+        examples = []
+        for lang in self._languages:
+            lang_examples = get_lang_examples(lang)
+            examples.extend(lang_examples[: self._num_examples_per_lang])
+        self.examples = examples
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(example: dict[str, str]):
+            language = example["lang"]
+            latin_language = (
+                "group_latin" if language in LATIN_LANGUAGES else "group_non_latin"
+            )
+            correct_answer = example["targets"]
+            instructoin = LANG_TO_INSTRUCTIONS[language]
+            prompt_messages = [
+                sampler._pack_message(
+                    content=instructoin.format(input=example["inputs"]), role="user"
+                )
+            ]
+            try:
+                response_text = sampler(prompt_messages)
+            except Exception as e:
+                response_text = ""
+
+            answer_prefix = LANG_TO_ANSWER_PREFIX[language]
+            extracted_answer = parse_answer(response_text, answer_prefix)
+
+            score = score_mgsm(correct_answer, extracted_answer)
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=correct_answer,
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(
+                html=html,
+                score=score,
+                convo=convo,
+                metrics={language: score, latin_language: score},
+            )
+
+        results = common.map_with_progress(
+            fn, self.examples, num_threads=self._num_threads
+        )
+        return common.aggregate_results(results, default_stats=("mean", "std"))
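A possible driver for the new MGSM eval, mirroring how the other simple_eval_* benchmarks are run against a local sglang server; the ChatCompletionSampler name, its constructor arguments, and the endpoint are assumptions based on sglang/test/run_eval.py rather than anything shown in this diff.

# Hypothetical MGSM eval driver against a local OpenAI-compatible sglang server.
# ChatCompletionSampler and its arguments are assumed from sglang.test.simple_eval_common;
# adjust to whatever sampler the test harness actually provides.
from sglang.test.simple_eval_common import ChatCompletionSampler
from sglang.test.simple_eval_mgsm import MGSMEval

sampler = ChatCompletionSampler(
    base_url="http://127.0.0.1:30000/v1",
    model="default",
)

# Restrict to English and a handful of examples to keep the run short.
mgsm_eval = MGSMEval(num_examples_per_lang=8, num_threads=8, languages=["en"])
result = mgsm_eval(sampler)
print(result.score, result.metrics)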