langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/template.py
CHANGED
@@ -17,7 +17,7 @@ import contextlib
|
|
17
17
|
import dataclasses
|
18
18
|
import functools
|
19
19
|
import inspect
|
20
|
-
from typing import Annotated, Any, Callable, Iterator, Set, Tuple, Type
|
20
|
+
from typing import Annotated, Any, Callable, Iterator, Set, Tuple, Type, Union
|
21
21
|
|
22
22
|
import jinja2
|
23
23
|
from jinja2 import meta as jinja2_meta
|
@@ -38,13 +38,23 @@ NO_TEMPLATE_DOCSTR_SIGN = 'THIS IS NOT A TEMPLATE'
|
|
38
38
|
_TLS_RENDER_STACK = '_template_render_stack'
|
39
39
|
_TLS_RENDER_RESULT_CACHE = '_template_render_result_cache'
|
40
40
|
|
41
|
+
# The prefix for fields or contextual attributes to be treated as additional
|
42
|
+
# metadata for rendered message.
|
43
|
+
_ADDITIONAL_METADATA_PREFIX = 'metadata_'
|
44
|
+
|
41
45
|
|
42
46
|
class Template(
|
43
47
|
natural_language.NaturalLanguageFormattable,
|
44
48
|
component.Component,
|
45
49
|
pg.typing.CustomTyping,
|
50
|
+
pg.views.HtmlTreeView.Extension
|
46
51
|
):
|
47
|
-
"""Langfun string template.
|
52
|
+
"""Langfun string template.
|
53
|
+
|
54
|
+
Langfun uses jinja2 as its template engine. Pleaes check out
|
55
|
+
https://jinja.palletsprojects.com/en/3.1.x/templates/ for detailed
|
56
|
+
explanation on the template language.
|
57
|
+
"""
|
48
58
|
|
49
59
|
template_str: Annotated[
|
50
60
|
str,
|
@@ -97,6 +107,11 @@ class Template(
|
|
97
107
|
# Declare template variables as symbolic attributes.
|
98
108
|
template_vars = Template.resolve_vars(template_str)
|
99
109
|
for var_name in template_vars:
|
110
|
+
if 'DEFAULT' == var_name:
|
111
|
+
raise ValueError(
|
112
|
+
'`{{ DEFAULT }}` cannot be used in pre-configured templates. '
|
113
|
+
f'Encountered: {template_str!r}'
|
114
|
+
)
|
100
115
|
# NOTE(daiyip): This is to avoid warning from accessing
|
101
116
|
# `pg.Object.schema`, which was replaced by `pg.Object.__schema__`.
|
102
117
|
if var_name == 'schema' or not hasattr(cls, var_name):
|
@@ -149,7 +164,7 @@ class Template(
|
|
149
164
|
# TODO(daiyip): Consider to delay template parsing upon usage.
|
150
165
|
unassigned_vars = {}
|
151
166
|
for k in self._variables:
|
152
|
-
if not hasattr(self, k):
|
167
|
+
if k not in ('DEFAULT',) and not hasattr(self, k):
|
153
168
|
unassigned_vars[k] = component.contextual()
|
154
169
|
if unassigned_vars:
|
155
170
|
self.rebind(unassigned_vars, skip_notification=True)
|
@@ -217,6 +232,16 @@ class Template(
|
|
217
232
|
"""Returns the missing variable names."""
|
218
233
|
return self.vars(closure=True, specified=False)
|
219
234
|
|
235
|
+
@classmethod
|
236
|
+
def raw_str(cls, text: str) -> str:
|
237
|
+
"""Returns a template string that preserve the text as original."""
|
238
|
+
return '{% raw %}' + text + '{% endraw %}'
|
239
|
+
|
240
|
+
@classmethod
|
241
|
+
def from_raw_str(cls, text: str) -> 'Template':
|
242
|
+
"""Returns a template that preserve the text as original."""
|
243
|
+
return cls(cls.raw_str(text), clean=False)
|
244
|
+
|
220
245
|
def render(
|
221
246
|
self,
|
222
247
|
*,
|
@@ -303,19 +328,19 @@ class Template(
|
|
303
328
|
with modality.format_modality_as_ref():
|
304
329
|
rendered_text = self._template.render(**inputs)
|
305
330
|
|
331
|
+
# Carry additional metadata.
|
332
|
+
metadata = self.additional_metadata()
|
333
|
+
|
306
334
|
if self.clean:
|
307
335
|
rendered_text = rendered_text.strip()
|
308
336
|
|
309
|
-
|
310
|
-
|
311
|
-
text=rendered_text,
|
312
|
-
metadata={
|
313
|
-
k: pg.Ref(v)
|
314
|
-
for k, v in inputs.items()
|
315
|
-
if not inspect.ismethod(v)
|
316
|
-
},
|
337
|
+
metadata.update(
|
338
|
+
{k: pg.Ref(v) for k, v in inputs.items() if not inspect.ismethod(v)}
|
317
339
|
)
|
318
340
|
|
341
|
+
# Fill the variables for rendering the template as metadata.
|
342
|
+
message = message_cls(text=rendered_text, metadata=metadata)
|
343
|
+
|
319
344
|
# Tag input as rendered message.
|
320
345
|
message.tag(message_lib.Message.TAG_RENDERED)
|
321
346
|
|
@@ -340,6 +365,20 @@ class Template(
|
|
340
365
|
top = pg.object_utils.thread_local_pop(_TLS_RENDER_STACK)
|
341
366
|
assert top is self, (top, self)
|
342
367
|
|
368
|
+
def additional_metadata(self) -> dict[str, Any]:
|
369
|
+
"""Returns additional metadta to be carried in the rendered message."""
|
370
|
+
metadata = {}
|
371
|
+
# Carry metadata from `lf.context`.
|
372
|
+
for k, v in component.all_contextual_values().items():
|
373
|
+
if k.startswith(_ADDITIONAL_METADATA_PREFIX):
|
374
|
+
metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
|
375
|
+
|
376
|
+
# Carry metadata from fields.
|
377
|
+
for k, v in self.sym_init_args.items():
|
378
|
+
if k.startswith(_ADDITIONAL_METADATA_PREFIX):
|
379
|
+
metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
|
380
|
+
return metadata
|
381
|
+
|
343
382
|
#
|
344
383
|
# Implements `pg.typing.CustomTyping`.
|
345
384
|
#
|
@@ -380,6 +419,196 @@ class Template(
|
|
380
419
|
# Override __hash__ since __eq__ has changed.
|
381
420
|
return object.__hash__(self)
|
382
421
|
|
422
|
+
#
|
423
|
+
# Special methods.
|
424
|
+
#
|
425
|
+
|
426
|
+
@property
|
427
|
+
def DEFAULT(self) -> 'Template':
|
428
|
+
"""Referring to the default value used for this template.
|
429
|
+
|
430
|
+
This method is intended to be used in template for referring to the default
|
431
|
+
value of current template. There are two scenarios:
|
432
|
+
|
433
|
+
Scenario 1: Use instance-level template_str to override the class default.
|
434
|
+
|
435
|
+
```
|
436
|
+
class Foo(lf.Template):
|
437
|
+
'''Foo template.
|
438
|
+
|
439
|
+
This is {{x}}.
|
440
|
+
'''
|
441
|
+
|
442
|
+
f = Foo(template_str='<h1>{{DEFAULT}}</h1>', x=1)
|
443
|
+
f.render()
|
444
|
+
|
445
|
+
>> <h1>This is 1.</h1>
|
446
|
+
```
|
447
|
+
|
448
|
+
Scenario 2: Use an ad-hoc template to override a predefined field.
|
449
|
+
|
450
|
+
```
|
451
|
+
class Bar(lf.Template):
|
452
|
+
'''Bar template.
|
453
|
+
|
454
|
+
{{preamble}}
|
455
|
+
{{prompt}}
|
456
|
+
'''
|
457
|
+
preamble: lf.Template = lf.Template('You are a chat bot.')
|
458
|
+
prompt: lf.Template = lf.Template('User: hi')
|
459
|
+
|
460
|
+
b = Bar(preamble=lf.Template('<h1>{{DEFAULT}}<h1>'),
|
461
|
+
prompt=lf.Template('<h2>{{DEFAULT}}</h2>')
|
462
|
+
b.render()
|
463
|
+
|
464
|
+
>> <h1>You are a chat bot.<h1>
|
465
|
+
>> <h2>User: hi</h2>
|
466
|
+
```
|
467
|
+
|
468
|
+
Returns:
|
469
|
+
The default (pre-configured) value used for this template.
|
470
|
+
"""
|
471
|
+
base_template = self.__class__.__schema__['template_str'].default_value
|
472
|
+
if base_template == pg.MISSING_VALUE:
|
473
|
+
if not self.sym_path:
|
474
|
+
raise ValueError(
|
475
|
+
f'No DEFAULT template found for {self!r}: '
|
476
|
+
'The template neither has a default `template_str` nor is '
|
477
|
+
'contained under another object.'
|
478
|
+
)
|
479
|
+
key = self.sym_path.key
|
480
|
+
assert self.sym_parent is not None
|
481
|
+
assigned_field = self.sym_parent.sym_attr_field(key)
|
482
|
+
container_cls = self.sym_parent.__class__
|
483
|
+
|
484
|
+
if (
|
485
|
+
assigned_field is None
|
486
|
+
or assigned_field.default_value == pg.MISSING_VALUE
|
487
|
+
):
|
488
|
+
raise ValueError(
|
489
|
+
f'No DEFAULT template found for {self!r}: '
|
490
|
+
f'`{container_cls.__name__}.{key}` '
|
491
|
+
'does not have a default value. '
|
492
|
+
)
|
493
|
+
base_template = assigned_field.default_value
|
494
|
+
if isinstance(base_template, Template):
|
495
|
+
base_template = base_template.template_str
|
496
|
+
if not isinstance(base_template, str):
|
497
|
+
raise ValueError(
|
498
|
+
f'No DEFAULT template found for {self!r}: The default '
|
499
|
+
f'value {base_template!r} of '
|
500
|
+
f'`{container_cls.__name__}.{key}` is not a '
|
501
|
+
'`lf.Template` object or str.'
|
502
|
+
)
|
503
|
+
t = Template(base_template)
|
504
|
+
# NOTE(daiyip): Set the parent of the newly created template to self so
|
505
|
+
# it could access all the contextual variables.
|
506
|
+
t.sym_setparent(self)
|
507
|
+
return t
|
508
|
+
|
509
|
+
@classmethod
|
510
|
+
def from_value(
|
511
|
+
cls,
|
512
|
+
value: Union[str, message_lib.Message, 'Template'],
|
513
|
+
**kwargs
|
514
|
+
) -> 'Template':
|
515
|
+
"""Create a template object from a string or template."""
|
516
|
+
if isinstance(value, cls):
|
517
|
+
return value.clone(override=kwargs) if kwargs else value # pylint: disable=no-value-for-parameter
|
518
|
+
if isinstance(value, str):
|
519
|
+
return cls(template_str=value, **kwargs)
|
520
|
+
if isinstance(value, message_lib.Message):
|
521
|
+
kwargs.update(value.metadata)
|
522
|
+
return cls(template_str=value.text, **kwargs)
|
523
|
+
if isinstance(value, Template):
|
524
|
+
lfun = cls(template_str=value.template_str, **kwargs)
|
525
|
+
# So lfun could acccess all attributes from value.
|
526
|
+
lfun.sym_setparent(value)
|
527
|
+
return lfun
|
528
|
+
return cls(template_str='{{input}}', input=value, **kwargs)
|
529
|
+
|
530
|
+
def _html_tree_view_content(
|
531
|
+
self,
|
532
|
+
*,
|
533
|
+
view: pg.views.HtmlTreeView,
|
534
|
+
root_path: pg.KeyPath | None = None,
|
535
|
+
collapse_level: int | None = None,
|
536
|
+
extra_flags: dict[str, Any] | None = None,
|
537
|
+
debug: bool = False,
|
538
|
+
**kwargs,
|
539
|
+
):
|
540
|
+
extra_flags = extra_flags if extra_flags is not None else {}
|
541
|
+
collapse_template_vars_level: int | None = extra_flags.get(
|
542
|
+
'collapse_template_vars_level', 1
|
543
|
+
)
|
544
|
+
|
545
|
+
def render_template_str():
|
546
|
+
return pg.Html.element(
|
547
|
+
'div',
|
548
|
+
[
|
549
|
+
pg.Html.element('span', [self.template_str])
|
550
|
+
],
|
551
|
+
css_classes=['template-str'],
|
552
|
+
)
|
553
|
+
|
554
|
+
def render_fields():
|
555
|
+
return view.complex_value(
|
556
|
+
{k: v for k, v in self.sym_items()},
|
557
|
+
name='fields',
|
558
|
+
root_path=root_path,
|
559
|
+
parent=self,
|
560
|
+
exclude_keys=['template_str', 'clean'],
|
561
|
+
collapse_level=max(
|
562
|
+
collapse_template_vars_level, collapse_level
|
563
|
+
) if collapse_level is not None else None,
|
564
|
+
extra_flags=extra_flags,
|
565
|
+
debug=debug,
|
566
|
+
**view.get_passthrough_kwargs(
|
567
|
+
remove=['exclude_keys'],
|
568
|
+
**kwargs,
|
569
|
+
)
|
570
|
+
)
|
571
|
+
|
572
|
+
return pg.views.html.controls.TabControl([
|
573
|
+
pg.views.html.controls.Tab(
|
574
|
+
'template_str',
|
575
|
+
render_template_str(),
|
576
|
+
),
|
577
|
+
pg.views.html.controls.Tab(
|
578
|
+
'variables',
|
579
|
+
render_fields(),
|
580
|
+
),
|
581
|
+
], selected=1)
|
582
|
+
|
583
|
+
@classmethod
|
584
|
+
def _html_tree_view_css_styles(cls) -> list[str]:
|
585
|
+
return super()._html_tree_view_css_styles() + [
|
586
|
+
"""
|
587
|
+
/* Langfun Template styles. */
|
588
|
+
.template-str {
|
589
|
+
padding: 10px;
|
590
|
+
margin: 10px 5px 10px 5px;
|
591
|
+
font-style: italic;
|
592
|
+
font-size: 1.1em;
|
593
|
+
white-space: pre-wrap;
|
594
|
+
border: 1px solid #EEE;
|
595
|
+
border-radius: 5px;
|
596
|
+
background-color: #EEE;
|
597
|
+
color: #cc2986;
|
598
|
+
}
|
599
|
+
"""
|
600
|
+
]
|
601
|
+
|
602
|
+
@classmethod
|
603
|
+
@functools.cache
|
604
|
+
def _html_tree_view_config(cls) -> dict[str, Any]:
|
605
|
+
return pg.views.HtmlTreeView.get_kwargs(
|
606
|
+
super()._html_tree_view_config(),
|
607
|
+
dict(
|
608
|
+
css_classes=['lf-template'],
|
609
|
+
)
|
610
|
+
)
|
611
|
+
|
383
612
|
|
384
613
|
# Register converter from str to LangFunc, therefore we can always
|
385
614
|
# pass strs to attributes that accept LangFunc.
|
langfun/core/template_test.py
CHANGED
@@ -13,9 +13,11 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Template test."""
|
15
15
|
import inspect
|
16
|
+
from typing import Any
|
16
17
|
import unittest
|
17
18
|
|
18
19
|
from langfun.core import component
|
20
|
+
from langfun.core import message as message_lib
|
19
21
|
from langfun.core import modality
|
20
22
|
from langfun.core import subscription
|
21
23
|
from langfun.core.template import Template
|
@@ -97,6 +99,21 @@ class BasicTest(unittest.TestCase):
|
|
97
99
|
self.assertEqual(d.z.render(), 'Bye, 1')
|
98
100
|
self.assertEqual(d.p.render(), 'Again Hello, 1')
|
99
101
|
|
102
|
+
def test_raw_text(self):
|
103
|
+
self.assertEqual(
|
104
|
+
Template(
|
105
|
+
'{{a}}' + Template.raw_str('\n{{d}}, {%x%}\n') + '{{b}}',
|
106
|
+
a='hi', b=1
|
107
|
+
).render().text,
|
108
|
+
'hi\n{{d}}, {%x%}\n1'
|
109
|
+
)
|
110
|
+
|
111
|
+
def test_from_raw_str(self):
|
112
|
+
self.assertEqual(
|
113
|
+
Template.from_raw_str('\n{{d}}, {%x%}\n').render().text,
|
114
|
+
'\n{{d}}, {%x%}\n'
|
115
|
+
)
|
116
|
+
|
100
117
|
|
101
118
|
class DefinitionTest(unittest.TestCase):
|
102
119
|
|
@@ -308,9 +325,75 @@ class RenderTest(unittest.TestCase):
|
|
308
325
|
Template(
|
309
326
|
'This is {{ x }} and {{ a }}', x=1, a=CustomModality('foo')
|
310
327
|
).render(),
|
311
|
-
'This is 1 and
|
328
|
+
'This is 1 and <<[[a]]>>',
|
312
329
|
)
|
313
330
|
|
331
|
+
def test_render_with_default(self):
|
332
|
+
|
333
|
+
class Foo(Template):
|
334
|
+
"""Foo.
|
335
|
+
|
336
|
+
This is {{x}}
|
337
|
+
"""
|
338
|
+
|
339
|
+
f = Foo(template_str='!{{DEFAULT}}!', x=1)
|
340
|
+
self.assertEqual(f.DEFAULT.x, 1)
|
341
|
+
self.assertEqual(
|
342
|
+
f.render(), '!This is 1!'
|
343
|
+
)
|
344
|
+
|
345
|
+
class Bar(Template):
|
346
|
+
"""Bar.
|
347
|
+
|
348
|
+
{{preamble}}
|
349
|
+
{{prompt}}
|
350
|
+
"""
|
351
|
+
|
352
|
+
preamble: Template = Template('You are a chat bot.')
|
353
|
+
prompt: Template = Template('User: hi! {{name}}')
|
354
|
+
|
355
|
+
b = Bar(
|
356
|
+
preamble=Template('<h1>{{DEFAULT}}</h1>'),
|
357
|
+
prompt=Template('<h2>{{DEFAULT}}</h2>'),
|
358
|
+
name='Tom',
|
359
|
+
)
|
360
|
+
# Test variable access.
|
361
|
+
self.assertEqual(
|
362
|
+
b.render(),
|
363
|
+
inspect.cleandoc("""
|
364
|
+
<h1>You are a chat bot.</h1>
|
365
|
+
<h2>User: hi! Tom</h2>
|
366
|
+
"""),
|
367
|
+
)
|
368
|
+
|
369
|
+
with self.assertRaisesRegex(ValueError, '`{{ DEFAULT }}` cannot be used'):
|
370
|
+
|
371
|
+
class Baz(Template): # pylint: disable=unused-variable
|
372
|
+
"""Baz.
|
373
|
+
|
374
|
+
{{DEFAULT}}
|
375
|
+
"""
|
376
|
+
|
377
|
+
with self.assertRaisesRegex(
|
378
|
+
ValueError, 'The template neither has a default `template_str` nor'
|
379
|
+
):
|
380
|
+
Template('{{DEFAULT}}').render()
|
381
|
+
|
382
|
+
d = pg.Dict(x=Template('{{DEFAULT}}'))
|
383
|
+
with self.assertRaisesRegex(
|
384
|
+
ValueError, 'does not have a default value'
|
385
|
+
):
|
386
|
+
_ = d.x.DEFAULT
|
387
|
+
|
388
|
+
class Tes(pg.Object):
|
389
|
+
x: str | None = None
|
390
|
+
|
391
|
+
t = Tes(x=Template('{{DEFAULT}}'))
|
392
|
+
with self.assertRaisesRegex(
|
393
|
+
ValueError, 'is not a `lf.Template` object or str'
|
394
|
+
):
|
395
|
+
_ = t.x.DEFAULT
|
396
|
+
|
314
397
|
def test_bad_render(self):
|
315
398
|
with self.assertRaises(ValueError):
|
316
399
|
Template('Hello {{x}}').render(allow_partial=False)
|
@@ -427,6 +510,14 @@ class RenderTest(unittest.TestCase):
|
|
427
510
|
# Test len.
|
428
511
|
self.assert_partial(Template('Hello {{len(x)}}'), 'Hello {{len(x)}}')
|
429
512
|
|
513
|
+
def test_additional_metadata(self):
|
514
|
+
t = Template('hi', metadata_weights=1.0, y=2)
|
515
|
+
self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
|
516
|
+
|
517
|
+
t = Template('hi')
|
518
|
+
with component.context(metadata_weights=1.0, y=2):
|
519
|
+
self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
|
520
|
+
|
430
521
|
|
431
522
|
class TemplateRenderEventTest(unittest.TestCase):
|
432
523
|
|
@@ -462,5 +553,59 @@ class TemplateRenderEventTest(unittest.TestCase):
|
|
462
553
|
self.assertEqual(render_stacks, [[l]])
|
463
554
|
|
464
555
|
|
556
|
+
class HtmlTest(unittest.TestCase):
|
557
|
+
|
558
|
+
def test_html(self):
|
559
|
+
|
560
|
+
class Foo(Template):
|
561
|
+
"""Template Foo.
|
562
|
+
|
563
|
+
{{x}} + {{y}} = ?
|
564
|
+
"""
|
565
|
+
x: Any
|
566
|
+
y: Any
|
567
|
+
|
568
|
+
class Bar(Template):
|
569
|
+
"""Template Bar.
|
570
|
+
|
571
|
+
{{y}} + {{z}}
|
572
|
+
"""
|
573
|
+
y: Any
|
574
|
+
|
575
|
+
self.assertIn(
|
576
|
+
inspect.cleandoc(
|
577
|
+
"""
|
578
|
+
/* Langfun Template styles. */
|
579
|
+
.template-str {
|
580
|
+
padding: 10px;
|
581
|
+
margin: 10px 5px 10px 5px;
|
582
|
+
font-style: italic;
|
583
|
+
font-size: 1.1em;
|
584
|
+
white-space: pre-wrap;
|
585
|
+
border: 1px solid #EEE;
|
586
|
+
border-radius: 5px;
|
587
|
+
background-color: #EEE;
|
588
|
+
color: #cc2986;
|
589
|
+
}
|
590
|
+
"""
|
591
|
+
),
|
592
|
+
Foo(x=1, y=2).to_html().style_section,
|
593
|
+
)
|
594
|
+
self.assertIn(
|
595
|
+
'template-str',
|
596
|
+
Foo(x=Bar('{{y}} + {{z}}'), y=1).to_html(
|
597
|
+
enable_summary_tooltip=False,
|
598
|
+
).content,
|
599
|
+
)
|
600
|
+
self.assertIn(
|
601
|
+
'template-str',
|
602
|
+
Foo(x=Bar('{{y}} + {{z}}'), y=1).to_html(
|
603
|
+
enable_summary_tooltip=False,
|
604
|
+
collapse_level=0,
|
605
|
+
key_style='label',
|
606
|
+
).content,
|
607
|
+
)
|
608
|
+
|
609
|
+
|
465
610
|
if __name__ == '__main__':
|
466
611
|
unittest.main()
|
@@ -38,6 +38,11 @@ class Conversation(Completion):
|
|
38
38
|
'(Optional) Preamble before beginning the conversation.',
|
39
39
|
] = None
|
40
40
|
|
41
|
+
role: Annotated[
|
42
|
+
str | None,
|
43
|
+
'(Optional) User defined role for the AI response in the conversation.',
|
44
|
+
] = None
|
45
|
+
|
41
46
|
conversation_context: Annotated[
|
42
47
|
lf.LangFunc | None,
|
43
48
|
(
|
@@ -71,6 +76,10 @@ class Conversation(Completion):
|
|
71
76
|
with lf.context(**kwargs):
|
72
77
|
# Call LM based on the prompt generated from `input_message`.
|
73
78
|
lm_response = super().__call__()
|
79
|
+
if self.role is not None:
|
80
|
+
lm_response.rebind(
|
81
|
+
sender=self.role, skip_notification=True, raise_on_no_change=False
|
82
|
+
)
|
74
83
|
|
75
84
|
# Add current turn to memory.
|
76
85
|
self.add(self.input_message, lm_response)
|
@@ -83,6 +83,7 @@ class ConversationTest(unittest.TestCase):
|
|
83
83
|
def test_call(self):
|
84
84
|
c = Conversation(
|
85
85
|
lm=QuestionCounter(),
|
86
|
+
role='Agent',
|
86
87
|
preamble="You are a helpful and joyful AI bot. Now let's chat.",
|
87
88
|
)
|
88
89
|
# First round.
|
@@ -102,7 +103,7 @@ class ConversationTest(unittest.TestCase):
|
|
102
103
|
inspect.cleandoc("""
|
103
104
|
You are a helpful and joyful AI bot. Now let's chat.
|
104
105
|
User: Hello
|
105
|
-
|
106
|
+
Agent: Response 1.
|
106
107
|
User: How are you?
|
107
108
|
"""),
|
108
109
|
)
|
@@ -114,9 +115,9 @@ class ConversationTest(unittest.TestCase):
|
|
114
115
|
inspect.cleandoc("""
|
115
116
|
You are a helpful and joyful AI bot. Now let's chat.
|
116
117
|
User: Hello
|
117
|
-
|
118
|
+
Agent: Response 1.
|
118
119
|
User: How are you?
|
119
|
-
|
120
|
+
Agent: Response 2.
|
120
121
|
User: Okay, bye.
|
121
122
|
"""),
|
122
123
|
)
|
@@ -56,7 +56,13 @@ class SelfPlayTest(unittest.TestCase):
|
|
56
56
|
g = NumberGuess(target_num=10)
|
57
57
|
|
58
58
|
with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 10])):
|
59
|
-
self.assertEqual(
|
59
|
+
self.assertEqual(
|
60
|
+
g(),
|
61
|
+
lf.AIMessage(
|
62
|
+
'10', score=0.0, logprobs=None, is_cached=False,
|
63
|
+
usage=lf.UsageNotAvailable()
|
64
|
+
)
|
65
|
+
)
|
60
66
|
|
61
67
|
self.assertEqual(g.num_turns, 4)
|
62
68
|
|
@@ -64,7 +70,13 @@ class SelfPlayTest(unittest.TestCase):
|
|
64
70
|
g = NumberGuess(target_num=10, max_turns=10)
|
65
71
|
|
66
72
|
with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 2, 5, 4])):
|
67
|
-
self.assertEqual(
|
73
|
+
self.assertEqual(
|
74
|
+
g(),
|
75
|
+
lf.AIMessage(
|
76
|
+
'2', score=0.0, logprobs=None, is_cached=False,
|
77
|
+
usage=lf.UsageNotAvailable()
|
78
|
+
)
|
79
|
+
)
|
68
80
|
|
69
81
|
self.assertEqual(g.num_turns, 10)
|
70
82
|
|