langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145) hide show
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/template.py CHANGED
@@ -17,7 +17,7 @@ import contextlib
17
17
  import dataclasses
18
18
  import functools
19
19
  import inspect
20
- from typing import Annotated, Any, Callable, Iterator, Set, Tuple, Type
20
+ from typing import Annotated, Any, Callable, Iterator, Set, Tuple, Type, Union
21
21
 
22
22
  import jinja2
23
23
  from jinja2 import meta as jinja2_meta
@@ -38,13 +38,23 @@ NO_TEMPLATE_DOCSTR_SIGN = 'THIS IS NOT A TEMPLATE'
38
38
  _TLS_RENDER_STACK = '_template_render_stack'
39
39
  _TLS_RENDER_RESULT_CACHE = '_template_render_result_cache'
40
40
 
41
+ # The prefix for fields or contextual attributes to be treated as additional
42
+ # metadata for rendered message.
43
+ _ADDITIONAL_METADATA_PREFIX = 'metadata_'
44
+
41
45
 
42
46
  class Template(
43
47
  natural_language.NaturalLanguageFormattable,
44
48
  component.Component,
45
49
  pg.typing.CustomTyping,
50
+ pg.views.HtmlTreeView.Extension
46
51
  ):
47
- """Langfun string template."""
52
+ """Langfun string template.
53
+
54
+ Langfun uses jinja2 as its template engine. Pleaes check out
55
+ https://jinja.palletsprojects.com/en/3.1.x/templates/ for detailed
56
+ explanation on the template language.
57
+ """
48
58
 
49
59
  template_str: Annotated[
50
60
  str,
@@ -97,6 +107,11 @@ class Template(
97
107
  # Declare template variables as symbolic attributes.
98
108
  template_vars = Template.resolve_vars(template_str)
99
109
  for var_name in template_vars:
110
+ if 'DEFAULT' == var_name:
111
+ raise ValueError(
112
+ '`{{ DEFAULT }}` cannot be used in pre-configured templates. '
113
+ f'Encountered: {template_str!r}'
114
+ )
100
115
  # NOTE(daiyip): This is to avoid warning from accessing
101
116
  # `pg.Object.schema`, which was replaced by `pg.Object.__schema__`.
102
117
  if var_name == 'schema' or not hasattr(cls, var_name):
@@ -149,7 +164,7 @@ class Template(
149
164
  # TODO(daiyip): Consider to delay template parsing upon usage.
150
165
  unassigned_vars = {}
151
166
  for k in self._variables:
152
- if not hasattr(self, k):
167
+ if k not in ('DEFAULT',) and not hasattr(self, k):
153
168
  unassigned_vars[k] = component.contextual()
154
169
  if unassigned_vars:
155
170
  self.rebind(unassigned_vars, skip_notification=True)
@@ -217,6 +232,16 @@ class Template(
217
232
  """Returns the missing variable names."""
218
233
  return self.vars(closure=True, specified=False)
219
234
 
235
+ @classmethod
236
+ def raw_str(cls, text: str) -> str:
237
+ """Returns a template string that preserve the text as original."""
238
+ return '{% raw %}' + text + '{% endraw %}'
239
+
240
+ @classmethod
241
+ def from_raw_str(cls, text: str) -> 'Template':
242
+ """Returns a template that preserve the text as original."""
243
+ return cls(cls.raw_str(text), clean=False)
244
+
220
245
  def render(
221
246
  self,
222
247
  *,
@@ -303,19 +328,19 @@ class Template(
303
328
  with modality.format_modality_as_ref():
304
329
  rendered_text = self._template.render(**inputs)
305
330
 
331
+ # Carry additional metadata.
332
+ metadata = self.additional_metadata()
333
+
306
334
  if self.clean:
307
335
  rendered_text = rendered_text.strip()
308
336
 
309
- # Fill the variables for rendering the template as metadata.
310
- message = message_cls(
311
- text=rendered_text,
312
- metadata={
313
- k: pg.Ref(v)
314
- for k, v in inputs.items()
315
- if not inspect.ismethod(v)
316
- },
337
+ metadata.update(
338
+ {k: pg.Ref(v) for k, v in inputs.items() if not inspect.ismethod(v)}
317
339
  )
318
340
 
341
+ # Fill the variables for rendering the template as metadata.
342
+ message = message_cls(text=rendered_text, metadata=metadata)
343
+
319
344
  # Tag input as rendered message.
320
345
  message.tag(message_lib.Message.TAG_RENDERED)
321
346
 
@@ -340,6 +365,20 @@ class Template(
340
365
  top = pg.object_utils.thread_local_pop(_TLS_RENDER_STACK)
341
366
  assert top is self, (top, self)
342
367
 
368
+ def additional_metadata(self) -> dict[str, Any]:
369
+ """Returns additional metadta to be carried in the rendered message."""
370
+ metadata = {}
371
+ # Carry metadata from `lf.context`.
372
+ for k, v in component.all_contextual_values().items():
373
+ if k.startswith(_ADDITIONAL_METADATA_PREFIX):
374
+ metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
375
+
376
+ # Carry metadata from fields.
377
+ for k, v in self.sym_init_args.items():
378
+ if k.startswith(_ADDITIONAL_METADATA_PREFIX):
379
+ metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
380
+ return metadata
381
+
343
382
  #
344
383
  # Implements `pg.typing.CustomTyping`.
345
384
  #
@@ -380,6 +419,196 @@ class Template(
380
419
  # Override __hash__ since __eq__ has changed.
381
420
  return object.__hash__(self)
382
421
 
422
+ #
423
+ # Special methods.
424
+ #
425
+
426
+ @property
427
+ def DEFAULT(self) -> 'Template':
428
+ """Referring to the default value used for this template.
429
+
430
+ This method is intended to be used in template for referring to the default
431
+ value of current template. There are two scenarios:
432
+
433
+ Scenario 1: Use instance-level template_str to override the class default.
434
+
435
+ ```
436
+ class Foo(lf.Template):
437
+ '''Foo template.
438
+
439
+ This is {{x}}.
440
+ '''
441
+
442
+ f = Foo(template_str='<h1>{{DEFAULT}}</h1>', x=1)
443
+ f.render()
444
+
445
+ >> <h1>This is 1.</h1>
446
+ ```
447
+
448
+ Scenario 2: Use an ad-hoc template to override a predefined field.
449
+
450
+ ```
451
+ class Bar(lf.Template):
452
+ '''Bar template.
453
+
454
+ {{preamble}}
455
+ {{prompt}}
456
+ '''
457
+ preamble: lf.Template = lf.Template('You are a chat bot.')
458
+ prompt: lf.Template = lf.Template('User: hi')
459
+
460
+ b = Bar(preamble=lf.Template('<h1>{{DEFAULT}}<h1>'),
461
+ prompt=lf.Template('<h2>{{DEFAULT}}</h2>')
462
+ b.render()
463
+
464
+ >> <h1>You are a chat bot.<h1>
465
+ >> <h2>User: hi</h2>
466
+ ```
467
+
468
+ Returns:
469
+ The default (pre-configured) value used for this template.
470
+ """
471
+ base_template = self.__class__.__schema__['template_str'].default_value
472
+ if base_template == pg.MISSING_VALUE:
473
+ if not self.sym_path:
474
+ raise ValueError(
475
+ f'No DEFAULT template found for {self!r}: '
476
+ 'The template neither has a default `template_str` nor is '
477
+ 'contained under another object.'
478
+ )
479
+ key = self.sym_path.key
480
+ assert self.sym_parent is not None
481
+ assigned_field = self.sym_parent.sym_attr_field(key)
482
+ container_cls = self.sym_parent.__class__
483
+
484
+ if (
485
+ assigned_field is None
486
+ or assigned_field.default_value == pg.MISSING_VALUE
487
+ ):
488
+ raise ValueError(
489
+ f'No DEFAULT template found for {self!r}: '
490
+ f'`{container_cls.__name__}.{key}` '
491
+ 'does not have a default value. '
492
+ )
493
+ base_template = assigned_field.default_value
494
+ if isinstance(base_template, Template):
495
+ base_template = base_template.template_str
496
+ if not isinstance(base_template, str):
497
+ raise ValueError(
498
+ f'No DEFAULT template found for {self!r}: The default '
499
+ f'value {base_template!r} of '
500
+ f'`{container_cls.__name__}.{key}` is not a '
501
+ '`lf.Template` object or str.'
502
+ )
503
+ t = Template(base_template)
504
+ # NOTE(daiyip): Set the parent of the newly created template to self so
505
+ # it could access all the contextual variables.
506
+ t.sym_setparent(self)
507
+ return t
508
+
509
+ @classmethod
510
+ def from_value(
511
+ cls,
512
+ value: Union[str, message_lib.Message, 'Template'],
513
+ **kwargs
514
+ ) -> 'Template':
515
+ """Create a template object from a string or template."""
516
+ if isinstance(value, cls):
517
+ return value.clone(override=kwargs) if kwargs else value # pylint: disable=no-value-for-parameter
518
+ if isinstance(value, str):
519
+ return cls(template_str=value, **kwargs)
520
+ if isinstance(value, message_lib.Message):
521
+ kwargs.update(value.metadata)
522
+ return cls(template_str=value.text, **kwargs)
523
+ if isinstance(value, Template):
524
+ lfun = cls(template_str=value.template_str, **kwargs)
525
+ # So lfun could acccess all attributes from value.
526
+ lfun.sym_setparent(value)
527
+ return lfun
528
+ return cls(template_str='{{input}}', input=value, **kwargs)
529
+
530
+ def _html_tree_view_content(
531
+ self,
532
+ *,
533
+ view: pg.views.HtmlTreeView,
534
+ root_path: pg.KeyPath | None = None,
535
+ collapse_level: int | None = None,
536
+ extra_flags: dict[str, Any] | None = None,
537
+ debug: bool = False,
538
+ **kwargs,
539
+ ):
540
+ extra_flags = extra_flags if extra_flags is not None else {}
541
+ collapse_template_vars_level: int | None = extra_flags.get(
542
+ 'collapse_template_vars_level', 1
543
+ )
544
+
545
+ def render_template_str():
546
+ return pg.Html.element(
547
+ 'div',
548
+ [
549
+ pg.Html.element('span', [self.template_str])
550
+ ],
551
+ css_classes=['template-str'],
552
+ )
553
+
554
+ def render_fields():
555
+ return view.complex_value(
556
+ {k: v for k, v in self.sym_items()},
557
+ name='fields',
558
+ root_path=root_path,
559
+ parent=self,
560
+ exclude_keys=['template_str', 'clean'],
561
+ collapse_level=max(
562
+ collapse_template_vars_level, collapse_level
563
+ ) if collapse_level is not None else None,
564
+ extra_flags=extra_flags,
565
+ debug=debug,
566
+ **view.get_passthrough_kwargs(
567
+ remove=['exclude_keys'],
568
+ **kwargs,
569
+ )
570
+ )
571
+
572
+ return pg.views.html.controls.TabControl([
573
+ pg.views.html.controls.Tab(
574
+ 'template_str',
575
+ render_template_str(),
576
+ ),
577
+ pg.views.html.controls.Tab(
578
+ 'variables',
579
+ render_fields(),
580
+ ),
581
+ ], selected=1)
582
+
583
+ @classmethod
584
+ def _html_tree_view_css_styles(cls) -> list[str]:
585
+ return super()._html_tree_view_css_styles() + [
586
+ """
587
+ /* Langfun Template styles. */
588
+ .template-str {
589
+ padding: 10px;
590
+ margin: 10px 5px 10px 5px;
591
+ font-style: italic;
592
+ font-size: 1.1em;
593
+ white-space: pre-wrap;
594
+ border: 1px solid #EEE;
595
+ border-radius: 5px;
596
+ background-color: #EEE;
597
+ color: #cc2986;
598
+ }
599
+ """
600
+ ]
601
+
602
+ @classmethod
603
+ @functools.cache
604
+ def _html_tree_view_config(cls) -> dict[str, Any]:
605
+ return pg.views.HtmlTreeView.get_kwargs(
606
+ super()._html_tree_view_config(),
607
+ dict(
608
+ css_classes=['lf-template'],
609
+ )
610
+ )
611
+
383
612
 
384
613
  # Register converter from str to LangFunc, therefore we can always
385
614
  # pass strs to attributes that accept LangFunc.
@@ -13,9 +13,11 @@
13
13
  # limitations under the License.
14
14
  """Template test."""
15
15
  import inspect
16
+ from typing import Any
16
17
  import unittest
17
18
 
18
19
  from langfun.core import component
20
+ from langfun.core import message as message_lib
19
21
  from langfun.core import modality
20
22
  from langfun.core import subscription
21
23
  from langfun.core.template import Template
@@ -97,6 +99,21 @@ class BasicTest(unittest.TestCase):
97
99
  self.assertEqual(d.z.render(), 'Bye, 1')
98
100
  self.assertEqual(d.p.render(), 'Again Hello, 1')
99
101
 
102
+ def test_raw_text(self):
103
+ self.assertEqual(
104
+ Template(
105
+ '{{a}}' + Template.raw_str('\n{{d}}, {%x%}\n') + '{{b}}',
106
+ a='hi', b=1
107
+ ).render().text,
108
+ 'hi\n{{d}}, {%x%}\n1'
109
+ )
110
+
111
+ def test_from_raw_str(self):
112
+ self.assertEqual(
113
+ Template.from_raw_str('\n{{d}}, {%x%}\n').render().text,
114
+ '\n{{d}}, {%x%}\n'
115
+ )
116
+
100
117
 
101
118
  class DefinitionTest(unittest.TestCase):
102
119
 
@@ -308,9 +325,75 @@ class RenderTest(unittest.TestCase):
308
325
  Template(
309
326
  'This is {{ x }} and {{ a }}', x=1, a=CustomModality('foo')
310
327
  ).render(),
311
- 'This is 1 and {{a}}',
328
+ 'This is 1 and <<[[a]]>>',
312
329
  )
313
330
 
331
+ def test_render_with_default(self):
332
+
333
+ class Foo(Template):
334
+ """Foo.
335
+
336
+ This is {{x}}
337
+ """
338
+
339
+ f = Foo(template_str='!{{DEFAULT}}!', x=1)
340
+ self.assertEqual(f.DEFAULT.x, 1)
341
+ self.assertEqual(
342
+ f.render(), '!This is 1!'
343
+ )
344
+
345
+ class Bar(Template):
346
+ """Bar.
347
+
348
+ {{preamble}}
349
+ {{prompt}}
350
+ """
351
+
352
+ preamble: Template = Template('You are a chat bot.')
353
+ prompt: Template = Template('User: hi! {{name}}')
354
+
355
+ b = Bar(
356
+ preamble=Template('<h1>{{DEFAULT}}</h1>'),
357
+ prompt=Template('<h2>{{DEFAULT}}</h2>'),
358
+ name='Tom',
359
+ )
360
+ # Test variable access.
361
+ self.assertEqual(
362
+ b.render(),
363
+ inspect.cleandoc("""
364
+ <h1>You are a chat bot.</h1>
365
+ <h2>User: hi! Tom</h2>
366
+ """),
367
+ )
368
+
369
+ with self.assertRaisesRegex(ValueError, '`{{ DEFAULT }}` cannot be used'):
370
+
371
+ class Baz(Template): # pylint: disable=unused-variable
372
+ """Baz.
373
+
374
+ {{DEFAULT}}
375
+ """
376
+
377
+ with self.assertRaisesRegex(
378
+ ValueError, 'The template neither has a default `template_str` nor'
379
+ ):
380
+ Template('{{DEFAULT}}').render()
381
+
382
+ d = pg.Dict(x=Template('{{DEFAULT}}'))
383
+ with self.assertRaisesRegex(
384
+ ValueError, 'does not have a default value'
385
+ ):
386
+ _ = d.x.DEFAULT
387
+
388
+ class Tes(pg.Object):
389
+ x: str | None = None
390
+
391
+ t = Tes(x=Template('{{DEFAULT}}'))
392
+ with self.assertRaisesRegex(
393
+ ValueError, 'is not a `lf.Template` object or str'
394
+ ):
395
+ _ = t.x.DEFAULT
396
+
314
397
  def test_bad_render(self):
315
398
  with self.assertRaises(ValueError):
316
399
  Template('Hello {{x}}').render(allow_partial=False)
@@ -427,6 +510,14 @@ class RenderTest(unittest.TestCase):
427
510
  # Test len.
428
511
  self.assert_partial(Template('Hello {{len(x)}}'), 'Hello {{len(x)}}')
429
512
 
513
+ def test_additional_metadata(self):
514
+ t = Template('hi', metadata_weights=1.0, y=2)
515
+ self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
516
+
517
+ t = Template('hi')
518
+ with component.context(metadata_weights=1.0, y=2):
519
+ self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
520
+
430
521
 
431
522
  class TemplateRenderEventTest(unittest.TestCase):
432
523
 
@@ -462,5 +553,59 @@ class TemplateRenderEventTest(unittest.TestCase):
462
553
  self.assertEqual(render_stacks, [[l]])
463
554
 
464
555
 
556
+ class HtmlTest(unittest.TestCase):
557
+
558
+ def test_html(self):
559
+
560
+ class Foo(Template):
561
+ """Template Foo.
562
+
563
+ {{x}} + {{y}} = ?
564
+ """
565
+ x: Any
566
+ y: Any
567
+
568
+ class Bar(Template):
569
+ """Template Bar.
570
+
571
+ {{y}} + {{z}}
572
+ """
573
+ y: Any
574
+
575
+ self.assertIn(
576
+ inspect.cleandoc(
577
+ """
578
+ /* Langfun Template styles. */
579
+ .template-str {
580
+ padding: 10px;
581
+ margin: 10px 5px 10px 5px;
582
+ font-style: italic;
583
+ font-size: 1.1em;
584
+ white-space: pre-wrap;
585
+ border: 1px solid #EEE;
586
+ border-radius: 5px;
587
+ background-color: #EEE;
588
+ color: #cc2986;
589
+ }
590
+ """
591
+ ),
592
+ Foo(x=1, y=2).to_html().style_section,
593
+ )
594
+ self.assertIn(
595
+ 'template-str',
596
+ Foo(x=Bar('{{y}} + {{z}}'), y=1).to_html(
597
+ enable_summary_tooltip=False,
598
+ ).content,
599
+ )
600
+ self.assertIn(
601
+ 'template-str',
602
+ Foo(x=Bar('{{y}} + {{z}}'), y=1).to_html(
603
+ enable_summary_tooltip=False,
604
+ collapse_level=0,
605
+ key_style='label',
606
+ ).content,
607
+ )
608
+
609
+
465
610
  if __name__ == '__main__':
466
611
  unittest.main()
@@ -38,6 +38,11 @@ class Conversation(Completion):
38
38
  '(Optional) Preamble before beginning the conversation.',
39
39
  ] = None
40
40
 
41
+ role: Annotated[
42
+ str | None,
43
+ '(Optional) User defined role for the AI response in the conversation.',
44
+ ] = None
45
+
41
46
  conversation_context: Annotated[
42
47
  lf.LangFunc | None,
43
48
  (
@@ -71,6 +76,10 @@ class Conversation(Completion):
71
76
  with lf.context(**kwargs):
72
77
  # Call LM based on the prompt generated from `input_message`.
73
78
  lm_response = super().__call__()
79
+ if self.role is not None:
80
+ lm_response.rebind(
81
+ sender=self.role, skip_notification=True, raise_on_no_change=False
82
+ )
74
83
 
75
84
  # Add current turn to memory.
76
85
  self.add(self.input_message, lm_response)
@@ -83,6 +83,7 @@ class ConversationTest(unittest.TestCase):
83
83
  def test_call(self):
84
84
  c = Conversation(
85
85
  lm=QuestionCounter(),
86
+ role='Agent',
86
87
  preamble="You are a helpful and joyful AI bot. Now let's chat.",
87
88
  )
88
89
  # First round.
@@ -102,7 +103,7 @@ class ConversationTest(unittest.TestCase):
102
103
  inspect.cleandoc("""
103
104
  You are a helpful and joyful AI bot. Now let's chat.
104
105
  User: Hello
105
- AI: Response 1.
106
+ Agent: Response 1.
106
107
  User: How are you?
107
108
  """),
108
109
  )
@@ -114,9 +115,9 @@ class ConversationTest(unittest.TestCase):
114
115
  inspect.cleandoc("""
115
116
  You are a helpful and joyful AI bot. Now let's chat.
116
117
  User: Hello
117
- AI: Response 1.
118
+ Agent: Response 1.
118
119
  User: How are you?
119
- AI: Response 2.
120
+ Agent: Response 2.
120
121
  User: Okay, bye.
121
122
  """),
122
123
  )
@@ -56,7 +56,13 @@ class SelfPlayTest(unittest.TestCase):
56
56
  g = NumberGuess(target_num=10)
57
57
 
58
58
  with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 10])):
59
- self.assertEqual(g(), lf.AIMessage('10', score=0.0, logprobs=None))
59
+ self.assertEqual(
60
+ g(),
61
+ lf.AIMessage(
62
+ '10', score=0.0, logprobs=None, is_cached=False,
63
+ usage=lf.UsageNotAvailable()
64
+ )
65
+ )
60
66
 
61
67
  self.assertEqual(g.num_turns, 4)
62
68
 
@@ -64,7 +70,13 @@ class SelfPlayTest(unittest.TestCase):
64
70
  g = NumberGuess(target_num=10, max_turns=10)
65
71
 
66
72
  with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 2, 5, 4])):
67
- self.assertEqual(g(), lf.AIMessage('2', score=0.0, logprobs=None))
73
+ self.assertEqual(
74
+ g(),
75
+ lf.AIMessage(
76
+ '2', score=0.0, logprobs=None, is_cached=False,
77
+ usage=lf.UsageNotAvailable()
78
+ )
79
+ )
68
80
 
69
81
  self.assertEqual(g.num_turns, 10)
70
82