langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/modalities/video.py
CHANGED
@@ -13,35 +13,18 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Video modality."""
|
15
15
|
|
16
|
-
import
|
17
|
-
from typing import cast
|
16
|
+
import functools
|
18
17
|
from langfun.core.modalities import mime
|
19
18
|
|
20
19
|
|
21
|
-
class Video(mime.
|
22
|
-
"""
|
20
|
+
class Video(mime.Mime):
|
21
|
+
"""Video."""
|
23
22
|
|
24
|
-
|
25
|
-
def video_format(self) -> str:
|
26
|
-
return cast(str, self.mime_type.lstrip('video/'))
|
27
|
-
|
28
|
-
@property
|
29
|
-
def mime_type(self) -> str:
|
30
|
-
# TODO(daiyip): after cl/619658455, LaunchPad binaries cannot import `magic`
|
31
|
-
# correctly. This is to mitigate the issue for major Langfun users who do
|
32
|
-
# not use Video. We shall move this import out once the issue is fixed.
|
33
|
-
import magic # pylint: disable=g-import-not-at-top
|
23
|
+
MIME_PREFIX = 'video'
|
34
24
|
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
return video_mime_type
|
25
|
+
@functools.cached_property
|
26
|
+
def video_format(self) -> str:
|
27
|
+
return self.mime_type.removeprefix(self.MIME_PREFIX + '/')
|
39
28
|
|
40
|
-
def
|
41
|
-
|
42
|
-
return f'<video controls> <source src="{self.uri}"> </video>'
|
43
|
-
video_raw = base64.b64encode(self.to_bytes()).decode()
|
44
|
-
return (
|
45
|
-
'<video controls> <source'
|
46
|
-
f' src="data:video/{self.video_format};base64,{video_raw}"> </video>'
|
47
|
-
)
|
29
|
+
def _mime_control_for(self, uri: str) -> str:
|
30
|
+
return f'<video controls> <source src="{uri}"> </video>'
|
@@ -38,12 +38,12 @@ class VideoContentTest(unittest.TestCase):
|
|
38
38
|
video = video_lib.Video.from_bytes(mp4_bytes)
|
39
39
|
self.assertEqual(video.mime_type, 'video/mp4')
|
40
40
|
self.assertEqual(video.video_format, 'mp4')
|
41
|
-
self.assertIn('data:video/mp4;base64,', video.
|
41
|
+
self.assertIn('data:video/mp4;base64,', video._raw_html())
|
42
42
|
self.assertEqual(video.to_bytes(), mp4_bytes)
|
43
43
|
|
44
44
|
def test_bad_video(self):
|
45
45
|
video = video_lib.Video.from_bytes(b'bad')
|
46
|
-
with self.assertRaisesRegex(ValueError, '
|
46
|
+
with self.assertRaisesRegex(ValueError, 'Expected MIME type'):
|
47
47
|
_ = video.video_format
|
48
48
|
|
49
49
|
|
@@ -56,7 +56,7 @@ class VideoFileTest(unittest.TestCase):
|
|
56
56
|
self.assertEqual(video.video_format, 'mp4')
|
57
57
|
self.assertEqual(video.mime_type, 'video/mp4')
|
58
58
|
self.assertEqual(
|
59
|
-
video.
|
59
|
+
video._raw_html(),
|
60
60
|
'<video controls> <source src="http://mock/web/a.mp4"> </video>',
|
61
61
|
)
|
62
62
|
self.assertEqual(video.to_bytes(), mp4_bytes)
|
langfun/core/modality.py
CHANGED
@@ -14,6 +14,8 @@
|
|
14
14
|
"""Interface for modality (e.g. Image, Video, etc.)."""
|
15
15
|
|
16
16
|
import abc
|
17
|
+
import functools
|
18
|
+
import hashlib
|
17
19
|
from typing import Any, ContextManager
|
18
20
|
from langfun.core import component
|
19
21
|
import pyglove as pg
|
@@ -29,11 +31,16 @@ def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
|
|
29
31
|
)
|
30
32
|
|
31
33
|
|
32
|
-
class Modality(component.Component):
|
34
|
+
class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
33
35
|
"""Base class for multimodal object."""
|
34
36
|
|
35
|
-
REF_START = '
|
36
|
-
REF_END = '
|
37
|
+
REF_START = '<<[['
|
38
|
+
REF_END = ']]>>'
|
39
|
+
|
40
|
+
def _on_bound(self):
|
41
|
+
super()._on_bound()
|
42
|
+
# Invalidate cached hash if modality member is changed.
|
43
|
+
self.__dict__.pop('hash', None)
|
37
44
|
|
38
45
|
def format(self, *args, **kwargs) -> str:
|
39
46
|
if self.referred_name is None or not pg.object_utils.thread_local_get(
|
@@ -42,10 +49,22 @@ class Modality(component.Component):
|
|
42
49
|
return super().format(*args, **kwargs)
|
43
50
|
return Modality.text_marker(self.referred_name)
|
44
51
|
|
52
|
+
def __str_kwargs__(self) -> dict[str, Any]:
|
53
|
+
# For modality objects, we don't want to use markdown format when they
|
54
|
+
# are rendered as parts of the prompt.
|
55
|
+
kwargs = super().__str_kwargs__()
|
56
|
+
kwargs.pop('markdown', None)
|
57
|
+
return kwargs
|
58
|
+
|
45
59
|
@abc.abstractmethod
|
46
60
|
def to_bytes(self) -> bytes:
|
47
61
|
"""Returns content in bytes."""
|
48
62
|
|
63
|
+
@functools.cached_property
|
64
|
+
def hash(self) -> str:
|
65
|
+
"""Returns an 8-character MD5 hex prefix as the identifier for this modality object."""
|
66
|
+
return hashlib.md5(self.to_bytes()).hexdigest()[:8]
|
67
|
+
|
49
68
|
@classmethod
|
50
69
|
def text_marker(cls, var_name: str) -> str:
|
51
70
|
"""Returns a marker in the text for this object."""
|
@@ -108,3 +127,7 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
|
|
108
127
|
return ModalityRef(name=value.sym_path + k)
|
109
128
|
return v
|
110
129
|
return value.clone().rebind(_placehold, raise_on_no_change=False)
|
130
|
+
|
131
|
+
|
132
|
+
class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
|
133
|
+
"""Exception raised when modality is not supported."""
|
langfun/core/modality_test.py
CHANGED
@@ -11,7 +11,6 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for modality."""
|
15
14
|
from typing import Any
|
16
15
|
import unittest
|
17
16
|
|
@@ -32,12 +31,13 @@ class ModalityTest(unittest.TestCase):
|
|
32
31
|
v = CustomModality('a')
|
33
32
|
self.assertIsNone(v.referred_name)
|
34
33
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
34
|
+
self.assertEqual(v.hash, '0cc175b9')
|
35
35
|
|
36
36
|
_ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
|
37
37
|
self.assertEqual(v.referred_name, 'x.metadata.y')
|
38
38
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
39
39
|
with modality.format_modality_as_ref():
|
40
|
-
self.assertEqual(str(v), '
|
40
|
+
self.assertEqual(str(v), '<<[[x.metadata.y]]>>')
|
41
41
|
|
42
42
|
|
43
43
|
class ModalityRefTest(unittest.TestCase):
|
langfun/core/sampling.py
CHANGED
@@ -28,14 +28,14 @@ def sweep(
|
|
28
28
|
*,
|
29
29
|
max_workers: int = 32,
|
30
30
|
silence_on_errors: Union[
|
31
|
-
Type[
|
31
|
+
Type[BaseException], Tuple[Type[BaseException], ...], None
|
32
32
|
] = None,
|
33
33
|
ignore_examples_with_errors: bool = True,
|
34
34
|
**kwargs,
|
35
35
|
) -> Iterator[
|
36
36
|
Tuple[
|
37
|
-
message_lib.Message |
|
38
|
-
Union[message_lib.Message,
|
37
|
+
message_lib.Message | BaseException, # LM input.
|
38
|
+
Union[message_lib.Message, BaseException, None], # LM output.
|
39
39
|
],
|
40
40
|
]:
|
41
41
|
"""Sweeps the input/output of this LangFunc concurrently.
|
@@ -73,15 +73,15 @@ def random_sample(
|
|
73
73
|
*,
|
74
74
|
max_workers: int = 32,
|
75
75
|
silence_on_errors: Union[
|
76
|
-
Type[
|
76
|
+
Type[BaseException], Tuple[Type[BaseException], ...], None
|
77
77
|
] = None,
|
78
78
|
ignore_examples_with_errors: bool = True,
|
79
79
|
seed: int | None = None,
|
80
80
|
**kwargs,
|
81
81
|
) -> Iterator[
|
82
82
|
Tuple[
|
83
|
-
message_lib.Message |
|
84
|
-
Union[message_lib.Message,
|
83
|
+
message_lib.Message | BaseException, # LM input.
|
84
|
+
Union[message_lib.Message, BaseException, None], # LM output.
|
85
85
|
],
|
86
86
|
]:
|
87
87
|
"""Randomly samples the input/output of this LangFunc concurrently.
|
@@ -121,14 +121,14 @@ def _concurrent_sample(
|
|
121
121
|
*,
|
122
122
|
max_workers: int = 32,
|
123
123
|
silence_on_errors: Union[
|
124
|
-
Type[
|
124
|
+
Type[BaseException], Tuple[Type[BaseException], ...], None
|
125
125
|
] = None,
|
126
126
|
ignore_examples_with_errors: bool = True,
|
127
127
|
**kwargs,
|
128
128
|
) -> Generator[
|
129
129
|
Tuple[
|
130
|
-
message_lib.Message |
|
131
|
-
Union[message_lib.Message,
|
130
|
+
message_lib.Message | BaseException, # LM input.
|
131
|
+
Union[message_lib.Message, BaseException, None], # LM output.
|
132
132
|
],
|
133
133
|
None,
|
134
134
|
None, # Sender type and return type.
|
@@ -177,6 +177,6 @@ def _concurrent_sample(
|
|
177
177
|
else:
|
178
178
|
lm_input, lm_output = error, error
|
179
179
|
if (not ignore_examples_with_errors
|
180
|
-
or not (isinstance(lm_input,
|
181
|
-
or isinstance(lm_output,
|
180
|
+
or not (isinstance(lm_input, BaseException)
|
181
|
+
or isinstance(lm_output, BaseException))):
|
182
182
|
yield lm_input, lm_output
|
@@ -16,6 +16,8 @@
|
|
16
16
|
# pylint: disable=g-bad-import-order
|
17
17
|
# pylint: disable=g-importing-member
|
18
18
|
|
19
|
+
from langfun.core.structured.schema import include_method_in_prompt
|
20
|
+
|
19
21
|
from langfun.core.structured.schema import Missing
|
20
22
|
from langfun.core.structured.schema import MISSING
|
21
23
|
from langfun.core.structured.schema import Unknown
|
@@ -34,12 +36,6 @@ from langfun.core.structured.schema import class_definitions
|
|
34
36
|
from langfun.core.structured.schema import annotation
|
35
37
|
from langfun.core.structured.schema import structure_from_python
|
36
38
|
|
37
|
-
from langfun.core.structured.schema import SchemaRepr
|
38
|
-
from langfun.core.structured.schema import SchemaJsonRepr
|
39
|
-
from langfun.core.structured.schema import SchemaPythonRepr
|
40
|
-
from langfun.core.structured.schema import ValueRepr
|
41
|
-
from langfun.core.structured.schema import ValueJsonRepr
|
42
|
-
from langfun.core.structured.schema import ValuePythonRepr
|
43
39
|
from langfun.core.structured.schema import schema_repr
|
44
40
|
from langfun.core.structured.schema import source_form
|
45
41
|
from langfun.core.structured.schema import value_repr
|
@@ -54,25 +50,25 @@ from langfun.core.structured.mapping import Mapping
|
|
54
50
|
from langfun.core.structured.mapping import MappingError
|
55
51
|
from langfun.core.structured.mapping import MappingExample
|
56
52
|
|
57
|
-
from langfun.core.structured.parsing import ParseStructure
|
58
|
-
from langfun.core.structured.parsing import ParseStructureJson
|
59
|
-
from langfun.core.structured.parsing import ParseStructurePython
|
60
53
|
from langfun.core.structured.parsing import parse
|
61
54
|
from langfun.core.structured.parsing import call
|
62
55
|
|
63
|
-
from langfun.core.structured.
|
64
|
-
from langfun.core.structured.
|
65
|
-
from langfun.core.structured.
|
66
|
-
from langfun.core.structured.
|
56
|
+
from langfun.core.structured.querying import track_queries
|
57
|
+
from langfun.core.structured.querying import QueryInvocation
|
58
|
+
from langfun.core.structured.querying import query
|
59
|
+
from langfun.core.structured.querying import query_and_reduce
|
67
60
|
|
68
|
-
from langfun.core.structured.
|
69
|
-
from langfun.core.structured.
|
61
|
+
from langfun.core.structured.querying import query_prompt
|
62
|
+
from langfun.core.structured.querying import query_output
|
63
|
+
from langfun.core.structured.querying import query_reward
|
70
64
|
|
71
|
-
from langfun.core.structured.
|
65
|
+
from langfun.core.structured.description import describe
|
72
66
|
from langfun.core.structured.completion import complete
|
73
67
|
|
74
68
|
from langfun.core.structured.scoring import score
|
75
69
|
|
70
|
+
from langfun.core.structured.tokenization import tokenize
|
71
|
+
|
76
72
|
# Expose default examples for structured operations so users could refer to
|
77
73
|
# them.
|
78
74
|
from langfun.core.structured.parsing import default_parse_examples
|
@@ -21,7 +21,7 @@ from langfun.core.structured import schema as schema_lib
|
|
21
21
|
import pyglove as pg
|
22
22
|
|
23
23
|
|
24
|
-
class
|
24
|
+
class _CompleteStructure(mapping.Mapping):
|
25
25
|
"""Complete structure by filling the missing fields."""
|
26
26
|
|
27
27
|
input: Annotated[
|
@@ -30,7 +30,7 @@ class CompleteStructure(mapping.Mapping):
|
|
30
30
|
|
31
31
|
mapping_template = lf.Template("""
|
32
32
|
{{ input_title }}:
|
33
|
-
{{ example.input_repr() | indent(2, True) }}
|
33
|
+
{{ example.input_repr(use_modality_ref=True) | indent(2, True) }}
|
34
34
|
|
35
35
|
{%- if missing_type_dependencies(example.input) %}
|
36
36
|
|
@@ -45,13 +45,16 @@ class CompleteStructure(mapping.Mapping):
|
|
45
45
|
|
46
46
|
{{ output_title }}:
|
47
47
|
{%- if example.has_output %}
|
48
|
-
{{ example.output_repr() | indent(2, True) }}
|
48
|
+
{{ example.output_repr(use_modality_ref=True) | indent(2, True) }}
|
49
49
|
{% endif -%}
|
50
50
|
""")
|
51
51
|
|
52
52
|
input_title = 'INPUT_OBJECT'
|
53
53
|
output_title = 'OUTPUT_OBJECT'
|
54
54
|
schema_title = 'CLASS_DEFINITIONS'
|
55
|
+
modality_refs_title: Annotated[
|
56
|
+
str, 'The section title for modality refs.'
|
57
|
+
] = 'MODALITY_REFERENCES'
|
55
58
|
|
56
59
|
preamble = lf.LangFunc(
|
57
60
|
"""
|
@@ -107,7 +110,9 @@ class CompleteStructure(mapping.Mapping):
|
|
107
110
|
|
108
111
|
def class_defs_repr(self, value: Any) -> str | None:
|
109
112
|
return schema_lib.class_definitions(
|
110
|
-
self.missing_type_dependencies(value),
|
113
|
+
self.missing_type_dependencies(value),
|
114
|
+
markdown=True,
|
115
|
+
allowed_dependencies=set()
|
111
116
|
)
|
112
117
|
|
113
118
|
def postprocess_result(self, result: Any) -> Any:
|
@@ -146,6 +151,28 @@ class CompleteStructure(mapping.Mapping):
|
|
146
151
|
pg.traverse(self.input, _visit)
|
147
152
|
return context
|
148
153
|
|
154
|
+
#
|
155
|
+
# Helper methods for handling modalities.
|
156
|
+
#
|
157
|
+
|
158
|
+
def has_modality_refs(self, value: Any) -> bool:
|
159
|
+
"""Returns true if the value has modalities."""
|
160
|
+
return not isinstance(value, lf.Modality) and pg.contains(
|
161
|
+
value, type=lf.Modality
|
162
|
+
)
|
163
|
+
|
164
|
+
def modalities(self, value: Any) -> dict[str, lf.Modality]:
|
165
|
+
return lf.Modality.from_value(value)
|
166
|
+
|
167
|
+
def modality_refs_repr(self, value: Any) -> str:
|
168
|
+
with lf.modality.format_modality_as_ref(True):
|
169
|
+
return pg.format(
|
170
|
+
self.modalities(value),
|
171
|
+
compact=False,
|
172
|
+
verbose=False,
|
173
|
+
python_format=True,
|
174
|
+
)
|
175
|
+
|
149
176
|
|
150
177
|
def complete(
|
151
178
|
input_value: pg.Symbolic,
|
@@ -214,7 +241,7 @@ def complete(
|
|
214
241
|
Returns:
|
215
242
|
The result based on the schema.
|
216
243
|
"""
|
217
|
-
t =
|
244
|
+
t = _CompleteStructure(
|
218
245
|
input=schema_lib.mark_missing(input_value),
|
219
246
|
default=default,
|
220
247
|
examples=examples,
|
@@ -46,7 +46,7 @@ class TripPlan(pg.Object):
|
|
46
46
|
class CompleteStructureTest(unittest.TestCase):
|
47
47
|
|
48
48
|
def test_render_no_examples(self):
|
49
|
-
l = completion.
|
49
|
+
l = completion._CompleteStructure()
|
50
50
|
input_value = schema_lib.mark_missing(
|
51
51
|
TripPlan.partial(
|
52
52
|
place='San Francisco',
|
@@ -120,7 +120,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
120
120
|
)
|
121
121
|
|
122
122
|
def test_render_no_class_definitions(self):
|
123
|
-
l = completion.
|
123
|
+
l = completion._CompleteStructure()
|
124
124
|
input_value = schema_lib.mark_missing(
|
125
125
|
TripPlan.partial(
|
126
126
|
place='San Francisco',
|
@@ -200,7 +200,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
200
200
|
)
|
201
201
|
|
202
202
|
def test_render_with_examples(self):
|
203
|
-
l = completion.
|
203
|
+
l = completion._CompleteStructure()
|
204
204
|
input_value = schema_lib.mark_missing(
|
205
205
|
TripPlan.partial(
|
206
206
|
place='San Francisco',
|
@@ -411,7 +411,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
411
411
|
modalities.Image.from_bytes(b'image_of_elephant'),
|
412
412
|
)
|
413
413
|
)
|
414
|
-
l = completion.
|
414
|
+
l = completion._CompleteStructure(
|
415
415
|
input=input_value,
|
416
416
|
examples=[
|
417
417
|
mapping.MappingExample(
|
@@ -464,7 +464,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
464
464
|
|
465
465
|
MODALITY_REFERENCES:
|
466
466
|
{
|
467
|
-
'examples[0].input.image':
|
467
|
+
'examples[0].input.image': <<[[examples[0].input.image]]>>
|
468
468
|
}
|
469
469
|
|
470
470
|
OUTPUT_OBJECT:
|
@@ -490,7 +490,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
490
490
|
|
491
491
|
MODALITY_REFERENCES:
|
492
492
|
{
|
493
|
-
'input.image':
|
493
|
+
'input.image': <<[[input.image]]>>
|
494
494
|
}
|
495
495
|
|
496
496
|
OUTPUT_OBJECT:
|
@@ -581,6 +581,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
581
581
|
text='Activity(description="foo")',
|
582
582
|
result=Activity(description='foo'),
|
583
583
|
score=1.0,
|
584
|
+
is_cached=False,
|
584
585
|
logprobs=None,
|
585
586
|
usage=lf.LMSamplingUsage(553, 27, 580),
|
586
587
|
tags=['lm-response', 'lm-output', 'transformed']
|
@@ -22,7 +22,7 @@ import pyglove as pg
|
|
22
22
|
|
23
23
|
|
24
24
|
@pg.use_init_args(['examples'])
|
25
|
-
class
|
25
|
+
class _DescribeStructure(mapping.Mapping):
|
26
26
|
"""Describe a structured value in natural language."""
|
27
27
|
|
28
28
|
input_title = 'PYTHON_OBJECT'
|
@@ -106,7 +106,7 @@ def describe(
|
|
106
106
|
Returns:
|
107
107
|
The parsed result based on the schema.
|
108
108
|
"""
|
109
|
-
return
|
109
|
+
return _DescribeStructure(
|
110
110
|
input=value,
|
111
111
|
context=context,
|
112
112
|
examples=examples or default_describe_examples(),
|
@@ -36,7 +36,7 @@ class Itinerary(pg.Object):
|
|
36
36
|
class DescribeStructureTest(unittest.TestCase):
|
37
37
|
|
38
38
|
def test_render(self):
|
39
|
-
l = description_lib.
|
39
|
+
l = description_lib._DescribeStructure(
|
40
40
|
input=Itinerary(
|
41
41
|
day=1,
|
42
42
|
type='daytime',
|
@@ -137,7 +137,7 @@ class DescribeStructureTest(unittest.TestCase):
|
|
137
137
|
],
|
138
138
|
hotel=None,
|
139
139
|
)
|
140
|
-
l = description_lib.
|
140
|
+
l = description_lib._DescribeStructure(
|
141
141
|
input=value, context='1 day itinerary to SF'
|
142
142
|
)
|
143
143
|
self.assertEqual(
|
@@ -187,7 +187,7 @@ class DescribeStructureTest(unittest.TestCase):
|
|
187
187
|
],
|
188
188
|
hotel=None,
|
189
189
|
)
|
190
|
-
l = description_lib.
|
190
|
+
l = description_lib._DescribeStructure(input=value)
|
191
191
|
self.assertEqual(
|
192
192
|
l.render().text,
|
193
193
|
inspect.cleandoc("""
|
@@ -16,16 +16,16 @@
|
|
16
16
|
import functools
|
17
17
|
import inspect
|
18
18
|
import re
|
19
|
-
from typing import Any, Callable, Optional, Tuple
|
19
|
+
from typing import Any, Callable, Literal, Optional, Tuple
|
20
20
|
|
21
21
|
from langfun.core import language_model
|
22
22
|
from langfun.core import template
|
23
23
|
from langfun.core.coding import python
|
24
|
-
from langfun.core.structured import
|
24
|
+
from langfun.core.structured import querying
|
25
25
|
import pyglove as pg
|
26
26
|
|
27
27
|
|
28
|
-
def unittest_gen(signature, lm, num_retries=
|
28
|
+
def unittest_gen(signature, lm, num_retries=1):
|
29
29
|
"""Generates unit tests for a python function signature."""
|
30
30
|
|
31
31
|
class UnitTest(pg.Object):
|
@@ -39,7 +39,7 @@ def unittest_gen(signature, lm, num_retries=10):
|
|
39
39
|
|
40
40
|
unittest_examples = None
|
41
41
|
for _ in range(num_retries):
|
42
|
-
r =
|
42
|
+
r = querying.query(
|
43
43
|
PythonFunctionSignature(signature=signature),
|
44
44
|
list[UnitTest],
|
45
45
|
lm=lm,
|
@@ -76,12 +76,16 @@ def unittest_with_test_cases(f, unittests):
|
|
76
76
|
|
77
77
|
def _function_gen(
|
78
78
|
func: Callable[..., Any],
|
79
|
+
context: dict[str, Any],
|
79
80
|
signature: str,
|
80
81
|
lm: language_model.LanguageModel,
|
81
|
-
num_retries: int =
|
82
|
+
num_retries: int = 1,
|
82
83
|
unittest: Optional[
|
83
|
-
Callable[[Callable[..., Any]], None]
|
84
|
+
Callable[[Callable[..., Any]], None]
|
85
|
+
| list[Tuple[Any, Any]]
|
86
|
+
| Literal["auto"]
|
84
87
|
] = None,
|
88
|
+
unittest_num_retries: int = 1,
|
85
89
|
):
|
86
90
|
"""Generates a python function with LLM and verifies its quality with unit testing."""
|
87
91
|
|
@@ -131,32 +135,43 @@ def _function_gen(
|
|
131
135
|
"""
|
132
136
|
|
133
137
|
unittest_examples = None
|
134
|
-
if unittest
|
135
|
-
unittest_examples = unittest_gen(
|
136
|
-
|
138
|
+
if unittest == "auto":
|
139
|
+
unittest_examples = unittest_gen(
|
140
|
+
signature, lm=lm, num_retries=unittest_num_retries
|
141
|
+
)
|
142
|
+
elif isinstance(unittest, list):
|
137
143
|
unittest_examples = unittest
|
138
144
|
|
145
|
+
last_error = None
|
139
146
|
for _ in range(num_retries):
|
140
147
|
try:
|
141
|
-
source_code =
|
148
|
+
source_code = querying.query(
|
142
149
|
PythonFunctionPrompt(signature=signature), lm=lm
|
143
150
|
)
|
144
|
-
f = python.evaluate(source_code)
|
151
|
+
f = python.evaluate(source_code, global_vars=context)
|
145
152
|
|
146
153
|
# Check whether the signatures are the same.
|
147
154
|
if inspect.signature(f) != inspect.signature(func):
|
148
|
-
|
155
|
+
raise python.CodeError(
|
156
|
+
code=source_code,
|
157
|
+
cause=TypeError(
|
158
|
+
f"Signature mismatch: Expected: {inspect.signature(func)}, "
|
159
|
+
f"Actual: {inspect.signature(f)}.",
|
160
|
+
),
|
161
|
+
)
|
149
162
|
|
150
163
|
if callable(unittest):
|
151
164
|
unittest(f)
|
152
|
-
|
165
|
+
elif unittest_examples:
|
153
166
|
unittest_with_test_cases(f, unittest_examples)
|
154
167
|
|
155
168
|
return f, source_code
|
156
|
-
except
|
157
|
-
|
158
|
-
|
159
|
-
|
169
|
+
except python.CodeError as e:
|
170
|
+
last_error = e
|
171
|
+
pg.logging.warning(
|
172
|
+
f"Bad code generated: {e}",
|
173
|
+
)
|
174
|
+
raise last_error
|
160
175
|
|
161
176
|
|
162
177
|
def _process_signature(signature):
|
@@ -172,10 +187,13 @@ def _process_signature(signature):
|
|
172
187
|
def function_gen(
|
173
188
|
lm: language_model.LanguageModel,
|
174
189
|
cache_filename: str | None = None,
|
175
|
-
num_retries: int =
|
190
|
+
num_retries: int = 1,
|
176
191
|
unittest: Optional[
|
177
|
-
Callable[[Callable[..., Any]], None]
|
192
|
+
Callable[[Callable[..., Any]], None]
|
193
|
+
| list[Tuple[Any, Any]]
|
194
|
+
| Literal["auto"]
|
178
195
|
] = None,
|
196
|
+
unittest_num_retries: int = 1,
|
179
197
|
):
|
180
198
|
"""A decorator for automating function generation using a language model.
|
181
199
|
|
@@ -192,9 +210,12 @@ def function_gen(
|
|
192
210
|
make to generate a suitable function implementation.
|
193
211
|
unittest: This optional parameter enables the definition of custom unit
|
194
212
|
tests. You can either provide a list of test cases as tuples of inputs
|
195
|
-
and outputs, or a function that throws an error if a test fails
|
196
|
-
|
197
|
-
|
213
|
+
and outputs, or a function that throws an error if a test fails, or let
|
214
|
+
LLM automatically create the unit test cases. If a generated function is
|
215
|
+
returned, it should pass all the unittests.
|
216
|
+
unittest_num_retries: If unittest is set to "auto", this parameter
|
217
|
+
specifies the number of times the LLM attempts to generate unit test
|
218
|
+
cases.
|
198
219
|
|
199
220
|
Returns:
|
200
221
|
The implemented function object.
|
@@ -204,6 +225,13 @@ def function_gen(
|
|
204
225
|
setattr(func, "__function__", None)
|
205
226
|
setattr(func, "__source_code__", None)
|
206
227
|
|
228
|
+
# Prepare the globals/locals for the generated code to be evaluated against.
|
229
|
+
callstack = inspect.stack()
|
230
|
+
assert len(callstack) > 1
|
231
|
+
context = dict(callstack[1][0].f_globals)
|
232
|
+
context.update(callstack[1][0].f_locals)
|
233
|
+
context.pop(func.__name__, None)
|
234
|
+
|
207
235
|
@functools.wraps(func)
|
208
236
|
def lm_generated_func(*args, **kwargs):
|
209
237
|
if func.__function__ is not None:
|
@@ -222,15 +250,20 @@ def function_gen(
|
|
222
250
|
|
223
251
|
if signature in cache:
|
224
252
|
func.__source_code__ = cache[signature]
|
225
|
-
func.__function__ = python.evaluate(
|
253
|
+
func.__function__ = python.evaluate(
|
254
|
+
func.__source_code__, global_vars=context
|
255
|
+
)
|
226
256
|
return func.__function__(*args, **kwargs)
|
227
257
|
|
228
258
|
func.__function__, func.__source_code__ = _function_gen(
|
229
|
-
func,
|
259
|
+
func,
|
260
|
+
context,
|
261
|
+
signature,
|
262
|
+
lm,
|
263
|
+
num_retries=num_retries,
|
264
|
+
unittest=unittest,
|
265
|
+
unittest_num_retries=unittest_num_retries,
|
230
266
|
)
|
231
|
-
if func.__function__ is None:
|
232
|
-
raise ValueError(f"Function generation failed. Signature:\n{signature}")
|
233
|
-
|
234
267
|
if cache_filename is not None:
|
235
268
|
cache[signature] = func.__source_code__
|
236
269
|
cache.save(cache_filename)
|