langfun 0.1.2.dev202510240805__py3-none-any.whl → 0.1.2.dev202510250803__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/concurrent_test.py +1 -0
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini_test.py +12 -9
- langfun/core/data/conversion/openai.py +134 -30
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/v2/progress_tracking_test.py +3 -0
- langfun/core/langfunc_test.py +4 -2
- langfun/core/language_model.py +6 -6
- langfun/core/language_model_test.py +9 -3
- langfun/core/llms/__init__.py +2 -1
- langfun/core/llms/cache/base.py +3 -1
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/deepseek.py +1 -1
- langfun/core/llms/groq.py +1 -1
- langfun/core/llms/llama_cpp.py +1 -1
- langfun/core/llms/openai.py +7 -2
- langfun/core/llms/openai_compatible.py +134 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/vertexai.py +2 -2
- langfun/core/message.py +78 -44
- langfun/core/message_test.py +56 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/mime.py +9 -0
- langfun/core/modality.py +104 -27
- langfun/core/modality_test.py +42 -12
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/completion.py +2 -7
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/mapping.py +4 -13
- langfun/core/structured/querying.py +13 -11
- langfun/core/structured/querying_test.py +65 -29
- langfun/core/template.py +39 -13
- langfun/core/template_test.py +83 -17
- langfun/env/event_handlers/metric_writer_test.py +3 -3
- langfun/env/load_balancers_test.py +2 -2
- {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/RECORD +41 -41
- {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/top_level.txt +0 -0
langfun/core/modalities/mime.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import base64
|
|
17
17
|
import functools
|
|
18
|
+
import hashlib
|
|
18
19
|
from typing import Annotated, Any, Iterable, Type, Union
|
|
19
20
|
import langfun.core as lf
|
|
20
21
|
# Placeholder for Google-internal internet access import.
|
|
@@ -87,6 +88,14 @@ class Mime(lf.Modality):
|
|
|
87
88
|
"""Returns True if the MIME type is a binary type."""
|
|
88
89
|
return not self.is_text
|
|
89
90
|
|
|
91
|
+
@property
|
|
92
|
+
def hash(self) -> str:
|
|
93
|
+
"""Returns the hash of the MIME content."""
|
|
94
|
+
# Hash the URI to avoid downloading the content.
|
|
95
|
+
if self.uri is not None:
|
|
96
|
+
return hashlib.md5(self.uri.encode()).hexdigest()[:8]
|
|
97
|
+
return super().hash
|
|
98
|
+
|
|
90
99
|
def to_text(self) -> str:
|
|
91
100
|
"""Returns the text content of the MIME type."""
|
|
92
101
|
if not self.is_text:
|
langfun/core/modality.py
CHANGED
|
@@ -14,23 +14,15 @@
|
|
|
14
14
|
"""Interface for modality (e.g. Image, Video, etc.)."""
|
|
15
15
|
|
|
16
16
|
import abc
|
|
17
|
+
import contextlib
|
|
17
18
|
import functools
|
|
18
19
|
import hashlib
|
|
19
|
-
|
|
20
|
+
import re
|
|
21
|
+
from typing import Any, ContextManager, Iterator
|
|
20
22
|
from langfun.core import component
|
|
21
23
|
import pyglove as pg
|
|
22
24
|
|
|
23
25
|
|
|
24
|
-
_TLS_MODALITY_AS_REF = '__format_modality_as_ref__'
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
|
|
28
|
-
"""A context manager that formats modality objects as references."""
|
|
29
|
-
return pg.object_utils.thread_local_value_scope(
|
|
30
|
-
_TLS_MODALITY_AS_REF, enabled, False
|
|
31
|
-
)
|
|
32
|
-
|
|
33
|
-
|
|
34
26
|
class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
35
27
|
"""Base class for multimodal object."""
|
|
36
28
|
|
|
@@ -39,15 +31,18 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
39
31
|
|
|
40
32
|
def _on_bound(self):
|
|
41
33
|
super()._on_bound()
|
|
42
|
-
# Invalidate cached hash if modality member is changed.
|
|
34
|
+
# Invalidate cached hash and id if modality member is changed.
|
|
43
35
|
self.__dict__.pop('hash', None)
|
|
36
|
+
self.__dict__.pop('id', None)
|
|
44
37
|
|
|
45
38
|
def format(self, *args, **kwargs) -> str:
|
|
46
|
-
if
|
|
47
|
-
_TLS_MODALITY_AS_REF, False
|
|
48
|
-
):
|
|
39
|
+
if not pg.object_utils.thread_local_get(_TLS_MODALITY_AS_REF, False):
|
|
49
40
|
return super().format(*args, **kwargs)
|
|
50
|
-
|
|
41
|
+
|
|
42
|
+
capture_scope = get_modality_capture_context()
|
|
43
|
+
if capture_scope is not None:
|
|
44
|
+
capture_scope.capture(self)
|
|
45
|
+
return Modality.text_marker(self.id)
|
|
51
46
|
|
|
52
47
|
def __str_kwargs__(self) -> dict[str, Any]:
|
|
53
48
|
# For modality objects, we don't want to use markdown format when they
|
|
@@ -70,14 +65,11 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
70
65
|
"""Returns a marker in the text for this object."""
|
|
71
66
|
return Modality.REF_START + var_name + Modality.REF_END
|
|
72
67
|
|
|
73
|
-
@
|
|
74
|
-
def
|
|
68
|
+
@functools.cached_property
|
|
69
|
+
def id(self) -> str | None:
|
|
75
70
|
"""Returns the referred name of this object in its template."""
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
# Strip the metadata prefix under message.
|
|
79
|
-
path = str(self.sym_path)
|
|
80
|
-
return path[9:] if path.startswith('metadata.') else path
|
|
71
|
+
modality_type = _camel_to_snake(self.__class__.__name__)
|
|
72
|
+
return f'{modality_type}:{self.hash}'
|
|
81
73
|
|
|
82
74
|
@classmethod
|
|
83
75
|
def from_value(cls, value: pg.Symbolic) -> dict[str, 'Modality']:
|
|
@@ -86,7 +78,7 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
86
78
|
def _visit(k, v, p):
|
|
87
79
|
del k, p
|
|
88
80
|
if isinstance(v, Modality):
|
|
89
|
-
modalities[v.
|
|
81
|
+
modalities[v.id] = v
|
|
90
82
|
return pg.TraverseAction.CONTINUE
|
|
91
83
|
return pg.TraverseAction.ENTER
|
|
92
84
|
|
|
@@ -102,7 +94,7 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
|
|
|
102
94
|
structure.
|
|
103
95
|
"""
|
|
104
96
|
|
|
105
|
-
|
|
97
|
+
id: str
|
|
106
98
|
|
|
107
99
|
def custom_apply(
|
|
108
100
|
self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs
|
|
@@ -122,12 +114,97 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
|
|
|
122
114
|
"""
|
|
123
115
|
|
|
124
116
|
def _placehold(k, v, p):
|
|
125
|
-
del p
|
|
117
|
+
del k, p
|
|
126
118
|
if isinstance(v, Modality):
|
|
127
|
-
return ModalityRef(
|
|
119
|
+
return ModalityRef(id=v.id)
|
|
128
120
|
return v
|
|
129
121
|
return value.clone().rebind(_placehold, raise_on_no_change=False)
|
|
130
122
|
|
|
123
|
+
@classmethod
|
|
124
|
+
def restore(cls, value: pg.Symbolic, modalities: dict[str, Modality]) -> Any:
|
|
125
|
+
"""Returns a copy of value by replacing refs with modality objects."""
|
|
126
|
+
def _restore(k, v, p):
|
|
127
|
+
del k, p
|
|
128
|
+
if isinstance(v, ModalityRef):
|
|
129
|
+
modality_object = modalities.get(v.id)
|
|
130
|
+
if modality_object is None:
|
|
131
|
+
raise ValueError(
|
|
132
|
+
f'Modality {v.id} not found in modalities {modalities.keys()}'
|
|
133
|
+
)
|
|
134
|
+
return modality_object
|
|
135
|
+
return v
|
|
136
|
+
return value.rebind(_restore, raise_on_no_change=False)
|
|
137
|
+
|
|
131
138
|
|
|
132
139
|
class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
|
|
133
140
|
"""Exception raised when modality is not supported."""
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
#
|
|
144
|
+
# Context managers to deal with modality objects.
|
|
145
|
+
#
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
_TLS_MODALITY_CAPTURE_SCOPE = '__modality_capture_scope__'
|
|
149
|
+
_TLS_MODALITY_AS_REF = '__format_modality_as_ref__'
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
|
|
153
|
+
"""A context manager that formats modality objects as references."""
|
|
154
|
+
return pg.object_utils.thread_local_value_scope(
|
|
155
|
+
_TLS_MODALITY_AS_REF, enabled, False
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class _ModalityCaptureContext:
|
|
160
|
+
"""A context to capture modality objects when being rendered."""
|
|
161
|
+
|
|
162
|
+
def __init__(self):
|
|
163
|
+
self._references: dict[str, pg.Ref[Modality]] = {}
|
|
164
|
+
|
|
165
|
+
def capture(self, modality: Modality) -> None:
|
|
166
|
+
"""Captures the modality object."""
|
|
167
|
+
self._references[modality.id] = pg.Ref(modality)
|
|
168
|
+
|
|
169
|
+
@property
|
|
170
|
+
def references(self) -> dict[str, pg.Ref[Modality]]:
|
|
171
|
+
"""Returns the modality references captured in this context."""
|
|
172
|
+
return self._references
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@contextlib.contextmanager
|
|
176
|
+
def capture_rendered_modalities() -> Iterator[dict[str, pg.Ref[Modality]]]:
|
|
177
|
+
"""Capture modality objects whose references is being rendered.
|
|
178
|
+
|
|
179
|
+
Example:
|
|
180
|
+
```
|
|
181
|
+
image = lf.Image.from_url(...)
|
|
182
|
+
with lf.modality.capture_rendered_modalities() as rendered_modalities:
|
|
183
|
+
with lf.modality.format_modality_as_ref():
|
|
184
|
+
print(f'Hello {image}')
|
|
185
|
+
self.assertEqual(rendered_modalities, {'image:<hash>': pg.Ref(image)})
|
|
186
|
+
```
|
|
187
|
+
"""
|
|
188
|
+
context = get_modality_capture_context()
|
|
189
|
+
top_level = context is None
|
|
190
|
+
if top_level:
|
|
191
|
+
context = _ModalityCaptureContext()
|
|
192
|
+
pg.object_utils.thread_local_set(_TLS_MODALITY_CAPTURE_SCOPE, context)
|
|
193
|
+
|
|
194
|
+
try:
|
|
195
|
+
yield context.references # pylint: disable=attribute-error
|
|
196
|
+
finally:
|
|
197
|
+
if top_level:
|
|
198
|
+
pg.object_utils.thread_local_del(_TLS_MODALITY_CAPTURE_SCOPE)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def get_modality_capture_context() -> _ModalityCaptureContext | None:
|
|
202
|
+
"""Returns the current modality capture context."""
|
|
203
|
+
return pg.object_utils.thread_local_get(_TLS_MODALITY_CAPTURE_SCOPE, None)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _camel_to_snake(name: str) -> str:
|
|
207
|
+
"""Converts a camelCase name to snake_case."""
|
|
208
|
+
return re.sub(
|
|
209
|
+
pattern=r'([A-Z]+)', repl=r'_\1', string=name
|
|
210
|
+
).lower().lstrip('_')
|
langfun/core/modality_test.py
CHANGED
|
@@ -29,34 +29,64 @@ class ModalityTest(unittest.TestCase):
|
|
|
29
29
|
|
|
30
30
|
def test_basic(self):
|
|
31
31
|
v = CustomModality('a')
|
|
32
|
-
self.
|
|
32
|
+
self.assertEqual(v.id, 'custom_modality:0cc175b9')
|
|
33
33
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
|
34
34
|
self.assertEqual(v.hash, '0cc175b9')
|
|
35
35
|
|
|
36
36
|
_ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
|
|
37
|
-
self.assertEqual(v.
|
|
37
|
+
self.assertEqual(v.id, 'custom_modality:0cc175b9')
|
|
38
38
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
|
39
39
|
with modality.format_modality_as_ref():
|
|
40
|
-
self.assertEqual(str(v), '<<[[
|
|
40
|
+
self.assertEqual(str(v), '<<[[custom_modality:0cc175b9]]>>')
|
|
41
|
+
|
|
42
|
+
def test_capture_rendered_modalities(self):
|
|
43
|
+
x = CustomModality('a')
|
|
44
|
+
y = CustomModality('b')
|
|
45
|
+
z = CustomModality('b')
|
|
46
|
+
|
|
47
|
+
with modality.capture_rendered_modalities() as rendered_modalities:
|
|
48
|
+
with modality.format_modality_as_ref():
|
|
49
|
+
self.assertEqual(
|
|
50
|
+
f'Hello {x} {y} {z}',
|
|
51
|
+
(
|
|
52
|
+
'Hello <<[[custom_modality:0cc175b9]]>> '
|
|
53
|
+
'<<[[custom_modality:92eb5ffe]]>> '
|
|
54
|
+
'<<[[custom_modality:92eb5ffe]]>>'
|
|
55
|
+
)
|
|
56
|
+
)
|
|
57
|
+
self.assertEqual(len(rendered_modalities), 2)
|
|
58
|
+
self.assertIs(rendered_modalities['custom_modality:0cc175b9'].value, x)
|
|
59
|
+
# y and z share the same content will be treated as the same object.
|
|
60
|
+
self.assertIs(rendered_modalities['custom_modality:92eb5ffe'].value, z)
|
|
41
61
|
|
|
42
62
|
|
|
43
63
|
class ModalityRefTest(unittest.TestCase):
|
|
44
64
|
|
|
45
|
-
def
|
|
65
|
+
def test_placehold_and_restore(self):
|
|
46
66
|
class A(pg.Object):
|
|
47
67
|
x: Any
|
|
48
68
|
y: Any
|
|
49
69
|
|
|
50
|
-
|
|
70
|
+
image_a = CustomModality('a')
|
|
71
|
+
image_b = CustomModality('b')
|
|
72
|
+
a = A(x=dict(z=image_a), y=image_b)
|
|
73
|
+
a_placehold = modality.ModalityRef.placehold(a)
|
|
51
74
|
self.assertEqual(
|
|
52
|
-
|
|
53
|
-
A(x=dict(z=modality.ModalityRef(
|
|
75
|
+
a_placehold,
|
|
76
|
+
A(x=dict(z=modality.ModalityRef(image_a.id)),
|
|
77
|
+
y=modality.ModalityRef(image_b.id)),
|
|
78
|
+
)
|
|
79
|
+
a_restore = modality.ModalityRef.restore(
|
|
80
|
+
a_placehold.clone(),
|
|
81
|
+
{image_a.id: image_a, image_b.id: image_b},
|
|
54
82
|
)
|
|
83
|
+
self.assertTrue(pg.eq(a_restore, a))
|
|
55
84
|
self.assertEqual(
|
|
56
85
|
modality.ModalityRef.placehold(a.x),
|
|
57
|
-
|
|
58
|
-
dict(z=modality.ModalityRef('x.z')),
|
|
86
|
+
dict(z=modality.ModalityRef(image_a.id)),
|
|
59
87
|
)
|
|
88
|
+
with self.assertRaisesRegex(ValueError, 'Modality .* not found'):
|
|
89
|
+
modality.ModalityRef.restore(a_placehold, {image_a.id: image_a})
|
|
60
90
|
|
|
61
91
|
def test_from_value(self):
|
|
62
92
|
class A(pg.Object):
|
|
@@ -68,8 +98,8 @@ class ModalityRefTest(unittest.TestCase):
|
|
|
68
98
|
pg.eq(
|
|
69
99
|
modality.Modality.from_value(a),
|
|
70
100
|
{
|
|
71
|
-
'
|
|
72
|
-
'
|
|
101
|
+
'custom_modality:0cc175b9': CustomModality('a'),
|
|
102
|
+
'custom_modality:92eb5ffe': CustomModality('b'),
|
|
73
103
|
},
|
|
74
104
|
)
|
|
75
105
|
)
|
|
@@ -77,7 +107,7 @@ class ModalityRefTest(unittest.TestCase):
|
|
|
77
107
|
pg.eq(
|
|
78
108
|
modality.Modality.from_value(a.x.z),
|
|
79
109
|
{
|
|
80
|
-
'
|
|
110
|
+
'custom_modality:0cc175b9': CustomModality('a'),
|
|
81
111
|
},
|
|
82
112
|
)
|
|
83
113
|
)
|
langfun/core/sampling_test.py
CHANGED
|
@@ -39,8 +39,13 @@ class SamplingTest(unittest.TestCase):
|
|
|
39
39
|
l = LangFunc('Compute {{x}} and {{y}}', x=pg.oneof([1, 2]))
|
|
40
40
|
with component.context(lm=ExcitedEchoer()):
|
|
41
41
|
samples = list(sampling.sweep(l, y=pg.oneof([3, 4])))
|
|
42
|
-
samples = sorted(
|
|
43
|
-
|
|
42
|
+
samples = sorted(
|
|
43
|
+
samples,
|
|
44
|
+
key=lambda x: (
|
|
45
|
+
x[0].__template_input__.x,
|
|
46
|
+
x[0].__template_input__.y
|
|
47
|
+
)
|
|
48
|
+
)
|
|
44
49
|
self.assertEqual(
|
|
45
50
|
samples,
|
|
46
51
|
[
|
|
@@ -57,7 +62,12 @@ class SamplingTest(unittest.TestCase):
|
|
|
57
62
|
samples = list(
|
|
58
63
|
sampling.random_sample(l, y=pg.oneof([2, 4]), num_examples=3, seed=1)
|
|
59
64
|
)
|
|
60
|
-
samples = sorted(
|
|
65
|
+
samples = sorted(
|
|
66
|
+
samples, key=lambda x: (
|
|
67
|
+
x[0].__template_input__.x,
|
|
68
|
+
x[0].__template_input__.y
|
|
69
|
+
)
|
|
70
|
+
)
|
|
61
71
|
|
|
62
72
|
self.assertEqual(
|
|
63
73
|
samples,
|
|
@@ -97,7 +107,13 @@ class SamplingTest(unittest.TestCase):
|
|
|
97
107
|
silence_on_errors=(AttributeError,),
|
|
98
108
|
ignore_examples_with_errors=False))
|
|
99
109
|
|
|
100
|
-
samples = sorted(
|
|
110
|
+
samples = sorted(
|
|
111
|
+
samples,
|
|
112
|
+
key=lambda x: (
|
|
113
|
+
x[0].__template_input__.x,
|
|
114
|
+
x[0].__template_input__.y
|
|
115
|
+
)
|
|
116
|
+
)
|
|
101
117
|
self.assertEqual(
|
|
102
118
|
[x[0] for x in samples],
|
|
103
119
|
[
|
|
@@ -118,13 +118,8 @@ class _CompleteStructure(mapping.Mapping):
|
|
|
118
118
|
def postprocess_result(self, result: Any) -> Any:
|
|
119
119
|
"""Postprocess result."""
|
|
120
120
|
# Try restore modality objects from the input value to output value.
|
|
121
|
-
modalities
|
|
122
|
-
|
|
123
|
-
# Remove the `input` prefix for all entries.
|
|
124
|
-
modalities = pg.object_utils.flatten(
|
|
125
|
-
pg.object_utils.canonicalize(modalities)['input']
|
|
126
|
-
)
|
|
127
|
-
result.rebind(modalities)
|
|
121
|
+
if modalities := self.modalities(self.input):
|
|
122
|
+
result = lf.ModalityRef.restore(result, modalities)
|
|
128
123
|
return result
|
|
129
124
|
|
|
130
125
|
def globals(self):
|
|
@@ -407,22 +407,17 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
407
407
|
image: modalities.Image
|
|
408
408
|
name: str
|
|
409
409
|
|
|
410
|
+
image_elephant = modalities.Image.from_bytes(b'image_of_elephant')
|
|
411
|
+
image_rabbit = modalities.Image.from_bytes(b'image_of_rabbit')
|
|
410
412
|
input_value = schema_lib.mark_missing(
|
|
411
|
-
Animal.partial(
|
|
412
|
-
modalities.Image.from_bytes(b'image_of_elephant'),
|
|
413
|
-
)
|
|
413
|
+
Animal.partial(image_elephant)
|
|
414
414
|
)
|
|
415
415
|
l = completion._CompleteStructure(
|
|
416
416
|
input=input_value,
|
|
417
417
|
examples=[
|
|
418
418
|
mapping.MappingExample(
|
|
419
|
-
input=Animal.partial(
|
|
420
|
-
|
|
421
|
-
),
|
|
422
|
-
output=Animal(
|
|
423
|
-
modalities.Image.from_bytes(b'image_of_rabbit'),
|
|
424
|
-
'rabbit',
|
|
425
|
-
),
|
|
419
|
+
input=Animal.partial(image_rabbit),
|
|
420
|
+
output=Animal(image_rabbit, 'rabbit'),
|
|
426
421
|
)
|
|
427
422
|
],
|
|
428
423
|
)
|
|
@@ -430,7 +425,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
430
425
|
self.maxDiff = None
|
|
431
426
|
self.assertEqual(
|
|
432
427
|
lm_input.text,
|
|
433
|
-
inspect.cleandoc("""
|
|
428
|
+
inspect.cleandoc(f"""
|
|
434
429
|
Please generate the OUTPUT_OBJECT by completing the MISSING fields from the last INPUT_OBJECT.
|
|
435
430
|
|
|
436
431
|
INSTRUCTIONS:
|
|
@@ -457,22 +452,22 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
457
452
|
```python
|
|
458
453
|
Animal(
|
|
459
454
|
image=ModalityRef(
|
|
460
|
-
|
|
455
|
+
id='{image_rabbit.id}'
|
|
461
456
|
),
|
|
462
457
|
name=MISSING(str)
|
|
463
458
|
)
|
|
464
459
|
```
|
|
465
460
|
|
|
466
461
|
MODALITY_REFERENCES:
|
|
467
|
-
{
|
|
468
|
-
'
|
|
469
|
-
}
|
|
462
|
+
{{
|
|
463
|
+
'{image_rabbit.id}': <<[[{image_rabbit.id}]]>>
|
|
464
|
+
}}
|
|
470
465
|
|
|
471
466
|
OUTPUT_OBJECT:
|
|
472
467
|
```python
|
|
473
468
|
Animal(
|
|
474
469
|
image=ModalityRef(
|
|
475
|
-
|
|
470
|
+
id='{image_rabbit.id}'
|
|
476
471
|
),
|
|
477
472
|
name='rabbit'
|
|
478
473
|
)
|
|
@@ -483,16 +478,16 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
483
478
|
```python
|
|
484
479
|
Animal(
|
|
485
480
|
image=ModalityRef(
|
|
486
|
-
|
|
481
|
+
id='{image_elephant.id}'
|
|
487
482
|
),
|
|
488
483
|
name=MISSING(str)
|
|
489
484
|
)
|
|
490
485
|
```
|
|
491
486
|
|
|
492
487
|
MODALITY_REFERENCES:
|
|
493
|
-
{
|
|
494
|
-
'
|
|
495
|
-
}
|
|
488
|
+
{{
|
|
489
|
+
'{image_elephant.id}': <<[[{image_elephant.id}]]>>
|
|
490
|
+
}}
|
|
496
491
|
|
|
497
492
|
OUTPUT_OBJECT:
|
|
498
493
|
"""),
|
|
@@ -500,39 +495,27 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
500
495
|
self.assertTrue(
|
|
501
496
|
pg.eq(
|
|
502
497
|
{
|
|
503
|
-
'examples': lm_input.
|
|
504
|
-
'input': lm_input.
|
|
498
|
+
'examples': lm_input.__template_input__.examples,
|
|
499
|
+
'input': lm_input.__template_input__.mapping_request.input,
|
|
505
500
|
},
|
|
506
501
|
{
|
|
507
502
|
'examples': [
|
|
508
503
|
mapping.MappingExample(
|
|
509
|
-
input=Animal.partial(
|
|
510
|
-
|
|
511
|
-
b'image_of_rabbit'
|
|
512
|
-
)
|
|
513
|
-
),
|
|
514
|
-
output=Animal.partial(
|
|
515
|
-
image=modalities.Image.from_bytes(
|
|
516
|
-
b'image_of_rabbit'
|
|
517
|
-
),
|
|
518
|
-
name='rabbit',
|
|
519
|
-
),
|
|
504
|
+
input=Animal.partial(image_rabbit),
|
|
505
|
+
output=Animal.partial(image_rabbit, 'rabbit'),
|
|
520
506
|
)
|
|
521
507
|
],
|
|
522
|
-
'input': Animal(
|
|
523
|
-
image=modalities.Image.from_bytes(b'image_of_elephant'),
|
|
524
|
-
name=schema_lib.MISSING,
|
|
525
|
-
),
|
|
508
|
+
'input': Animal(image_elephant, name=schema_lib.MISSING),
|
|
526
509
|
},
|
|
527
510
|
)
|
|
528
511
|
)
|
|
529
512
|
lm_output = l(
|
|
530
513
|
input=input_value,
|
|
531
|
-
lm=fake.StaticResponse(inspect.cleandoc("""
|
|
514
|
+
lm=fake.StaticResponse(inspect.cleandoc(f"""
|
|
532
515
|
```python
|
|
533
516
|
Animal(
|
|
534
517
|
image=ModalityRef(
|
|
535
|
-
|
|
518
|
+
id='{image_elephant.id}'
|
|
536
519
|
),
|
|
537
520
|
name='elephant'
|
|
538
521
|
)
|
|
@@ -542,10 +525,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
542
525
|
self.assertTrue(
|
|
543
526
|
pg.eq(
|
|
544
527
|
lm_output.result,
|
|
545
|
-
Animal(
|
|
546
|
-
image=modalities.Image.from_bytes(b'image_of_elephant'),
|
|
547
|
-
name='elephant',
|
|
548
|
-
),
|
|
528
|
+
Animal(image=image_elephant, name='elephant'),
|
|
549
529
|
)
|
|
550
530
|
)
|
|
551
531
|
|
|
@@ -127,6 +127,8 @@ class MappingExample(lf.NaturalLanguageFormattable,
|
|
|
127
127
|
) -> str:
|
|
128
128
|
if isinstance(value, str):
|
|
129
129
|
return value
|
|
130
|
+
if isinstance(value, lf.Message):
|
|
131
|
+
return str(value)
|
|
130
132
|
if isinstance(value, lf.Modality):
|
|
131
133
|
with lf.modality.format_modality_as_ref():
|
|
132
134
|
return str(value)
|
|
@@ -192,9 +194,7 @@ class MappingExample(lf.NaturalLanguageFormattable,
|
|
|
192
194
|
|
|
193
195
|
def render_value(view, *, value, **kwargs):
|
|
194
196
|
if isinstance(value, lf.Template):
|
|
195
|
-
|
|
196
|
-
# the input.
|
|
197
|
-
value = value.clone().render()
|
|
197
|
+
value = value.render()
|
|
198
198
|
if value is None:
|
|
199
199
|
return None
|
|
200
200
|
return view.render(value, **kwargs)
|
|
@@ -286,12 +286,8 @@ class Mapping(lf.LangFunc):
|
|
|
286
286
|
@property
|
|
287
287
|
def mapping_request(self) -> MappingExample:
|
|
288
288
|
"""Returns a MappingExample as the mapping request."""
|
|
289
|
-
if isinstance(self.input, lf.Message):
|
|
290
|
-
input_value = self.input.text
|
|
291
|
-
else:
|
|
292
|
-
input_value = pg.Ref(self.input)
|
|
293
289
|
return MappingExample(
|
|
294
|
-
input=
|
|
290
|
+
input=pg.Ref(self.input),
|
|
295
291
|
schema=pg.Ref(self.schema),
|
|
296
292
|
context=self.context,
|
|
297
293
|
)
|
|
@@ -402,11 +398,6 @@ class Mapping(lf.LangFunc):
|
|
|
402
398
|
|
|
403
399
|
def transform_input(self, lm_input: lf.Message) -> lf.Message:
|
|
404
400
|
# Find modalities to fill the input message.
|
|
405
|
-
lm_input.metadata.update(
|
|
406
|
-
examples=pg.Ref(self.examples),
|
|
407
|
-
input=pg.Ref(self.input),
|
|
408
|
-
schema=pg.Ref(self.schema) if self.schema is not None else None,
|
|
409
|
-
)
|
|
410
401
|
if isinstance(self.input, lf.Message):
|
|
411
402
|
lm_input.source = self.input
|
|
412
403
|
return lm_input
|
|
@@ -529,24 +529,22 @@ def query(
|
|
|
529
529
|
).render(message_cls=lf.SystemMessage)
|
|
530
530
|
|
|
531
531
|
# Normalize query input.
|
|
532
|
-
if isinstance(prompt,
|
|
532
|
+
if isinstance(prompt, str):
|
|
533
533
|
# Query with structured output.
|
|
534
534
|
prompt_kwargs = kwargs.copy()
|
|
535
535
|
prompt_kwargs.pop('template_str', None)
|
|
536
536
|
query_input = lf.Template.from_value(prompt, **prompt_kwargs)
|
|
537
|
+
elif isinstance(prompt, lf.Message):
|
|
538
|
+
query_input = prompt
|
|
537
539
|
elif isinstance(prompt, lf.Template):
|
|
538
|
-
# Create a copy of the prompt if it has a parent object, so all child
|
|
539
|
-
# modality objects could be referred by path relative to the prompt.
|
|
540
|
-
query_input = prompt.clone() if prompt.sym_parent is not None else prompt
|
|
541
|
-
|
|
542
540
|
# Attach template metadata from kwargs. This is used to pass through fields
|
|
543
541
|
# from kwargs to the rendered message.
|
|
544
|
-
|
|
545
|
-
k: v for k, v in kwargs.items() if k.startswith('metadata_')
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
template_metadata, skip_notification=True, raise_on_no_change=False
|
|
542
|
+
prompt.rebind(
|
|
543
|
+
{k: v for k, v in kwargs.items() if k.startswith('metadata_')},
|
|
544
|
+
skip_notification=True,
|
|
545
|
+
raise_on_no_change=False
|
|
549
546
|
)
|
|
547
|
+
query_input = prompt
|
|
550
548
|
elif pg.MISSING_VALUE == prompt:
|
|
551
549
|
query_input = lf.UserMessage('')
|
|
552
550
|
else:
|
|
@@ -665,7 +663,11 @@ def query(
|
|
|
665
663
|
|
|
666
664
|
if returns_message:
|
|
667
665
|
return output_message
|
|
668
|
-
|
|
666
|
+
if schema not in (None, str):
|
|
667
|
+
return output_message.result
|
|
668
|
+
if returns_message or output_message.referred_modalities:
|
|
669
|
+
return output_message
|
|
670
|
+
return output_message.text
|
|
669
671
|
|
|
670
672
|
|
|
671
673
|
async def aquery(
|