langfun 0.1.2.dev202504280818__py3-none-any.whl → 0.1.2.dev202504300804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/__init__.py +3 -0
- langfun/core/agentic/action.py +253 -105
- langfun/core/agentic/action_eval.py +10 -3
- langfun/core/agentic/action_test.py +173 -47
- langfun/core/eval/base_test.py +4 -4
- langfun/core/eval/v2/evaluation.py +78 -12
- langfun/core/eval/v2/evaluation_test.py +2 -0
- langfun/core/structured/__init__.py +3 -0
- langfun/core/structured/mapping.py +5 -2
- langfun/core/structured/parsing_test.py +1 -1
- langfun/core/structured/querying.py +205 -18
- langfun/core/structured/querying_test.py +286 -47
- {langfun-0.1.2.dev202504280818.dist-info → langfun-0.1.2.dev202504300804.dist-info}/METADATA +29 -6
- {langfun-0.1.2.dev202504280818.dist-info → langfun-0.1.2.dev202504300804.dist-info}/RECORD +17 -17
- {langfun-0.1.2.dev202504280818.dist-info → langfun-0.1.2.dev202504300804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202504280818.dist-info → langfun-0.1.2.dev202504300804.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202504280818.dist-info → langfun-0.1.2.dev202504300804.dist-info}/top_level.txt +0 -0
@@ -15,8 +15,9 @@
|
|
15
15
|
|
16
16
|
import contextlib
|
17
17
|
import functools
|
18
|
+
import inspect
|
18
19
|
import time
|
19
|
-
from typing import Annotated, Any, Callable, Iterator, Type, Union
|
20
|
+
from typing import Annotated, Any, Callable, ClassVar, Iterator, Type, Union
|
20
21
|
import uuid
|
21
22
|
|
22
23
|
import langfun.core as lf
|
@@ -26,8 +27,35 @@ import pyglove as pg
|
|
26
27
|
|
27
28
|
|
28
29
|
@lf.use_init_args(['schema', 'default', 'examples'])
|
29
|
-
class
|
30
|
-
"""
|
30
|
+
class LfQuery(mapping.Mapping):
|
31
|
+
"""Base class for different implementations of `lf.query`.
|
32
|
+
|
33
|
+
By subclassing this class, users could create variations of prompts for
|
34
|
+
`lf.query` and associated them with specific protocols and versions.
|
35
|
+
|
36
|
+
For example:
|
37
|
+
|
38
|
+
```
|
39
|
+
class _MyLfQuery(LFQuery):
|
40
|
+
protocol = 'my_format'
|
41
|
+
version = '1.0'
|
42
|
+
|
43
|
+
template_str = inspect.cleandoc(
|
44
|
+
'''
|
45
|
+
...
|
46
|
+
'''
|
47
|
+
)
|
48
|
+
mapping_template = lf.Template(
|
49
|
+
'''
|
50
|
+
...
|
51
|
+
'''
|
52
|
+
)
|
53
|
+
|
54
|
+
lf.query(..., protocol='my_format:1.0')
|
55
|
+
```
|
56
|
+
|
57
|
+
(THIS IS NOT A TEMPLATE)
|
58
|
+
"""
|
31
59
|
|
32
60
|
context_title = 'CONTEXT'
|
33
61
|
input_title = 'INPUT_OBJECT'
|
@@ -37,8 +65,81 @@ class _QueryStructure(mapping.Mapping):
|
|
37
65
|
schema_lib.schema_spec(), 'Required schema for parsing.'
|
38
66
|
]
|
39
67
|
|
68
|
+
# A map from (protocol, version) to the query structure class.
|
69
|
+
# This is used to map different protocols/versions to different templates.
|
70
|
+
# So users can use `lf.query(..., protocol='<protocol>:<version>')` to use
|
71
|
+
# a specific version of the prompt. We use this feature to support variations
|
72
|
+
# of prompts and maintain backward compatibility.
|
73
|
+
_OOP_PROMPT_MAP: ClassVar[
|
74
|
+
dict[
|
75
|
+
str, # protocol.
|
76
|
+
dict[
|
77
|
+
str, # version.
|
78
|
+
Type['LfQuery']
|
79
|
+
]
|
80
|
+
]
|
81
|
+
] = {}
|
82
|
+
|
83
|
+
# This the flag to update default protocol version.
|
84
|
+
_DEFAULT_PROTOCOL_VERSIONS: ClassVar[dict[str, str]] = {
|
85
|
+
'python': '2.0',
|
86
|
+
'json': '1.0',
|
87
|
+
}
|
88
|
+
|
89
|
+
def __init_subclass__(cls) -> Any:
|
90
|
+
super().__init_subclass__()
|
91
|
+
if not inspect.isabstract(cls):
|
92
|
+
protocol = cls.__schema__['protocol'].default_value
|
93
|
+
version_dict = cls._OOP_PROMPT_MAP.get(protocol)
|
94
|
+
if version_dict is None:
|
95
|
+
version_dict = {}
|
96
|
+
cls._OOP_PROMPT_MAP[protocol] = version_dict
|
97
|
+
dest_cls = version_dict.get(cls.version)
|
98
|
+
if dest_cls is not None and dest_cls.__type_name__ != cls.__type_name__:
|
99
|
+
raise ValueError(
|
100
|
+
f'Version {cls.version} is already registered for {dest_cls!r} '
|
101
|
+
f'under protocol {protocol!r}. Please use a different version.'
|
102
|
+
)
|
103
|
+
version_dict[cls.version] = cls
|
104
|
+
|
105
|
+
@classmethod
|
106
|
+
def from_protocol(cls, protocol: str) -> Type['LfQuery']:
|
107
|
+
"""Returns a query structure from the given protocol and version."""
|
108
|
+
if ':' in protocol:
|
109
|
+
protocol, version = protocol.split(':')
|
110
|
+
else:
|
111
|
+
version = cls._DEFAULT_PROTOCOL_VERSIONS.get(protocol)
|
112
|
+
if version is None:
|
113
|
+
version_dict = cls._OOP_PROMPT_MAP.get(protocol)
|
114
|
+
if version_dict is None:
|
115
|
+
raise ValueError(
|
116
|
+
f'Protocol {protocol!r} is not supported. Available protocols: '
|
117
|
+
f'{sorted(cls._OOP_PROMPT_MAP.keys())}.'
|
118
|
+
)
|
119
|
+
elif len(version_dict) == 1:
|
120
|
+
version = list(version_dict.keys())[0]
|
121
|
+
else:
|
122
|
+
raise ValueError(
|
123
|
+
f'Multiple versions found for protocol {protocol!r}, please '
|
124
|
+
f'specify a version with "{protocol}:<version>".'
|
125
|
+
)
|
126
|
+
|
127
|
+
version_dict = cls._OOP_PROMPT_MAP.get(protocol)
|
128
|
+
if version_dict is None:
|
129
|
+
raise ValueError(
|
130
|
+
f'Protocol {protocol!r} is not supported. Available protocols: '
|
131
|
+
f'{sorted(cls._OOP_PROMPT_MAP.keys())}.'
|
132
|
+
)
|
133
|
+
dest_cls = version_dict.get(version)
|
134
|
+
if dest_cls is None:
|
135
|
+
raise ValueError(
|
136
|
+
f'Version {version!r} is not supported for protocol {protocol!r}. '
|
137
|
+
f'Available versions: {sorted(version_dict.keys())}.'
|
138
|
+
)
|
139
|
+
return dest_cls
|
40
140
|
|
41
|
-
|
141
|
+
|
142
|
+
class _LfQueryJsonV1(LfQuery):
|
42
143
|
"""Query a structured value using JSON as the protocol."""
|
43
144
|
|
44
145
|
preamble = """
|
@@ -58,12 +159,13 @@ class _QueryStructureJson(_QueryStructure):
|
|
58
159
|
{"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
|
59
160
|
"""
|
60
161
|
|
162
|
+
version = '1.0'
|
61
163
|
protocol = 'json'
|
62
164
|
schema_title = 'SCHEMA'
|
63
165
|
output_title = 'JSON'
|
64
166
|
|
65
167
|
|
66
|
-
class
|
168
|
+
class _LfQueryPythonV1(LfQuery):
|
67
169
|
"""Query a structured value using Python as the protocol."""
|
68
170
|
|
69
171
|
preamble = """
|
@@ -87,20 +189,87 @@ class _QueryStructurePython(_QueryStructure):
|
|
87
189
|
)
|
88
190
|
```
|
89
191
|
"""
|
192
|
+
version = '1.0'
|
90
193
|
protocol = 'python'
|
91
194
|
schema_title = 'OUTPUT_TYPE'
|
92
195
|
output_title = 'OUTPUT_OBJECT'
|
196
|
+
mapping_template = lf.Template(
|
197
|
+
"""
|
198
|
+
{%- if example.context -%}
|
199
|
+
{{ context_title}}:
|
200
|
+
{{ example.context | indent(2, True)}}
|
93
201
|
|
202
|
+
{% endif -%}
|
94
203
|
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
204
|
+
{{ input_title }}:
|
205
|
+
{{ example.input_repr(protocol, compact=False) | indent(2, True) }}
|
206
|
+
|
207
|
+
{% if example.schema -%}
|
208
|
+
{{ schema_title }}:
|
209
|
+
{{ example.schema_repr(protocol) | indent(2, True) }}
|
210
|
+
|
211
|
+
{% endif -%}
|
212
|
+
|
213
|
+
{{ output_title }}:
|
214
|
+
{%- if example.has_output %}
|
215
|
+
{{ example.output_repr(protocol, compact=False) | indent(2, True) }}
|
216
|
+
{% endif -%}
|
217
|
+
"""
|
218
|
+
)
|
219
|
+
|
220
|
+
|
221
|
+
class _LfQueryPythonV2(LfQuery):
|
222
|
+
"""Query a structured value using Python as the protocol."""
|
223
|
+
|
224
|
+
preamble = """
|
225
|
+
Please respond to the last {{ input_title }} with {{ output_title }} only according to {{ schema_title }}.
|
226
|
+
|
227
|
+
{{ input_title }}:
|
228
|
+
1 + 1 =
|
229
|
+
|
230
|
+
{{ schema_title }}:
|
231
|
+
Answer
|
232
|
+
|
233
|
+
```python
|
234
|
+
class Answer:
|
235
|
+
final_answer: int
|
236
|
+
```
|
237
|
+
|
238
|
+
{{ output_title }}:
|
239
|
+
```python
|
240
|
+
output = Answer(
|
241
|
+
final_answer=2
|
242
|
+
)
|
243
|
+
```
|
244
|
+
"""
|
245
|
+
version = '2.0'
|
246
|
+
protocol = 'python'
|
247
|
+
input_title = 'REQUEST'
|
248
|
+
schema_title = 'OUTPUT PYTHON TYPE'
|
249
|
+
output_title = 'OUTPUT PYTHON OBJECT'
|
250
|
+
mapping_template = lf.Template(
|
251
|
+
"""
|
252
|
+
{%- if example.context -%}
|
253
|
+
{{ context_title}}:
|
254
|
+
{{ example.context | indent(2, True)}}
|
255
|
+
|
256
|
+
{% endif -%}
|
257
|
+
|
258
|
+
{{ input_title }}:
|
259
|
+
{{ example.input_repr(protocol, compact=False) | indent(2, True) }}
|
260
|
+
|
261
|
+
{% if example.schema -%}
|
262
|
+
{{ schema_title }}:
|
263
|
+
{{ example.schema_repr(protocol) | indent(2, True) }}
|
264
|
+
|
265
|
+
{% endif -%}
|
266
|
+
|
267
|
+
{{ output_title }}:
|
268
|
+
{%- if example.has_output %}
|
269
|
+
{{ example.output_repr(protocol, compact=False, assign_to_var='output') | indent(2, True) }}
|
270
|
+
{% endif -%}
|
271
|
+
"""
|
272
|
+
)
|
104
273
|
|
105
274
|
|
106
275
|
def query(
|
@@ -116,7 +285,7 @@ def query(
|
|
116
285
|
response_postprocess: Callable[[str], str] | None = None,
|
117
286
|
autofix: int = 0,
|
118
287
|
autofix_lm: lf.LanguageModel | None = None,
|
119
|
-
protocol:
|
288
|
+
protocol: str | None = None,
|
120
289
|
returns_message: bool = False,
|
121
290
|
skip_lm: bool = False,
|
122
291
|
invocation_id: str | None = None,
|
@@ -259,8 +428,14 @@ def query(
|
|
259
428
|
disable auto-fixing. Not supported with the `'json'` protocol.
|
260
429
|
autofix_lm: The LM to use for auto-fixing. Defaults to the `autofix_lm`
|
261
430
|
from `lf.context` or the main `lm`.
|
262
|
-
protocol: Format for schema representation.
|
263
|
-
`'python'
|
431
|
+
protocol: Format for schema representation. Builtin choices are `'json'` or
|
432
|
+
`'python'`, users could extend with their own protocols by subclassing
|
433
|
+
`lf.structured.LfQuery'. Also protocol could be specified with a version
|
434
|
+
in the format of 'protocol:version', e.g., 'python:1.0', so users could
|
435
|
+
use a specific version of the prompt based on the protocol. Please see the
|
436
|
+
documentation of `LfQuery` for more details. If None, the protocol from
|
437
|
+
context manager `lf.query_protocol` will be used, or 'python' if not
|
438
|
+
specified.
|
264
439
|
returns_message: If `True`, returns an `lf.Message` object instead of
|
265
440
|
the final parsed result.
|
266
441
|
skip_lm: If `True`, skips the LLM call and returns the rendered
|
@@ -280,6 +455,9 @@ def query(
|
|
280
455
|
"""
|
281
456
|
# Internal usage logging.
|
282
457
|
|
458
|
+
if protocol is None:
|
459
|
+
protocol = lf.context_value('__query_protocol__', 'python')
|
460
|
+
|
283
461
|
invocation_id = invocation_id or f'query@{uuid.uuid4().hex[-7:]}'
|
284
462
|
# Multiple quries will be issued when `lm` is a list or `num_samples` is
|
285
463
|
# greater than 1.
|
@@ -382,7 +560,7 @@ def query(
|
|
382
560
|
output_message = lf.AIMessage(processed_text, source=output_message)
|
383
561
|
else:
|
384
562
|
# Query with structured output.
|
385
|
-
output_message =
|
563
|
+
output_message = LfQuery.from_protocol(protocol)(
|
386
564
|
input=(
|
387
565
|
query_input.render(lm=lm)
|
388
566
|
if isinstance(query_input, lf.Template)
|
@@ -436,6 +614,15 @@ def query(
|
|
436
614
|
return output_message if returns_message else _result(output_message)
|
437
615
|
|
438
616
|
|
617
|
+
@contextlib.contextmanager
|
618
|
+
def query_protocol(protocol: str) -> Iterator[None]:
|
619
|
+
"""Context manager for setting the query protocol for the scope."""
|
620
|
+
with lf.context(__query_protocol__=protocol):
|
621
|
+
try:
|
622
|
+
yield
|
623
|
+
finally:
|
624
|
+
pass
|
625
|
+
|
439
626
|
#
|
440
627
|
# Helper function for map-reduce style querying.
|
441
628
|
#
|