eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170) hide show
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,584 @@
1
+ import base64
2
+ import logging
3
+ import os
4
+ import random
5
+ import re
6
+ import string
7
+ from collections.abc import Callable
8
+ from pathlib import Path
9
+ from typing import Any, Literal, NamedTuple
10
+
11
+ import dill
12
+ import numpy as np
13
+ from llm_sandbox import SandboxSession
14
+
15
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Default seed for the perturbation editors defined further down in this file.
RANDOM_SEED = 42  # hacky way to get around circular import
# NOTE(review): not referenced in the visible part of this file; presumably a "warn only once"
# flag that other code flips after printing a Redis-related warning -- confirm against callers.
redis_warning_printed = False
19
+
20
+
21
def raise_errors() -> bool:
    """
    Tell whether the ``DEBUG`` environment variable is switched on.

    The variable is compared case-insensitively and is treated as ``"FALSE"`` when it is unset.

    :return: True for ``1``/``true``, False for ``0``/``false``.
    :raises ValueError: if ``DEBUG`` holds any other value.
    """
    flag = os.environ.get("DEBUG", "FALSE").lower()
    recognised = {"1": True, "true": True, "0": False, "false": False}
    if flag not in recognised:
        raise ValueError(f"Invalid value for DEBUG environment variable: {flag}. Use one of 1, 0, true, false.")
    return recognised[flag]
29
+
30
+
31
def get_n_letters(n: int) -> list[str]:
    """
    Return the first ``n`` uppercase ASCII letters ("A", "B", ...).

    A non-positive ``n`` gives an empty list; an ``n`` above 26 gives the whole alphabet.
    """
    count = n if n > 0 else 0
    return list(string.ascii_uppercase[:count])
33
+
34
+
35
def run_python_code(
    code: str,
    image: str | None = None,
    input_files: list[tuple[str, str]] | None = None,
    timeout: int = 60,
    packages: list[str] | None = None,
) -> str:
    """
    Execute Python source inside a ``SandboxSession`` and return what it printed.

    :param code: The code to run.
    :param image: Docker image to use.
    :param input_files: ``(host_path, docker_path)`` pairs; each host file is copied into the sandbox first.
    :param timeout: Timeout in seconds; any value that is not positive disables the timeout.
    :param packages: Python packages handed to the session as ``libraries`` (installed with pip).
    :return: The text of the run result, stripped of surrounding whitespace.
    """
    with SandboxSession(lang="python", image=image, keep_template=True, commit_container=False) as session:
        if input_files:
            for host_file, docker_file in input_files:
                session.copy_to_runtime(host_file, docker_file)

        if timeout > 0:
            # Hack: wrap every command the session executes in coreutils' `timeout`.
            session.orig_execute_command = session.execute_command

            def execute_with_timeout(command: str) -> Any:
                return session.orig_execute_command(f"timeout {timeout} {command}")

            session.execute_command = execute_with_timeout

        result = session.run(code, libraries=packages)
        return result.text.strip()
60
+
61
+
62
def unittest_merge_snippets(code: str, test_code: str) -> str:
    """
    Join implementation code and unittest code into one runnable script.

    :param code: The implementation under test.
    :param test_code: The unittest test cases.
    :return: ``code``, a blank line, then ``test_code``; a ``unittest.main()`` entry point is appended
        when ``test_code`` does not already call it.
    """
    # The runner is placed behind an `if __name__ == '__main__'` guard because without the guard
    # unittest sometimes erroneously reports "Ran 0 tests".
    runner = "" if "unittest.main(" in test_code else "\n\nif __name__ == '__main__':\n unittest.main()"
    return f"{code}\n\n{test_code}{runner}"
71
+
72
+
73
class ExecutionResult(NamedTuple):
    """
    Outcome of one code execution, as a ``(success, output)`` pair.

    Fields:
        success: True when the execution counted as successful.
        output: The output of the execution, or the error messages it produced.
    """

    success: bool
    output: str
84
+
85
+
86
def execute_python_code_with_tests(
    code: str,
    test_code: str,
    package_mapping: dict[str, str | None],
    merge_code_fn: Callable[[str, str], str],
    image: str | None,
    timeout: int,
    parse_output_fn: Callable[[str], ExecutionResult],
) -> ExecutionResult:
    """
    Run code together with its test cases in the sandbox and interpret the outcome.

    :param code: The code to be tested.
    :param test_code: The test cases to run against the code.
    :param package_mapping: Maps an imported package name to the package to install, or to None
        when nothing has to be installed.
    :param merge_code_fn: Function that merges the LLM code and the test code into one script.
    :param image: Docker image to use.
    :param timeout: Timeout for the execution in seconds.
    :param parse_output_fn: Function that turns the sandbox output into an ExecutionResult.
    :return: An ExecutionResult named tuple with success status and output or errors.
    """
    program = merge_code_fn(code, test_code)

    # Only packages that the merged script actually imports are installed in the sandbox.
    raw_output = run_python_code(
        program,
        image=image,
        timeout=timeout,
        packages=get_external_dependencies(program, package_mapping),
    )

    return parse_output_fn(raw_output)
116
+
117
+
118
class SerializationError(Exception):
    """Common base class for errors raised while serializing or deserializing callables."""
122
+
123
+
124
class EncodingError(SerializationError):
    """Signals that a callable could not be encoded."""
128
+
129
+
130
class DecodingError(SerializationError):
    """Signals that a callable could not be decoded."""
134
+
135
+
136
class CallableSerializer:
    """Turn callables into base64 text and back, using ``dill`` for the byte-level pickling."""

    @staticmethod
    def encode(fn: Callable[..., Any]) -> str:
        """
        Serialize ``fn`` with dill and return the bytes as a base64 string.

        :raises EncodingError: if pickling or encoding fails for any reason.
        """
        try:
            payload = base64.b64encode(dill.dumps(fn))
            return payload.decode("utf-8")
        except Exception as e:
            raise EncodingError(f"Failed to encode callable {fn}: {e}") from e

    @staticmethod
    def decode(fn_str: str) -> Callable[..., Any]:
        """
        Rebuild a callable from a base64 string produced by :meth:`encode`.

        :raises DecodingError: if base64 decoding or unpickling fails for any reason.
        """
        # NOTE: unpickling can execute arbitrary code, so only decode strings from a trusted source.
        try:
            raw = base64.b64decode(fn_str.encode("utf-8"))
            return dill.loads(raw)
        except Exception as e:
            raise DecodingError(f"Failed to decode callable from string: {e}") from e
152
+
153
+
154
def _parse_unittest_output(output: str) -> ExecutionResult:
    """
    Parse the unittest output to determine success and format the result.

    :param output: Captured text of a unittest run.
    :return: ``ExecutionResult(True, summary)`` when the run reports its ``OK`` verdict and no
        ``FAILED`` marker; otherwise ``ExecutionResult(False, details)`` where ``details`` embeds
        the raw output.
    """
    has_failed_marker = "FAILED" in output

    # unittest prints its verdict on a line of its own: "OK" or "OK (skipped=1)". Matching the
    # whole line instead of the substring keeps a crash whose traceback merely contains the
    # letters "OK" (for instance an identifier such as TOKEN) from being reported as a pass.
    has_ok_verdict = re.search(r"^\s*OK(?: \(.*\))?\s*$", output, re.MULTILINE) is not None

    # Check for unittest success pattern
    if has_ok_verdict and not has_failed_marker:
        # Extract the test summary if possible
        summary_match = re.search(r"Ran (\d+) tests? in [\d.]+s", output)
        if summary_match:
            test_output = f"All {summary_match.group(1)} tests completed successfully."
        else:
            test_output = "All tests completed successfully."
        return ExecutionResult(True, test_output)

    # Check for unittest failure pattern
    if has_failed_marker:
        # Try to extract failure details, e.g. "failures=1, errors=2"
        failure_match = re.search(r"FAILED \((.+)\)", output)
        if failure_match:
            return ExecutionResult(False, f"Tests failed: {failure_match.group(1)}\n{output}")
        return ExecutionResult(False, f"Tests failed: {output}")

    # Check for common error patterns
    if "AssertionError" in output:
        return ExecutionResult(False, f"Test failed with assertion error: {output}")
    if "Error:" in output or "Exception:" in output:
        return ExecutionResult(False, f"Error during execution: {output}")

    # If we can't determine success/failure, return the raw output
    return ExecutionResult(False, f"Could not determine test results, potentially due to timeout. Output: {output}")
186
+
187
+
188
def get_external_dependencies(code: str, package_mapping: dict[str, str | None]) -> list[str]:
    """
    List the installable packages that ``code`` imports.

    :param code: Source code whose import statements are inspected.
    :param package_mapping: Maps an imported package name to the package to install, or to None
        when nothing has to be installed.
    :return: The mapped, non-None install names of all imported packages found in the mapping.
    """
    _, imported_packages = extract_imports(code)

    # `.get` yields None both for unknown packages and for packages mapped to None; both are skipped.
    install_names = (package_mapping.get(name) for name in imported_packages)
    return [install_name for install_name in install_names if install_name is not None]
197
+
198
+
199
def extract_imports(code: str) -> tuple[list[str], set[str]]:
    """
    Extract all import statements and the imported packages from code.

    The code is scanned line by line with regular expressions (it is not parsed), so any line
    that starts with ``import``/``from ... import`` after stripping is treated as an import
    statement, whatever its indentation.

    :param code: Source code to scan.
    :return: A pair ``(imports, packages)``: the stripped import lines in order of appearance,
        and the set of top-level package names they refer to. Relative imports such as
        ``from . import x`` are listed in ``imports`` but contribute no package name.
    """
    # Pattern for 'import x' or 'import x, y, z'
    import_pattern = r"^import\s+([\w\s,.]+)"

    # Pattern for 'from x import y'
    from_pattern = r"^from\s+([\w.]+)\s+import\s+"

    imports: list[str] = []
    packages: set[str] = set()

    for raw_line in code.split("\n"):
        line = raw_line.strip()

        # Skip empty lines
        if not line:
            continue

        # Handle 'import x' or 'import x, y, z'
        import_match = re.match(import_pattern, line)
        if import_match:
            imports.append(line)
            for item in import_match.group(1).split(","):
                # The first whitespace-separated token is the dotted module path; everything
                # after it is an "as alias" clause. Splitting on arbitrary whitespace also
                # handles "import x  as y" and tab-separated aliases.
                tokens = item.split()
                if not tokens:
                    continue
                pkg = tokens[0].split(".")[0]
                if pkg:
                    packages.add(pkg)
            continue

        # Handle 'from x import y'
        from_match = re.match(from_pattern, line)
        if from_match:
            imports.append(line)
            # Get the base package name (empty for relative imports, which are skipped)
            pkg = from_match.group(1).split(".")[0]
            if pkg:
                packages.add(pkg)

    return imports, packages
240
+
241
+
242
def get_docker_address() -> str:
    """
    Return the address under which containers started by this process are reachable.

    With docker-in-docker the new container is actually started on the host, so the host's IP has
    to be used instead of localhost.
    See https://stackoverflow.com/questions/48546124/what-is-the-linux-equivalent-of-host-docker-internal
    """
    # The presence of /.dockerenv is taken as the sign that we are running inside a container.
    running_in_docker = Path("/.dockerenv").exists()
    if running_in_docker:
        return "172.17.0.1"
    return "localhost"
246
+
247
+
248
class Editor:
    """
    Inject random character-level typos into a sentence.

    Calling an instance edits a share of the word characters of a sentence with random
    insertions, case flips, transpositions, word splits and deletions. All randomness comes from
    two generators seeded in ``__init__``, so a fresh instance with the same seed produces the
    same edits for the same sequence of calls.
    """

    def __init__(self, language: Literal["en", "de"] = "en", seed: int = RANDOM_SEED) -> None:
        """
        :param language: Selects the alphabet used for inserted letters; "de" adds "ßöäü".
        :param seed: Seed for both internal random generators.
        :raises NotImplementedError: for any other language.
        """
        self.np_rng = np.random.RandomState(seed)
        self.rng = random.Random(seed)
        if language == "en":
            self.letters = string.ascii_lowercase
        elif language == "de":
            self.letters = string.ascii_lowercase + "ßöäü"
        else:
            raise NotImplementedError

    @staticmethod
    def _split_sentence(sentence: str) -> tuple[list[str], list[str], bool]:
        """Split into runs of word characters, runs of everything else, and a leading-separator flag."""
        words = re.findall(r"\w+", sentence)
        spaces = re.findall(r"[^\w]+", sentence)
        # True when the sentence does not start with its first word (also when it has no words).
        has_leading_space = not words or sentence[: len(words[0])] != words[0]
        return words, spaces, has_leading_space

    @staticmethod
    def _recombine(words: list[str], spaces: list[str], has_leading_space: bool) -> str:
        """Inverse of ``_split_sentence``: interleave words and separators back into one string."""
        if has_leading_space:
            # The first separator precedes the first word. Slicing keeps an empty list harmless.
            pieces = list(spaces[:1])
            separators = spaces[1:]
        else:
            pieces = []
            separators = spaces

        for word, separator in zip(words, separators):
            pieces.append(word)
            pieces.append(separator)

        # zip stops at the shorter list; a sentence ending in a word has one word left over.
        if len(words) > len(separators):
            pieces.append(words[-1])
        return "".join(pieces)

    @staticmethod
    def _get_word_probs(words: list[str]) -> np.ndarray:
        """
        Return sampling probabilities proportional to ``len(word) - 1``.

        One-character words such as "I" and "a" get probability zero because they can't be
        transposed or split. At least one word must be longer than one character, otherwise the
        normalisation divides by zero.
        """
        lengths = np.array([len(word) - 1 for word in words])
        probs = lengths / np.sum(lengths)
        return probs

    @staticmethod
    def _transpose(word: str, idx1: int, idx2: int) -> str:
        """Swap the two adjacent characters at ``idx1`` and ``idx2``."""
        assert abs(idx2 - idx1) == 1, "idx1 and idx2 are not next to each other"
        if idx1 > idx2:
            idx1, idx2 = idx2, idx1
        return word[:idx1] + word[idx2] + word[idx1] + word[idx2 + 1 :]

    @staticmethod
    def _delete(word: str, idx: int) -> str:
        """Remove the character at ``idx``."""
        return word[:idx] + word[idx + 1 :]

    @staticmethod
    def _insert(word: str, idx: int, letter: str) -> str:
        """Insert the single character ``letter`` in front of position ``idx``."""
        assert len(letter) == 1, "`letter` is not a single character"
        return word[:idx] + letter + word[idx:]

    @staticmethod
    def _change_casing(word: str, idx: int) -> str:
        """Flip the case of the character at ``idx`` (lower becomes upper, anything else is lowered)."""
        character = word[idx]
        if character.islower():
            character = character.upper()
        else:
            character = character.lower()
        return word[:idx] + character + word[idx + 1 :]

    @staticmethod
    def _split_word(word: str, idx: int) -> str:
        """Insert a space in front of position ``idx``."""
        return word[:idx] + " " + word[idx:]

    def _edit_word(self, word: str, num_edits: int) -> str:
        """
        Apply ``num_edits`` randomly chosen edits to ``word`` and return the result.

        NB: two edits may cancel each other out, but the chance of this is sufficiently small
        that it doesn't make sense to complicate the code to prevent it.
        """
        for _ in range(num_edits):
            # upweighted change casing
            choices = ["insert", "change_casing", "change_casing"]
            if len(word) > 1:
                choices.extend(["transpose", "split_word"])
            if len(word) > 4:
                # use delete more sparingly since it has a big impact
                choices.extend(["delete"])

            edit_function = self.rng.choice(choices)
            if edit_function == "transpose":
                idx = self.rng.randint(0, len(word) - 2)
                word = self._transpose(word, idx, idx + 1)
            elif edit_function == "delete":
                # never deletes the first or the last character
                idx = self.rng.randint(1, len(word) - 2)
                word = self._delete(word, idx)
            elif edit_function == "insert":
                idx = self.rng.randint(0, len(word) - 1)
                letter = self.rng.choice(self.letters)
                word = self._insert(word, idx, letter)
            elif edit_function == "change_casing":
                idx = self.rng.randint(0, len(word) - 1)
                word = self._change_casing(word, idx)
            elif edit_function == "split_word":
                # never splits in front of the first character
                idx = self.rng.randint(1, len(word) - 1)
                word = self._split_word(word, idx)

        return word

    def __call__(self, sentence: str, character_edit_change: float, unmodifiable_words: list[str] | None = None) -> str:
        """
        Return ``sentence`` with random edits applied to its words.

        :param sentence: Text to perturb.
        :param character_edit_change: Number of edits per word character; the total number of
            edits is ``int(number_of_word_characters * character_edit_change)``.
        :param unmodifiable_words: Words (compared case-insensitively) that are left untouched.
            Edits drawn for such a word are dropped, not redistributed.
        :return: The edited sentence, or ``sentence`` itself when there is nothing to edit.
        """
        words, spaces, has_leading_space = self._split_sentence(sentence)

        num_characters = sum(map(len, words))
        num_edits = int(num_characters * character_edit_change)
        if num_edits == 0:
            return sentence

        # Words are sampled proportional to len(word) - 1. If every word is a single character
        # all weights are zero and normalising them would produce NaN probabilities, so there
        # is nothing that can be edited.
        if all(len(word) <= 1 for word in words):
            return sentence

        probs = self._get_word_probs(words)
        edits_per_word = self.np_rng.multinomial(num_edits, probs)
        unmodifiable_words_set = {w.lower() for w in unmodifiable_words or []}
        edited_words = []
        for edits, word in zip(edits_per_word, words):
            if word.lower() not in unmodifiable_words_set:
                edited_words.append(self._edit_word(word, int(edits)))
            else:
                edited_words.append(word)
        return self._recombine(edited_words, spaces, has_leading_space)
368
+
369
+
370
class HatPaperEditor:
    """
    Character-level text perturbations (permute, replace, delete, upper-case).

    Used for Section 4.4 in the HAT paper (https://openreview.net/pdf?id=tU074jg2vS).
    """

    def __init__(self, seed: int = RANDOM_SEED) -> None:
        """:param seed: Seed of the random generator that picks positions and replacement characters."""
        self.rng = random.Random(seed)

    def _get_indices(self, input_text: str, pct: float, unmodifiable_words: list[str] | None = None) -> list[int]:
        """
        Pick random positions of word-interior characters.

        A position is eligible when the character itself and both of its neighbours are
        alphanumeric and it does not lie inside an occurrence of an unmodifiable word.

        :param input_text: Text to pick positions from.
        :param pct: Fraction of the eligible positions to pick.
        :param unmodifiable_words: Words matched literally, case-insensitively and at word
            boundaries; their characters are never picked.
        :return: ``int(len(eligible) * pct)`` distinct positions in random order.
        :raises ValueError: from ``random.sample`` when the requested count is negative or larger
            than the number of eligible positions.
        """
        indices = [
            i + 1
            for i, c in enumerate(input_text[1:-1])
            if c.isalnum() and input_text[i].isalnum() and input_text[i + 2].isalnum()
        ]

        blocked: set[int] = set()
        for word in unmodifiable_words or []:
            # Escape the word so that characters such as "+", "(" or "." are matched literally
            # instead of being interpreted as (possibly invalid) regex syntax.
            for match in re.finditer(r"\b" + re.escape(word) + r"\b", input_text, re.IGNORECASE):
                blocked.update(range(match.start(0), match.end(0)))
        if blocked:
            indices = [i for i in indices if i not in blocked]

        return self.rng.sample(indices, int(len(indices) * pct))

    def permute_chars_in_string(
        self, input_text: str, permute_pct: float, unmodifiable_words: list[str] | None = None
    ) -> str:
        """
        Randomly swap characters with their right-hand neighbour.

        ``permute_pct`` is the fraction of eligible positions that are swapped. Only permutes
        within words (whitespaces and first word chars are preserved).
        """
        chars_to_permute = self._get_indices(input_text, permute_pct, unmodifiable_words)
        permuted_text = list(input_text)
        for char_index in chars_to_permute:
            # char_index + 1 always exists: eligible positions have an alphanumeric right neighbour
            permuted_text[char_index], permuted_text[char_index + 1] = (
                permuted_text[char_index + 1],
                permuted_text[char_index],
            )
        return "".join(permuted_text)

    def replace_chars_in_string(
        self, input_text: str, replace_pct: float, unmodifiable_words: list[str] | None = None
    ) -> str:
        """
        Randomly replace characters with random printable ASCII characters.

        ``replace_pct`` is the fraction of eligible positions that are replaced. Only replaces
        within words (whitespaces and first and last word chars are preserved).
        """
        chars_to_replace = self._get_indices(input_text, replace_pct, unmodifiable_words)
        replaced_text = list(input_text)
        for char_index in chars_to_replace:
            replace_char = chr(self.rng.randint(33, 126))  # ASCII printable characters
            replaced_text[char_index] = replace_char
        return "".join(replaced_text)

    def delete_chars_in_string(
        self, input_text: str, delete_pct: float, unmodifiable_words: list[str] | None = None
    ) -> str:
        """
        Randomly delete characters.

        ``delete_pct`` is the fraction of eligible positions that are deleted. Only deletes
        within words (whitespaces and first and last word chars are preserved).
        """
        chars_to_delete = self._get_indices(input_text, delete_pct, unmodifiable_words)
        deleted_text = list(input_text)
        for char_index in chars_to_delete:
            deleted_text[char_index] = ""  # do not delete list entry since then the length of the list changes
        return "".join(deleted_text)

    def upper_case_string(self, input_text: str) -> str:
        """Upper case all characters in the input string."""
        return input_text.upper()
438
+
439
+
440
+ # these are all the packages that occur in the BigCodeBench dataset
441
+ BIG_CODE_BENCH_PACKAGE_MAPPING = {
442
+ # Standard library packages (built-in)
443
+ "array": None,
444
+ "ast": None,
445
+ "base64": None,
446
+ "binascii": None,
447
+ "bisect": None,
448
+ "calendar": None,
449
+ "cgi": None,
450
+ "cmath": None,
451
+ "codecs": None,
452
+ "collections": None,
453
+ "configparser": None,
454
+ "csv": None,
455
+ "ctypes": None,
456
+ "datetime": None,
457
+ "decimal": None,
458
+ "difflib": None,
459
+ "email": None,
460
+ "enum": None,
461
+ "errno": None,
462
+ "fnmatch": None,
463
+ "ftplib": None,
464
+ "functools": None,
465
+ "getpass": None,
466
+ "glob": None,
467
+ "gzip": None,
468
+ "hashlib": None,
469
+ "heapq": None,
470
+ "hmac": None,
471
+ "html": None,
472
+ "http": None,
473
+ "importlib": None,
474
+ "inspect": None,
475
+ "io": None,
476
+ "ipaddress": None,
477
+ "itertools": None,
478
+ "json": None,
479
+ "logging": None,
480
+ "math": None,
481
+ "mimetypes": None,
482
+ "multiprocessing": None,
483
+ "operator": None,
484
+ "os": None,
485
+ "pathlib": None,
486
+ "pickle": None,
487
+ "pkgutil": None,
488
+ "platform": None,
489
+ "queue": None,
490
+ "random": None,
491
+ "re": None,
492
+ "select": None,
493
+ "secrets": None,
494
+ "shlex": None,
495
+ "shutil": None,
496
+ "signal": None,
497
+ "smtplib": None,
498
+ "socket": None,
499
+ "sqlite3": None,
500
+ "ssl": None,
501
+ "statistics": None,
502
+ "string": None,
503
+ "struct": None,
504
+ "subprocess": None,
505
+ "sys": None,
506
+ "tarfile": None,
507
+ "textwrap": None,
508
+ "threading": None,
509
+ "time": None,
510
+ "turtle": None,
511
+ "types": None,
512
+ "typing": None,
513
+ "unicodedata": None,
514
+ "urllib": None,
515
+ "uuid": None,
516
+ "warnings": None,
517
+ "xml": None,
518
+ "zipfile": None,
519
+ "zlib": None,
520
+ "zoneinfo": None,
521
+ # External packages (need pip install)
522
+ "PIL": "pillow",
523
+ "Crypto": "pycryptodome",
524
+ "Levenshtein": "python-Levenshtein",
525
+ "blake3": "blake3",
526
+ "bs4": "beautifulsoup4",
527
+ "chardet": "chardet",
528
+ "cryptography": "cryptography",
529
+ "cv2": "opencv-python",
530
+ "dateutil": "python-dateutil",
531
+ "django": "django",
532
+ "docx": "python-docx",
533
+ "faker": "Faker",
534
+ "flask": "flask",
535
+ "flask_login": "flask-login",
536
+ "flask_mail": "flask-mail",
537
+ "flask_restful": "flask-restful",
538
+ "flask_wtf": "flask-wtf",
539
+ "folium": "folium",
540
+ "gensim": "gensim",
541
+ "geopandas": "geopandas",
542
+ "geopy": "geopy",
543
+ "holidays": "holidays",
544
+ "keras": "keras",
545
+ "librosa": "librosa",
546
+ "lxml": "lxml",
547
+ "matplotlib": "matplotlib",
548
+ "mechanize": "mechanize",
549
+ "mpl_toolkits": "matplotlib",
550
+ "natsort": "natsort",
551
+ "nltk": "nltk",
552
+ "numpy": "numpy",
553
+ "openpyxl": "openpyxl",
554
+ "pandas": "pandas",
555
+ "prettytable": "prettytable",
556
+ "psutil": "psutil",
557
+ "pyquery": "pyquery",
558
+ "pytesseract": "pytesseract",
559
+ "python_http_client": "python-http-client",
560
+ "pytz": "pytz",
561
+ "regex": "regex",
562
+ "requests": "requests",
563
+ "rsa": "rsa",
564
+ "scipy": "scipy",
565
+ "seaborn": "seaborn",
566
+ "sendgrid": "sendgrid",
567
+ "shapely": "shapely",
568
+ "skimage": "scikit-image",
569
+ "sklearn": "scikit-learn",
570
+ "soundfile": "soundfile",
571
+ "statsmodels": "statsmodels",
572
+ "sympy": "sympy",
573
+ "tensorflow": "tensorflow",
574
+ "textblob": "textblob",
575
+ "texttable": "texttable",
576
+ "werkzeug": "werkzeug",
577
+ "wikipedia": "wikipedia",
578
+ "wordcloud": "wordcloud",
579
+ "wordninja": "wordninja",
580
+ "wtforms": "wtforms",
581
+ "xlwt": "xlwt",
582
+ "xmltodict": "xmltodict",
583
+ "yaml": "pyyaml",
584
+ }
@@ -0,0 +1,9 @@
1
+ from pathlib import Path
2
+
3
# ANSI escape sequences for styling terminal output; RESET switches all styling off again.
RED = "\033[91m"  # bright red foreground
YELLOW = "\033[93m"  # bright yellow foreground
MAGENTA = "\033[1;35;40m"  # bold, magenta foreground, black background
RESET = "\033[0m"
GREEN = "\033[92m"  # bright green foreground

# Third ancestor of this file's path, i.e. two directories above the directory containing this file.
ROOT_DIR = Path(__file__).parents[2]