langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/modalities/video.py
CHANGED
@@ -13,32 +13,18 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Video modality."""
|
15
15
|
|
16
|
-
import
|
17
|
-
from typing import cast
|
18
|
-
|
16
|
+
import functools
|
19
17
|
from langfun.core.modalities import mime
|
20
|
-
import magic
|
21
18
|
|
22
19
|
|
23
|
-
class Video(mime.
|
24
|
-
"""
|
20
|
+
class Video(mime.Mime):
|
21
|
+
"""Video."""
|
25
22
|
|
26
|
-
|
27
|
-
def video_format(self) -> str:
|
28
|
-
return cast(str, self.mime_type.lstrip('video/'))
|
23
|
+
MIME_PREFIX = 'video'
|
29
24
|
|
30
|
-
@
|
31
|
-
def
|
32
|
-
|
33
|
-
if 'video/' not in video_mime_type:
|
34
|
-
raise ValueError(f'Not a video: {video_mime_type!r}.')
|
35
|
-
return video_mime_type
|
25
|
+
@functools.cached_property
|
26
|
+
def video_format(self) -> str:
|
27
|
+
return self.mime_type.removeprefix(self.MIME_PREFIX + '/')
|
36
28
|
|
37
|
-
def
|
38
|
-
|
39
|
-
return f'<video controls> <source src="{self.uri}"> </video>'
|
40
|
-
video_raw = base64.b64encode(self.to_bytes()).decode()
|
41
|
-
return (
|
42
|
-
'<video controls> <source'
|
43
|
-
f' src="data:video/{self.video_format};base64,{video_raw}"> </video>'
|
44
|
-
)
|
29
|
+
def _mime_control_for(self, uri: str) -> str:
|
30
|
+
return f'<video controls> <source src="{uri}"> </video>'
|
@@ -38,12 +38,12 @@ class VideoContentTest(unittest.TestCase):
|
|
38
38
|
video = video_lib.Video.from_bytes(mp4_bytes)
|
39
39
|
self.assertEqual(video.mime_type, 'video/mp4')
|
40
40
|
self.assertEqual(video.video_format, 'mp4')
|
41
|
-
self.assertIn('data:video/mp4;base64,', video.
|
41
|
+
self.assertIn('data:video/mp4;base64,', video._raw_html())
|
42
42
|
self.assertEqual(video.to_bytes(), mp4_bytes)
|
43
43
|
|
44
44
|
def test_bad_video(self):
|
45
45
|
video = video_lib.Video.from_bytes(b'bad')
|
46
|
-
with self.assertRaisesRegex(ValueError, '
|
46
|
+
with self.assertRaisesRegex(ValueError, 'Expected MIME type'):
|
47
47
|
_ = video.video_format
|
48
48
|
|
49
49
|
|
@@ -56,7 +56,7 @@ class VideoFileTest(unittest.TestCase):
|
|
56
56
|
self.assertEqual(video.video_format, 'mp4')
|
57
57
|
self.assertEqual(video.mime_type, 'video/mp4')
|
58
58
|
self.assertEqual(
|
59
|
-
video.
|
59
|
+
video._raw_html(),
|
60
60
|
'<video controls> <source src="http://mock/web/a.mp4"> </video>',
|
61
61
|
)
|
62
62
|
self.assertEqual(video.to_bytes(), mp4_bytes)
|
langfun/core/modality.py
CHANGED
@@ -14,6 +14,8 @@
|
|
14
14
|
"""Interface for modality (e.g. Image, Video, etc.)."""
|
15
15
|
|
16
16
|
import abc
|
17
|
+
import functools
|
18
|
+
import hashlib
|
17
19
|
from typing import Any, ContextManager
|
18
20
|
from langfun.core import component
|
19
21
|
import pyglove as pg
|
@@ -29,11 +31,16 @@ def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
|
|
29
31
|
)
|
30
32
|
|
31
33
|
|
32
|
-
class Modality(component.Component):
|
34
|
+
class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
33
35
|
"""Base class for multimodal object."""
|
34
36
|
|
35
|
-
REF_START = '
|
36
|
-
REF_END = '
|
37
|
+
REF_START = '<<[['
|
38
|
+
REF_END = ']]>>'
|
39
|
+
|
40
|
+
def _on_bound(self):
|
41
|
+
super()._on_bound()
|
42
|
+
# Invalidate cached hash if modality member is changed.
|
43
|
+
self.__dict__.pop('hash', None)
|
37
44
|
|
38
45
|
def format(self, *args, **kwargs) -> str:
|
39
46
|
if self.referred_name is None or not pg.object_utils.thread_local_get(
|
@@ -42,10 +49,22 @@ class Modality(component.Component):
|
|
42
49
|
return super().format(*args, **kwargs)
|
43
50
|
return Modality.text_marker(self.referred_name)
|
44
51
|
|
52
|
+
def __str_kwargs__(self) -> dict[str, Any]:
|
53
|
+
# For modality objects, we don't want to use markdown format when they
|
54
|
+
# are rendered as parts of the prompt.
|
55
|
+
kwargs = super().__str_kwargs__()
|
56
|
+
kwargs.pop('markdown', None)
|
57
|
+
return kwargs
|
58
|
+
|
45
59
|
@abc.abstractmethod
|
46
60
|
def to_bytes(self) -> bytes:
|
47
61
|
"""Returns content in bytes."""
|
48
62
|
|
63
|
+
@functools.cached_property
|
64
|
+
def hash(self) -> str:
|
65
|
+
"""Returns a 8-byte MD5 hash as the identifier for this modality object."""
|
66
|
+
return hashlib.md5(self.to_bytes()).hexdigest()[:8]
|
67
|
+
|
49
68
|
@classmethod
|
50
69
|
def text_marker(cls, var_name: str) -> str:
|
51
70
|
"""Returns a marker in the text for this object."""
|
@@ -108,3 +127,7 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
|
|
108
127
|
return ModalityRef(name=value.sym_path + k)
|
109
128
|
return v
|
110
129
|
return value.clone().rebind(_placehold, raise_on_no_change=False)
|
130
|
+
|
131
|
+
|
132
|
+
class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
|
133
|
+
"""Exception raised when modality is not supported."""
|
langfun/core/modality_test.py
CHANGED
@@ -11,7 +11,6 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for modality."""
|
15
14
|
from typing import Any
|
16
15
|
import unittest
|
17
16
|
|
@@ -32,12 +31,13 @@ class ModalityTest(unittest.TestCase):
|
|
32
31
|
v = CustomModality('a')
|
33
32
|
self.assertIsNone(v.referred_name)
|
34
33
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
34
|
+
self.assertEqual(v.hash, '0cc175b9')
|
35
35
|
|
36
36
|
_ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
|
37
37
|
self.assertEqual(v.referred_name, 'x.metadata.y')
|
38
38
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
39
39
|
with modality.format_modality_as_ref():
|
40
|
-
self.assertEqual(str(v), '
|
40
|
+
self.assertEqual(str(v), '<<[[x.metadata.y]]>>')
|
41
41
|
|
42
42
|
|
43
43
|
class ModalityRefTest(unittest.TestCase):
|
langfun/core/sampling.py
CHANGED
@@ -28,14 +28,14 @@ def sweep(
|
|
28
28
|
*,
|
29
29
|
max_workers: int = 32,
|
30
30
|
silence_on_errors: Union[
|
31
|
-
Type[
|
31
|
+
Type[BaseException], Tuple[Type[BaseException], ...], None
|
32
32
|
] = None,
|
33
33
|
ignore_examples_with_errors: bool = True,
|
34
34
|
**kwargs,
|
35
35
|
) -> Iterator[
|
36
36
|
Tuple[
|
37
|
-
message_lib.Message |
|
38
|
-
Union[message_lib.Message,
|
37
|
+
message_lib.Message | BaseException, # LM input.
|
38
|
+
Union[message_lib.Message, BaseException, None], # LM output.
|
39
39
|
],
|
40
40
|
]:
|
41
41
|
"""Sweeps the input/output of this LangFunc concurrently.
|
@@ -73,15 +73,15 @@ def random_sample(
|
|
73
73
|
*,
|
74
74
|
max_workers: int = 32,
|
75
75
|
silence_on_errors: Union[
|
76
|
-
Type[
|
76
|
+
Type[BaseException], Tuple[Type[BaseException], ...], None
|
77
77
|
] = None,
|
78
78
|
ignore_examples_with_errors: bool = True,
|
79
79
|
seed: int | None = None,
|
80
80
|
**kwargs,
|
81
81
|
) -> Iterator[
|
82
82
|
Tuple[
|
83
|
-
message_lib.Message |
|
84
|
-
Union[message_lib.Message,
|
83
|
+
message_lib.Message | BaseException, # LM input.
|
84
|
+
Union[message_lib.Message, BaseException, None], # LM output.
|
85
85
|
],
|
86
86
|
]:
|
87
87
|
"""Random samples the input/output of this LangFunc concurrently.
|
@@ -121,14 +121,14 @@ def _concurrent_sample(
|
|
121
121
|
*,
|
122
122
|
max_workers: int = 32,
|
123
123
|
silence_on_errors: Union[
|
124
|
-
Type[
|
124
|
+
Type[BaseException], Tuple[Type[BaseException], ...], None
|
125
125
|
] = None,
|
126
126
|
ignore_examples_with_errors: bool = True,
|
127
127
|
**kwargs,
|
128
128
|
) -> Generator[
|
129
129
|
Tuple[
|
130
|
-
message_lib.Message |
|
131
|
-
Union[message_lib.Message,
|
130
|
+
message_lib.Message | BaseException, # LM input.
|
131
|
+
Union[message_lib.Message, BaseException, None], # LM output.
|
132
132
|
],
|
133
133
|
None,
|
134
134
|
None, # Sender type and return type.
|
@@ -177,6 +177,6 @@ def _concurrent_sample(
|
|
177
177
|
else:
|
178
178
|
lm_input, lm_output = error, error
|
179
179
|
if (not ignore_examples_with_errors
|
180
|
-
or not (isinstance(lm_input,
|
181
|
-
or isinstance(lm_output,
|
180
|
+
or not (isinstance(lm_input, BaseException)
|
181
|
+
or isinstance(lm_output, BaseException))):
|
182
182
|
yield lm_input, lm_output
|
@@ -16,6 +16,8 @@
|
|
16
16
|
# pylint: disable=g-bad-import-order
|
17
17
|
# pylint: disable=g-importing-member
|
18
18
|
|
19
|
+
from langfun.core.structured.schema import include_method_in_prompt
|
20
|
+
|
19
21
|
from langfun.core.structured.schema import Missing
|
20
22
|
from langfun.core.structured.schema import MISSING
|
21
23
|
from langfun.core.structured.schema import Unknown
|
@@ -34,12 +36,6 @@ from langfun.core.structured.schema import class_definitions
|
|
34
36
|
from langfun.core.structured.schema import annotation
|
35
37
|
from langfun.core.structured.schema import structure_from_python
|
36
38
|
|
37
|
-
from langfun.core.structured.schema import SchemaRepr
|
38
|
-
from langfun.core.structured.schema import SchemaJsonRepr
|
39
|
-
from langfun.core.structured.schema import SchemaPythonRepr
|
40
|
-
from langfun.core.structured.schema import ValueRepr
|
41
|
-
from langfun.core.structured.schema import ValueJsonRepr
|
42
|
-
from langfun.core.structured.schema import ValuePythonRepr
|
43
39
|
from langfun.core.structured.schema import schema_repr
|
44
40
|
from langfun.core.structured.schema import source_form
|
45
41
|
from langfun.core.structured.schema import value_repr
|
@@ -48,28 +44,31 @@ from langfun.core.structured.schema_generation import generate_class
|
|
48
44
|
from langfun.core.structured.schema_generation import classgen_example
|
49
45
|
from langfun.core.structured.schema_generation import default_classgen_examples
|
50
46
|
|
47
|
+
from langfun.core.structured.function_generation import function_gen
|
48
|
+
|
51
49
|
from langfun.core.structured.mapping import Mapping
|
50
|
+
from langfun.core.structured.mapping import MappingError
|
52
51
|
from langfun.core.structured.mapping import MappingExample
|
53
52
|
|
54
|
-
from langfun.core.structured.parsing import ParseStructure
|
55
|
-
from langfun.core.structured.parsing import ParseStructureJson
|
56
|
-
from langfun.core.structured.parsing import ParseStructurePython
|
57
53
|
from langfun.core.structured.parsing import parse
|
58
54
|
from langfun.core.structured.parsing import call
|
59
55
|
|
60
|
-
from langfun.core.structured.
|
61
|
-
from langfun.core.structured.
|
62
|
-
from langfun.core.structured.
|
63
|
-
from langfun.core.structured.
|
56
|
+
from langfun.core.structured.querying import track_queries
|
57
|
+
from langfun.core.structured.querying import QueryInvocation
|
58
|
+
from langfun.core.structured.querying import query
|
59
|
+
from langfun.core.structured.querying import query_and_reduce
|
64
60
|
|
65
|
-
from langfun.core.structured.
|
66
|
-
from langfun.core.structured.
|
61
|
+
from langfun.core.structured.querying import query_prompt
|
62
|
+
from langfun.core.structured.querying import query_output
|
63
|
+
from langfun.core.structured.querying import query_reward
|
67
64
|
|
68
|
-
from langfun.core.structured.
|
65
|
+
from langfun.core.structured.description import describe
|
69
66
|
from langfun.core.structured.completion import complete
|
70
67
|
|
71
68
|
from langfun.core.structured.scoring import score
|
72
69
|
|
70
|
+
from langfun.core.structured.tokenization import tokenize
|
71
|
+
|
73
72
|
# Expose default examples for structured operations so users could refer to
|
74
73
|
# them.
|
75
74
|
from langfun.core.structured.parsing import default_parse_examples
|
@@ -21,7 +21,7 @@ from langfun.core.structured import schema as schema_lib
|
|
21
21
|
import pyglove as pg
|
22
22
|
|
23
23
|
|
24
|
-
class
|
24
|
+
class _CompleteStructure(mapping.Mapping):
|
25
25
|
"""Complete structure by filling the missing fields."""
|
26
26
|
|
27
27
|
input: Annotated[
|
@@ -30,7 +30,7 @@ class CompleteStructure(mapping.Mapping):
|
|
30
30
|
|
31
31
|
mapping_template = lf.Template("""
|
32
32
|
{{ input_title }}:
|
33
|
-
{{ example.input_repr() | indent(2, True) }}
|
33
|
+
{{ example.input_repr(use_modality_ref=True) | indent(2, True) }}
|
34
34
|
|
35
35
|
{%- if missing_type_dependencies(example.input) %}
|
36
36
|
|
@@ -45,13 +45,16 @@ class CompleteStructure(mapping.Mapping):
|
|
45
45
|
|
46
46
|
{{ output_title }}:
|
47
47
|
{%- if example.has_output %}
|
48
|
-
{{ example.output_repr() | indent(2, True) }}
|
48
|
+
{{ example.output_repr(use_modality_ref=True) | indent(2, True) }}
|
49
49
|
{% endif -%}
|
50
50
|
""")
|
51
51
|
|
52
52
|
input_title = 'INPUT_OBJECT'
|
53
53
|
output_title = 'OUTPUT_OBJECT'
|
54
54
|
schema_title = 'CLASS_DEFINITIONS'
|
55
|
+
modality_refs_title: Annotated[
|
56
|
+
str, 'The section title for modality refs.'
|
57
|
+
] = 'MODALITY_REFERENCES'
|
55
58
|
|
56
59
|
preamble = lf.LangFunc(
|
57
60
|
"""
|
@@ -107,7 +110,9 @@ class CompleteStructure(mapping.Mapping):
|
|
107
110
|
|
108
111
|
def class_defs_repr(self, value: Any) -> str | None:
|
109
112
|
return schema_lib.class_definitions(
|
110
|
-
self.missing_type_dependencies(value),
|
113
|
+
self.missing_type_dependencies(value),
|
114
|
+
markdown=True,
|
115
|
+
allowed_dependencies=set()
|
111
116
|
)
|
112
117
|
|
113
118
|
def postprocess_result(self, result: Any) -> Any:
|
@@ -146,6 +151,28 @@ class CompleteStructure(mapping.Mapping):
|
|
146
151
|
pg.traverse(self.input, _visit)
|
147
152
|
return context
|
148
153
|
|
154
|
+
#
|
155
|
+
# Helper methods for handling modalities.
|
156
|
+
#
|
157
|
+
|
158
|
+
def has_modality_refs(self, value: Any) -> bool:
|
159
|
+
"""Returns true if the value has modalities."""
|
160
|
+
return not isinstance(value, lf.Modality) and pg.contains(
|
161
|
+
value, type=lf.Modality
|
162
|
+
)
|
163
|
+
|
164
|
+
def modalities(self, value: Any) -> dict[str, lf.Modality]:
|
165
|
+
return lf.Modality.from_value(value)
|
166
|
+
|
167
|
+
def modality_refs_repr(self, value: Any) -> str:
|
168
|
+
with lf.modality.format_modality_as_ref(True):
|
169
|
+
return pg.format(
|
170
|
+
self.modalities(value),
|
171
|
+
compact=False,
|
172
|
+
verbose=False,
|
173
|
+
python_format=True,
|
174
|
+
)
|
175
|
+
|
149
176
|
|
150
177
|
def complete(
|
151
178
|
input_value: pg.Symbolic,
|
@@ -214,7 +241,7 @@ def complete(
|
|
214
241
|
Returns:
|
215
242
|
The result based on the schema.
|
216
243
|
"""
|
217
|
-
t =
|
244
|
+
t = _CompleteStructure(
|
218
245
|
input=schema_lib.mark_missing(input_value),
|
219
246
|
default=default,
|
220
247
|
examples=examples,
|
@@ -17,7 +17,6 @@ import inspect
|
|
17
17
|
import unittest
|
18
18
|
|
19
19
|
import langfun.core as lf
|
20
|
-
from langfun.core import coding
|
21
20
|
from langfun.core import modalities
|
22
21
|
from langfun.core.llms import fake
|
23
22
|
from langfun.core.structured import completion
|
@@ -47,7 +46,7 @@ class TripPlan(pg.Object):
|
|
47
46
|
class CompleteStructureTest(unittest.TestCase):
|
48
47
|
|
49
48
|
def test_render_no_examples(self):
|
50
|
-
l = completion.
|
49
|
+
l = completion._CompleteStructure()
|
51
50
|
input_value = schema_lib.mark_missing(
|
52
51
|
TripPlan.partial(
|
53
52
|
place='San Francisco',
|
@@ -121,7 +120,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
121
120
|
)
|
122
121
|
|
123
122
|
def test_render_no_class_definitions(self):
|
124
|
-
l = completion.
|
123
|
+
l = completion._CompleteStructure()
|
125
124
|
input_value = schema_lib.mark_missing(
|
126
125
|
TripPlan.partial(
|
127
126
|
place='San Francisco',
|
@@ -201,7 +200,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
201
200
|
)
|
202
201
|
|
203
202
|
def test_render_with_examples(self):
|
204
|
-
l = completion.
|
203
|
+
l = completion._CompleteStructure()
|
205
204
|
input_value = schema_lib.mark_missing(
|
206
205
|
TripPlan.partial(
|
207
206
|
place='San Francisco',
|
@@ -412,7 +411,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
412
411
|
modalities.Image.from_bytes(b'image_of_elephant'),
|
413
412
|
)
|
414
413
|
)
|
415
|
-
l = completion.
|
414
|
+
l = completion._CompleteStructure(
|
416
415
|
input=input_value,
|
417
416
|
examples=[
|
418
417
|
mapping.MappingExample(
|
@@ -465,7 +464,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
465
464
|
|
466
465
|
MODALITY_REFERENCES:
|
467
466
|
{
|
468
|
-
'examples[0].input.image':
|
467
|
+
'examples[0].input.image': <<[[examples[0].input.image]]>>
|
469
468
|
}
|
470
469
|
|
471
470
|
OUTPUT_OBJECT:
|
@@ -491,7 +490,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
491
490
|
|
492
491
|
MODALITY_REFERENCES:
|
493
492
|
{
|
494
|
-
'input.image':
|
493
|
+
'input.image': <<[[input.image]]>>
|
495
494
|
}
|
496
495
|
|
497
496
|
OUTPUT_OBJECT:
|
@@ -582,7 +581,9 @@ class CompleteStructureTest(unittest.TestCase):
|
|
582
581
|
text='Activity(description="foo")',
|
583
582
|
result=Activity(description='foo'),
|
584
583
|
score=1.0,
|
584
|
+
is_cached=False,
|
585
585
|
logprobs=None,
|
586
|
+
usage=lf.LMSamplingUsage(553, 27, 580),
|
586
587
|
tags=['lm-response', 'lm-output', 'transformed']
|
587
588
|
)
|
588
589
|
)
|
@@ -607,7 +608,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
607
608
|
override_attrs=True,
|
608
609
|
):
|
609
610
|
with self.assertRaisesRegex(
|
610
|
-
|
611
|
+
mapping.MappingError,
|
611
612
|
'Expect .* but encountered .*',
|
612
613
|
):
|
613
614
|
completion.complete(Activity.partial(), autofix=0)
|
@@ -22,7 +22,7 @@ import pyglove as pg
|
|
22
22
|
|
23
23
|
|
24
24
|
@pg.use_init_args(['examples'])
|
25
|
-
class
|
25
|
+
class _DescribeStructure(mapping.Mapping):
|
26
26
|
"""Describe a structured value in natural language."""
|
27
27
|
|
28
28
|
input_title = 'PYTHON_OBJECT'
|
@@ -106,7 +106,7 @@ def describe(
|
|
106
106
|
Returns:
|
107
107
|
The parsed result based on the schema.
|
108
108
|
"""
|
109
|
-
return
|
109
|
+
return _DescribeStructure(
|
110
110
|
input=value,
|
111
111
|
context=context,
|
112
112
|
examples=examples or default_describe_examples(),
|
@@ -36,7 +36,7 @@ class Itinerary(pg.Object):
|
|
36
36
|
class DescribeStructureTest(unittest.TestCase):
|
37
37
|
|
38
38
|
def test_render(self):
|
39
|
-
l = description_lib.
|
39
|
+
l = description_lib._DescribeStructure(
|
40
40
|
input=Itinerary(
|
41
41
|
day=1,
|
42
42
|
type='daytime',
|
@@ -137,7 +137,7 @@ class DescribeStructureTest(unittest.TestCase):
|
|
137
137
|
],
|
138
138
|
hotel=None,
|
139
139
|
)
|
140
|
-
l = description_lib.
|
140
|
+
l = description_lib._DescribeStructure(
|
141
141
|
input=value, context='1 day itinerary to SF'
|
142
142
|
)
|
143
143
|
self.assertEqual(
|
@@ -187,7 +187,7 @@ class DescribeStructureTest(unittest.TestCase):
|
|
187
187
|
],
|
188
188
|
hotel=None,
|
189
189
|
)
|
190
|
-
l = description_lib.
|
190
|
+
l = description_lib._DescribeStructure(input=value)
|
191
191
|
self.assertEqual(
|
192
192
|
l.render().text,
|
193
193
|
inspect.cleandoc("""
|