langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. langfun/__init__.py +2 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +2 -0
  7. langfun/core/eval/base.py +202 -23
  8. langfun/core/eval/base_test.py +49 -10
  9. langfun/core/eval/matching.py +26 -9
  10. langfun/core/eval/matching_test.py +2 -1
  11. langfun/core/eval/scoring.py +15 -6
  12. langfun/core/eval/scoring_test.py +2 -1
  13. langfun/core/langfunc.py +0 -5
  14. langfun/core/langfunc_test.py +6 -4
  15. langfun/core/language_model.py +124 -24
  16. langfun/core/language_model_test.py +249 -26
  17. langfun/core/llms/__init__.py +19 -2
  18. langfun/core/llms/anthropic.py +263 -0
  19. langfun/core/llms/anthropic_test.py +167 -0
  20. langfun/core/llms/cache/in_memory_test.py +37 -28
  21. langfun/core/llms/fake.py +31 -22
  22. langfun/core/llms/fake_test.py +122 -11
  23. langfun/core/llms/google_genai_test.py +8 -3
  24. langfun/core/llms/groq.py +260 -0
  25. langfun/core/llms/groq_test.py +170 -0
  26. langfun/core/llms/llama_cpp.py +3 -1
  27. langfun/core/llms/openai.py +97 -79
  28. langfun/core/llms/openai_test.py +285 -59
  29. langfun/core/modalities/video.py +5 -2
  30. langfun/core/structured/__init__.py +3 -0
  31. langfun/core/structured/completion_test.py +2 -2
  32. langfun/core/structured/function_generation.py +245 -0
  33. langfun/core/structured/function_generation_test.py +329 -0
  34. langfun/core/structured/mapping.py +56 -2
  35. langfun/core/structured/mapping_test.py +17 -0
  36. langfun/core/structured/parsing_test.py +18 -13
  37. langfun/core/structured/prompting.py +27 -6
  38. langfun/core/structured/prompting_test.py +79 -12
  39. langfun/core/structured/schema.py +4 -2
  40. langfun/core/structured/schema_generation_test.py +2 -2
  41. langfun/core/structured/schema_test.py +4 -6
  42. langfun/core/template.py +125 -10
  43. langfun/core/template_test.py +75 -0
  44. langfun/core/templates/selfplay_test.py +6 -2
  45. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
  46. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +49 -43
  47. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
  48. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
  49. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
@@ -17,11 +17,9 @@ import inspect
17
17
  import unittest
18
18
 
19
19
  import langfun.core as lf
20
- from langfun.core import coding
21
20
  from langfun.core.llms import fake
22
21
  from langfun.core.structured import mapping
23
22
  from langfun.core.structured import parsing
24
- from langfun.core.structured import schema as schema_lib
25
23
  import pyglove as pg
26
24
 
27
25
 
@@ -255,7 +253,7 @@ class ParseStructurePythonTest(unittest.TestCase):
255
253
  override_attrs=True,
256
254
  ):
257
255
  with self.assertRaisesRegex(
258
- coding.CodeError,
256
+ mapping.MappingError,
259
257
  'name .* is not defined',
260
258
  ):
261
259
  parsing.parse('three', int)
@@ -280,13 +278,15 @@ class ParseStructurePythonTest(unittest.TestCase):
280
278
  ),
281
279
  1,
282
280
  )
281
+ r = parsing.parse(
282
+ 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
283
+ returns_message=True
284
+ )
283
285
  self.assertEqual(
284
- parsing.parse(
285
- 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
286
- returns_message=True
287
- ),
286
+ r,
288
287
  lf.AIMessage(
289
288
  '1', score=1.0, result=1, logprobs=None,
289
+ usage=lf.LMSamplingUsage(652, 1, 653),
290
290
  tags=['lm-response', 'lm-output', 'transformed']
291
291
  ),
292
292
  )
@@ -544,7 +544,7 @@ class ParseStructureJsonTest(unittest.TestCase):
544
544
  override_attrs=True,
545
545
  ):
546
546
  with self.assertRaisesRegex(
547
- schema_lib.JsonError,
547
+ mapping.MappingError,
548
548
  'No JSON dict in the output',
549
549
  ):
550
550
  parsing.parse('three', int, protocol='json')
@@ -634,13 +634,18 @@ class CallTest(unittest.TestCase):
634
634
  )
635
635
 
636
636
  def test_call_with_returning_message(self):
637
+ r = parsing.call(
638
+ 'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
639
+ returns_message=True
640
+ )
637
641
  self.assertEqual(
638
- parsing.call(
639
- 'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
640
- returns_message=True
641
- ),
642
+ r,
642
643
  lf.AIMessage(
643
- '3', result=3, score=1.0, logprobs=None,
644
+ '3',
645
+ result=3,
646
+ score=1.0,
647
+ logprobs=None,
648
+ usage=lf.LMSamplingUsage(315, 1, 316),
644
649
  tags=['lm-response', 'lm-output', 'transformed']
645
650
  ),
646
651
  )
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Symbolic query."""
15
15
 
16
- from typing import Any, Type, Union
16
+ from typing import Any, Callable, Type, Union
17
17
 
18
18
  import langfun.core as lf
19
19
  from langfun.core.structured import mapping
@@ -78,7 +78,9 @@ class QueryStructurePython(QueryStructure):
78
78
 
79
79
  {{ output_title }}:
80
80
  ```python
81
- Answer(final_answer=2)
81
+ Answer(
82
+ final_answer=2
83
+ )
82
84
  ```
83
85
  """
84
86
  protocol = 'python'
@@ -107,6 +109,7 @@ def query(
107
109
  lm: lf.LanguageModel | None = None,
108
110
  examples: list[mapping.MappingExample] | None = None,
109
111
  cache_seed: int | None = 0,
112
+ response_postprocess: Callable[[str], str] | None = None,
110
113
  autofix: int = 0,
111
114
  autofix_lm: lf.LanguageModel | None = None,
112
115
  protocol: schema_lib.SchemaProtocol = 'python',
@@ -159,6 +162,9 @@ def query(
159
162
  cache_seed: Seed for computing cache key. The cache key is determined by a
160
163
  tuple of (lm, prompt, cache seed). If None, cache will be disabled for
161
164
  the query even cache is configured by the LM.
165
+ response_postprocess: An optional callable object to process the raw LM
166
+ response before parsing it into the final output object. If None, the
167
+ raw LM response will not be processed.
162
168
  autofix: Number of attempts to auto fix the generated code. If 0, autofix is
163
169
  disabled. Auto-fix is not supported for 'json' protocol.
164
170
  autofix_lm: The language model to use for autofix. If not specified, the
@@ -170,8 +176,11 @@ def query(
170
176
  returning the structured `message.result`.
171
177
  skip_lm: If True, returns the rendered prompt as a UserMessage object.
172
178
  otherwise return the LLM response based on the rendered prompt.
173
- **kwargs: Keyword arguments passed to the
174
- `lf.structured.NaturalLanguageToStructureed` transform.
179
+ **kwargs: Keyword arguments passed to render the prompt or configure the
180
+ `lf.structured.Mapping` class. Notable kwargs are:
181
+ - template_str: Change the root template for query.
182
+ - preamble: Change the preamble for query.
183
+ - mapping_template: Change the template for each mapping examle.
175
184
 
176
185
  Returns:
177
186
  The result based on the schema.
@@ -188,13 +197,24 @@ def query(
188
197
  output = lf.LangFunc.from_value(prompt, **kwargs)(
189
198
  lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
190
199
  )
200
+ if response_postprocess:
201
+ processed_text = response_postprocess(output.text)
202
+ if processed_text != output.text:
203
+ output = lf.AIMessage(processed_text, source=output)
191
204
  return output if returns_message else output.text
192
205
 
193
206
  # Query with structured output.
207
+ prompt_kwargs = kwargs.copy()
208
+
209
+ # NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
210
+ # QueryStructure template string. Therefore, we pop out the argument for
211
+ # prompt rendering.
212
+ prompt_kwargs.pop('template_str', None)
213
+
194
214
  if isinstance(prompt, str):
195
- prompt = lf.Template(prompt, **kwargs)
215
+ prompt = lf.Template(prompt, **prompt_kwargs)
196
216
  elif isinstance(prompt, lf.Template):
197
- prompt = prompt.rebind(**kwargs)
217
+ prompt = prompt.rebind(**prompt_kwargs, raise_on_no_change=False)
198
218
 
199
219
  if isinstance(prompt, lf.Template):
200
220
  prompt = prompt.render(lm=lm)
@@ -206,6 +226,7 @@ def query(
206
226
  schema=schema,
207
227
  default=default,
208
228
  examples=examples,
229
+ response_postprocess=response_postprocess,
209
230
  autofix=autofix if protocol == 'python' else 0,
210
231
  **kwargs,
211
232
  )(
@@ -17,12 +17,10 @@ import inspect
17
17
  import unittest
18
18
 
19
19
  import langfun.core as lf
20
- from langfun.core import coding
21
20
  from langfun.core import modalities
22
21
  from langfun.core.llms import fake
23
22
  from langfun.core.structured import mapping
24
23
  from langfun.core.structured import prompting
25
- from langfun.core.structured import schema as schema_lib
26
24
  import pyglove as pg
27
25
 
28
26
 
@@ -77,6 +75,7 @@ class QueryTest(unittest.TestCase):
77
75
  result=1,
78
76
  score=1.0,
79
77
  logprobs=None,
78
+ usage=lf.LMSamplingUsage(323, 1, 324),
80
79
  tags=['lm-response', 'lm-output', 'transformed'],
81
80
  ),
82
81
  )
@@ -116,12 +115,59 @@ class QueryTest(unittest.TestCase):
116
115
  y=2,
117
116
  lm=lm.clone(),
118
117
  expected_snippet=(
119
- 'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT'
120
- ' according to OUTPUT_TYPE.\n\nINPUT_OBJECT:\n 1 + 1'
121
- ' =\n\nOUTPUT_TYPE:\n Answer\n\n ```python\n class Answer:\n '
122
- ' final_answer: int\n ```\n\nOUTPUT_OBJECT:\n ```python\n '
123
- ' Answer(final_answer=2)\n ```\n\nINPUT_OBJECT:\n What is 1 +'
124
- ' 2?\n\nOUTPUT_TYPE:\n int\n\nOUTPUT_OBJECT:'
118
+ 'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
119
+ 'according to OUTPUT_TYPE.\n\n'
120
+ 'INPUT_OBJECT:\n 1 + 1 =\n\n'
121
+ 'OUTPUT_TYPE:\n'
122
+ ' Answer\n\n'
123
+ ' ```python\n'
124
+ ' class Answer:\n'
125
+ ' final_answer: int\n'
126
+ ' ```\n\n'
127
+ 'OUTPUT_OBJECT:\n'
128
+ ' ```python\n'
129
+ ' Answer(\n'
130
+ ' final_answer=2\n'
131
+ ' )\n'
132
+ ' ```\n\n'
133
+ 'INPUT_OBJECT:\n'
134
+ ' What is 1 + 2?\n\n'
135
+ 'OUTPUT_TYPE:\n'
136
+ ' int\n\n'
137
+ 'OUTPUT_OBJECT:'
138
+ ),
139
+ )
140
+
141
+ def test_str_to_structure_render_custom_template(self):
142
+ lm = fake.StaticResponse('1')
143
+ self.assert_render(
144
+ 'What is {{x}} + {{y}}?',
145
+ int,
146
+ x=1,
147
+ y=2,
148
+ lm=lm.clone(),
149
+ template_str='!!{{ DEFAULT }}!!',
150
+ expected_snippet=(
151
+ '!!Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
152
+ 'according to OUTPUT_TYPE.\n\n'
153
+ 'INPUT_OBJECT:\n 1 + 1 =\n\n'
154
+ 'OUTPUT_TYPE:\n'
155
+ ' Answer\n\n'
156
+ ' ```python\n'
157
+ ' class Answer:\n'
158
+ ' final_answer: int\n'
159
+ ' ```\n\n'
160
+ 'OUTPUT_OBJECT:\n'
161
+ ' ```python\n'
162
+ ' Answer(\n'
163
+ ' final_answer=2\n'
164
+ ' )\n'
165
+ ' ```\n\n'
166
+ 'INPUT_OBJECT:\n'
167
+ ' What is 1 + 2?\n\n'
168
+ 'OUTPUT_TYPE:\n'
169
+ ' int\n\n'
170
+ 'OUTPUT_OBJECT:!!'
125
171
  ),
126
172
  )
127
173
 
@@ -264,7 +310,9 @@ class QueryStructurePythonTest(unittest.TestCase):
264
310
 
265
311
  OUTPUT_OBJECT:
266
312
  ```python
267
- Answer(final_answer=2)
313
+ Answer(
314
+ final_answer=2
315
+ )
268
316
  ```
269
317
 
270
318
  INPUT_OBJECT:
@@ -308,7 +356,9 @@ class QueryStructurePythonTest(unittest.TestCase):
308
356
 
309
357
  OUTPUT_OBJECT:
310
358
  ```python
311
- Answer(final_answer=2)
359
+ Answer(
360
+ final_answer=2
361
+ )
312
362
  ```
313
363
 
314
364
  INPUT_OBJECT:
@@ -420,7 +470,7 @@ class QueryStructurePythonTest(unittest.TestCase):
420
470
  override_attrs=True,
421
471
  ):
422
472
  with self.assertRaisesRegex(
423
- coding.CodeError,
473
+ mapping.MappingError,
424
474
  'name .* is not defined',
425
475
  ):
426
476
  prompting.query('Compute 1 + 2', int)
@@ -436,6 +486,23 @@ class QueryStructurePythonTest(unittest.TestCase):
436
486
  ])
437
487
  self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
438
488
 
489
+ def test_response_postprocess(self):
490
+ with lf.context(
491
+ lm=fake.StaticResponse('<!-- some comment-->\n3'),
492
+ override_attrs=True,
493
+ ):
494
+ self.assertEqual(
495
+ prompting.query(
496
+ 'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
497
+ '3'
498
+ )
499
+ self.assertEqual(
500
+ prompting.query(
501
+ 'Compute 1 + 2', int,
502
+ response_postprocess=lambda x: x.split('\n')[1]),
503
+ 3
504
+ )
505
+
439
506
 
440
507
  class QueryStructureJsonTest(unittest.TestCase):
441
508
 
@@ -641,7 +708,7 @@ class QueryStructureJsonTest(unittest.TestCase):
641
708
  override_attrs=True,
642
709
  ):
643
710
  with self.assertRaisesRegex(
644
- schema_lib.JsonError,
711
+ mapping.MappingError,
645
712
  'No JSON dict in the output',
646
713
  ):
647
714
  prompting.query('Compute 1 + 2', int, protocol='json')
@@ -386,10 +386,12 @@ def class_definition(
386
386
  if schema.fields:
387
387
  for key, field in schema.items():
388
388
  if not isinstance(key, pg.typing.ConstStrKey):
389
- raise TypeError(
389
+ pg.logging.warning(
390
390
  'Variable-length keyword arguments is not supported in '
391
- f'structured parsing or query. Encountered: {field}'
391
+ f'structured parsing or query. Encountered: {cls}, Schema: {schema}'
392
392
  )
393
+ continue
394
+
393
395
  # Write field doc string as comments before the field definition.
394
396
  if field.description:
395
397
  for line in field.description.split('\n'):
@@ -14,8 +14,8 @@
14
14
  import inspect
15
15
  import unittest
16
16
 
17
- import langfun.core.coding as lf_coding
18
17
  from langfun.core.llms import fake
18
+ from langfun.core.structured import mapping
19
19
  from langfun.core.structured import schema_generation
20
20
 
21
21
 
@@ -92,7 +92,7 @@ class GenerateClassTest(unittest.TestCase):
92
92
  )
93
93
  self.assertIs(cls.__name__, 'B')
94
94
 
95
- with self.assertRaises(lf_coding.CodeError):
95
+ with self.assertRaises(mapping.MappingError):
96
96
  schema_generation.generate_class(
97
97
  'Foo',
98
98
  'Generate a Foo class with a field pointing to another class A',
@@ -192,9 +192,9 @@ class SchemaTest(unittest.TestCase):
192
192
  self.assertEqual(schema.parse('{"result": 1}'), 1)
193
193
  schema = schema_lib.Schema(dict[str, int])
194
194
  self.assertEqual(
195
- schema.parse(
196
- '{"result": {"_type": "Unknown", "x": 1}}}', force_dict=True),
197
- dict(x=1))
195
+ schema.parse('{"result": {"x": 1}}}'),
196
+ dict(x=1)
197
+ )
198
198
  with self.assertRaisesRegex(
199
199
  schema_lib.SchemaError, 'Expect .* but encountered .*'):
200
200
  schema.parse('{"result": "def"}')
@@ -459,9 +459,7 @@ class SchemaPythonReprTest(unittest.TestCase):
459
459
  x: str
460
460
  __kwargs__: typing.Any
461
461
 
462
- with self.assertRaisesRegex(
463
- TypeError, 'Variable-length keyword arguments is not supported'):
464
- schema_lib.class_definition(C)
462
+ self.assertEqual(schema_lib.class_definition(C), 'class C:\n x: str\n')
465
463
 
466
464
  def test_repr(self):
467
465
  class Foo(pg.Object):
langfun/core/template.py CHANGED
@@ -38,13 +38,22 @@ NO_TEMPLATE_DOCSTR_SIGN = 'THIS IS NOT A TEMPLATE'
38
38
  _TLS_RENDER_STACK = '_template_render_stack'
39
39
  _TLS_RENDER_RESULT_CACHE = '_template_render_result_cache'
40
40
 
41
+ # The prefix for fields or contextual attributes to be treated as additional
42
+ # metadata for rendered message.
43
+ _ADDITIONAL_METADATA_PREFIX = 'metadata_'
44
+
41
45
 
42
46
  class Template(
43
47
  natural_language.NaturalLanguageFormattable,
44
48
  component.Component,
45
49
  pg.typing.CustomTyping,
46
50
  ):
47
- """Langfun string template."""
51
+ """Langfun string template.
52
+
53
+ Langfun uses jinja2 as its template engine. Pleaes check out
54
+ https://jinja.palletsprojects.com/en/3.1.x/templates/ for detailed
55
+ explanation on the template language.
56
+ """
48
57
 
49
58
  template_str: Annotated[
50
59
  str,
@@ -97,6 +106,11 @@ class Template(
97
106
  # Declare template variables as symbolic attributes.
98
107
  template_vars = Template.resolve_vars(template_str)
99
108
  for var_name in template_vars:
109
+ if 'DEFAULT' == var_name:
110
+ raise ValueError(
111
+ '`{{ DEFAULT }}` cannot be used in pre-configured templates. '
112
+ f'Encountered: {template_str!r}'
113
+ )
100
114
  # NOTE(daiyip): This is to avoid warning from accessing
101
115
  # `pg.Object.schema`, which was replaced by `pg.Object.__schema__`.
102
116
  if var_name == 'schema' or not hasattr(cls, var_name):
@@ -149,7 +163,7 @@ class Template(
149
163
  # TODO(daiyip): Consider to delay template parsing upon usage.
150
164
  unassigned_vars = {}
151
165
  for k in self._variables:
152
- if not hasattr(self, k):
166
+ if k not in ('DEFAULT',) and not hasattr(self, k):
153
167
  unassigned_vars[k] = component.contextual()
154
168
  if unassigned_vars:
155
169
  self.rebind(unassigned_vars, skip_notification=True)
@@ -303,19 +317,19 @@ class Template(
303
317
  with modality.format_modality_as_ref():
304
318
  rendered_text = self._template.render(**inputs)
305
319
 
320
+ # Carry additional metadata.
321
+ metadata = self.additional_metadata()
322
+
306
323
  if self.clean:
307
324
  rendered_text = rendered_text.strip()
308
325
 
309
- # Fill the variables for rendering the template as metadata.
310
- message = message_cls(
311
- text=rendered_text,
312
- metadata={
313
- k: pg.Ref(v)
314
- for k, v in inputs.items()
315
- if not inspect.ismethod(v)
316
- },
326
+ metadata.update(
327
+ {k: pg.Ref(v) for k, v in inputs.items() if not inspect.ismethod(v)}
317
328
  )
318
329
 
330
+ # Fill the variables for rendering the template as metadata.
331
+ message = message_cls(text=rendered_text, metadata=metadata)
332
+
319
333
  # Tag input as rendered message.
320
334
  message.tag(message_lib.Message.TAG_RENDERED)
321
335
 
@@ -340,6 +354,20 @@ class Template(
340
354
  top = pg.object_utils.thread_local_pop(_TLS_RENDER_STACK)
341
355
  assert top is self, (top, self)
342
356
 
357
+ def additional_metadata(self) -> dict[str, Any]:
358
+ """Returns additional metadta to be carried in the rendered message."""
359
+ metadata = {}
360
+ # Carry metadata from `lf.context`.
361
+ for k, v in component.all_contextual_values().items():
362
+ if k.startswith(_ADDITIONAL_METADATA_PREFIX):
363
+ metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
364
+
365
+ # Carry metadata from fields.
366
+ for k, v in self.sym_init_args.items():
367
+ if k.startswith(_ADDITIONAL_METADATA_PREFIX):
368
+ metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
369
+ return metadata
370
+
343
371
  #
344
372
  # Implements `pg.typing.CustomTyping`.
345
373
  #
@@ -380,6 +408,93 @@ class Template(
380
408
  # Override __hash__ since __eq__ has changed.
381
409
  return object.__hash__(self)
382
410
 
411
+ #
412
+ # Special methods.
413
+ #
414
+
415
+ @property
416
+ def DEFAULT(self) -> 'Template':
417
+ """Referring to the default value used for this template.
418
+
419
+ This method is intended to be used in template for referring to the default
420
+ value of current template. There are two scenarios:
421
+
422
+ Scenario 1: Use instance-level template_str to override the class default.
423
+
424
+ ```
425
+ class Foo(lf.Template):
426
+ '''Foo template.
427
+
428
+ This is {{x}}.
429
+ '''
430
+
431
+ f = Foo(template_str='<h1>{{DEFAULT}}</h1>', x=1)
432
+ f.render()
433
+
434
+ >> <h1>This is 1.</h1>
435
+ ```
436
+
437
+ Scenario 2: Use an ad-hoc template to override a predefined field.
438
+
439
+ ```
440
+ class Bar(lf.Template):
441
+ '''Bar template.
442
+
443
+ {{preamble}}
444
+ {{prompt}}
445
+ '''
446
+ preamble: lf.Template = lf.Template('You are a chat bot.')
447
+ prompt: lf.Template = lf.Template('User: hi')
448
+
449
+ b = Bar(preamble=lf.Template('<h1>{{DEFAULT}}<h1>'),
450
+ prompt=lf.Template('<h2>{{DEFAULT}}</h2>')
451
+ b.render()
452
+
453
+ >> <h1>You are a chat bot.<h1>
454
+ >> <h2>User: hi</h2>
455
+ ```
456
+
457
+ Returns:
458
+ The default (pre-configured) value used for this template.
459
+ """
460
+ base_template = self.__class__.__schema__['template_str'].default_value
461
+ if base_template == pg.MISSING_VALUE:
462
+ if not self.sym_path:
463
+ raise ValueError(
464
+ f'No DEFAULT template found for {self!r}: '
465
+ 'The template neither has a default `template_str` nor is '
466
+ 'contained under another object.'
467
+ )
468
+ key = self.sym_path.key
469
+ assert self.sym_parent is not None
470
+ assigned_field = self.sym_parent.sym_attr_field(key)
471
+ container_cls = self.sym_parent.__class__
472
+
473
+ if (
474
+ assigned_field is None
475
+ or assigned_field.default_value == pg.MISSING_VALUE
476
+ ):
477
+ raise ValueError(
478
+ f'No DEFAULT template found for {self!r}: '
479
+ f'`{container_cls.__name__}.{key}` '
480
+ 'does not have a default value. '
481
+ )
482
+ base_template = assigned_field.default_value
483
+ if isinstance(base_template, Template):
484
+ base_template = base_template.template_str
485
+ if not isinstance(base_template, str):
486
+ raise ValueError(
487
+ f'No DEFAULT template found for {self!r}: The default '
488
+ f'value {base_template!r} of '
489
+ f'`{container_cls.__name__}.{key}` is not a '
490
+ '`lf.Template` object or str.'
491
+ )
492
+ t = Template(base_template)
493
+ # NOTE(daiyip): Set the parent of the newly created template to self so
494
+ # it could access all the contextual variables.
495
+ t.sym_setparent(self)
496
+ return t
497
+
383
498
 
384
499
  # Register converter from str to LangFunc, therefore we can always
385
500
  # pass strs to attributes that accept LangFunc.
@@ -16,6 +16,7 @@ import inspect
16
16
  import unittest
17
17
 
18
18
  from langfun.core import component
19
+ from langfun.core import message as message_lib
19
20
  from langfun.core import modality
20
21
  from langfun.core import subscription
21
22
  from langfun.core.template import Template
@@ -311,6 +312,72 @@ class RenderTest(unittest.TestCase):
311
312
  'This is 1 and {{a}}',
312
313
  )
313
314
 
315
+ def test_render_with_default(self):
316
+
317
+ class Foo(Template):
318
+ """Foo.
319
+
320
+ This is {{x}}
321
+ """
322
+
323
+ f = Foo(template_str='!{{DEFAULT}}!', x=1)
324
+ self.assertEqual(f.DEFAULT.x, 1)
325
+ self.assertEqual(
326
+ f.render(), '!This is 1!'
327
+ )
328
+
329
+ class Bar(Template):
330
+ """Bar.
331
+
332
+ {{preamble}}
333
+ {{prompt}}
334
+ """
335
+
336
+ preamble: Template = Template('You are a chat bot.')
337
+ prompt: Template = Template('User: hi! {{name}}')
338
+
339
+ b = Bar(
340
+ preamble=Template('<h1>{{DEFAULT}}</h1>'),
341
+ prompt=Template('<h2>{{DEFAULT}}</h2>'),
342
+ name='Tom',
343
+ )
344
+ # Test variable access.
345
+ self.assertEqual(
346
+ b.render(),
347
+ inspect.cleandoc("""
348
+ <h1>You are a chat bot.</h1>
349
+ <h2>User: hi! Tom</h2>
350
+ """),
351
+ )
352
+
353
+ with self.assertRaisesRegex(ValueError, '`{{ DEFAULT }}` cannot be used'):
354
+
355
+ class Baz(Template): # pylint: disable=unused-variable
356
+ """Baz.
357
+
358
+ {{DEFAULT}}
359
+ """
360
+
361
+ with self.assertRaisesRegex(
362
+ ValueError, 'The template neither has a default `template_str` nor'
363
+ ):
364
+ Template('{{DEFAULT}}').render()
365
+
366
+ d = pg.Dict(x=Template('{{DEFAULT}}'))
367
+ with self.assertRaisesRegex(
368
+ ValueError, 'does not have a default value'
369
+ ):
370
+ _ = d.x.DEFAULT
371
+
372
+ class Tes(pg.Object):
373
+ x: str | None = None
374
+
375
+ t = Tes(x=Template('{{DEFAULT}}'))
376
+ with self.assertRaisesRegex(
377
+ ValueError, 'is not a `lf.Template` object or str'
378
+ ):
379
+ _ = t.x.DEFAULT
380
+
314
381
  def test_bad_render(self):
315
382
  with self.assertRaises(ValueError):
316
383
  Template('Hello {{x}}').render(allow_partial=False)
@@ -427,6 +494,14 @@ class RenderTest(unittest.TestCase):
427
494
  # Test len.
428
495
  self.assert_partial(Template('Hello {{len(x)}}'), 'Hello {{len(x)}}')
429
496
 
497
+ def test_additional_metadata(self):
498
+ t = Template('hi', metadata_weights=1.0, y=2)
499
+ self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
500
+
501
+ t = Template('hi')
502
+ with component.context(metadata_weights=1.0, y=2):
503
+ self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
504
+
430
505
 
431
506
  class TemplateRenderEventTest(unittest.TestCase):
432
507
 
@@ -56,7 +56,9 @@ class SelfPlayTest(unittest.TestCase):
56
56
  g = NumberGuess(target_num=10)
57
57
 
58
58
  with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 10])):
59
- self.assertEqual(g(), lf.AIMessage('10', score=0.0, logprobs=None))
59
+ self.assertEqual(
60
+ g(), lf.AIMessage('10', score=0.0, logprobs=None, usage=None)
61
+ )
60
62
 
61
63
  self.assertEqual(g.num_turns, 4)
62
64
 
@@ -64,7 +66,9 @@ class SelfPlayTest(unittest.TestCase):
64
66
  g = NumberGuess(target_num=10, max_turns=10)
65
67
 
66
68
  with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 2, 5, 4])):
67
- self.assertEqual(g(), lf.AIMessage('2', score=0.0, logprobs=None))
69
+ self.assertEqual(
70
+ g(), lf.AIMessage('2', score=0.0, logprobs=None, usage=None)
71
+ )
68
72
 
69
73
  self.assertEqual(g.num_turns, 10)
70
74
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240330
3
+ Version: 0.0.2.dev20240429
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -21,10 +21,11 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
21
21
  Classifier: Topic :: Software Development :: Libraries
22
22
  Description-Content-Type: text/markdown
23
23
  License-File: LICENSE
24
+ Requires-Dist: absl-py >=1.0.0
24
25
  Requires-Dist: google-generativeai >=0.3.2
25
26
  Requires-Dist: jinja2 >=3.1.2
26
27
  Requires-Dist: openai ==0.27.2
27
- Requires-Dist: pyglove >=0.4.5.dev20240323
28
+ Requires-Dist: pyglove >=0.4.5.dev20240423
28
29
  Requires-Dist: python-magic >=0.4.27
29
30
  Requires-Dist: requests >=2.31.0
30
31
  Requires-Dist: termcolor ==1.1.0