crfm-helm 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of crfm-helm might be problematic. Click here for more details.

Files changed (268)
  1. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/METADATA +74 -53
  2. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/RECORD +262 -182
  3. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +3 -3
  5. helm/benchmark/adaptation/adapters/test_adapter.py +4 -4
  6. helm/benchmark/annotation/air_bench_annotator.py +2 -2
  7. helm/benchmark/annotation/bigcodebench_annotator.py +3 -3
  8. helm/benchmark/annotation/bird_sql_annotator.py +2 -2
  9. helm/benchmark/annotation/chw_care_plan_annotator.py +7 -12
  10. helm/benchmark/annotation/ehr_sql_annotator.py +2 -2
  11. helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +7 -7
  12. helm/benchmark/annotation/live_qa_annotator.py +1 -1
  13. helm/benchmark/annotation/mimic_bhc_annotator.py +100 -0
  14. helm/benchmark/annotation/model_as_judge.py +12 -16
  15. helm/benchmark/annotation/omni_math_annotator.py +13 -14
  16. helm/benchmark/annotation/wildbench_annotator.py +9 -9
  17. helm/benchmark/executor.py +11 -12
  18. helm/benchmark/metrics/aci_bench_metrics.py +9 -29
  19. helm/benchmark/metrics/bias_word_lists.py +1 -1
  20. helm/benchmark/metrics/chw_care_plan_metrics.py +10 -30
  21. helm/benchmark/metrics/classification_metrics.py +3 -3
  22. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  23. helm/benchmark/metrics/codeinsights_code_efficiency_metrics.py +186 -0
  24. helm/benchmark/metrics/codeinsights_code_evaluation_metrics.py +477 -0
  25. helm/benchmark/metrics/codeinsights_correct_code_metrics.py +366 -0
  26. helm/benchmark/metrics/codeinsights_edge_case_metrics.py +92 -0
  27. helm/benchmark/metrics/codeinsights_metric_specs.py +51 -0
  28. helm/benchmark/metrics/comet_metric.py +1 -1
  29. helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +2 -2
  30. helm/benchmark/metrics/copyright_metrics.py +1 -1
  31. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +1 -1
  32. helm/benchmark/metrics/dischargeme_metrics.py +9 -29
  33. helm/benchmark/metrics/efficiency_metrics.py +3 -3
  34. helm/benchmark/metrics/evaluate_reference_metrics.py +1 -1
  35. helm/benchmark/metrics/gpt4_audio_refusal_metrics.py +145 -0
  36. helm/benchmark/metrics/ifeval_metrics.py +2 -2
  37. helm/benchmark/metrics/image_generation/clip_score_metrics.py +13 -2
  38. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +1 -1
  39. helm/benchmark/metrics/kpi_edgar_metrics.py +121 -0
  40. helm/benchmark/metrics/llm_jury_metrics.py +46 -0
  41. helm/benchmark/metrics/lmkt_metric_specs.py +12 -0
  42. helm/benchmark/metrics/lmkt_metrics.py +47 -0
  43. helm/benchmark/metrics/med_dialog_metrics.py +9 -29
  44. helm/benchmark/metrics/medalign_metrics.py +9 -29
  45. helm/benchmark/metrics/medi_qa_metrics.py +9 -29
  46. helm/benchmark/metrics/medication_qa_metrics.py +10 -30
  47. helm/benchmark/metrics/melt_bias_metric.py +234 -0
  48. helm/benchmark/metrics/melt_bias_word_lists.py +1367 -0
  49. helm/benchmark/metrics/melt_metric_specs.py +43 -0
  50. helm/benchmark/metrics/melt_toxicity_metric.py +107 -0
  51. helm/benchmark/metrics/mental_health_metrics.py +9 -29
  52. helm/benchmark/metrics/metric_service.py +11 -11
  53. helm/benchmark/metrics/mimic_bhc_metrics.py +14 -0
  54. helm/benchmark/metrics/mimic_rrs_metrics.py +9 -29
  55. helm/benchmark/metrics/mtsamples_procedures_metrics.py +9 -29
  56. helm/benchmark/metrics/mtsamples_replicate_metrics.py +9 -29
  57. helm/benchmark/metrics/openai_mrcr_metrics.py +52 -0
  58. helm/benchmark/metrics/ruler_qa_metrics.py +34 -0
  59. helm/benchmark/metrics/starr_patient_instructions_metrics.py +9 -29
  60. helm/benchmark/metrics/summac/model_summac.py +2 -3
  61. helm/benchmark/metrics/summarization_metrics.py +2 -1
  62. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +2 -2
  63. helm/benchmark/metrics/toxicity_metrics.py +2 -2
  64. helm/benchmark/metrics/unitxt_metrics.py +3 -4
  65. helm/benchmark/metrics/vision_language/emd_utils.py +4 -2
  66. helm/benchmark/metrics/vision_language/image_utils.py +2 -2
  67. helm/benchmark/model_deployment_registry.py +16 -26
  68. helm/benchmark/presentation/contamination.py +3 -3
  69. helm/benchmark/presentation/create_plots.py +43 -13
  70. helm/benchmark/presentation/run_display.py +13 -0
  71. helm/benchmark/presentation/schema.py +7 -1
  72. helm/benchmark/presentation/summarize.py +84 -61
  73. helm/benchmark/presentation/test_create_plots.py +4 -1
  74. helm/benchmark/reeval_run.py +3 -4
  75. helm/benchmark/reeval_runner.py +3 -3
  76. helm/benchmark/run.py +84 -73
  77. helm/benchmark/run_expander.py +12 -1
  78. helm/benchmark/run_spec_factory.py +7 -6
  79. helm/benchmark/run_specs/arabic_run_specs.py +73 -0
  80. helm/benchmark/run_specs/audio_run_specs.py +52 -8
  81. helm/benchmark/run_specs/bluex_run_specs.py +40 -0
  82. helm/benchmark/run_specs/classic_run_specs.py +0 -53
  83. helm/benchmark/run_specs/codeinsights_run_specs.py +192 -0
  84. helm/benchmark/run_specs/enterprise_run_specs.py +20 -0
  85. helm/benchmark/run_specs/experimental_run_specs.py +31 -1
  86. helm/benchmark/run_specs/healthqa_br_run_specs.py +40 -0
  87. helm/benchmark/run_specs/heim_run_specs.py +3 -1
  88. helm/benchmark/run_specs/lmkt_run_specs.py +144 -0
  89. helm/benchmark/run_specs/long_context_run_specs.py +114 -15
  90. helm/benchmark/run_specs/medhelm_run_specs.py +146 -41
  91. helm/benchmark/run_specs/melt_run_specs.py +783 -0
  92. helm/benchmark/run_specs/multilingual_run_specs.py +50 -0
  93. helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +163 -0
  94. helm/benchmark/run_specs/vlm_run_specs.py +28 -0
  95. helm/benchmark/runner.py +5 -5
  96. helm/benchmark/scenarios/aci_bench_scenario.py +7 -1
  97. helm/benchmark/scenarios/alghafa_scenario.py +126 -0
  98. helm/benchmark/scenarios/arabic_mmlu_scenario.py +78 -0
  99. helm/benchmark/scenarios/aratrust_scenario.py +76 -0
  100. helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +3 -1
  101. helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +5 -5
  102. helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +1 -1
  103. helm/benchmark/scenarios/audio_language/corebench_scenario.py +77 -0
  104. helm/benchmark/scenarios/audio_language/mustard_scenario.py +1 -1
  105. helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification_scenario.py +104 -0
  106. helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +99 -0
  107. helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +118 -0
  108. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +86 -0
  109. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +117 -0
  110. helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +15 -1
  111. helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +1 -2
  112. helm/benchmark/scenarios/autobencher_capabilities_scenario.py +2 -2
  113. helm/benchmark/scenarios/bluex_scenario.py +66 -0
  114. helm/benchmark/scenarios/chw_care_plan_scenario.py +14 -13
  115. helm/benchmark/scenarios/clear_scenario.py +11 -7
  116. helm/benchmark/scenarios/cleva_scenario.py +1 -1
  117. helm/benchmark/scenarios/codeinsights_code_efficiency_scenario.py +197 -0
  118. helm/benchmark/scenarios/codeinsights_correct_code_scenario.py +78 -0
  119. helm/benchmark/scenarios/codeinsights_edge_case_scenario.py +192 -0
  120. helm/benchmark/scenarios/codeinsights_student_coding_scenario.py +162 -0
  121. helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py +188 -0
  122. helm/benchmark/scenarios/dischargeme_scenario.py +36 -21
  123. helm/benchmark/scenarios/ehr_sql_scenario.py +7 -1
  124. helm/benchmark/scenarios/ehrshot_scenario.py +28 -55
  125. helm/benchmark/scenarios/exams_multilingual_scenario.py +115 -0
  126. helm/benchmark/scenarios/grammar.py +2 -2
  127. helm/benchmark/scenarios/headqa_scenario.py +6 -1
  128. helm/benchmark/scenarios/healthqa_br_scenario.py +80 -0
  129. helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +90 -0
  130. helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +85 -0
  131. helm/benchmark/scenarios/{infinite_bench_sum_scenario.py → infinite_bench_en_sum_scenario.py} +10 -13
  132. helm/benchmark/scenarios/kpi_edgar_scenario.py +151 -0
  133. helm/benchmark/scenarios/lmkt_scenarios.py +288 -0
  134. helm/benchmark/scenarios/math_scenario.py +21 -20
  135. helm/benchmark/scenarios/med_dialog_scenario.py +6 -1
  136. helm/benchmark/scenarios/medalign_scenario.py +9 -3
  137. helm/benchmark/scenarios/medalign_scenario_helper.py +27 -130
  138. helm/benchmark/scenarios/medbullets_scenario.py +7 -2
  139. helm/benchmark/scenarios/medcalc_bench_scenario.py +4 -2
  140. helm/benchmark/scenarios/medec_scenario.py +6 -1
  141. helm/benchmark/scenarios/medhallu_scenario.py +7 -1
  142. helm/benchmark/scenarios/medi_qa_scenario.py +10 -4
  143. helm/benchmark/scenarios/medication_qa_scenario.py +7 -1
  144. helm/benchmark/scenarios/melt_ir_scenario.py +171 -0
  145. helm/benchmark/scenarios/melt_knowledge_scenario.py +246 -0
  146. helm/benchmark/scenarios/melt_lm_scenarios.py +252 -0
  147. helm/benchmark/scenarios/melt_scenarios.py +793 -0
  148. helm/benchmark/scenarios/melt_srn_scenario.py +342 -0
  149. helm/benchmark/scenarios/melt_synthetic_reasoning_scenario.py +222 -0
  150. helm/benchmark/scenarios/melt_translation_scenario.py +152 -0
  151. helm/benchmark/scenarios/mental_health_scenario.py +16 -5
  152. helm/benchmark/scenarios/mimic_bhc_scenario.py +13 -8
  153. helm/benchmark/scenarios/mimic_rrs_scenario.py +17 -8
  154. helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +14 -8
  155. helm/benchmark/scenarios/mmlu_pro_scenario.py +1 -1
  156. helm/benchmark/scenarios/mmmlu_scenario.py +85 -0
  157. helm/benchmark/scenarios/mtsamples_procedures_scenario.py +5 -2
  158. helm/benchmark/scenarios/mtsamples_replicate_scenario.py +3 -2
  159. helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +11 -5
  160. helm/benchmark/scenarios/openai_mrcr_scenario.py +79 -0
  161. helm/benchmark/scenarios/pubmed_qa_scenario.py +6 -1
  162. helm/benchmark/scenarios/race_based_med_scenario.py +18 -8
  163. helm/benchmark/scenarios/ruler_qa_scenario_helper.py +2 -2
  164. helm/benchmark/scenarios/ruler_qa_scenarios.py +2 -2
  165. helm/benchmark/scenarios/seahelm_scenario.py +2 -2
  166. helm/benchmark/scenarios/shc_bmt_scenario.py +12 -6
  167. helm/benchmark/scenarios/shc_cdi_scenario.py +11 -6
  168. helm/benchmark/scenarios/shc_conf_scenario.py +12 -6
  169. helm/benchmark/scenarios/shc_ent_scenario.py +11 -6
  170. helm/benchmark/scenarios/shc_gip_scenario.py +13 -5
  171. helm/benchmark/scenarios/shc_privacy_scenario.py +78 -0
  172. helm/benchmark/scenarios/shc_proxy_scenario.py +76 -0
  173. helm/benchmark/scenarios/shc_ptbm_scenario.py +12 -7
  174. helm/benchmark/scenarios/shc_sei_scenario.py +12 -7
  175. helm/benchmark/scenarios/shc_sequoia_scenario.py +13 -5
  176. helm/benchmark/scenarios/starr_patient_instructions_scenario.py +15 -8
  177. helm/benchmark/scenarios/test_alghafa_scenario.py +29 -0
  178. helm/benchmark/scenarios/test_aratrust_scenario.py +21 -0
  179. helm/benchmark/scenarios/test_bluex_scenario.py +59 -0
  180. helm/benchmark/scenarios/test_exams_multilingual_scenario.py +29 -0
  181. helm/benchmark/scenarios/test_healtha_br_scenario.py +57 -0
  182. helm/benchmark/scenarios/test_infinite_bench_en_qa_scenario.py +18 -0
  183. helm/benchmark/scenarios/test_infinite_bench_en_sum_scenario.py +31 -0
  184. helm/benchmark/scenarios/truthful_qa_scenario.py +2 -1
  185. helm/benchmark/scenarios/vision_language/msr_vtt_scenario.py +75 -0
  186. helm/benchmark/server.py +2 -1
  187. helm/benchmark/slurm_jobs.py +1 -2
  188. helm/benchmark/slurm_runner.py +8 -1
  189. helm/benchmark/static/schema_arabic.yaml +228 -0
  190. helm/benchmark/static/schema_audio.yaml +60 -49
  191. helm/benchmark/static/schema_classic.yaml +0 -17
  192. helm/benchmark/static/schema_enterprise.yaml +21 -0
  193. helm/benchmark/static/schema_long_context.yaml +81 -20
  194. helm/benchmark/static/schema_medhelm.yaml +272 -213
  195. helm/benchmark/static/schema_melt.yaml +1257 -0
  196. helm/benchmark/static/schema_slphelm.yaml +162 -0
  197. helm/benchmark/static/schema_vhelm.yaml +26 -26
  198. helm/benchmark/static/schema_video.yaml +219 -0
  199. helm/benchmark/static_build/assets/index-b9779128.css +1 -0
  200. helm/benchmark/static_build/assets/index-e439d5e1.js +10 -0
  201. helm/benchmark/static_build/assets/medhelm-overview-eac29843.png +0 -0
  202. helm/benchmark/static_build/assets/{tremor-9cefc3c5.js → tremor-38a10867.js} +1 -1
  203. helm/benchmark/static_build/index.html +4 -4
  204. helm/benchmark/window_services/encoder_decoder_window_service.py +3 -3
  205. helm/benchmark/window_services/image_generation/clip_window_service.py +1 -3
  206. helm/benchmark/window_services/test_utils.py +3 -4
  207. helm/benchmark/window_services/tokenizer_service.py +7 -8
  208. helm/clients/anthropic_client.py +69 -29
  209. helm/clients/audio_language/diva_llama_client.py +4 -2
  210. helm/clients/audio_language/qwen2_5_omni_client.py +209 -0
  211. helm/clients/audio_language/qwen2_audiolm_client.py +8 -6
  212. helm/clients/audio_language/qwen_audiolm_client.py +4 -2
  213. helm/clients/audio_language/test.py +62 -0
  214. helm/clients/bedrock_client.py +3 -1
  215. helm/clients/client.py +7 -7
  216. helm/clients/grok_client.py +36 -0
  217. helm/clients/huggingface_client.py +42 -3
  218. helm/clients/huggingface_pipeline_client.py +138 -0
  219. helm/clients/image_generation/dalle_mini/model/configuration.py +1 -1
  220. helm/clients/image_generation/dalle_mini/model/modeling.py +1 -1
  221. helm/clients/image_generation/dalle_mini/model/processor.py +1 -1
  222. helm/clients/image_generation/dalle_mini/model/tokenizer.py +1 -1
  223. helm/clients/openai_client.py +102 -55
  224. helm/clients/openai_responses_client.py +176 -0
  225. helm/clients/palmyra_client.py +2 -5
  226. helm/clients/reka_client.py +2 -2
  227. helm/clients/test_huggingface_client.py +3 -3
  228. helm/clients/together_client.py +31 -6
  229. helm/clients/vertexai_client.py +17 -9
  230. helm/clients/vision_language/huggingface_vision2seq_client.py +6 -4
  231. helm/clients/vision_language/huggingface_vlm_client.py +2 -2
  232. helm/clients/vision_language/idefics_client.py +6 -2
  233. helm/clients/vision_language/paligemma_client.py +2 -2
  234. helm/clients/vision_language/qwen2_vlm_client.py +66 -53
  235. helm/clients/vision_language/qwen_vlm_client.py +7 -5
  236. helm/clients/vllm_client.py +43 -7
  237. helm/clients/vllm_granite_thinking_client.py +56 -0
  238. helm/clients/writer_client.py +102 -0
  239. helm/common/context.py +80 -0
  240. helm/common/credentials_utils.py +5 -5
  241. helm/common/critique_request.py +0 -1
  242. helm/common/general.py +9 -2
  243. helm/common/hierarchical_logger.py +104 -12
  244. helm/common/local_context.py +140 -0
  245. helm/common/object_spec.py +23 -8
  246. helm/common/remote_context.py +61 -0
  247. helm/common/request.py +8 -0
  248. helm/common/test_logging.py +94 -0
  249. helm/config/model_deployments.yaml +995 -45
  250. helm/config/model_metadata.yaml +780 -59
  251. helm/config/tokenizer_configs.yaml +224 -3
  252. helm/proxy/cli.py +4 -2
  253. helm/proxy/critique/mechanical_turk_utils.py +1 -1
  254. helm/proxy/retry.py +5 -0
  255. helm/proxy/services/server_service.py +21 -85
  256. helm/tokenizers/grok_tokenizer.py +55 -0
  257. helm/tokenizers/huggingface_tokenizer.py +1 -1
  258. helm/tokenizers/test_grok_tokenizer.py +33 -0
  259. helm/benchmark/metrics/numeracy_metrics.py +0 -72
  260. helm/benchmark/metrics/test_numeracy_metrics.py +0 -95
  261. helm/benchmark/scenarios/numeracy_scenario.py +0 -793
  262. helm/benchmark/scenarios/test_infinite_bench_sum_scenario.py +0 -46
  263. helm/benchmark/static_build/assets/index-262903c1.js +0 -10
  264. helm/benchmark/static_build/assets/index-42060d71.css +0 -1
  265. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/entry_points.txt +0 -0
  266. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/licenses/LICENSE +0 -0
  267. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/top_level.txt +0 -0
  268. /helm/benchmark/static_build/assets/{medhelm-overview-3ddfcd65.png → medhelm-v1-overview-3ddfcd65.png} +0 -0
@@ -1,793 +0,0 @@
1
- # flake8: noqa
2
- from collections import defaultdict
3
- from dataclasses import dataclass, field
4
- from itertools import combinations_with_replacement, product
5
- import math
6
- from math import comb
7
- import numpy as np
8
- import numpy.typing as npt
9
- import random
10
- from typing import List, Optional, Tuple, Dict
11
-
12
- from helm.benchmark.adaptation.adapters.adapter_factory import ADAPT_GENERATION
13
- from helm.benchmark.adaptation.adapter_spec import AdapterSpec
14
- from helm.benchmark.window_services.tokenizer_service import TokenizerService
15
- from helm.common.authentication import Authentication
16
- from helm.common.optional_dependencies import handle_module_not_found_error
17
- from helm.proxy.services.server_service import ServerService
18
- from helm.benchmark.scenarios.scenario import (
19
- Scenario,
20
- Instance,
21
- Reference,
22
- TRAIN_SPLIT,
23
- TEST_SPLIT,
24
- CORRECT_TAG,
25
- Input,
26
- Output,
27
- )
28
-
29
- try:
30
- import sympy
31
- from sympy import Symbol, Poly, diff
32
- from sympy.parsing.sympy_parser import standard_transformations, implicit_multiplication_application
33
- except ModuleNotFoundError as e:
34
- handle_module_not_found_error(e, ["scenarios"])
35
-
36
-
37
- # TODO: we shouldn't create an Adapter and TokenizerService in a scenario
38
- # The Adapter and Scenarios should be completely decoupled.
39
- # https://github.com/stanford-crfm/benchmarking/issues/569
40
- def get_test_tokenizer_service() -> TokenizerService:
41
- # Pointed to the default local path set in run.py (--local-path)
42
- return TokenizerService(ServerService(base_path="prod_env", root_mode=True), Authentication("test"))
43
-
44
-
45
- SOLUTION_TAG: str = "solution"
46
- CLASS_TAG: str = "class"
47
- Range = List[Tuple[int, int]]
48
-
49
- SYMPY_TRANSFORMATIONS = standard_transformations + (implicit_multiplication_application,)
50
-
51
-
52
- def generate_terms(degree: int, num_variables: int) -> List[List[int]]:
53
- """Lists out multisets corresponding to all possible terms up to degree `degree` and `num_variables` variables."""
54
- return sum(
55
- [
56
- list(map(lambda _: list(_), combinations_with_replacement(range(num_variables), d)))
57
- for d in reversed(range(degree + 1))
58
- ],
59
- [],
60
- )
61
-
62
-
63
- def get_powers(terms: List[List[int]]) -> List[List[Tuple[int, int]]]:
64
- return list(map(lambda _: list(zip(*np.unique(_, return_counts=True))), terms))
65
-
66
-
67
- def sympy_power_to_power(power: Tuple[int, ...]) -> List[Tuple[int, int]]:
68
- return [(idx, exp) for idx, exp in enumerate(power) if exp]
69
-
70
-
71
- def stringify_terms(terms: List[List[int]], variable_names: List[str] = list("xyz")) -> List[str]:
72
- """Formatting utility for multisets."""
73
-
74
- def stringify_power(index: int, degree: int) -> str:
75
- """Helper formatting utility for powers."""
76
- var = variable_names[index]
77
- if degree == 0:
78
- return ""
79
- if degree == 1:
80
- return var
81
- return f"{var}^{degree}"
82
-
83
- powers = get_powers(terms)
84
- return list(map(lambda _: "".join([stringify_power(*el) for el in _]), powers))
85
-
86
-
87
- @dataclass
88
- class Polynomial:
89
- """A simple polynomial class over the integers that supports evaluation and pretty-printing."""
90
-
91
- degree: int
92
- num_variables: int
93
- coeffs: npt.NDArray[np.int64]
94
- terms: List[List[int]] = field(init=False)
95
-
96
- def __post_init__(self):
97
- self.terms = generate_terms(self.degree, self.num_variables)
98
-
99
- def eval(self, vals: List[int]):
100
- return np.dot(self.coeffs, np.array(list(map(lambda _: np.prod(np.array(vals).__getitem__(_)), self.terms))))
101
-
102
- def __str__(self):
103
- def stringify_monomial(coeff: int, term: str) -> Optional[str]:
104
- if coeff == 0:
105
- return None
106
- if coeff == 1:
107
- return term or str(coeff)
108
- if coeff == -1:
109
- return f"-{term}" if term else "-1"
110
- return f"{coeff}{term}"
111
-
112
- monomials = [stringify_monomial(c, x) for c, x in zip(self.coeffs, stringify_terms(self.terms))]
113
- present_monomials: List[str] = [m for m in monomials if m]
114
- return " + ".join(present_monomials).replace(" + -", " - ")
115
-
116
- @classmethod
117
- def from_string(cls, expr_str: str, degree: int, num_variables: int):
118
- expr = sympy.parse_expr(expr_str.replace("^", "**"), transformations=SYMPY_TRANSFORMATIONS)
119
- poly = Poly(expr, list(sorted(expr.free_symbols, key=lambda _: _.name)))
120
- return sympy_poly_to_poly(poly, degree, num_variables)
121
-
122
-
123
- def sympy_poly_to_poly(poly: Poly, degree: int, num_variables: int) -> Polynomial:
124
- terms = poly.terms()
125
- all_terms = generate_terms(degree, num_variables)
126
- all_powers = get_powers(all_terms)
127
- coeffs_dict = defaultdict(int, {tuple(sympy_power_to_power(power)): coeff for power, coeff in terms})
128
- coeffs = [coeffs_dict[tuple(_)] for _ in all_powers]
129
- return Polynomial(degree=degree, num_variables=num_variables, coeffs=np.array(coeffs))
130
-
131
-
132
- def generate_polynomial(
133
- degree: int,
134
- num_variables: int,
135
- range_coeffs: Range, # inclusive
136
- seed: Optional[int] = None,
137
- strict_degree=True,
138
- strict_variables=True,
139
- strict_constant=True,
140
- ) -> Polynomial:
141
- """Sample the coefficients (A, B, ...) of the polynomial equation y = ... + A x + B.
142
- A generic method used by the function class-specific methods below.
143
-
144
- Args:
145
- strict_degree (bool): if True, require `rel` to have degree strictly equal to `degree`
146
- strict_variables (bool): if True, require `rel` to use exactly `num_variables`
147
- strict_constant (bool): if True, require the constant (ie. term of degree 0) to be non-zero
148
- Returns:
149
- `rel` (Polynomial)
150
- """
151
- MAX_ATTEMPTS = 100
152
- if seed is not None:
153
- random.seed(seed)
154
- np.random.seed(seed)
155
- count = 0
156
- terms = generate_terms(degree, num_variables)
157
- while count < MAX_ATTEMPTS:
158
- done = True
159
- coeffs = [random.randint(r[0], r[1]) for r in range_coeffs]
160
- if strict_constant and coeffs[-1] == 0:
161
- done = False
162
- if strict_degree and not sum(coeffs[: comb(degree + num_variables - 1, num_variables - 1)]):
163
- done = False
164
- if strict_variables:
165
- for idx in range(num_variables):
166
- vals = np.zeros(num_variables)
167
- vals[idx] = 1
168
- res = np.dot(coeffs[:-1], np.array(list(map(lambda _: np.prod(vals.__getitem__(_)), terms[:-1]))))
169
- if not res:
170
- done = False
171
- break
172
- if done:
173
- break
174
- count += 1
175
- if count >= MAX_ATTEMPTS:
176
- raise ValueError(
177
- "Failed to sample valid polynomial equation within "
178
- + f"{MAX_ATTEMPTS} attempts from ranges {str(range_coeffs)}."
179
- )
180
- return Polynomial(degree=degree, num_variables=num_variables, coeffs=np.array(coeffs))
181
-
182
-
183
- def generate_linear(range_coeffs: Range) -> Polynomial:
184
- return generate_polynomial(
185
- degree=1,
186
- num_variables=1,
187
- range_coeffs=range_coeffs,
188
- strict_degree=True,
189
- strict_variables=True,
190
- strict_constant=True,
191
- )
192
-
193
-
194
- def generate_parabola(range_coeffs: Range) -> Polynomial:
195
- return generate_polynomial(
196
- degree=2,
197
- num_variables=1,
198
- range_coeffs=range_coeffs,
199
- strict_degree=True,
200
- strict_variables=True,
201
- strict_constant=True,
202
- )
203
-
204
-
205
- def generate_plane(range_coeffs: Range) -> Polynomial:
206
- return generate_polynomial(
207
- degree=1,
208
- num_variables=2,
209
- range_coeffs=range_coeffs,
210
- strict_degree=True,
211
- strict_variables=True,
212
- strict_constant=True,
213
- )
214
-
215
-
216
- def generate_paraboloid(range_coeffs: Range) -> Polynomial:
217
- return generate_polynomial(
218
- degree=2,
219
- num_variables=2,
220
- range_coeffs=range_coeffs,
221
- strict_degree=True,
222
- strict_variables=True,
223
- strict_constant=True,
224
- )
225
-
226
-
227
- def generate_rotated_translated_paraboloid(range_coeffs: Range) -> Polynomial:
228
- """Unused."""
229
- do_sample = True
230
- while do_sample:
231
- coeffs_0 = generate_plane(range_coeffs).coeffs
232
- coeffs_1 = generate_plane(range_coeffs).coeffs
233
- mat = np.array(
234
- [
235
- coeffs_0,
236
- coeffs_1,
237
- ]
238
- )
239
- if np.linalg.matrix_rank(mat) == 2:
240
- do_sample = False
241
- x = Symbol("x")
242
- y = Symbol("y")
243
- xprime = coeffs_0[0] * x + coeffs_0[1] * y + coeffs_0[2]
244
- yprime = coeffs_1[0] * x + coeffs_1[1] * y + coeffs_1[2]
245
- expr = xprime**2 + yprime**2
246
- poly = Poly(expr, [x, y])
247
- return sympy_poly_to_poly(poly, 2, 2)
248
-
249
-
250
- def distance_linear(point: List[int], rel_str: str):
251
- """
252
- Returns the minimum distance from the given point to the relation given by `rel_str` which has the form:
253
- A x - y + B = 0
254
- """
255
- relation_type = "linear"
256
- degree: int = RELTYPE_INFO[relation_type].degree
257
- num_variables: int = RELTYPE_INFO[relation_type].num_variables
258
- rel = Polynomial.from_string(rel_str.split(" = ")[-1], degree, num_variables)
259
- A = rel.coeffs[0]
260
- B = -1
261
- C = rel.coeffs[1]
262
- x, y = point
263
- return float(abs((A * x + B * y + C)) / (math.sqrt(A**2 + B**2)))
264
-
265
-
266
- def distance_parabola(point: List[int], rel_str: str, TOL: float = 1e-10):
267
- """
268
- Returns the minimum distance from the given point to the relation given by `rel_str` which has the form:
269
- y = A x^2 + B x + C
270
- """
271
- rel_str = rel_str.split(" = ")[-1]
272
- expr = sympy.parse_expr(rel_str.replace("^", "**"), transformations=SYMPY_TRANSFORMATIONS)
273
- poly = sympy.Poly(expr, list(expr.free_symbols))
274
- x = list(expr.free_symbols)[0]
275
- x0, y0 = point
276
- dist = (x - x0) ** 2 + (poly - y0) ** 2
277
- deriv = sympy.diff(dist, x)
278
- try:
279
- sols = sympy.solve(deriv, x)
280
- except ZeroDivisionError:
281
- # This shouldn't happen, but has happened for a prior implementation of
282
- # `distance_paraboloid`, so catch it conservatively:
283
- print("Failed to compute minimum distance.")
284
- # pdb.set_trace()
285
- return float(0.0)
286
- dist_vals = list(map(lambda _: sympy.N(dist.eval(_)), sols))
287
- try:
288
- dist_val = min([sympy.re(_) for _ in dist_vals if abs(sympy.im(_)) < TOL and sympy.re(_) >= 0])
289
- except ValueError:
290
- # A real solution should exist, but if not (eg. numerical error exceeds TOL):
291
- print("Failed to compute minimum distance.")
292
- # pdb.set_trace()
293
- return float(0.0)
294
- return np.sqrt(float(dist_val))
295
-
296
-
297
- def distance_plane(point: List[int], rel_str: str):
298
- """
299
- Returns the minimum distance from the given point to the relation given by `rel_str` which has the form:
300
- A x + B y - z + C = 0
301
- """
302
- relation_type = "plane"
303
- degree: int = RELTYPE_INFO[relation_type].degree
304
- num_variables: int = RELTYPE_INFO[relation_type].num_variables
305
- rel = Polynomial.from_string(rel_str.split(" = ")[-1], degree, num_variables)
306
- A = rel.coeffs[0]
307
- B = rel.coeffs[1]
308
- C = -1
309
- D = rel.coeffs[2]
310
- x, y, z = point
311
- d = abs((A * x + B * y + C * z + D))
312
- e = math.sqrt(A**2 + B**2 + C**2)
313
- return float(d / e)
314
-
315
-
316
- def distance_paraboloid(point: List[int], rel_str: str, TOL: float = 1e-10):
317
- """
318
- Returns the minimum distance from the given point to the relation given by `rel_str` which has the form:
319
- z = A x^2 + B x y + C y^2 + D x + E y + F
320
- Uses method of Lagrange multipliers.
321
- """
322
- rel_str = rel_str.split(" = ")[-1]
323
- expr = sympy.parse_expr(rel_str.replace("^", "**"), transformations=SYMPY_TRANSFORMATIONS)
324
- x, y = list(expr.free_symbols)
325
- if x.name == "y":
326
- x, y = y, x
327
- z = Symbol("z")
328
- x0, y0, z0 = point
329
- f = (x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2
330
- g = z - expr
331
- if abs(g.subs([(x, x0), (y, y0), (z, z0)])) < TOL:
332
- return float(0.0)
333
- λ = Symbol("λ")
334
- # The code below is meant to be equivalent to
335
- # `sols = sympy.solve([eq_x, eq_y, eq_z, g], [x, y, z, λ])`
336
- # but sympy.solve was failing to find any solution on many inputs
337
- # as well as not finding some solutions
338
- # so this breaks it down for the special case of `f - λ g` which is at most quadratic.
339
-
340
- # Set up the equations from method of Lagrange multipliers
341
- eq_x = diff(f, x) - λ * diff(g, x)
342
- eq_y = diff(f, y) - λ * diff(g, y)
343
- eq_z = diff(f, z) - λ * diff(g, z)
344
- # Solve for each variable individually
345
- has_xy = y in eq_x.free_symbols # has xy term
346
- if has_xy:
347
- sols_x = sympy.solve(eq_x, [x, y, λ])
348
- sols_y = sympy.solve(eq_y, [x, y, λ])
349
- sols_z = sympy.solve(eq_z, [z, λ])
350
- else:
351
- sols_x = sympy.solve(eq_x, [x, λ])
352
- sols_y = sympy.solve(eq_y, [y, λ])
353
- sols_z = sympy.solve(eq_z, [z, λ])
354
- try:
355
- # Put the solutions together
356
-
357
- # Extract x,y,z resp. from tuples
358
- sols_lst_xyz = [[_[0] for _ in lst] for lst in [sols_x, sols_y, sols_z]]
359
-
360
- # Extract solutions for λ from tuples
361
- sols_lst_λλλ = [[_[-1] for _ in lst] for lst in [sols_x, sols_y, sols_z]]
362
-
363
- # Get list of possible solution tuples and corresponding solutions for λ
364
- sols_xyz = list(product(*sols_lst_xyz))
365
- vals_λ = list(product(*sols_lst_λλλ))
366
-
367
- sols = []
368
- # Try each possible combined solution for x, y, z, λ
369
- for sol_xyz, val_λs in zip(sols_xyz, vals_λ):
370
- val_λs = tuple(set(filter(lambda _: not _.is_symbol, val_λs))) # get distinct values for λ if there are any
371
- if len(val_λs) > 1: # there can be at most one distinct value for λ
372
- continue
373
- val_λ = val_λs[0] if val_λs else λ
374
- sol_x, sol_y, sol_z = sol_xyz
375
- if not val_λ.is_symbol:
376
- # Substitute in values of λ
377
- sol_x = sol_x.subs(λ, val_λ)
378
- sol_y = sol_y.subs(λ, val_λ)
379
- sol_z = sol_z.subs(λ, val_λ)
380
- g_λ = g.subs(λ, val_λ)
381
- else:
382
- g_λ = g
383
-
384
- # Substitute in solutions for x, y, z
385
- if has_xy:
386
- g_λ = g_λ.subs([(x, sol_x), (z, sol_z)])
387
- sol_ys = sympy.solve(sol_x - sol_y, y)
388
- for sol_y in sol_ys:
389
- g_λy = g_λ.subs(y, sol_y)
390
- sol_xy = sol_x.subs(y, sol_y)
391
- syms = list(g_λy.free_symbols)
392
- if len(syms) > 1: # underdetermined system
393
- continue
394
- sym = syms[0]
395
- vals = [sympy.N(_) for _ in sympy.solveset(g_λy, sym)]
396
- sols.extend([(sol_xy.subs(sym, _), sol_y.subs(sym, _), sol_z.subs(sym, _)) for _ in vals])
397
- else:
398
- g_λ = g_λ.subs([(x, sol_x), (y, sol_y), (z, sol_z)])
399
- syms = list(g_λ.free_symbols)
400
- if len(syms) > 1: # underdetermined system
401
- continue
402
- # Solve for remaining variable
403
- sym = syms[0]
404
- vals = [sympy.N(_) for _ in sympy.solveset(g_λ, sym)]
405
- sols.extend([(sol_x.subs(sym, _), sol_y.subs(sym, _), sol_z.subs(sym, _)) for _ in vals])
406
- except ZeroDivisionError:
407
- # This shouldn't happen, but has happened for a prior implementation of
408
- # `distance_paraboloid`, so catch it conservatively:
409
- print("Failed to compute minimum distance.")
410
- # pdb.set_trace()
411
- return float(0.0)
412
- poly_f = sympy.Poly(f, [x, y, z])
413
- # Evaluate f on found solutions
414
- try:
415
- dist_vals = list(map(lambda _: sympy.N(poly_f.eval(_)), sols))
416
- except sympy.polys.polyerrors.UnificationFailed:
417
- # Forgot to substitute all variables in some expression.
418
- # This shouldn't happen, but has happened for a prior implementation of
419
- # `distance_paraboloid`, so catch it conservatively:
420
- print("sympy error: Unification failed.")
421
- # pdb.set_trace()
422
- return float(0.0)
423
- # Get the minimum nonnegative real value
424
- try:
425
- dist_val = min([sympy.re(_) for _ in dist_vals if abs(sympy.im(_)) < TOL and sympy.re(_) >= 0])
426
- except ValueError:
427
- # A real solution should exist, but if not (eg. numerical error exceeds TOL):
428
- print("Failed to compute minimum distance.")
429
- print([eq_x, eq_y, eq_z, g])
430
- print(sols)
431
- # pdb.set_trace()
432
- return float(0.0)
433
- return np.sqrt(float(dist_val))
434
-
435
-
436
def select_ranges(
    num_train: int, num_test: int, dim: int, overlap: bool = True, nonnegative_only: bool = False
) -> Tuple[Range, Range]:
    """
    Choose the integer boxes from which the train and test points are sampled.

    Every box consists of `dim` copies of a single interval: `(-c, c)`, or `(0, c)` when
    `nonnegative_only` is set, where `c` is taken from a fixed list of candidate bounds.

    If `overlap` is `True`, both ranges are the same box: the smallest candidate that
    contains at least `num_train + num_test` integer points.
    Otherwise the test box is the smallest candidate containing at least `num_test` points,
    and the train box is the smallest candidate that still contains at least `num_train`
    points once the points of the test box are discounted, so the test points lie within
    a region bounded by the region that the train points are sampled from.

    Returns `(train_range, test_range)`, each a list of `dim` `(low, high)` pairs of plain ints.
    Raises `ValueError` if even the largest candidate box does not contain enough points.
    """
    choices: npt.NDArray[np.int64] = np.array([0, 1, 2, 5, 10, 20, 50, 100, 200], dtype=np.int64)

    def select_index(lst: npt.NDArray[np.int64], val: int) -> int:
        # Position of the first (ie. smallest) candidate offering at least `val` points
        hits = np.flatnonzero(lst >= val)
        if hits.size == 0:
            raise ValueError(
                f"Cannot select a range holding {val} points: "
                f"the largest candidate range only offers {int(lst.max())} points."
            )
        return int(hits[0])

    def construct_range(index: int, dim: int) -> List[Tuple[int, int]]:
        # Convert to a plain int so that the result matches the annotated type
        bound = int(choices[index])
        low = 0 if nonnegative_only else -bound
        return [(low, bound) for _ in range(dim)]

    # Number of integer points contained in the box built from each candidate bound
    points_per_axis = choices + 1 if nonnegative_only else 2 * choices + 1
    num_points = points_per_axis**dim

    if overlap:
        train_index = test_index = select_index(num_points, num_train + num_test)
    else:
        test_index = select_index(num_points, num_test)
        # Only the points outside of the test box are available for training
        train_index = select_index(num_points - num_points[test_index], num_train)

    test_range = construct_range(test_index, dim)
    train_range = construct_range(train_index, dim)
    return (train_range, test_range)
468
-
469
-
470
@dataclass(frozen=True)
class RelationTypeInfo:
    """Immutable settings describing one family of polynomial relations."""

    # Identifier of the relation family
    name: str

    # Degree of the polynomial
    degree: int

    # Number of input variables of the polynomial
    num_variables: int

    # One (low, high) pair per coefficient; handed to the matching `generate_*` function.
    # NOTE(review): presumably the bounds each coefficient is sampled from -- confirm against `generate_*`.
    range: Range

    # Coefficients of the fixed representative polynomial used in "example" mode
    example_coeffs: npt.NDArray[np.int64]
477
-
478
-
479
# Registry of the supported relation families; each entry is keyed by its own `name`.
RELTYPE_INFO: Dict[str, RelationTypeInfo] = {
    info.name: info
    for info in (
        RelationTypeInfo(
            name="linear",
            degree=1,
            num_variables=1,
            range=[(1, 5), (1, 5)],
            example_coeffs=np.array([2, 5]),  # 2x + 5
        ),
        RelationTypeInfo(
            # parabolas with axis of symmetry to the left of the origin
            name="parabola",
            degree=2,
            num_variables=1,
            range=[(1, 2), (0, 2), (1, 5)],
            example_coeffs=np.array([1, 0, 2]),  # x^2 + 2
        ),
        RelationTypeInfo(
            name="plane",
            degree=1,
            num_variables=2,
            range=[(1, 5), (1, 5), (1, 5)],
            example_coeffs=np.array([2, 1, 5]),  # 2x + y + 5
        ),
        RelationTypeInfo(
            # axis-aligned elliptic paraboloids only, ie. of the form z = A x^2 + B y^2 + C
            name="paraboloid",
            degree=2,
            num_variables=2,
            range=[(1, 2), (0, 1), (1, 2), (0, 0), (0, 0), (1, 5)],
            example_coeffs=np.array([2, 0, 1, 0, 0, 2]),  # 2x^2 + y^2 + 2
        ),
    )
}
503
-
504
-
505
- # MODE_INFO = { # Testing purposes
506
- # "example": {"num_function_train": 1, "num_function_test": 1, "num_train": 10, "num_test": 1,},
507
- # "standard": {"num_function_train": 1, "num_function_test": 1, "num_train": 10, "num_test": 1,},
508
- # "function": {"num_function_train": 2, "num_function_test": 2, "num_train": 2, "num_test": 1,},
509
- # }
510
-
511
-
512
# Per-mode sizes: how many functions (datasets) and how many points per function to generate.
MODE_INFO = {
    "example": dict(num_function_train=1, num_function_test=1, num_train=100, num_test=100),
    "standard": dict(num_function_train=1, num_function_test=1, num_train=100, num_test=100),
    "function": dict(
        num_function_train=1000,
        num_function_test=1000,  # don't bother excluding from train set
        num_train=100,
        num_test=1,
    ),
}
532
-
533
-
534
def get_var(dim: int, variable_names: Optional[List[str]] = None) -> str:
    """
    Return the name of the `dim`-th variable (1-indexed).

    `variable_names` defaults to `["x", "y", "z"]`.
    Raises `IndexError` if `dim` is not between 1 and `len(variable_names)`.
    """
    names = list("xyz") if variable_names is None else variable_names
    # Reject dim <= 0 explicitly: it would otherwise wrap around to a negative index
    # and silently return the wrong variable.
    if not 1 <= dim <= len(names):
        raise IndexError(f"dim must be between 1 and {len(names)}, got {dim}")
    return names[dim - 1]
536
-
537
-
538
def get_dataset_header(
    dim: int, variable_names: Optional[List[str]] = None, delimiter: str = ", ", output_prefix: str = ", "
) -> str:
    """
    Build the header line naming the columns of a `dim`-dimensional dataset.

    The first `dim - 1` variable names are joined with `delimiter`, followed by
    `output_prefix` and the name of the `dim`-th variable, eg. "x, y, z" for `dim=3`.

    `variable_names` defaults to `["x", "y", "z"]`.
    Raises `IndexError` if `dim` is not between 1 and `len(variable_names)`.
    """
    names = list("xyz") if variable_names is None else variable_names
    # Reject dim <= 0 explicitly: negative slicing would otherwise silently yield a wrong header.
    if not 1 <= dim <= len(names):
        raise IndexError(f"dim must be between 1 and {len(names)}, got {dim}")
    *input_names, output_name = names[:dim]
    return delimiter.join(input_names) + output_prefix + output_name
542
-
543
-
544
def get_numeracy_adapter_spec(
    max_train_instances: int, max_eval_instances: int, dim: int, delimiter: str = ", ", **kwargs
) -> AdapterSpec:
    """
    Build the `AdapterSpec` used for the numeracy prompts.

    The instructions are the dataset header for a `dim`-dimensional dataset.
    Any keyword argument in `kwargs` overrides the default of the same name.
    """
    settings = dict(
        method=ADAPT_GENERATION,
        instructions=get_dataset_header(dim, delimiter=delimiter, output_prefix=", "),
        max_train_instances=max_train_instances,
        max_eval_instances=max_eval_instances,
        num_outputs=1,
        num_train_trials=1,
        model_deployment="openai/davinci",
        temperature=0,
        stop_sequences=["\n"],
        max_tokens=20,
        input_prefix="",
        output_prefix=", ",
        instance_prefix="\n",
    )
    settings.update(kwargs)  # enable override
    return AdapterSpec(**settings)
567
-
568
-
569
class NumeracyScenario(Scenario):
    """
    A task that asks the model to induce an unknown polynomial at a point given a set of function evaluations.
    Unlike pre-existing tasks testing arithmetic, this task attempts to test a deeper notion of numeracy
    which the model cannot rely purely on rote memorization of standard tables of arithmetic operations
    in order to succeed on and which intuitively occurs as an implicit subroutine in broader contexts.

    Decomposes into 4 function classes:
    - linear                (1 degree,  1 variable)
    - parabola              (2 degrees, 1 variable)
    - plane                 (1 degree,  2 variables)
    - (elliptic) paraboloid (2 degrees, 2 variables)

    with coefficients drawn from restricted ranges
    (see dict `RELTYPE_INFO`), and
    where {parabola, paraboloid} have nonnegative domains,
    ie. the right ray of the x-axis or upper-right
    quadrant of the plane resp. so that the model cannot
    rely on symmetry.

    and independently 2 + 1 modes:
    - standard
        - A single dataset corresponding to the same polynomial.
            Evaluate on different points.
    - function
        - Multiple datasets, where each dataset instance corresponds to
            an independently sampled polynomial belonging to the same class.
            Evaluate on different (dataset, point) pairs.
    and
    - example
        - A single dataset corresponding to the same fixed representative for each class.

    If `overlap` is `True`:
        Train and test datapoints are drawn from the same rectilinear region
        centered at the origin (see function `select_ranges`),
        making sure to exclude the training set from the test set.
    Otherwise:
        Train datapoints are drawn from a rectilinear border region while
        test datapoints are drawn from a disjoint rectilinear interior region,
        centered at the origin (see function `select_ranges`).

    Example prompt for `relation_type=parabola,mode=function` with `num_function_train=num_function_test=num_train=2`:
        x,y
        1,4
        -1,2
        0,2

        x,y
        -1,0
        1,20
        0,8

        x,y
        -1,7
        1,11
        0,
    """

    name = "numeracy"
    description = "polynomial induction"
    tags: List[str] = []
    RELTYPES: List[str] = ["linear", "parabola", "plane", "paraboloid"]
    MODES: List[str] = ["example", "standard", "function"]
    delimiter: str = ", "

    def __init__(
        self,
        relation_type: str = "linear",
        mode: str = "function",
        seed: Optional[int] = None,
        overlap: bool = True,  # whether the in-context and eval points are drawn from the same region
        sort_vals: bool = False,  # whether to sort the in-context examples
    ):
        super().__init__()
        assert relation_type in NumeracyScenario.RELTYPES
        assert mode in NumeracyScenario.MODES

        reltype_info = RELTYPE_INFO[relation_type]
        mode_info = MODE_INFO[mode]

        # The seed is exposed under both attribute names.
        self.random_seed = seed
        self.seed = seed

        self.relation_type = relation_type
        self.mode = mode
        self.delimiter = NumeracyScenario.delimiter
        self.overlap = overlap
        self.sort_vals = sort_vals

        # Shape of the polynomial family
        self.degree: int = reltype_info.degree
        self.num_variables: int = reltype_info.num_variables
        self.range_coeffs = reltype_info.range
        self.dim = self.num_variables + 1  # input variables plus the output variable

        # Dataset sizes for the chosen mode
        self.num_function_train = mode_info["num_function_train"]
        self.num_function_test = mode_info["num_function_test"]
        self.num_train = mode_info["num_train"]
        self.num_test = mode_info["num_test"]

    def get_instances(self, output_path: str) -> List[Instance]:
        """
        Generate the train and test instances for the configured relation type and mode.

        Requires a seed to have been passed to the constructor; both `random` and
        `np.random` are re-seeded with it before anything is sampled.
        """
        assert self.random_seed is not None
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)

        train_range, test_range = select_ranges(
            num_train=100,
            num_test=100,
            dim=self.num_variables,  # not a typo
            overlap=self.overlap,
            nonnegative_only=self.relation_type in ["parabola", "paraboloid"],
        )
        # train_range = test_range:
        # -------------------------
        # linear:     [(-100, 100)]
        # parabola:   [(0, 200)]
        # plane:      [(-10, 10), (-10, 10)]
        # paraboloid: [(0, 20), (0, 20)]

        def lattice_points(box):
            # Iterate over every integer point of the box, endpoints included
            return product(*[range(low, high + 1) for low, high in box])

        test_vals = list(lattice_points(test_range))
        if self.overlap:
            train_vals = test_vals
        else:
            train_vals = list(set(lattice_points(train_range)) - set(test_vals))
        if self.sort_vals:
            train_vals = sorted(train_vals)
        if self.num_variables == 2:
            # Keep only the points whose first coordinate does not exceed the second
            test_vals = [point for point in test_vals if point[0] <= point[1]]
            train_vals = [point for point in train_vals if point[0] <= point[1]]

        def make_instance(rel: Polynomial, point, split: str) -> Instance:
            # One instance per evaluation point: the input is the point, the answer is rel(point)
            answer = str(rel.eval(point))
            prompt = self.delimiter.join([str(coordinate) for coordinate in point])
            var = get_var(self.dim)
            solution = f"{var} = {rel}"
            references = [
                Reference(Output(text=answer), tags=[CORRECT_TAG]),
                Reference(Output(text=solution), tags=[SOLUTION_TAG]),
                Reference(Output(text=self.relation_type), tags=[CLASS_TAG]),
            ]
            return Instance(Input(text=prompt), references=references, split=split)

        def sample_instances(rel: Polynomial):
            # Draw the train points first, then the test points (never reusing a train point
            # when both are drawn from the same region).
            train_idxs = list(np.random.choice(range(len(train_vals)), self.num_train, replace=False))
            if self.sort_vals:
                train_idxs = sorted(train_idxs)
            if self.overlap:
                candidate_idxs = list(set(range(len(test_vals))) - set(train_idxs))
            else:
                candidate_idxs = list(range(len(test_vals)))
            test_idxs = np.random.choice(candidate_idxs, self.num_test, replace=False)

            train_instances = [make_instance(rel, train_vals[idx], TRAIN_SPLIT) for idx in train_idxs]
            test_instances = [make_instance(rel, test_vals[idx], TEST_SPLIT) for idx in test_idxs]
            return train_instances + test_instances

        def generate_dataset():
            # Sample a fresh polynomial of the configured class together with its datapoints.
            # Currently unused; kept for the pending rewrite of `generate_datasets`.
            generate_func = globals()[f"generate_{self.relation_type}"]
            return sample_instances(generate_func(self.range_coeffs))

        def generate_datasets(num_instances: int, split: str):
            # TODO: construct_prompt is no longer part of adapter, and this function needs to be rewritten
            # https://github.com/stanford-crfm/benchmarking/issues/569
            #
            # The previous implementation called `generate_dataset()` once per instance, rendered
            # each of its eval points as a prompt with the inner adapter spec (stripping the trailing
            # output prefix), and then joined those prompts into a single instance using an outer
            # adapter spec with empty instructions and "\n\n" as the instance prefix.
            return []

        def build():
            generate_func = globals()[f"generate_{self.relation_type}"]
            if self.mode == "example":
                coeffs = RELTYPE_INFO[self.relation_type].example_coeffs
                return sample_instances(Polynomial(self.degree, self.num_variables, coeffs))
            if self.mode == "standard":
                return sample_instances(generate_func(self.range_coeffs))
            if self.mode == "function":
                train_datasets = generate_datasets(self.num_function_train, TRAIN_SPLIT)
                test_datasets = generate_datasets(self.num_function_test, TEST_SPLIT)
                return train_datasets + test_datasets

        return build()