langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,278 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """LLM-based function generation."""
15
+
16
+ import functools
17
+ import inspect
18
+ import re
19
+ from typing import Any, Callable, Literal, Optional, Tuple
20
+
21
+ from langfun.core import language_model
22
+ from langfun.core import template
23
+ from langfun.core.coding import python
24
+ from langfun.core.structured import querying
25
+ import pyglove as pg
26
+
27
+
28
def unittest_gen(signature, lm, num_retries=1):
  """Generates unit tests for a python function signature.

  Queries `lm` for a list of `UnitTest` objects describing test cases for
  `signature`, making up to `num_retries` attempts.

  Args:
    signature: Source text of the python function signature to test.
    lm: The language model that proposes the test cases.
    num_retries: Maximum number of queries issued to `lm`.

  Returns:
    A list of `(input, expected_output)` tuples taken from the first attempt
    that produced a non-empty list, or None if no attempt did.
  """

  class UnitTest(pg.Object):
    """A valid unit test for a python function."""

    input: dict[str, Any]
    expected_output: Any

  class PythonFunctionSignature(pg.Object):
    signature: str

  for _ in range(num_retries):
    proposed_tests = querying.query(
        PythonFunctionSignature(signature=signature),
        list[UnitTest],
        lm=lm,
        default=None,
    )
    # Anything other than a non-empty list (e.g. the `default=None`
    # fallback) counts as a failed attempt.
    if isinstance(proposed_tests, list) and proposed_tests:
      return [(case.input, case.expected_output) for case in proposed_tests]
  return None
55
+
56
+
57
def unittest_with_test_cases(f, unittests):
  """Applies unit tests to a python function to be tested.

  Args:
    f: The function under test.
    unittests: A non-empty sequence of test cases. Element 0 of each case
      holds the inputs and element 1 the expected output. Inputs are passed
      to `f` as keyword arguments when they are a dict, as positional
      arguments when they are a tuple, and as a single argument otherwise.

  Raises:
    ValueError: If `unittests` is empty or None.
    AssertionError: If `f` returns something not equal to the expected output
      for any test case.
  """
  if not unittests:
    raise ValueError(f"No unit tests provided: {unittests}")

  for unit_test in unittests:
    inputs = unit_test[0]
    if isinstance(inputs, dict):
      actual = f(**inputs)
    elif isinstance(inputs, tuple):
      actual = f(*inputs)
    else:
      actual = f(inputs)

    expected = unit_test[1]
    if actual == expected:
      continue
    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped under `python -O`, which would make every test case pass.
    raise AssertionError(
        f"Test FAILED: Inputs: {inputs}, Expected: {expected}, Actual: {actual}"
    )
75
+
76
+
77
def _function_gen(
    func: Callable[..., Any],
    context: dict[str, Any],
    signature: str,
    lm: language_model.LanguageModel,
    num_retries: int = 1,
    unittest: Optional[
        Callable[[Callable[..., Any]], None]
        | list[Tuple[Any, Any]]
        | Literal["auto"]
    ] = None,
    unittest_num_retries: int = 1,
):
  """Generates a python function with an LLM and verifies it with unit tests.

  Each attempt asks `lm` to implement `signature`, evaluates the returned code
  against `context`, checks that the result is callable with the same
  signature as `func`, and then runs the unit tests (if any). An attempt that
  fails any of these checks is logged and retried, up to `num_retries`
  attempts in total.

  Args:
    func: The declared (unimplemented) function whose signature the generated
      function must match.
    context: Global variables the generated code is evaluated against.
    signature: Source text of the function signature shown to the LM.
    lm: The language model used for code (and optionally test) generation.
    num_retries: Maximum number of code-generation attempts; must be >= 1.
    unittest: A callable that raises if a test fails, a list of
      `(inputs, expected_output)` tuples, "auto" to let `lm` generate the test
      cases, or None to skip unit testing.
    unittest_num_retries: Number of attempts for generating test cases when
      `unittest` is "auto".

  Returns:
    A tuple of (generated function, its source code).

  Raises:
    ValueError: If `num_retries` is less than 1.
    Exception: The error of the last attempt when every attempt fails: a
      `python.CodeError` for bad code or a signature mismatch, or whatever the
      unit tests raised.
  """
  if num_retries < 1:
    raise ValueError(f"`num_retries` must be at least 1, got {num_retries}.")

  # NOTE(review): this docstring is the prompt template. It is a raw string,
  # so the `\"` sequences reach the prompt with their backslashes; confirm
  # this is intended.
  class PythonFunctionPrompt(template.Template):
    r"""A template for a python function generation.

    Please reply to the last PYTHON_FUNCTION_SIGNATURE with a self-sufficient,
    error-free, and efficiently coded PYTHON_FUNCTION, crafted to the standards
    of a world-class programmer.

    PYTHON_FUNCTION_SIGNATURE:
    ```python
    def calculate_area_circle(radius: float) -> float:
      \"\"\"Calculates the area of a circle given its radius.

      Args:
        radius: The radius of the circle.

      Returns:
        The area of the circle.
      \"\"\"
    ```

    PYTHON_FUNCTION:
    ```python
    def calculate_area_circle(radius: float) -> float:
      \"\"\"Calculates the area of a circle given its radius.

      Args:
        radius: The radius of the circle.

      Returns:
        The area of the circle.
      \"\"\"
      import math

      area = math.pi * radius**2
      return area
    ```

    PYTHON_FUNCTION_SIGNATURE:
    ```python
    {{signature}}
    ```

    PYTHON_FUNCTION:
    """

  unittest_examples = None
  if unittest == "auto":
    unittest_examples = unittest_gen(
        signature, lm=lm, num_retries=unittest_num_retries
    )
    if not unittest_examples:
      pg.logging.warning(
          "No unit tests could be generated; the generated function will "
          "not be verified by unit tests.",
      )
  elif isinstance(unittest, list):
    unittest_examples = unittest

  last_error = None
  for _ in range(num_retries):
    # Stage 1: generate the code, evaluate it and check its signature.
    try:
      source_code = querying.query(
          PythonFunctionPrompt(signature=signature), lm=lm
      )
      f = python.evaluate(source_code, global_vars=context)

      if not callable(f):
        raise python.CodeError(
            code=source_code,
            cause=TypeError(
                f"Generated code does not evaluate to a function: {f!r}."
            ),
        )

      # Check whether the signatures are the same.
      if inspect.signature(f) != inspect.signature(func):
        raise python.CodeError(
            code=source_code,
            cause=TypeError(
                f"Signature mismatch: Expected: {inspect.signature(func)}, "
                f"Actual: {inspect.signature(f)}.",
            ),
        )
    except python.CodeError as e:
      last_error = e
      pg.logging.warning(
          f"Bad code generated: {e}",
      )
      continue

    # Stage 2: run the unit tests. A failing test, or any error raised by the
    # generated function while being tested, means the code is wrong, so the
    # attempt is retried instead of aborting generation.
    try:
      if callable(unittest):
        unittest(f)
      elif unittest_examples:
        unittest_with_test_cases(f, unittest_examples)
    except Exception as e:  # pylint: disable=broad-exception-caught
      last_error = e
      pg.logging.warning(
          f"Generated code failed unit tests: {e}",
      )
      continue

    return f, source_code
  raise last_error
175
+
176
+
177
+ def _process_signature(signature):
178
+ # Remove the decorator.
179
+ pattern = r"^\@.*function_gen.*$"
180
+ signature = re.sub(pattern, "", signature, flags=re.MULTILINE)
181
+ # Remove the possible 'pass' in an empty function.
182
+ pattern = r"^\s*pass\s*$"
183
+ signature = re.sub(pattern, "", signature, flags=re.MULTILINE)
184
+ return signature.strip()
185
+
186
+
187
def function_gen(
    lm: language_model.LanguageModel,
    cache_filename: str | None = None,
    num_retries: int = 1,
    unittest: Optional[
        Callable[[Callable[..., Any]], None]
        | list[Tuple[Any, Any]]
        | Literal["auto"]
    ] = None,
    unittest_num_retries: int = 1,
):
  """A decorator for automating function generation using a language model.

  This decorator should be applied to functions that are not yet implemented.
  The implementation is produced lazily, on the first call of the decorated
  function, by the specified LLM (or loaded from `cache_filename`). Later
  calls reuse it.

  Args:
    lm (lf.LanguageModel): The language model used for generating function
      implementations.
    cache_filename (str | None): Optional. The path of the file where
      generated function implementations are loaded from or saved to. Entries
      are keyed by the decorated function's processed source text.
      Implementations loaded from the cache are not re-checked against the
      signature or the unit tests.
    num_retries (int): Maximum number of attempts the language model should
      make to generate a suitable function implementation.
    unittest: This optional parameter enables the definition of custom unit
      tests. You can provide a list of test cases as tuples of inputs and
      outputs, a function that throws an error if a test fails, or "auto" to
      let the LLM create the unit test cases. A newly generated function is
      only returned once it passes all the unit tests.
    unittest_num_retries: If unittest is set to "auto", this parameter
      specifies the number of attempts the LLM makes to generate unit test
      cases.

  Returns:
    A decorator that replaces an unimplemented function with a callable that
    obtains the implementation on first use and then invokes it. The callable
    also exposes `source()`, which returns the implementation's source code
    (None before the first call).
  """

  def _decorate(func):
    func.__function__ = None
    func.__source_code__ = None

    # Snapshot the globals/locals of the scope applying the decorator, so the
    # generated code is evaluated against the same names. The caller's frame
    # is reached directly through `inspect.currentframe()`. `inspect.stack()`
    # would also build source context for every frame on the stack.
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    if caller is None:
      raise RuntimeError(
          "`function_gen` could not access the caller's stack frame."
      )
    try:
      context = dict(caller.f_globals)
      context.update(caller.f_locals)
    finally:
      # Release frame references promptly to avoid reference cycles.
      del frame, caller
    # The decorated name itself is excluded from the evaluation context.
    context.pop(func.__name__, None)

    @functools.wraps(func)
    def lm_generated_func(*args, **kwargs):
      if func.__function__ is not None:
        return func.__function__(*args, **kwargs)

      signature = _process_signature(inspect.getsource(func))
      cache = pg.Dict()
      if cache_filename is not None:
        try:
          cache = pg.load(cache_filename)
        except FileNotFoundError:
          pg.logging.warning(
              "Creating a new cache as cache file '%s' does not exist.",
              cache_filename,
          )

      if signature in cache:
        func.__source_code__ = cache[signature]
        func.__function__ = python.evaluate(
            func.__source_code__, global_vars=context
        )
        return func.__function__(*args, **kwargs)

      func.__function__, func.__source_code__ = _function_gen(
          func,
          context,
          signature,
          lm,
          num_retries=num_retries,
          unittest=unittest,
          unittest_num_retries=unittest_num_retries,
      )
      if cache_filename is not None:
        cache[signature] = func.__source_code__
        cache.save(cache_filename)
      return func.__function__(*args, **kwargs)

    # `functools.wraps` already copies `__name__`, `__qualname__`,
    # `__module__` and `__doc__` from `func`.
    lm_generated_func.source = lambda: func.__source_code__
    return lm_generated_func

  return _decorate
@@ -0,0 +1,399 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ import os
16
+ import tempfile
17
+ import unittest
18
+ from langfun.core.llms import fake
19
+ from langfun.core.structured import function_generation
20
+ import pyglove as pg
21
+
22
+
23
class FunctionGenerationTest(unittest.TestCase):
  """Tests for `function_generation.function_gen`."""

  def _cache_file_path(self):
    """Returns the path of a cache file that does not exist yet.

    The file lives in a per-test temporary directory that is removed when the
    test finishes. Cache-file tests therefore neither see each other's files
    nor leftovers from earlier runs in the shared system temp directory.
    """
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    return os.path.join(tmp_dir.name, 'cache_file.json')

  def test_generate_function(self):
    function_gen_lm_response = inspect.cleandoc("""
        def linear_search(items, target):
          \"\"\"
          Performs a linear search on a list to find a target value.

          Args:
            items (list): The list to search within.
            target: The value to search for.

          Returns:
            int: The index of the target value if found, otherwise -1.
          \"\"\"
          for i, item in enumerate(items):
            if item == target:
              return i
          return -1
        """)
    unittest_lm_response = inspect.cleandoc("""
        ```python
        [
            UnitTest(
                input={
                    'items': [1, 2, 3, 4, 5],
                    'target': 3
                },
                expected_output=2
            ),
            UnitTest(
                input={
                    'items': [1, 2, 3, 4, 5],
                    'target': 6
                },
                expected_output=-1
            )
        ]
        ```
        """)

    # With unittest='auto', the test cases are requested before the code, so
    # the unit test response comes first.
    lm = fake.StaticSequence([unittest_lm_response, function_gen_lm_response])

    @function_generation.function_gen(lm=lm, unittest='auto')
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    self.assertEqual(linear_search.source(), function_gen_lm_response)

  def test_generate_function_without_unittest(self):
    function_gen_lm_response = inspect.cleandoc("""
        def linear_search(items, target):
          \"\"\"
          Performs a linear search on a list to find a target value.

          Args:
            items (list): The list to search within.
            target: The value to search for.

          Returns:
            int: The index of the target value if found, otherwise -1.
          \"\"\"
          for i, item in enumerate(items):
            if item == target:
              return i
          return -1
        """)

    lm = fake.StaticSequence([function_gen_lm_response])

    @function_generation.function_gen(lm=lm)
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    self.assertEqual(linear_search.source(), function_gen_lm_response)

  def test_custom_unittest_examples(self):
    function_gen_lm_response = inspect.cleandoc("""
        ```python
        def linear_search(items, target):
          \"\"\"
          Performs a linear search on a list to find a target value.

          Args:
            items (list): The list to search within.
            target: The value to search for.

          Returns:
            int: The index of the target value if found, otherwise -1.
          \"\"\"
          for i, item in enumerate(items):
            if item == target:
              return i
          return -1
        ```
        """)

    lm = fake.StaticSequence([function_gen_lm_response])

    # Tuple inputs are passed to the function as positional arguments.
    custom_unittest = [(([1, 2, 3, 4, 5], 3), 2)]

    @function_generation.function_gen(lm=lm, unittest=custom_unittest)
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    self.assertEqual(linear_search.source(), function_gen_lm_response)

  def test_custom_unittest_fn(self):
    function_gen_lm_response = inspect.cleandoc("""
        ```python
        def linear_search(items, target):
          \"\"\"
          Performs a linear search on a list to find a target value.

          Args:
            items (list): The list to search within.
            target: The value to search for.

          Returns:
            int: The index of the target value if found, otherwise -1.
          \"\"\"
          for i, item in enumerate(items):
            if item == target:
              return i
          return -1
        ```
        """)

    lm = fake.StaticSequence([function_gen_lm_response])

    def _unittest_fn(func):
      assert func([1, 2, 3, 4, 5], 3) == 2
      assert func([1, 2, 3, 4, 5], 6) == -1

    custom_unittest = _unittest_fn

    @function_generation.function_gen(lm=lm, unittest=custom_unittest)
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    self.assertEqual(linear_search.source(), function_gen_lm_response)

  def test_load_function_from_cache_file(self):
    # An empty LM sequence: the implementation must come from the cache.
    lm = fake.StaticSequence([])

    def _unittest_fn(func):
      assert func([1, 2, 3, 4, 5], 3) == 2
      assert func([1, 2, 3, 4, 5], 6) == -1

    cache_file = self._cache_file_path()

    # The cache is keyed by the decorated function's source text. Because the
    # definition below is indented, the decorator is part of the key. Outer
    # whitespace is stripped, so the text and indentation of this key must
    # mirror the `linear_search` definition below exactly.
    cache_key = """@function_generation.function_gen(
        lm=lm,
        unittest=_unittest_fn,
        cache_filename=cache_file,
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      \"\"\"Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      \"\"\""""
    cache_value = inspect.cleandoc("""
        ```python
        def linear_search(items, target):
          \"\"\"
          Performs a linear search on a list to find a target value.

          Args:
            items (list): The list to search within.
            target: The value to search for.

          Returns:
            int: The index of the target value if found, otherwise -1.
          \"\"\"
          for i, item in enumerate(items):
            if item == target:
              return i
          return -1
        ```
        """)
    cache = pg.Dict()
    cache[cache_key] = cache_value
    cache.save(cache_file)

    @function_generation.function_gen(
        lm=lm,
        unittest=_unittest_fn,
        cache_filename=cache_file,
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    self.assertEqual(linear_search(['a', 'b', 'c'], 'd'), -1)
    self.assertEqual(linear_search.source(), cache_value)

  def test_empty_cache_file(self):
    function_gen_lm_response = inspect.cleandoc("""
        ```python
        def linear_search(items, target):
          \"\"\"
          Performs a linear search on a list to find a target value.

          Args:
            items (list): The list to search within.
            target: The value to search for.

          Returns:
            int: The index of the target value if found, otherwise -1.
          \"\"\"
          for i, item in enumerate(items):
            if item == target:
              return i
          return -1
        ```
        """)

    lm = fake.StaticSequence([function_gen_lm_response])

    def _unittest_fn(func):
      assert func([1, 2, 3, 4, 5], 3) == 2
      assert func([1, 2, 3, 4, 5], 6) == -1

    # A fresh path: the cache file does not exist before the first call.
    cache_file = self._cache_file_path()

    @function_generation.function_gen(
        lm=lm,
        unittest=_unittest_fn,
        cache_filename=cache_file,
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    # The generated implementation is written to the new cache file.
    self.assertTrue(os.path.exists(cache_file))
    self.assertEqual(
        list(pg.load(cache_file).values()), [function_gen_lm_response]
    )

  def test_context_passthrough(self):

    class Number(pg.Object):
      value: int

    function_gen_lm_response = inspect.cleandoc("""
        ```python
        def add(a: Number, b: Number) -> Number:
          \"\"\"Adds two numbers together.\"\"\"
          return Number(a.value + b.value)
        ```
        """)

    lm = fake.StaticSequence(
        [function_gen_lm_response]
    )

    def _unittest_fn(func):
      assert func(Number(1), Number(2)) == Number(3)

    custom_unittest = _unittest_fn

    # `Number` is a local of this test; the generated code must still see it.
    @function_generation.function_gen(
        lm=lm, unittest=custom_unittest, num_retries=1
    )
    def add(a: Number, b: Number) -> Number:  # pylint: disable=unused-argument
      """Adds two numbers together."""

    self.assertEqual(add(Number(2), Number(3)), Number(5))

  def test_signature_check(self):
    incorrect_signature_lm_response = inspect.cleandoc("""
        ```python
        def dummy():
          pass
        ```
        """)
    function_gen_lm_response = inspect.cleandoc("""
        ```python
        def linear_search(items, target):
          \"\"\"
          Performs a linear search on a list to find a target value.

          Args:
            items (list): The list to search within.
            target: The value to search for.

          Returns:
            int: The index of the target value if found, otherwise -1.
          \"\"\"
          for i, item in enumerate(items):
            if item == target:
              return i
          return -1
        ```
        """)

    # The first response has the wrong signature and must be rejected, so the
    # second attempt's code is the one returned.
    lm = fake.StaticSequence(
        [incorrect_signature_lm_response, function_gen_lm_response]
    )

    def _unittest_fn(func):
      assert func([1, 2, 3, 4, 5], 3) == 2
      assert func([1, 2, 3, 4, 5], 6) == -1

    custom_unittest = _unittest_fn

    @function_generation.function_gen(
        lm=lm, unittest=custom_unittest, num_retries=2
    )
    def linear_search(items, target):  # pylint: disable=unused-argument
      """Performs a linear search on a list to find a target value.

      Args:
        items (list): The list to search within.
        target: The value to search for.

      Returns:
        int: The index of the target value if found, otherwise -1.
      """

    self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
    self.assertEqual(linear_search.source(), function_gen_lm_response)
396
+
397
+
398
if __name__ == '__main__':
  # Run this module's tests when it is executed directly.
  unittest.main()