langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +2 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/base.py +202 -23
- langfun/core/eval/base_test.py +49 -10
- langfun/core/eval/matching.py +26 -9
- langfun/core/eval/matching_test.py +2 -1
- langfun/core/eval/scoring.py +15 -6
- langfun/core/eval/scoring_test.py +2 -1
- langfun/core/langfunc.py +0 -5
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +124 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +19 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +31 -22
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +97 -79
- langfun/core/llms/openai_test.py +285 -59
- langfun/core/modalities/video.py +5 -2
- langfun/core/structured/__init__.py +3 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +56 -2
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +27 -6
- langfun/core/structured/prompting_test.py +79 -12
- langfun/core/structured/schema.py +4 -2
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +4 -6
- langfun/core/template.py +125 -10
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +49 -43
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
@@ -17,11 +17,9 @@ import inspect
|
|
17
17
|
import unittest
|
18
18
|
|
19
19
|
import langfun.core as lf
|
20
|
-
from langfun.core import coding
|
21
20
|
from langfun.core.llms import fake
|
22
21
|
from langfun.core.structured import mapping
|
23
22
|
from langfun.core.structured import parsing
|
24
|
-
from langfun.core.structured import schema as schema_lib
|
25
23
|
import pyglove as pg
|
26
24
|
|
27
25
|
|
@@ -255,7 +253,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
255
253
|
override_attrs=True,
|
256
254
|
):
|
257
255
|
with self.assertRaisesRegex(
|
258
|
-
|
256
|
+
mapping.MappingError,
|
259
257
|
'name .* is not defined',
|
260
258
|
):
|
261
259
|
parsing.parse('three', int)
|
@@ -280,13 +278,15 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
280
278
|
),
|
281
279
|
1,
|
282
280
|
)
|
281
|
+
r = parsing.parse(
|
282
|
+
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
|
283
|
+
returns_message=True
|
284
|
+
)
|
283
285
|
self.assertEqual(
|
284
|
-
|
285
|
-
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
|
286
|
-
returns_message=True
|
287
|
-
),
|
286
|
+
r,
|
288
287
|
lf.AIMessage(
|
289
288
|
'1', score=1.0, result=1, logprobs=None,
|
289
|
+
usage=lf.LMSamplingUsage(652, 1, 653),
|
290
290
|
tags=['lm-response', 'lm-output', 'transformed']
|
291
291
|
),
|
292
292
|
)
|
@@ -544,7 +544,7 @@ class ParseStructureJsonTest(unittest.TestCase):
|
|
544
544
|
override_attrs=True,
|
545
545
|
):
|
546
546
|
with self.assertRaisesRegex(
|
547
|
-
|
547
|
+
mapping.MappingError,
|
548
548
|
'No JSON dict in the output',
|
549
549
|
):
|
550
550
|
parsing.parse('three', int, protocol='json')
|
@@ -634,13 +634,18 @@ class CallTest(unittest.TestCase):
|
|
634
634
|
)
|
635
635
|
|
636
636
|
def test_call_with_returning_message(self):
|
637
|
+
r = parsing.call(
|
638
|
+
'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
|
639
|
+
returns_message=True
|
640
|
+
)
|
637
641
|
self.assertEqual(
|
638
|
-
|
639
|
-
'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
|
640
|
-
returns_message=True
|
641
|
-
),
|
642
|
+
r,
|
642
643
|
lf.AIMessage(
|
643
|
-
'3',
|
644
|
+
'3',
|
645
|
+
result=3,
|
646
|
+
score=1.0,
|
647
|
+
logprobs=None,
|
648
|
+
usage=lf.LMSamplingUsage(315, 1, 316),
|
644
649
|
tags=['lm-response', 'lm-output', 'transformed']
|
645
650
|
),
|
646
651
|
)
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Symbolic query."""
|
15
15
|
|
16
|
-
from typing import Any, Type, Union
|
16
|
+
from typing import Any, Callable, Type, Union
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.structured import mapping
|
@@ -78,7 +78,9 @@ class QueryStructurePython(QueryStructure):
|
|
78
78
|
|
79
79
|
{{ output_title }}:
|
80
80
|
```python
|
81
|
-
Answer(
|
81
|
+
Answer(
|
82
|
+
final_answer=2
|
83
|
+
)
|
82
84
|
```
|
83
85
|
"""
|
84
86
|
protocol = 'python'
|
@@ -107,6 +109,7 @@ def query(
|
|
107
109
|
lm: lf.LanguageModel | None = None,
|
108
110
|
examples: list[mapping.MappingExample] | None = None,
|
109
111
|
cache_seed: int | None = 0,
|
112
|
+
response_postprocess: Callable[[str], str] | None = None,
|
110
113
|
autofix: int = 0,
|
111
114
|
autofix_lm: lf.LanguageModel | None = None,
|
112
115
|
protocol: schema_lib.SchemaProtocol = 'python',
|
@@ -159,6 +162,9 @@ def query(
|
|
159
162
|
cache_seed: Seed for computing cache key. The cache key is determined by a
|
160
163
|
tuple of (lm, prompt, cache seed). If None, cache will be disabled for
|
161
164
|
the query even cache is configured by the LM.
|
165
|
+
response_postprocess: An optional callable object to process the raw LM
|
166
|
+
response before parsing it into the final output object. If None, the
|
167
|
+
raw LM response will not be processed.
|
162
168
|
autofix: Number of attempts to auto fix the generated code. If 0, autofix is
|
163
169
|
disabled. Auto-fix is not supported for 'json' protocol.
|
164
170
|
autofix_lm: The language model to use for autofix. If not specified, the
|
@@ -170,8 +176,11 @@ def query(
|
|
170
176
|
returning the structured `message.result`.
|
171
177
|
skip_lm: If True, returns the rendered prompt as a UserMessage object.
|
172
178
|
otherwise return the LLM response based on the rendered prompt.
|
173
|
-
**kwargs: Keyword arguments passed to the
|
174
|
-
`lf.structured.
|
179
|
+
**kwargs: Keyword arguments passed to render the prompt or configure the
|
180
|
+
`lf.structured.Mapping` class. Notable kwargs are:
|
181
|
+
- template_str: Change the root template for query.
|
182
|
+
- preamble: Change the preamble for query.
|
183
|
+
- mapping_template: Change the template for each mapping examle.
|
175
184
|
|
176
185
|
Returns:
|
177
186
|
The result based on the schema.
|
@@ -188,13 +197,24 @@ def query(
|
|
188
197
|
output = lf.LangFunc.from_value(prompt, **kwargs)(
|
189
198
|
lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
|
190
199
|
)
|
200
|
+
if response_postprocess:
|
201
|
+
processed_text = response_postprocess(output.text)
|
202
|
+
if processed_text != output.text:
|
203
|
+
output = lf.AIMessage(processed_text, source=output)
|
191
204
|
return output if returns_message else output.text
|
192
205
|
|
193
206
|
# Query with structured output.
|
207
|
+
prompt_kwargs = kwargs.copy()
|
208
|
+
|
209
|
+
# NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
|
210
|
+
# QueryStructure template string. Therefore, we pop out the argument for
|
211
|
+
# prompt rendering.
|
212
|
+
prompt_kwargs.pop('template_str', None)
|
213
|
+
|
194
214
|
if isinstance(prompt, str):
|
195
|
-
prompt = lf.Template(prompt, **
|
215
|
+
prompt = lf.Template(prompt, **prompt_kwargs)
|
196
216
|
elif isinstance(prompt, lf.Template):
|
197
|
-
prompt = prompt.rebind(**
|
217
|
+
prompt = prompt.rebind(**prompt_kwargs, raise_on_no_change=False)
|
198
218
|
|
199
219
|
if isinstance(prompt, lf.Template):
|
200
220
|
prompt = prompt.render(lm=lm)
|
@@ -206,6 +226,7 @@ def query(
|
|
206
226
|
schema=schema,
|
207
227
|
default=default,
|
208
228
|
examples=examples,
|
229
|
+
response_postprocess=response_postprocess,
|
209
230
|
autofix=autofix if protocol == 'python' else 0,
|
210
231
|
**kwargs,
|
211
232
|
)(
|
@@ -17,12 +17,10 @@ import inspect
|
|
17
17
|
import unittest
|
18
18
|
|
19
19
|
import langfun.core as lf
|
20
|
-
from langfun.core import coding
|
21
20
|
from langfun.core import modalities
|
22
21
|
from langfun.core.llms import fake
|
23
22
|
from langfun.core.structured import mapping
|
24
23
|
from langfun.core.structured import prompting
|
25
|
-
from langfun.core.structured import schema as schema_lib
|
26
24
|
import pyglove as pg
|
27
25
|
|
28
26
|
|
@@ -77,6 +75,7 @@ class QueryTest(unittest.TestCase):
|
|
77
75
|
result=1,
|
78
76
|
score=1.0,
|
79
77
|
logprobs=None,
|
78
|
+
usage=lf.LMSamplingUsage(323, 1, 324),
|
80
79
|
tags=['lm-response', 'lm-output', 'transformed'],
|
81
80
|
),
|
82
81
|
)
|
@@ -116,12 +115,59 @@ class QueryTest(unittest.TestCase):
|
|
116
115
|
y=2,
|
117
116
|
lm=lm.clone(),
|
118
117
|
expected_snippet=(
|
119
|
-
'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT'
|
120
|
-
'
|
121
|
-
'
|
122
|
-
'
|
123
|
-
'
|
124
|
-
'
|
118
|
+
'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
|
119
|
+
'according to OUTPUT_TYPE.\n\n'
|
120
|
+
'INPUT_OBJECT:\n 1 + 1 =\n\n'
|
121
|
+
'OUTPUT_TYPE:\n'
|
122
|
+
' Answer\n\n'
|
123
|
+
' ```python\n'
|
124
|
+
' class Answer:\n'
|
125
|
+
' final_answer: int\n'
|
126
|
+
' ```\n\n'
|
127
|
+
'OUTPUT_OBJECT:\n'
|
128
|
+
' ```python\n'
|
129
|
+
' Answer(\n'
|
130
|
+
' final_answer=2\n'
|
131
|
+
' )\n'
|
132
|
+
' ```\n\n'
|
133
|
+
'INPUT_OBJECT:\n'
|
134
|
+
' What is 1 + 2?\n\n'
|
135
|
+
'OUTPUT_TYPE:\n'
|
136
|
+
' int\n\n'
|
137
|
+
'OUTPUT_OBJECT:'
|
138
|
+
),
|
139
|
+
)
|
140
|
+
|
141
|
+
def test_str_to_structure_render_custom_template(self):
|
142
|
+
lm = fake.StaticResponse('1')
|
143
|
+
self.assert_render(
|
144
|
+
'What is {{x}} + {{y}}?',
|
145
|
+
int,
|
146
|
+
x=1,
|
147
|
+
y=2,
|
148
|
+
lm=lm.clone(),
|
149
|
+
template_str='!!{{ DEFAULT }}!!',
|
150
|
+
expected_snippet=(
|
151
|
+
'!!Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
|
152
|
+
'according to OUTPUT_TYPE.\n\n'
|
153
|
+
'INPUT_OBJECT:\n 1 + 1 =\n\n'
|
154
|
+
'OUTPUT_TYPE:\n'
|
155
|
+
' Answer\n\n'
|
156
|
+
' ```python\n'
|
157
|
+
' class Answer:\n'
|
158
|
+
' final_answer: int\n'
|
159
|
+
' ```\n\n'
|
160
|
+
'OUTPUT_OBJECT:\n'
|
161
|
+
' ```python\n'
|
162
|
+
' Answer(\n'
|
163
|
+
' final_answer=2\n'
|
164
|
+
' )\n'
|
165
|
+
' ```\n\n'
|
166
|
+
'INPUT_OBJECT:\n'
|
167
|
+
' What is 1 + 2?\n\n'
|
168
|
+
'OUTPUT_TYPE:\n'
|
169
|
+
' int\n\n'
|
170
|
+
'OUTPUT_OBJECT:!!'
|
125
171
|
),
|
126
172
|
)
|
127
173
|
|
@@ -264,7 +310,9 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
264
310
|
|
265
311
|
OUTPUT_OBJECT:
|
266
312
|
```python
|
267
|
-
Answer(
|
313
|
+
Answer(
|
314
|
+
final_answer=2
|
315
|
+
)
|
268
316
|
```
|
269
317
|
|
270
318
|
INPUT_OBJECT:
|
@@ -308,7 +356,9 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
308
356
|
|
309
357
|
OUTPUT_OBJECT:
|
310
358
|
```python
|
311
|
-
Answer(
|
359
|
+
Answer(
|
360
|
+
final_answer=2
|
361
|
+
)
|
312
362
|
```
|
313
363
|
|
314
364
|
INPUT_OBJECT:
|
@@ -420,7 +470,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
420
470
|
override_attrs=True,
|
421
471
|
):
|
422
472
|
with self.assertRaisesRegex(
|
423
|
-
|
473
|
+
mapping.MappingError,
|
424
474
|
'name .* is not defined',
|
425
475
|
):
|
426
476
|
prompting.query('Compute 1 + 2', int)
|
@@ -436,6 +486,23 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
436
486
|
])
|
437
487
|
self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
|
438
488
|
|
489
|
+
def test_response_postprocess(self):
|
490
|
+
with lf.context(
|
491
|
+
lm=fake.StaticResponse('<!-- some comment-->\n3'),
|
492
|
+
override_attrs=True,
|
493
|
+
):
|
494
|
+
self.assertEqual(
|
495
|
+
prompting.query(
|
496
|
+
'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
|
497
|
+
'3'
|
498
|
+
)
|
499
|
+
self.assertEqual(
|
500
|
+
prompting.query(
|
501
|
+
'Compute 1 + 2', int,
|
502
|
+
response_postprocess=lambda x: x.split('\n')[1]),
|
503
|
+
3
|
504
|
+
)
|
505
|
+
|
439
506
|
|
440
507
|
class QueryStructureJsonTest(unittest.TestCase):
|
441
508
|
|
@@ -641,7 +708,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
641
708
|
override_attrs=True,
|
642
709
|
):
|
643
710
|
with self.assertRaisesRegex(
|
644
|
-
|
711
|
+
mapping.MappingError,
|
645
712
|
'No JSON dict in the output',
|
646
713
|
):
|
647
714
|
prompting.query('Compute 1 + 2', int, protocol='json')
|
@@ -386,10 +386,12 @@ def class_definition(
|
|
386
386
|
if schema.fields:
|
387
387
|
for key, field in schema.items():
|
388
388
|
if not isinstance(key, pg.typing.ConstStrKey):
|
389
|
-
|
389
|
+
pg.logging.warning(
|
390
390
|
'Variable-length keyword arguments is not supported in '
|
391
|
-
f'structured parsing or query. Encountered: {
|
391
|
+
f'structured parsing or query. Encountered: {cls}, Schema: {schema}'
|
392
392
|
)
|
393
|
+
continue
|
394
|
+
|
393
395
|
# Write field doc string as comments before the field definition.
|
394
396
|
if field.description:
|
395
397
|
for line in field.description.split('\n'):
|
@@ -14,8 +14,8 @@
|
|
14
14
|
import inspect
|
15
15
|
import unittest
|
16
16
|
|
17
|
-
import langfun.core.coding as lf_coding
|
18
17
|
from langfun.core.llms import fake
|
18
|
+
from langfun.core.structured import mapping
|
19
19
|
from langfun.core.structured import schema_generation
|
20
20
|
|
21
21
|
|
@@ -92,7 +92,7 @@ class GenerateClassTest(unittest.TestCase):
|
|
92
92
|
)
|
93
93
|
self.assertIs(cls.__name__, 'B')
|
94
94
|
|
95
|
-
with self.assertRaises(
|
95
|
+
with self.assertRaises(mapping.MappingError):
|
96
96
|
schema_generation.generate_class(
|
97
97
|
'Foo',
|
98
98
|
'Generate a Foo class with a field pointing to another class A',
|
@@ -192,9 +192,9 @@ class SchemaTest(unittest.TestCase):
|
|
192
192
|
self.assertEqual(schema.parse('{"result": 1}'), 1)
|
193
193
|
schema = schema_lib.Schema(dict[str, int])
|
194
194
|
self.assertEqual(
|
195
|
-
schema.parse(
|
196
|
-
|
197
|
-
|
195
|
+
schema.parse('{"result": {"x": 1}}}'),
|
196
|
+
dict(x=1)
|
197
|
+
)
|
198
198
|
with self.assertRaisesRegex(
|
199
199
|
schema_lib.SchemaError, 'Expect .* but encountered .*'):
|
200
200
|
schema.parse('{"result": "def"}')
|
@@ -459,9 +459,7 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
459
459
|
x: str
|
460
460
|
__kwargs__: typing.Any
|
461
461
|
|
462
|
-
|
463
|
-
TypeError, 'Variable-length keyword arguments is not supported'):
|
464
|
-
schema_lib.class_definition(C)
|
462
|
+
self.assertEqual(schema_lib.class_definition(C), 'class C:\n x: str\n')
|
465
463
|
|
466
464
|
def test_repr(self):
|
467
465
|
class Foo(pg.Object):
|
langfun/core/template.py
CHANGED
@@ -38,13 +38,22 @@ NO_TEMPLATE_DOCSTR_SIGN = 'THIS IS NOT A TEMPLATE'
|
|
38
38
|
_TLS_RENDER_STACK = '_template_render_stack'
|
39
39
|
_TLS_RENDER_RESULT_CACHE = '_template_render_result_cache'
|
40
40
|
|
41
|
+
# The prefix for fields or contextual attributes to be treated as additional
|
42
|
+
# metadata for rendered message.
|
43
|
+
_ADDITIONAL_METADATA_PREFIX = 'metadata_'
|
44
|
+
|
41
45
|
|
42
46
|
class Template(
|
43
47
|
natural_language.NaturalLanguageFormattable,
|
44
48
|
component.Component,
|
45
49
|
pg.typing.CustomTyping,
|
46
50
|
):
|
47
|
-
"""Langfun string template.
|
51
|
+
"""Langfun string template.
|
52
|
+
|
53
|
+
Langfun uses jinja2 as its template engine. Pleaes check out
|
54
|
+
https://jinja.palletsprojects.com/en/3.1.x/templates/ for detailed
|
55
|
+
explanation on the template language.
|
56
|
+
"""
|
48
57
|
|
49
58
|
template_str: Annotated[
|
50
59
|
str,
|
@@ -97,6 +106,11 @@ class Template(
|
|
97
106
|
# Declare template variables as symbolic attributes.
|
98
107
|
template_vars = Template.resolve_vars(template_str)
|
99
108
|
for var_name in template_vars:
|
109
|
+
if 'DEFAULT' == var_name:
|
110
|
+
raise ValueError(
|
111
|
+
'`{{ DEFAULT }}` cannot be used in pre-configured templates. '
|
112
|
+
f'Encountered: {template_str!r}'
|
113
|
+
)
|
100
114
|
# NOTE(daiyip): This is to avoid warning from accessing
|
101
115
|
# `pg.Object.schema`, which was replaced by `pg.Object.__schema__`.
|
102
116
|
if var_name == 'schema' or not hasattr(cls, var_name):
|
@@ -149,7 +163,7 @@ class Template(
|
|
149
163
|
# TODO(daiyip): Consider to delay template parsing upon usage.
|
150
164
|
unassigned_vars = {}
|
151
165
|
for k in self._variables:
|
152
|
-
if not hasattr(self, k):
|
166
|
+
if k not in ('DEFAULT',) and not hasattr(self, k):
|
153
167
|
unassigned_vars[k] = component.contextual()
|
154
168
|
if unassigned_vars:
|
155
169
|
self.rebind(unassigned_vars, skip_notification=True)
|
@@ -303,19 +317,19 @@ class Template(
|
|
303
317
|
with modality.format_modality_as_ref():
|
304
318
|
rendered_text = self._template.render(**inputs)
|
305
319
|
|
320
|
+
# Carry additional metadata.
|
321
|
+
metadata = self.additional_metadata()
|
322
|
+
|
306
323
|
if self.clean:
|
307
324
|
rendered_text = rendered_text.strip()
|
308
325
|
|
309
|
-
|
310
|
-
|
311
|
-
text=rendered_text,
|
312
|
-
metadata={
|
313
|
-
k: pg.Ref(v)
|
314
|
-
for k, v in inputs.items()
|
315
|
-
if not inspect.ismethod(v)
|
316
|
-
},
|
326
|
+
metadata.update(
|
327
|
+
{k: pg.Ref(v) for k, v in inputs.items() if not inspect.ismethod(v)}
|
317
328
|
)
|
318
329
|
|
330
|
+
# Fill the variables for rendering the template as metadata.
|
331
|
+
message = message_cls(text=rendered_text, metadata=metadata)
|
332
|
+
|
319
333
|
# Tag input as rendered message.
|
320
334
|
message.tag(message_lib.Message.TAG_RENDERED)
|
321
335
|
|
@@ -340,6 +354,20 @@ class Template(
|
|
340
354
|
top = pg.object_utils.thread_local_pop(_TLS_RENDER_STACK)
|
341
355
|
assert top is self, (top, self)
|
342
356
|
|
357
|
+
def additional_metadata(self) -> dict[str, Any]:
|
358
|
+
"""Returns additional metadta to be carried in the rendered message."""
|
359
|
+
metadata = {}
|
360
|
+
# Carry metadata from `lf.context`.
|
361
|
+
for k, v in component.all_contextual_values().items():
|
362
|
+
if k.startswith(_ADDITIONAL_METADATA_PREFIX):
|
363
|
+
metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
|
364
|
+
|
365
|
+
# Carry metadata from fields.
|
366
|
+
for k, v in self.sym_init_args.items():
|
367
|
+
if k.startswith(_ADDITIONAL_METADATA_PREFIX):
|
368
|
+
metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
|
369
|
+
return metadata
|
370
|
+
|
343
371
|
#
|
344
372
|
# Implements `pg.typing.CustomTyping`.
|
345
373
|
#
|
@@ -380,6 +408,93 @@ class Template(
|
|
380
408
|
# Override __hash__ since __eq__ has changed.
|
381
409
|
return object.__hash__(self)
|
382
410
|
|
411
|
+
#
|
412
|
+
# Special methods.
|
413
|
+
#
|
414
|
+
|
415
|
+
@property
|
416
|
+
def DEFAULT(self) -> 'Template':
|
417
|
+
"""Referring to the default value used for this template.
|
418
|
+
|
419
|
+
This method is intended to be used in template for referring to the default
|
420
|
+
value of current template. There are two scenarios:
|
421
|
+
|
422
|
+
Scenario 1: Use instance-level template_str to override the class default.
|
423
|
+
|
424
|
+
```
|
425
|
+
class Foo(lf.Template):
|
426
|
+
'''Foo template.
|
427
|
+
|
428
|
+
This is {{x}}.
|
429
|
+
'''
|
430
|
+
|
431
|
+
f = Foo(template_str='<h1>{{DEFAULT}}</h1>', x=1)
|
432
|
+
f.render()
|
433
|
+
|
434
|
+
>> <h1>This is 1.</h1>
|
435
|
+
```
|
436
|
+
|
437
|
+
Scenario 2: Use an ad-hoc template to override a predefined field.
|
438
|
+
|
439
|
+
```
|
440
|
+
class Bar(lf.Template):
|
441
|
+
'''Bar template.
|
442
|
+
|
443
|
+
{{preamble}}
|
444
|
+
{{prompt}}
|
445
|
+
'''
|
446
|
+
preamble: lf.Template = lf.Template('You are a chat bot.')
|
447
|
+
prompt: lf.Template = lf.Template('User: hi')
|
448
|
+
|
449
|
+
b = Bar(preamble=lf.Template('<h1>{{DEFAULT}}<h1>'),
|
450
|
+
prompt=lf.Template('<h2>{{DEFAULT}}</h2>')
|
451
|
+
b.render()
|
452
|
+
|
453
|
+
>> <h1>You are a chat bot.<h1>
|
454
|
+
>> <h2>User: hi</h2>
|
455
|
+
```
|
456
|
+
|
457
|
+
Returns:
|
458
|
+
The default (pre-configured) value used for this template.
|
459
|
+
"""
|
460
|
+
base_template = self.__class__.__schema__['template_str'].default_value
|
461
|
+
if base_template == pg.MISSING_VALUE:
|
462
|
+
if not self.sym_path:
|
463
|
+
raise ValueError(
|
464
|
+
f'No DEFAULT template found for {self!r}: '
|
465
|
+
'The template neither has a default `template_str` nor is '
|
466
|
+
'contained under another object.'
|
467
|
+
)
|
468
|
+
key = self.sym_path.key
|
469
|
+
assert self.sym_parent is not None
|
470
|
+
assigned_field = self.sym_parent.sym_attr_field(key)
|
471
|
+
container_cls = self.sym_parent.__class__
|
472
|
+
|
473
|
+
if (
|
474
|
+
assigned_field is None
|
475
|
+
or assigned_field.default_value == pg.MISSING_VALUE
|
476
|
+
):
|
477
|
+
raise ValueError(
|
478
|
+
f'No DEFAULT template found for {self!r}: '
|
479
|
+
f'`{container_cls.__name__}.{key}` '
|
480
|
+
'does not have a default value. '
|
481
|
+
)
|
482
|
+
base_template = assigned_field.default_value
|
483
|
+
if isinstance(base_template, Template):
|
484
|
+
base_template = base_template.template_str
|
485
|
+
if not isinstance(base_template, str):
|
486
|
+
raise ValueError(
|
487
|
+
f'No DEFAULT template found for {self!r}: The default '
|
488
|
+
f'value {base_template!r} of '
|
489
|
+
f'`{container_cls.__name__}.{key}` is not a '
|
490
|
+
'`lf.Template` object or str.'
|
491
|
+
)
|
492
|
+
t = Template(base_template)
|
493
|
+
# NOTE(daiyip): Set the parent of the newly created template to self so
|
494
|
+
# it could access all the contextual variables.
|
495
|
+
t.sym_setparent(self)
|
496
|
+
return t
|
497
|
+
|
383
498
|
|
384
499
|
# Register converter from str to LangFunc, therefore we can always
|
385
500
|
# pass strs to attributes that accept LangFunc.
|
langfun/core/template_test.py
CHANGED
@@ -16,6 +16,7 @@ import inspect
|
|
16
16
|
import unittest
|
17
17
|
|
18
18
|
from langfun.core import component
|
19
|
+
from langfun.core import message as message_lib
|
19
20
|
from langfun.core import modality
|
20
21
|
from langfun.core import subscription
|
21
22
|
from langfun.core.template import Template
|
@@ -311,6 +312,72 @@ class RenderTest(unittest.TestCase):
|
|
311
312
|
'This is 1 and {{a}}',
|
312
313
|
)
|
313
314
|
|
315
|
+
def test_render_with_default(self):
|
316
|
+
|
317
|
+
class Foo(Template):
|
318
|
+
"""Foo.
|
319
|
+
|
320
|
+
This is {{x}}
|
321
|
+
"""
|
322
|
+
|
323
|
+
f = Foo(template_str='!{{DEFAULT}}!', x=1)
|
324
|
+
self.assertEqual(f.DEFAULT.x, 1)
|
325
|
+
self.assertEqual(
|
326
|
+
f.render(), '!This is 1!'
|
327
|
+
)
|
328
|
+
|
329
|
+
class Bar(Template):
|
330
|
+
"""Bar.
|
331
|
+
|
332
|
+
{{preamble}}
|
333
|
+
{{prompt}}
|
334
|
+
"""
|
335
|
+
|
336
|
+
preamble: Template = Template('You are a chat bot.')
|
337
|
+
prompt: Template = Template('User: hi! {{name}}')
|
338
|
+
|
339
|
+
b = Bar(
|
340
|
+
preamble=Template('<h1>{{DEFAULT}}</h1>'),
|
341
|
+
prompt=Template('<h2>{{DEFAULT}}</h2>'),
|
342
|
+
name='Tom',
|
343
|
+
)
|
344
|
+
# Test variable access.
|
345
|
+
self.assertEqual(
|
346
|
+
b.render(),
|
347
|
+
inspect.cleandoc("""
|
348
|
+
<h1>You are a chat bot.</h1>
|
349
|
+
<h2>User: hi! Tom</h2>
|
350
|
+
"""),
|
351
|
+
)
|
352
|
+
|
353
|
+
with self.assertRaisesRegex(ValueError, '`{{ DEFAULT }}` cannot be used'):
|
354
|
+
|
355
|
+
class Baz(Template): # pylint: disable=unused-variable
|
356
|
+
"""Baz.
|
357
|
+
|
358
|
+
{{DEFAULT}}
|
359
|
+
"""
|
360
|
+
|
361
|
+
with self.assertRaisesRegex(
|
362
|
+
ValueError, 'The template neither has a default `template_str` nor'
|
363
|
+
):
|
364
|
+
Template('{{DEFAULT}}').render()
|
365
|
+
|
366
|
+
d = pg.Dict(x=Template('{{DEFAULT}}'))
|
367
|
+
with self.assertRaisesRegex(
|
368
|
+
ValueError, 'does not have a default value'
|
369
|
+
):
|
370
|
+
_ = d.x.DEFAULT
|
371
|
+
|
372
|
+
class Tes(pg.Object):
|
373
|
+
x: str | None = None
|
374
|
+
|
375
|
+
t = Tes(x=Template('{{DEFAULT}}'))
|
376
|
+
with self.assertRaisesRegex(
|
377
|
+
ValueError, 'is not a `lf.Template` object or str'
|
378
|
+
):
|
379
|
+
_ = t.x.DEFAULT
|
380
|
+
|
314
381
|
def test_bad_render(self):
|
315
382
|
with self.assertRaises(ValueError):
|
316
383
|
Template('Hello {{x}}').render(allow_partial=False)
|
@@ -427,6 +494,14 @@ class RenderTest(unittest.TestCase):
|
|
427
494
|
# Test len.
|
428
495
|
self.assert_partial(Template('Hello {{len(x)}}'), 'Hello {{len(x)}}')
|
429
496
|
|
497
|
+
def test_additional_metadata(self):
|
498
|
+
t = Template('hi', metadata_weights=1.0, y=2)
|
499
|
+
self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
|
500
|
+
|
501
|
+
t = Template('hi')
|
502
|
+
with component.context(metadata_weights=1.0, y=2):
|
503
|
+
self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
|
504
|
+
|
430
505
|
|
431
506
|
class TemplateRenderEventTest(unittest.TestCase):
|
432
507
|
|
@@ -56,7 +56,9 @@ class SelfPlayTest(unittest.TestCase):
|
|
56
56
|
g = NumberGuess(target_num=10)
|
57
57
|
|
58
58
|
with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 10])):
|
59
|
-
self.assertEqual(
|
59
|
+
self.assertEqual(
|
60
|
+
g(), lf.AIMessage('10', score=0.0, logprobs=None, usage=None)
|
61
|
+
)
|
60
62
|
|
61
63
|
self.assertEqual(g.num_turns, 4)
|
62
64
|
|
@@ -64,7 +66,9 @@ class SelfPlayTest(unittest.TestCase):
|
|
64
66
|
g = NumberGuess(target_num=10, max_turns=10)
|
65
67
|
|
66
68
|
with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 2, 5, 4])):
|
67
|
-
self.assertEqual(
|
69
|
+
self.assertEqual(
|
70
|
+
g(), lf.AIMessage('2', score=0.0, logprobs=None, usage=None)
|
71
|
+
)
|
68
72
|
|
69
73
|
self.assertEqual(g.num_turns, 10)
|
70
74
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: langfun
|
3
|
-
Version: 0.0.2.
|
3
|
+
Version: 0.0.2.dev20240429
|
4
4
|
Summary: Langfun: Language as Functions.
|
5
5
|
Home-page: https://github.com/google/langfun
|
6
6
|
Author: Langfun Authors
|
@@ -21,10 +21,11 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
21
21
|
Classifier: Topic :: Software Development :: Libraries
|
22
22
|
Description-Content-Type: text/markdown
|
23
23
|
License-File: LICENSE
|
24
|
+
Requires-Dist: absl-py >=1.0.0
|
24
25
|
Requires-Dist: google-generativeai >=0.3.2
|
25
26
|
Requires-Dist: jinja2 >=3.1.2
|
26
27
|
Requires-Dist: openai ==0.27.2
|
27
|
-
Requires-Dist: pyglove >=0.4.5.
|
28
|
+
Requires-Dist: pyglove >=0.4.5.dev20240423
|
28
29
|
Requires-Dist: python-magic >=0.4.27
|
29
30
|
Requires-Dist: requests >=2.31.0
|
30
31
|
Requires-Dist: termcolor ==1.1.0
|