langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -13,35 +13,18 @@
13
13
  # limitations under the License.
14
14
  """Video modality."""
15
15
 
16
- import base64
17
- from typing import cast
16
+ import functools
18
17
  from langfun.core.modalities import mime
19
18
 
20
19
 
21
- class Video(mime.MimeType):
22
- """Base class for Video."""
20
+ class Video(mime.Mime):
21
+ """Video."""
23
22
 
24
- @property
25
- def video_format(self) -> str:
26
- return cast(str, self.mime_type.lstrip('video/'))
27
-
28
- @property
29
- def mime_type(self) -> str:
30
- # TODO(daiyip): after cl/619658455, LaunchPad binaries cannot import `magic`
31
- # correctly. This is to mitigate the issue for major Langfun users who do
32
- # not use Video. We shall move this import out once the issue is fixed.
33
- import magic # pylint: disable=g-import-not-at-top
23
+ MIME_PREFIX = 'video'
34
24
 
35
- video_mime_type = magic.from_buffer(self.to_bytes(), mime=True)
36
- if 'video/' not in video_mime_type:
37
- raise ValueError(f'Not a video: {video_mime_type!r}.')
38
- return video_mime_type
25
+ @functools.cached_property
26
+ def video_format(self) -> str:
27
+ return self.mime_type.removeprefix(self.MIME_PREFIX + '/')
39
28
 
40
- def _repr_html_(self) -> str:
41
- if self.uri and self.uri.lower().startswith(('http:', 'https:', 'ftp:')):
42
- return f'<video controls> <source src="{self.uri}"> </video>'
43
- video_raw = base64.b64encode(self.to_bytes()).decode()
44
- return (
45
- '<video controls> <source'
46
- f' src="data:video/{self.video_format};base64,{video_raw}"> </video>'
47
- )
29
+ def _mime_control_for(self, uri: str) -> str:
30
+ return f'<video controls> <source src="{uri}"> </video>'
@@ -38,12 +38,12 @@ class VideoContentTest(unittest.TestCase):
38
38
  video = video_lib.Video.from_bytes(mp4_bytes)
39
39
  self.assertEqual(video.mime_type, 'video/mp4')
40
40
  self.assertEqual(video.video_format, 'mp4')
41
- self.assertIn('data:video/mp4;base64,', video._repr_html_())
41
+ self.assertIn('data:video/mp4;base64,', video._raw_html())
42
42
  self.assertEqual(video.to_bytes(), mp4_bytes)
43
43
 
44
44
  def test_bad_video(self):
45
45
  video = video_lib.Video.from_bytes(b'bad')
46
- with self.assertRaisesRegex(ValueError, 'Not a video'):
46
+ with self.assertRaisesRegex(ValueError, 'Expected MIME type'):
47
47
  _ = video.video_format
48
48
 
49
49
 
@@ -56,7 +56,7 @@ class VideoFileTest(unittest.TestCase):
56
56
  self.assertEqual(video.video_format, 'mp4')
57
57
  self.assertEqual(video.mime_type, 'video/mp4')
58
58
  self.assertEqual(
59
- video._repr_html_(),
59
+ video._raw_html(),
60
60
  '<video controls> <source src="http://mock/web/a.mp4"> </video>',
61
61
  )
62
62
  self.assertEqual(video.to_bytes(), mp4_bytes)
langfun/core/modality.py CHANGED
@@ -14,6 +14,8 @@
14
14
  """Interface for modality (e.g. Image, Video, etc.)."""
15
15
 
16
16
  import abc
17
+ import functools
18
+ import hashlib
17
19
  from typing import Any, ContextManager
18
20
  from langfun.core import component
19
21
  import pyglove as pg
@@ -29,11 +31,16 @@ def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
29
31
  )
30
32
 
31
33
 
32
- class Modality(component.Component):
34
+ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
33
35
  """Base class for multimodal object."""
34
36
 
35
- REF_START = '{{'
36
- REF_END = '}}'
37
+ REF_START = '<<[['
38
+ REF_END = ']]>>'
39
+
40
+ def _on_bound(self):
41
+ super()._on_bound()
42
+ # Invalidate cached hash if modality member is changed.
43
+ self.__dict__.pop('hash', None)
37
44
 
38
45
  def format(self, *args, **kwargs) -> str:
39
46
  if self.referred_name is None or not pg.object_utils.thread_local_get(
@@ -42,10 +49,22 @@ class Modality(component.Component):
42
49
  return super().format(*args, **kwargs)
43
50
  return Modality.text_marker(self.referred_name)
44
51
 
52
+ def __str_kwargs__(self) -> dict[str, Any]:
53
+ # For modality objects, we don't want to use markdown format when they
54
+ # are rendered as parts of the prompt.
55
+ kwargs = super().__str_kwargs__()
56
+ kwargs.pop('markdown', None)
57
+ return kwargs
58
+
45
59
  @abc.abstractmethod
46
60
  def to_bytes(self) -> bytes:
47
61
  """Returns content in bytes."""
48
62
 
63
+ @functools.cached_property
64
+ def hash(self) -> str:
65
+ """Returns an 8-character MD5 hex prefix as the identifier for this modality object."""
66
+ return hashlib.md5(self.to_bytes()).hexdigest()[:8]
67
+
49
68
  @classmethod
50
69
  def text_marker(cls, var_name: str) -> str:
51
70
  """Returns a marker in the text for this object."""
@@ -108,3 +127,7 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
108
127
  return ModalityRef(name=value.sym_path + k)
109
128
  return v
110
129
  return value.clone().rebind(_placehold, raise_on_no_change=False)
130
+
131
+
132
+ class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
133
+ """Exception raised when modality is not supported."""
@@ -11,7 +11,6 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for modality."""
15
14
  from typing import Any
16
15
  import unittest
17
16
 
@@ -32,12 +31,13 @@ class ModalityTest(unittest.TestCase):
32
31
  v = CustomModality('a')
33
32
  self.assertIsNone(v.referred_name)
34
33
  self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
34
+ self.assertEqual(v.hash, '0cc175b9')
35
35
 
36
36
  _ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
37
37
  self.assertEqual(v.referred_name, 'x.metadata.y')
38
38
  self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
39
39
  with modality.format_modality_as_ref():
40
- self.assertEqual(str(v), '{{x.metadata.y}}')
40
+ self.assertEqual(str(v), '<<[[x.metadata.y]]>>')
41
41
 
42
42
 
43
43
  class ModalityRefTest(unittest.TestCase):
langfun/core/sampling.py CHANGED
@@ -28,14 +28,14 @@ def sweep(
28
28
  *,
29
29
  max_workers: int = 32,
30
30
  silence_on_errors: Union[
31
- Type[Exception], Tuple[Type[Exception]], None
31
+ Type[BaseException], Tuple[Type[BaseException], ...], None
32
32
  ] = None,
33
33
  ignore_examples_with_errors: bool = True,
34
34
  **kwargs,
35
35
  ) -> Iterator[
36
36
  Tuple[
37
- message_lib.Message | Exception, # LM input.
38
- Union[message_lib.Message, Exception, None], # LM output.
37
+ message_lib.Message | BaseException, # LM input.
38
+ Union[message_lib.Message, BaseException, None], # LM output.
39
39
  ],
40
40
  ]:
41
41
  """Sweeps the input/output of this LangFunc concurrently.
@@ -73,15 +73,15 @@ def random_sample(
73
73
  *,
74
74
  max_workers: int = 32,
75
75
  silence_on_errors: Union[
76
- Type[Exception], Tuple[Type[Exception]], None
76
+ Type[BaseException], Tuple[Type[BaseException], ...], None
77
77
  ] = None,
78
78
  ignore_examples_with_errors: bool = True,
79
79
  seed: int | None = None,
80
80
  **kwargs,
81
81
  ) -> Iterator[
82
82
  Tuple[
83
- message_lib.Message | Exception, # LM input.
84
- Union[message_lib.Message, Exception, None], # LM output.
83
+ message_lib.Message | BaseException, # LM input.
84
+ Union[message_lib.Message, BaseException, None], # LM output.
85
85
  ],
86
86
  ]:
87
87
  """Randomly samples the input/output of this LangFunc concurrently.
@@ -121,14 +121,14 @@ def _concurrent_sample(
121
121
  *,
122
122
  max_workers: int = 32,
123
123
  silence_on_errors: Union[
124
- Type[Exception], Tuple[Type[Exception]], None
124
+ Type[BaseException], Tuple[Type[BaseException], ...], None
125
125
  ] = None,
126
126
  ignore_examples_with_errors: bool = True,
127
127
  **kwargs,
128
128
  ) -> Generator[
129
129
  Tuple[
130
- message_lib.Message | Exception, # LM input.
131
- Union[message_lib.Message, Exception, None], # LM output.
130
+ message_lib.Message | BaseException, # LM input.
131
+ Union[message_lib.Message, BaseException, None], # LM output.
132
132
  ],
133
133
  None,
134
134
  None, # Sender type and return type.
@@ -177,6 +177,6 @@ def _concurrent_sample(
177
177
  else:
178
178
  lm_input, lm_output = error, error
179
179
  if (not ignore_examples_with_errors
180
- or not (isinstance(lm_input, Exception)
181
- or isinstance(lm_output, Exception))):
180
+ or not (isinstance(lm_input, BaseException)
181
+ or isinstance(lm_output, BaseException))):
182
182
  yield lm_input, lm_output
@@ -16,6 +16,8 @@
16
16
  # pylint: disable=g-bad-import-order
17
17
  # pylint: disable=g-importing-member
18
18
 
19
+ from langfun.core.structured.schema import include_method_in_prompt
20
+
19
21
  from langfun.core.structured.schema import Missing
20
22
  from langfun.core.structured.schema import MISSING
21
23
  from langfun.core.structured.schema import Unknown
@@ -34,12 +36,6 @@ from langfun.core.structured.schema import class_definitions
34
36
  from langfun.core.structured.schema import annotation
35
37
  from langfun.core.structured.schema import structure_from_python
36
38
 
37
- from langfun.core.structured.schema import SchemaRepr
38
- from langfun.core.structured.schema import SchemaJsonRepr
39
- from langfun.core.structured.schema import SchemaPythonRepr
40
- from langfun.core.structured.schema import ValueRepr
41
- from langfun.core.structured.schema import ValueJsonRepr
42
- from langfun.core.structured.schema import ValuePythonRepr
43
39
  from langfun.core.structured.schema import schema_repr
44
40
  from langfun.core.structured.schema import source_form
45
41
  from langfun.core.structured.schema import value_repr
@@ -54,25 +50,25 @@ from langfun.core.structured.mapping import Mapping
54
50
  from langfun.core.structured.mapping import MappingError
55
51
  from langfun.core.structured.mapping import MappingExample
56
52
 
57
- from langfun.core.structured.parsing import ParseStructure
58
- from langfun.core.structured.parsing import ParseStructureJson
59
- from langfun.core.structured.parsing import ParseStructurePython
60
53
  from langfun.core.structured.parsing import parse
61
54
  from langfun.core.structured.parsing import call
62
55
 
63
- from langfun.core.structured.prompting import QueryStructure
64
- from langfun.core.structured.prompting import QueryStructureJson
65
- from langfun.core.structured.prompting import QueryStructurePython
66
- from langfun.core.structured.prompting import query
56
+ from langfun.core.structured.querying import track_queries
57
+ from langfun.core.structured.querying import QueryInvocation
58
+ from langfun.core.structured.querying import query
59
+ from langfun.core.structured.querying import query_and_reduce
67
60
 
68
- from langfun.core.structured.description import DescribeStructure
69
- from langfun.core.structured.description import describe
61
+ from langfun.core.structured.querying import query_prompt
62
+ from langfun.core.structured.querying import query_output
63
+ from langfun.core.structured.querying import query_reward
70
64
 
71
- from langfun.core.structured.completion import CompleteStructure
65
+ from langfun.core.structured.description import describe
72
66
  from langfun.core.structured.completion import complete
73
67
 
74
68
  from langfun.core.structured.scoring import score
75
69
 
70
+ from langfun.core.structured.tokenization import tokenize
71
+
76
72
  # Expose default examples for structured operations so users could refer to
77
73
  # them.
78
74
  from langfun.core.structured.parsing import default_parse_examples
@@ -21,7 +21,7 @@ from langfun.core.structured import schema as schema_lib
21
21
  import pyglove as pg
22
22
 
23
23
 
24
- class CompleteStructure(mapping.Mapping):
24
+ class _CompleteStructure(mapping.Mapping):
25
25
  """Complete structure by filling the missing fields."""
26
26
 
27
27
  input: Annotated[
@@ -30,7 +30,7 @@ class CompleteStructure(mapping.Mapping):
30
30
 
31
31
  mapping_template = lf.Template("""
32
32
  {{ input_title }}:
33
- {{ example.input_repr() | indent(2, True) }}
33
+ {{ example.input_repr(use_modality_ref=True) | indent(2, True) }}
34
34
 
35
35
  {%- if missing_type_dependencies(example.input) %}
36
36
 
@@ -45,13 +45,16 @@ class CompleteStructure(mapping.Mapping):
45
45
 
46
46
  {{ output_title }}:
47
47
  {%- if example.has_output %}
48
- {{ example.output_repr() | indent(2, True) }}
48
+ {{ example.output_repr(use_modality_ref=True) | indent(2, True) }}
49
49
  {% endif -%}
50
50
  """)
51
51
 
52
52
  input_title = 'INPUT_OBJECT'
53
53
  output_title = 'OUTPUT_OBJECT'
54
54
  schema_title = 'CLASS_DEFINITIONS'
55
+ modality_refs_title: Annotated[
56
+ str, 'The section title for modality refs.'
57
+ ] = 'MODALITY_REFERENCES'
55
58
 
56
59
  preamble = lf.LangFunc(
57
60
  """
@@ -107,7 +110,9 @@ class CompleteStructure(mapping.Mapping):
107
110
 
108
111
  def class_defs_repr(self, value: Any) -> str | None:
109
112
  return schema_lib.class_definitions(
110
- self.missing_type_dependencies(value), markdown=True
113
+ self.missing_type_dependencies(value),
114
+ markdown=True,
115
+ allowed_dependencies=set()
111
116
  )
112
117
 
113
118
  def postprocess_result(self, result: Any) -> Any:
@@ -146,6 +151,28 @@ class CompleteStructure(mapping.Mapping):
146
151
  pg.traverse(self.input, _visit)
147
152
  return context
148
153
 
154
+ #
155
+ # Helper methods for handling modalities.
156
+ #
157
+
158
+ def has_modality_refs(self, value: Any) -> bool:
159
+ """Returns true if the value has modalities."""
160
+ return not isinstance(value, lf.Modality) and pg.contains(
161
+ value, type=lf.Modality
162
+ )
163
+
164
+ def modalities(self, value: Any) -> dict[str, lf.Modality]:
165
+ return lf.Modality.from_value(value)
166
+
167
+ def modality_refs_repr(self, value: Any) -> str:
168
+ with lf.modality.format_modality_as_ref(True):
169
+ return pg.format(
170
+ self.modalities(value),
171
+ compact=False,
172
+ verbose=False,
173
+ python_format=True,
174
+ )
175
+
149
176
 
150
177
  def complete(
151
178
  input_value: pg.Symbolic,
@@ -214,7 +241,7 @@ def complete(
214
241
  Returns:
215
242
  The result based on the schema.
216
243
  """
217
- t = CompleteStructure(
244
+ t = _CompleteStructure(
218
245
  input=schema_lib.mark_missing(input_value),
219
246
  default=default,
220
247
  examples=examples,
@@ -46,7 +46,7 @@ class TripPlan(pg.Object):
46
46
  class CompleteStructureTest(unittest.TestCase):
47
47
 
48
48
  def test_render_no_examples(self):
49
- l = completion.CompleteStructure()
49
+ l = completion._CompleteStructure()
50
50
  input_value = schema_lib.mark_missing(
51
51
  TripPlan.partial(
52
52
  place='San Francisco',
@@ -120,7 +120,7 @@ class CompleteStructureTest(unittest.TestCase):
120
120
  )
121
121
 
122
122
  def test_render_no_class_definitions(self):
123
- l = completion.CompleteStructure()
123
+ l = completion._CompleteStructure()
124
124
  input_value = schema_lib.mark_missing(
125
125
  TripPlan.partial(
126
126
  place='San Francisco',
@@ -200,7 +200,7 @@ class CompleteStructureTest(unittest.TestCase):
200
200
  )
201
201
 
202
202
  def test_render_with_examples(self):
203
- l = completion.CompleteStructure()
203
+ l = completion._CompleteStructure()
204
204
  input_value = schema_lib.mark_missing(
205
205
  TripPlan.partial(
206
206
  place='San Francisco',
@@ -411,7 +411,7 @@ class CompleteStructureTest(unittest.TestCase):
411
411
  modalities.Image.from_bytes(b'image_of_elephant'),
412
412
  )
413
413
  )
414
- l = completion.CompleteStructure(
414
+ l = completion._CompleteStructure(
415
415
  input=input_value,
416
416
  examples=[
417
417
  mapping.MappingExample(
@@ -464,7 +464,7 @@ class CompleteStructureTest(unittest.TestCase):
464
464
 
465
465
  MODALITY_REFERENCES:
466
466
  {
467
- 'examples[0].input.image': {{examples[0].input.image}}
467
+ 'examples[0].input.image': <<[[examples[0].input.image]]>>
468
468
  }
469
469
 
470
470
  OUTPUT_OBJECT:
@@ -490,7 +490,7 @@ class CompleteStructureTest(unittest.TestCase):
490
490
 
491
491
  MODALITY_REFERENCES:
492
492
  {
493
- 'input.image': {{input.image}}
493
+ 'input.image': <<[[input.image]]>>
494
494
  }
495
495
 
496
496
  OUTPUT_OBJECT:
@@ -581,6 +581,7 @@ class CompleteStructureTest(unittest.TestCase):
581
581
  text='Activity(description="foo")',
582
582
  result=Activity(description='foo'),
583
583
  score=1.0,
584
+ is_cached=False,
584
585
  logprobs=None,
585
586
  usage=lf.LMSamplingUsage(553, 27, 580),
586
587
  tags=['lm-response', 'lm-output', 'transformed']
@@ -22,7 +22,7 @@ import pyglove as pg
22
22
 
23
23
 
24
24
  @pg.use_init_args(['examples'])
25
- class DescribeStructure(mapping.Mapping):
25
+ class _DescribeStructure(mapping.Mapping):
26
26
  """Describe a structured value in natural language."""
27
27
 
28
28
  input_title = 'PYTHON_OBJECT'
@@ -106,7 +106,7 @@ def describe(
106
106
  Returns:
107
107
  The parsed result based on the schema.
108
108
  """
109
- return DescribeStructure(
109
+ return _DescribeStructure(
110
110
  input=value,
111
111
  context=context,
112
112
  examples=examples or default_describe_examples(),
@@ -36,7 +36,7 @@ class Itinerary(pg.Object):
36
36
  class DescribeStructureTest(unittest.TestCase):
37
37
 
38
38
  def test_render(self):
39
- l = description_lib.DescribeStructure(
39
+ l = description_lib._DescribeStructure(
40
40
  input=Itinerary(
41
41
  day=1,
42
42
  type='daytime',
@@ -137,7 +137,7 @@ class DescribeStructureTest(unittest.TestCase):
137
137
  ],
138
138
  hotel=None,
139
139
  )
140
- l = description_lib.DescribeStructure(
140
+ l = description_lib._DescribeStructure(
141
141
  input=value, context='1 day itinerary to SF'
142
142
  )
143
143
  self.assertEqual(
@@ -187,7 +187,7 @@ class DescribeStructureTest(unittest.TestCase):
187
187
  ],
188
188
  hotel=None,
189
189
  )
190
- l = description_lib.DescribeStructure(input=value)
190
+ l = description_lib._DescribeStructure(input=value)
191
191
  self.assertEqual(
192
192
  l.render().text,
193
193
  inspect.cleandoc("""
@@ -16,16 +16,16 @@
16
16
  import functools
17
17
  import inspect
18
18
  import re
19
- from typing import Any, Callable, Optional, Tuple
19
+ from typing import Any, Callable, Literal, Optional, Tuple
20
20
 
21
21
  from langfun.core import language_model
22
22
  from langfun.core import template
23
23
  from langfun.core.coding import python
24
- from langfun.core.structured import prompting
24
+ from langfun.core.structured import querying
25
25
  import pyglove as pg
26
26
 
27
27
 
28
- def unittest_gen(signature, lm, num_retries=10):
28
+ def unittest_gen(signature, lm, num_retries=1):
29
29
  """Generates unit tests for a python function signature."""
30
30
 
31
31
  class UnitTest(pg.Object):
@@ -39,7 +39,7 @@ def unittest_gen(signature, lm, num_retries=10):
39
39
 
40
40
  unittest_examples = None
41
41
  for _ in range(num_retries):
42
- r = prompting.query(
42
+ r = querying.query(
43
43
  PythonFunctionSignature(signature=signature),
44
44
  list[UnitTest],
45
45
  lm=lm,
@@ -76,12 +76,16 @@ def unittest_with_test_cases(f, unittests):
76
76
 
77
77
  def _function_gen(
78
78
  func: Callable[..., Any],
79
+ context: dict[str, Any],
79
80
  signature: str,
80
81
  lm: language_model.LanguageModel,
81
- num_retries: int = 10,
82
+ num_retries: int = 1,
82
83
  unittest: Optional[
83
- Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
84
+ Callable[[Callable[..., Any]], None]
85
+ | list[Tuple[Any, Any]]
86
+ | Literal["auto"]
84
87
  ] = None,
88
+ unittest_num_retries: int = 1,
85
89
  ):
86
90
  """Generates a python function with LLM and verify its quality with unit testing."""
87
91
 
@@ -131,32 +135,43 @@ def _function_gen(
131
135
  """
132
136
 
133
137
  unittest_examples = None
134
- if unittest is None:
135
- unittest_examples = unittest_gen(signature, lm=lm)
136
- elif not callable(unittest):
138
+ if unittest == "auto":
139
+ unittest_examples = unittest_gen(
140
+ signature, lm=lm, num_retries=unittest_num_retries
141
+ )
142
+ elif isinstance(unittest, list):
137
143
  unittest_examples = unittest
138
144
 
145
+ last_error = None
139
146
  for _ in range(num_retries):
140
147
  try:
141
- source_code = prompting.query(
148
+ source_code = querying.query(
142
149
  PythonFunctionPrompt(signature=signature), lm=lm
143
150
  )
144
- f = python.evaluate(source_code)
151
+ f = python.evaluate(source_code, global_vars=context)
145
152
 
146
153
  # Check whether the sigantures are the same.
147
154
  if inspect.signature(f) != inspect.signature(func):
148
- continue
155
+ raise python.CodeError(
156
+ code=source_code,
157
+ cause=TypeError(
158
+ f"Signature mismatch: Expected: {inspect.signature(func)}, "
159
+ f"Actual: {inspect.signature(f)}.",
160
+ ),
161
+ )
149
162
 
150
163
  if callable(unittest):
151
164
  unittest(f)
152
- else:
165
+ elif unittest_examples:
153
166
  unittest_with_test_cases(f, unittest_examples)
154
167
 
155
168
  return f, source_code
156
- except Exception: # pylint: disable=broad-exception-caught
157
- pass
158
-
159
- return None, None
169
+ except python.CodeError as e:
170
+ last_error = e
171
+ pg.logging.warning(
172
+ f"Bad code generated: {e}",
173
+ )
174
+ raise last_error
160
175
 
161
176
 
162
177
  def _process_signature(signature):
@@ -172,10 +187,13 @@ def _process_signature(signature):
172
187
  def function_gen(
173
188
  lm: language_model.LanguageModel,
174
189
  cache_filename: str | None = None,
175
- num_retries: int = 10,
190
+ num_retries: int = 1,
176
191
  unittest: Optional[
177
- Callable[[Callable[..., Any]], None] | list[Tuple[Any, Any]]
192
+ Callable[[Callable[..., Any]], None]
193
+ | list[Tuple[Any, Any]]
194
+ | Literal["auto"]
178
195
  ] = None,
196
+ unittest_num_retries: int = 1,
179
197
  ):
180
198
  """A decorator for automating function generation using a language model.
181
199
 
@@ -192,9 +210,12 @@ def function_gen(
192
210
  make to generate a suitable function implementation.
193
211
  unittest: This optional parameter enables the definition of custom unit
194
212
  tests. You can either provide a list of test cases as tuples of inputs
195
- and outputs, or a function that throws an error if a test fails. If left
196
- as None (the default setting), the LLM will automatically create the
197
- unit test cases.
213
+ and outputs, or a function that throws an error if a test fails, or let
214
+ LLM automatically create the unit test cases. If a generated function is
215
+ returned, it should pass all the unittests.
216
+ unittest_num_retries: If unittest is set to "auto", this parameter
217
+ specifies the number of times the LLM attempts to generate unit test
218
+ cases.
198
219
 
199
220
  Returns:
200
221
  The implemented function object.
@@ -204,6 +225,13 @@ def function_gen(
204
225
  setattr(func, "__function__", None)
205
226
  setattr(func, "__source_code__", None)
206
227
 
228
+ # Prepare the globals/locals for the generated code to be evaluated against.
229
+ callstack = inspect.stack()
230
+ assert len(callstack) > 1
231
+ context = dict(callstack[1][0].f_globals)
232
+ context.update(callstack[1][0].f_locals)
233
+ context.pop(func.__name__, None)
234
+
207
235
  @functools.wraps(func)
208
236
  def lm_generated_func(*args, **kwargs):
209
237
  if func.__function__ is not None:
@@ -222,15 +250,20 @@ def function_gen(
222
250
 
223
251
  if signature in cache:
224
252
  func.__source_code__ = cache[signature]
225
- func.__function__ = python.evaluate(func.__source_code__)
253
+ func.__function__ = python.evaluate(
254
+ func.__source_code__, global_vars=context
255
+ )
226
256
  return func.__function__(*args, **kwargs)
227
257
 
228
258
  func.__function__, func.__source_code__ = _function_gen(
229
- func, signature, lm, num_retries=num_retries, unittest=unittest
259
+ func,
260
+ context,
261
+ signature,
262
+ lm,
263
+ num_retries=num_retries,
264
+ unittest=unittest,
265
+ unittest_num_retries=unittest_num_retries,
230
266
  )
231
- if func.__function__ is None:
232
- raise ValueError(f"Function generation failed. Signature:\n{signature}")
233
-
234
267
  if cache_filename is not None:
235
268
  cache[signature] = func.__source_code__
236
269
  cache.save(cache_filename)