langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/__init__.py +2 -0
- langfun/core/agentic/__init__.py +4 -1
- langfun/core/agentic/action.py +447 -29
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +149 -21
- langfun/core/async_support.py +32 -3
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +1 -0
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +9 -2
- langfun/core/data/conversion/gemini_test.py +12 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +47 -43
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +1 -0
- langfun/core/eval/v2/checkpointing.py +64 -6
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/eval_test_helper.py +103 -2
- langfun/core/eval/v2/evaluation.py +91 -16
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +74 -8
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +30 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +12 -3
- langfun/core/eval/v2/progress_tracking_test.py +6 -1
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +59 -142
- langfun/core/eval/v2/runners/beam.py +341 -0
- langfun/core/eval/v2/runners/beam_test.py +131 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +100 -0
- langfun/core/eval/v2/runners/parallel_test.py +95 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +172 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +141 -21
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +9 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +55 -17
- langfun/core/llms/gemini_test.py +84 -0
- langfun/core/llms/google_genai.py +34 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +36 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +12 -1
- langfun/core/llms/vertexai.py +58 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/client.py +77 -22
- langfun/core/mcp/client_test.py +8 -35
- langfun/core/mcp/session.py +94 -29
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/tool.py +151 -22
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +19 -1
- langfun/core/modalities/mime.py +64 -3
- langfun/core/modalities/mime_test.py +11 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +215 -142
- langfun/core/structured/querying_test.py +65 -29
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +174 -49
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +8 -2
- langfun/env/base_environment.py +320 -128
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +92 -15
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +84 -361
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +1 -1
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +95 -98
- langfun/env/event_handlers/event_logger_test.py +21 -21
- langfun/env/event_handlers/metric_writer.py +225 -140
- langfun/env/event_handlers/metric_writer_test.py +23 -6
- langfun/env/interface.py +854 -40
- langfun/env/interface_test.py +112 -2
- langfun/env/load_balancers_test.py +23 -2
- langfun/env/test_utils.py +126 -84
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
- langfun-0.1.2.dev202511270805.dist-info/RECORD +215 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun/env/base_test.py +0 -1481
- langfun/env/event_handlers/base.py +0 -350
- langfun-0.1.2.dev202510230805.dist-info/RECORD +0 -195
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0
langfun/core/modalities/video.py
CHANGED
|
@@ -18,7 +18,27 @@ from langfun.core.modalities import mime
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class Video(mime.Mime):
|
|
21
|
-
"""
|
|
21
|
+
"""Represents a video for communicating with language models.
|
|
22
|
+
|
|
23
|
+
`lf.Video` can be initialized from a URI (HTTP/HTTPS URL or local path)
|
|
24
|
+
using `lf.Video.from_uri()` or from raw bytes using `lf.Video.from_bytes()`.
|
|
25
|
+
|
|
26
|
+
**Example:**
|
|
27
|
+
|
|
28
|
+
```python
|
|
29
|
+
import langfun as lf
|
|
30
|
+
|
|
31
|
+
# Load video from path
|
|
32
|
+
video = lf.Video.from_path('/path/to/video.mp4')
|
|
33
|
+
|
|
34
|
+
# Use video in a prompt
|
|
35
|
+
prompt = lf.Template(
|
|
36
|
+
'What is happening in this video? {{video}}', video=video
|
|
37
|
+
)
|
|
38
|
+
response = lf.query(prompt, lm=lf.llms.Gemini25Flash())
|
|
39
|
+
print(response)
|
|
40
|
+
```
|
|
41
|
+
"""
|
|
22
42
|
|
|
23
43
|
MIME_PREFIX = 'video'
|
|
24
44
|
|
langfun/core/modality.py
CHANGED
|
@@ -14,40 +14,63 @@
|
|
|
14
14
|
"""Interface for modality (e.g. Image, Video, etc.)."""
|
|
15
15
|
|
|
16
16
|
import abc
|
|
17
|
+
import contextlib
|
|
17
18
|
import functools
|
|
18
19
|
import hashlib
|
|
19
|
-
|
|
20
|
+
import re
|
|
21
|
+
from typing import Any, ContextManager, Iterator
|
|
20
22
|
from langfun.core import component
|
|
21
23
|
import pyglove as pg
|
|
22
24
|
|
|
23
25
|
|
|
24
|
-
|
|
26
|
+
class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
27
|
+
"""Base class for representing non-text content in prompts.
|
|
25
28
|
|
|
29
|
+
`lf.Modality` is the base class for multimodal objects such as `lf.Image`,
|
|
30
|
+
`lf.Audio`, and `lf.Video`. It allows these non-text inputs to be
|
|
31
|
+
seamlessly embedded within text prompts for processing by multimodal
|
|
32
|
+
language models.
|
|
26
33
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
34
|
+
When a `Modality` object is rendered within an `lf.Template`, it is
|
|
35
|
+
replaced by a text marker (e.g., `<<[[image:b10a8db1]]>>`), and the
|
|
36
|
+
modality object itself is stored in the `referred_modalities` field of
|
|
37
|
+
the resulting `lf.Message`. This allows language models to associate
|
|
38
|
+
the placeholder with its content during processing.
|
|
32
39
|
|
|
40
|
+
**Example:**
|
|
33
41
|
|
|
34
|
-
|
|
35
|
-
|
|
42
|
+
```python
|
|
43
|
+
import langfun as lf
|
|
44
|
+
|
|
45
|
+
image = lf.Image.from_path('/path/to/image.png')
|
|
46
|
+
prompt = lf.Template('What is in this image? {{image}}', image=image)
|
|
47
|
+
|
|
48
|
+
message = prompt.render()
|
|
49
|
+
print(message.text)
|
|
50
|
+
# Output: What is in this image? <<[[image:b10a8db1]]>>
|
|
51
|
+
|
|
52
|
+
print(message.modalities())
|
|
53
|
+
# Output: [<Image object>]
|
|
54
|
+
```
|
|
55
|
+
"""
|
|
36
56
|
|
|
37
57
|
REF_START = '<<[['
|
|
38
58
|
REF_END = ']]>>'
|
|
39
59
|
|
|
40
60
|
def _on_bound(self):
|
|
41
61
|
super()._on_bound()
|
|
42
|
-
# Invalidate cached hash if modality member is changed.
|
|
62
|
+
# Invalidate cached hash and id if modality member is changed.
|
|
43
63
|
self.__dict__.pop('hash', None)
|
|
64
|
+
self.__dict__.pop('id', None)
|
|
44
65
|
|
|
45
66
|
def format(self, *args, **kwargs) -> str:
|
|
46
|
-
if
|
|
47
|
-
_TLS_MODALITY_AS_REF, False
|
|
48
|
-
):
|
|
67
|
+
if not pg.object_utils.thread_local_get(_TLS_MODALITY_AS_REF, False):
|
|
49
68
|
return super().format(*args, **kwargs)
|
|
50
|
-
|
|
69
|
+
|
|
70
|
+
capture_scope = get_modality_capture_context()
|
|
71
|
+
if capture_scope is not None:
|
|
72
|
+
capture_scope.capture(self)
|
|
73
|
+
return Modality.text_marker(self.id)
|
|
51
74
|
|
|
52
75
|
def __str_kwargs__(self) -> dict[str, Any]:
|
|
53
76
|
# For modality objects, we don't want to use markdown format when they
|
|
@@ -70,14 +93,11 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
70
93
|
"""Returns a marker in the text for this object."""
|
|
71
94
|
return Modality.REF_START + var_name + Modality.REF_END
|
|
72
95
|
|
|
73
|
-
@
|
|
74
|
-
def
|
|
96
|
+
@functools.cached_property
|
|
97
|
+
def id(self) -> str | None:
|
|
75
98
|
"""Returns the ID of this object, in the form `<snake_case_class_name>:<hash>`."""
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
# Strip the metadata prefix under message.
|
|
79
|
-
path = str(self.sym_path)
|
|
80
|
-
return path[9:] if path.startswith('metadata.') else path
|
|
99
|
+
modality_type = _camel_to_snake(self.__class__.__name__)
|
|
100
|
+
return f'{modality_type}:{self.hash}'
|
|
81
101
|
|
|
82
102
|
@classmethod
|
|
83
103
|
def from_value(cls, value: pg.Symbolic) -> dict[str, 'Modality']:
|
|
@@ -86,7 +106,7 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
86
106
|
def _visit(k, v, p):
|
|
87
107
|
del k, p
|
|
88
108
|
if isinstance(v, Modality):
|
|
89
|
-
modalities[v.
|
|
109
|
+
modalities[v.id] = v
|
|
90
110
|
return pg.TraverseAction.CONTINUE
|
|
91
111
|
return pg.TraverseAction.ENTER
|
|
92
112
|
|
|
@@ -95,14 +115,47 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
95
115
|
|
|
96
116
|
|
|
97
117
|
class ModalityRef(pg.Object, pg.typing.CustomTyping):
|
|
98
|
-
"""
|
|
118
|
+
"""Lightweight placeholder for a `lf.Modality` object in a symbolic tree.
|
|
99
119
|
|
|
100
|
-
`ModalityRef`
|
|
101
|
-
|
|
102
|
-
|
|
120
|
+
`ModalityRef` acts as a reference to a `Modality` object (like `lf.Image`
|
|
121
|
+
or `lf.Audio`) within a structured object hierarchy (e.g., a `pg.Object`).
|
|
122
|
+
Instead of embedding potentially large modality data directly, `ModalityRef`
|
|
123
|
+
stores only the ID of the modality object.
|
|
124
|
+
|
|
125
|
+
This is useful in scenarios where structured objects are serialized or
|
|
126
|
+
manipulated, and it's more efficient to refer to modalities by ID rather
|
|
127
|
+
than copying their content. The `lf.ModalityRef.placehold()` class method
|
|
128
|
+
can be used to replace `Modality` instances in a symbolic object with
|
|
129
|
+
`ModalityRef` placeholders, while `lf.ModalityRef.restore()` can reinstate
|
|
130
|
+
the original `Modality` objects using a lookup table.
|
|
131
|
+
|
|
132
|
+
**Example:**
|
|
133
|
+
|
|
134
|
+
```python
|
|
135
|
+
import langfun as lf
|
|
136
|
+
import pyglove as pg
|
|
137
|
+
|
|
138
|
+
class ImagePair(pg.Object):
|
|
139
|
+
image1: lf.Image
|
|
140
|
+
image2: lf.Image
|
|
141
|
+
|
|
142
|
+
pair = ImagePair(
|
|
143
|
+
image1=lf.Image(content=b'abc'), image2=lf.Image(content=b'def')
|
|
144
|
+
)
|
|
145
|
+
modalities = lf.Modality.from_value(pair)
|
|
146
|
+
|
|
147
|
+
# Replace Image objects with ModalityRef placeholders
|
|
148
|
+
pair_with_refs = lf.ModalityRef.placehold(pair)
|
|
149
|
+
print(pair_with_refs.image1)
|
|
150
|
+
# Output: ModalityRef(id='image:d81e5a68')
|
|
151
|
+
|
|
152
|
+
# Restore Image objects from ModalityRef placeholders
|
|
153
|
+
pair_restored = lf.ModalityRef.restore(pair_with_refs, modalities)
|
|
154
|
+
assert pair_restored.image1.content == b'abc'
|
|
155
|
+
```
|
|
103
156
|
"""
|
|
104
157
|
|
|
105
|
-
|
|
158
|
+
id: str
|
|
106
159
|
|
|
107
160
|
def custom_apply(
|
|
108
161
|
self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs
|
|
@@ -122,12 +175,97 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
|
|
|
122
175
|
"""
|
|
123
176
|
|
|
124
177
|
def _placehold(k, v, p):
|
|
125
|
-
del p
|
|
178
|
+
del k, p
|
|
126
179
|
if isinstance(v, Modality):
|
|
127
|
-
return ModalityRef(
|
|
180
|
+
return ModalityRef(id=v.id)
|
|
128
181
|
return v
|
|
129
182
|
return value.clone().rebind(_placehold, raise_on_no_change=False)
|
|
130
183
|
|
|
184
|
+
@classmethod
|
|
185
|
+
def restore(cls, value: pg.Symbolic, modalities: dict[str, Modality]) -> Any:
|
|
186
|
+
"""Returns a copy of value by replacing refs with modality objects."""
|
|
187
|
+
def _restore(k, v, p):
|
|
188
|
+
del k, p
|
|
189
|
+
if isinstance(v, ModalityRef):
|
|
190
|
+
modality_object = modalities.get(v.id)
|
|
191
|
+
if modality_object is None:
|
|
192
|
+
raise ValueError(
|
|
193
|
+
f'Modality {v.id} not found in modalities {modalities.keys()}'
|
|
194
|
+
)
|
|
195
|
+
return modality_object
|
|
196
|
+
return v
|
|
197
|
+
return value.rebind(_restore, raise_on_no_change=False)
|
|
198
|
+
|
|
131
199
|
|
|
132
200
|
class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
|
|
133
201
|
"""Exception raised when modality is not supported."""
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
#
|
|
205
|
+
# Context managers to deal with modality objects.
|
|
206
|
+
#
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
_TLS_MODALITY_CAPTURE_SCOPE = '__modality_capture_scope__'
|
|
210
|
+
_TLS_MODALITY_AS_REF = '__format_modality_as_ref__'
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
|
|
214
|
+
"""A context manager that formats modality objects as references."""
|
|
215
|
+
return pg.object_utils.thread_local_value_scope(
|
|
216
|
+
_TLS_MODALITY_AS_REF, enabled, False
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class _ModalityCaptureContext:
|
|
221
|
+
"""A context to capture modality objects when being rendered."""
|
|
222
|
+
|
|
223
|
+
def __init__(self):
|
|
224
|
+
self._references: dict[str, pg.Ref[Modality]] = {}
|
|
225
|
+
|
|
226
|
+
def capture(self, modality: Modality) -> None:
|
|
227
|
+
"""Captures the modality object."""
|
|
228
|
+
self._references[modality.id] = pg.Ref(modality)
|
|
229
|
+
|
|
230
|
+
@property
|
|
231
|
+
def references(self) -> dict[str, pg.Ref[Modality]]:
|
|
232
|
+
"""Returns the modality references captured in this context."""
|
|
233
|
+
return self._references
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
@contextlib.contextmanager
|
|
237
|
+
def capture_rendered_modalities() -> Iterator[dict[str, pg.Ref[Modality]]]:
|
|
238
|
+
"""Captures modality objects whose references are being rendered.
|
|
239
|
+
|
|
240
|
+
Example:
|
|
241
|
+
```
|
|
242
|
+
image = lf.Image.from_url(...)
|
|
243
|
+
with lf.modality.capture_rendered_modalities() as rendered_modalities:
|
|
244
|
+
with lf.modality.format_modality_as_ref():
|
|
245
|
+
print(f'Hello {image}')
|
|
246
|
+
self.assertEqual(rendered_modalities, {'image:<hash>': pg.Ref(image)})
|
|
247
|
+
```
|
|
248
|
+
"""
|
|
249
|
+
context = get_modality_capture_context()
|
|
250
|
+
top_level = context is None
|
|
251
|
+
if top_level:
|
|
252
|
+
context = _ModalityCaptureContext()
|
|
253
|
+
pg.object_utils.thread_local_set(_TLS_MODALITY_CAPTURE_SCOPE, context)
|
|
254
|
+
|
|
255
|
+
try:
|
|
256
|
+
yield context.references # pylint: disable=attribute-error
|
|
257
|
+
finally:
|
|
258
|
+
if top_level:
|
|
259
|
+
pg.object_utils.thread_local_del(_TLS_MODALITY_CAPTURE_SCOPE)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def get_modality_capture_context() -> _ModalityCaptureContext | None:
|
|
263
|
+
"""Returns the current modality capture context."""
|
|
264
|
+
return pg.object_utils.thread_local_get(_TLS_MODALITY_CAPTURE_SCOPE, None)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _camel_to_snake(name: str) -> str:
|
|
268
|
+
"""Converts a camelCase name to snake_case."""
|
|
269
|
+
return re.sub(
|
|
270
|
+
pattern=r'([A-Z]+)', repl=r'_\1', string=name
|
|
271
|
+
).lower().lstrip('_')
|
langfun/core/modality_test.py
CHANGED
|
@@ -29,34 +29,64 @@ class ModalityTest(unittest.TestCase):
|
|
|
29
29
|
|
|
30
30
|
def test_basic(self):
|
|
31
31
|
v = CustomModality('a')
|
|
32
|
-
self.
|
|
32
|
+
self.assertEqual(v.id, 'custom_modality:0cc175b9')
|
|
33
33
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
|
34
34
|
self.assertEqual(v.hash, '0cc175b9')
|
|
35
35
|
|
|
36
36
|
_ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
|
|
37
|
-
self.assertEqual(v.
|
|
37
|
+
self.assertEqual(v.id, 'custom_modality:0cc175b9')
|
|
38
38
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
|
39
39
|
with modality.format_modality_as_ref():
|
|
40
|
-
self.assertEqual(str(v), '<<[[
|
|
40
|
+
self.assertEqual(str(v), '<<[[custom_modality:0cc175b9]]>>')
|
|
41
|
+
|
|
42
|
+
def test_capture_rendered_modalities(self):
|
|
43
|
+
x = CustomModality('a')
|
|
44
|
+
y = CustomModality('b')
|
|
45
|
+
z = CustomModality('b')
|
|
46
|
+
|
|
47
|
+
with modality.capture_rendered_modalities() as rendered_modalities:
|
|
48
|
+
with modality.format_modality_as_ref():
|
|
49
|
+
self.assertEqual(
|
|
50
|
+
f'Hello {x} {y} {z}',
|
|
51
|
+
(
|
|
52
|
+
'Hello <<[[custom_modality:0cc175b9]]>> '
|
|
53
|
+
'<<[[custom_modality:92eb5ffe]]>> '
|
|
54
|
+
'<<[[custom_modality:92eb5ffe]]>>'
|
|
55
|
+
)
|
|
56
|
+
)
|
|
57
|
+
self.assertEqual(len(rendered_modalities), 2)
|
|
58
|
+
self.assertIs(rendered_modalities['custom_modality:0cc175b9'].value, x)
|
|
59
|
+
# y and z share the same content, so they are treated as the same object.
|
|
60
|
+
self.assertIs(rendered_modalities['custom_modality:92eb5ffe'].value, z)
|
|
41
61
|
|
|
42
62
|
|
|
43
63
|
class ModalityRefTest(unittest.TestCase):
|
|
44
64
|
|
|
45
|
-
def
|
|
65
|
+
def test_placehold_and_restore(self):
|
|
46
66
|
class A(pg.Object):
|
|
47
67
|
x: Any
|
|
48
68
|
y: Any
|
|
49
69
|
|
|
50
|
-
|
|
70
|
+
image_a = CustomModality('a')
|
|
71
|
+
image_b = CustomModality('b')
|
|
72
|
+
a = A(x=dict(z=image_a), y=image_b)
|
|
73
|
+
a_placehold = modality.ModalityRef.placehold(a)
|
|
51
74
|
self.assertEqual(
|
|
52
|
-
|
|
53
|
-
A(x=dict(z=modality.ModalityRef(
|
|
75
|
+
a_placehold,
|
|
76
|
+
A(x=dict(z=modality.ModalityRef(image_a.id)),
|
|
77
|
+
y=modality.ModalityRef(image_b.id)),
|
|
78
|
+
)
|
|
79
|
+
a_restore = modality.ModalityRef.restore(
|
|
80
|
+
a_placehold.clone(),
|
|
81
|
+
{image_a.id: image_a, image_b.id: image_b},
|
|
54
82
|
)
|
|
83
|
+
self.assertTrue(pg.eq(a_restore, a))
|
|
55
84
|
self.assertEqual(
|
|
56
85
|
modality.ModalityRef.placehold(a.x),
|
|
57
|
-
|
|
58
|
-
dict(z=modality.ModalityRef('x.z')),
|
|
86
|
+
dict(z=modality.ModalityRef(image_a.id)),
|
|
59
87
|
)
|
|
88
|
+
with self.assertRaisesRegex(ValueError, 'Modality .* not found'):
|
|
89
|
+
modality.ModalityRef.restore(a_placehold, {image_a.id: image_a})
|
|
60
90
|
|
|
61
91
|
def test_from_value(self):
|
|
62
92
|
class A(pg.Object):
|
|
@@ -68,8 +98,8 @@ class ModalityRefTest(unittest.TestCase):
|
|
|
68
98
|
pg.eq(
|
|
69
99
|
modality.Modality.from_value(a),
|
|
70
100
|
{
|
|
71
|
-
'
|
|
72
|
-
'
|
|
101
|
+
'custom_modality:0cc175b9': CustomModality('a'),
|
|
102
|
+
'custom_modality:92eb5ffe': CustomModality('b'),
|
|
73
103
|
},
|
|
74
104
|
)
|
|
75
105
|
)
|
|
@@ -77,7 +107,7 @@ class ModalityRefTest(unittest.TestCase):
|
|
|
77
107
|
pg.eq(
|
|
78
108
|
modality.Modality.from_value(a.x.z),
|
|
79
109
|
{
|
|
80
|
-
'
|
|
110
|
+
'custom_modality:0cc175b9': CustomModality('a'),
|
|
81
111
|
},
|
|
82
112
|
)
|
|
83
113
|
)
|
langfun/core/natural_language.py
CHANGED
|
@@ -11,7 +11,7 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
"""Natural language
|
|
14
|
+
"""Natural language formatting."""
|
|
15
15
|
|
|
16
16
|
import abc
|
|
17
17
|
import pyglove as pg
|
langfun/core/sampling.py
CHANGED
|
@@ -38,10 +38,10 @@ def sweep(
|
|
|
38
38
|
Union[message_lib.Message, BaseException, None], # LM output.
|
|
39
39
|
],
|
|
40
40
|
]:
|
|
41
|
-
"""Sweeps the input/output of
|
|
41
|
+
"""Sweeps the input/output of a LangFunc search space concurrently.
|
|
42
42
|
|
|
43
43
|
Args:
|
|
44
|
-
lfun: An LangFunc object that contains `pg.oneof` as the search space
|
|
44
|
+
lfun: A LangFunc object that contains `pg.oneof` as the search space
|
|
45
45
|
for sampling.
|
|
46
46
|
num_examples: Number of examples to sample.
|
|
47
47
|
max_workers: Max number of concurrent workers to do sampling.
|
|
@@ -84,10 +84,10 @@ def random_sample(
|
|
|
84
84
|
Union[message_lib.Message, BaseException, None], # LM output.
|
|
85
85
|
],
|
|
86
86
|
]:
|
|
87
|
-
"""Random samples the input/output of
|
|
87
|
+
"""Randomly samples the input/output of a LangFunc search space concurrently.
|
|
88
88
|
|
|
89
89
|
Args:
|
|
90
|
-
lfun: An LangFunc object that contains `pg.oneof` as the search space
|
|
90
|
+
lfun: A LangFunc object that contains `pg.oneof` as the search space
|
|
91
91
|
for sampling.
|
|
92
92
|
num_examples: Number of examples to sample.
|
|
93
93
|
max_workers: Max number of concurrent workers to do sampling.
|
langfun/core/sampling_test.py
CHANGED
|
@@ -39,8 +39,13 @@ class SamplingTest(unittest.TestCase):
|
|
|
39
39
|
l = LangFunc('Compute {{x}} and {{y}}', x=pg.oneof([1, 2]))
|
|
40
40
|
with component.context(lm=ExcitedEchoer()):
|
|
41
41
|
samples = list(sampling.sweep(l, y=pg.oneof([3, 4])))
|
|
42
|
-
samples = sorted(
|
|
43
|
-
|
|
42
|
+
samples = sorted(
|
|
43
|
+
samples,
|
|
44
|
+
key=lambda x: (
|
|
45
|
+
x[0].__template_input__.x,
|
|
46
|
+
x[0].__template_input__.y
|
|
47
|
+
)
|
|
48
|
+
)
|
|
44
49
|
self.assertEqual(
|
|
45
50
|
samples,
|
|
46
51
|
[
|
|
@@ -57,7 +62,12 @@ class SamplingTest(unittest.TestCase):
|
|
|
57
62
|
samples = list(
|
|
58
63
|
sampling.random_sample(l, y=pg.oneof([2, 4]), num_examples=3, seed=1)
|
|
59
64
|
)
|
|
60
|
-
samples = sorted(
|
|
65
|
+
samples = sorted(
|
|
66
|
+
samples, key=lambda x: (
|
|
67
|
+
x[0].__template_input__.x,
|
|
68
|
+
x[0].__template_input__.y
|
|
69
|
+
)
|
|
70
|
+
)
|
|
61
71
|
|
|
62
72
|
self.assertEqual(
|
|
63
73
|
samples,
|
|
@@ -97,7 +107,13 @@ class SamplingTest(unittest.TestCase):
|
|
|
97
107
|
silence_on_errors=(AttributeError,),
|
|
98
108
|
ignore_examples_with_errors=False))
|
|
99
109
|
|
|
100
|
-
samples = sorted(
|
|
110
|
+
samples = sorted(
|
|
111
|
+
samples,
|
|
112
|
+
key=lambda x: (
|
|
113
|
+
x[0].__template_input__.x,
|
|
114
|
+
x[0].__template_input__.y
|
|
115
|
+
)
|
|
116
|
+
)
|
|
101
117
|
self.assertEqual(
|
|
102
118
|
[x[0] for x in samples],
|
|
103
119
|
[
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -16,29 +16,7 @@
|
|
|
16
16
|
# pylint: disable=g-bad-import-order
|
|
17
17
|
# pylint: disable=g-importing-member
|
|
18
18
|
|
|
19
|
-
from langfun.core.structured.schema import
|
|
20
|
-
|
|
21
|
-
from langfun.core.structured.schema import Missing
|
|
22
|
-
from langfun.core.structured.schema import MISSING
|
|
23
|
-
from langfun.core.structured.schema import Unknown
|
|
24
|
-
from langfun.core.structured.schema import UNKNOWN
|
|
25
|
-
|
|
26
|
-
from langfun.core.structured.schema import Schema
|
|
27
|
-
from langfun.core.structured.schema import SchemaProtocol
|
|
28
|
-
from langfun.core.structured.schema import schema_spec
|
|
29
|
-
|
|
30
|
-
from langfun.core.structured.schema import SchemaError
|
|
31
|
-
from langfun.core.structured.schema import JsonError
|
|
32
|
-
|
|
33
|
-
from langfun.core.structured.schema import class_dependencies
|
|
34
|
-
from langfun.core.structured.schema import class_definition
|
|
35
|
-
from langfun.core.structured.schema import class_definitions
|
|
36
|
-
from langfun.core.structured.schema import annotation
|
|
37
|
-
from langfun.core.structured.schema import structure_from_python
|
|
38
|
-
|
|
39
|
-
from langfun.core.structured.schema import schema_repr
|
|
40
|
-
from langfun.core.structured.schema import source_form
|
|
41
|
-
from langfun.core.structured.schema import value_repr
|
|
19
|
+
from langfun.core.structured.schema import *
|
|
42
20
|
|
|
43
21
|
from langfun.core.structured.schema_generation import generate_class
|
|
44
22
|
from langfun.core.structured.schema_generation import classgen_example
|
|
@@ -116,15 +116,10 @@ class _CompleteStructure(mapping.Mapping):
|
|
|
116
116
|
)
|
|
117
117
|
|
|
118
118
|
def postprocess_result(self, result: Any) -> Any:
|
|
119
|
-
"""
|
|
119
|
+
"""Postprocesses result."""
|
|
120
120
|
# Try restore modality objects from the input value to output value.
|
|
121
|
-
modalities
|
|
122
|
-
|
|
123
|
-
# Remove the `input` prefix for all entries.
|
|
124
|
-
modalities = pg.object_utils.flatten(
|
|
125
|
-
pg.object_utils.canonicalize(modalities)['input']
|
|
126
|
-
)
|
|
127
|
-
result.rebind(modalities)
|
|
121
|
+
if modalities := self.modalities(self.input):
|
|
122
|
+
result = lf.ModalityRef.restore(result, modalities)
|
|
128
123
|
return result
|
|
129
124
|
|
|
130
125
|
def globals(self):
|
|
@@ -156,7 +151,7 @@ class _CompleteStructure(mapping.Mapping):
|
|
|
156
151
|
#
|
|
157
152
|
|
|
158
153
|
def has_modality_refs(self, value: Any) -> bool:
|
|
159
|
-
"""Returns
|
|
154
|
+
"""Returns True if the value has modalities."""
|
|
160
155
|
return not isinstance(value, lf.Modality) and pg.contains(
|
|
161
156
|
value, type=lf.Modality
|
|
162
157
|
)
|
|
@@ -186,41 +181,36 @@ def complete(
|
|
|
186
181
|
returns_message: bool = False,
|
|
187
182
|
**kwargs,
|
|
188
183
|
) -> Any:
|
|
189
|
-
"""
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
r = lf.query(prompt, Flight)
|
|
214
|
-
assert isinstance(r, Flight)
|
|
215
|
-
assert r.airline == 'United Airlines'
|
|
216
|
-
assert r.departure_airport_code == 'SFO'
|
|
217
|
-
assert r.duration.hour = 7
|
|
218
|
-
```
|
|
184
|
+
"""Completes a symbolic value by filling its missing fields using an LLM.
|
|
185
|
+
|
|
186
|
+
`lf.complete` is used to fill in missing information in structured
|
|
187
|
+
data. It takes a partially defined `pg.Object` instance where some fields
|
|
188
|
+
are marked as `lf.MISSING`, and uses a language model to infer and
|
|
189
|
+
populate those fields based on the provided values.
|
|
190
|
+
|
|
191
|
+
**Example:**
|
|
192
|
+
|
|
193
|
+
```python
|
|
194
|
+
import langfun as lf
|
|
195
|
+
import pyglove as pg
|
|
196
|
+
|
|
197
|
+
class Country(pg.Object):
|
|
198
|
+
name: str
|
|
199
|
+
capital: str = lf.MISSING
|
|
200
|
+
population: int = lf.MISSING
|
|
201
|
+
|
|
202
|
+
# Filling missing fields of Country(name='France')
|
|
203
|
+
country = lf.complete(Country(name='France'), lm=lf.llms.Gemini25Flash())
|
|
204
|
+
print(country)
|
|
205
|
+
# Output: Country(name='France', capital='Paris', population=67000000)
|
|
206
|
+
```
|
|
219
207
|
|
|
220
208
|
Args:
|
|
221
|
-
input_value: A symbolic value that may contain missing values
|
|
222
|
-
|
|
223
|
-
|
|
209
|
+
input_value: A symbolic value that may contain missing values marked
|
|
210
|
+
by `lf.MISSING`.
|
|
211
|
+
default: The default value to return if parsing fails. If
|
|
212
|
+
`lf.RAISE_IF_HAS_ERROR` is used (default), an error will be raised
|
|
213
|
+
instead.
|
|
224
214
|
lm: The language model to use. If not specified, the language model from
|
|
225
215
|
`lf.context` context manager will be used.
|
|
226
216
|
examples: An optional list of fewshot examples for helping parsing. If None,
|
|
@@ -236,10 +226,10 @@ def complete(
|
|
|
236
226
|
returns_message: If True, returns `lf.Message` as the output, instead of
|
|
237
227
|
returning the structured `message.result`.
|
|
238
228
|
**kwargs: Keyword arguments passed to the
|
|
239
|
-
`lf.structured.
|
|
229
|
+
`lf.structured.Mapping` transform.
|
|
240
230
|
|
|
241
231
|
Returns:
|
|
242
|
-
The
|
|
232
|
+
The input object with missing fields completed by LLM.
|
|
243
233
|
"""
|
|
244
234
|
t = _CompleteStructure(
|
|
245
235
|
input=schema_lib.mark_missing(input_value),
|