langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202510250803__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/concurrent_test.py +1 -0
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini_test.py +12 -9
- langfun/core/data/conversion/openai.py +134 -30
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base_test.py +4 -4
- langfun/core/eval/v2/progress_tracking_test.py +3 -0
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +15 -6
- langfun/core/language_model_test.py +9 -3
- langfun/core/llms/__init__.py +7 -1
- langfun/core/llms/anthropic.py +130 -0
- langfun/core/llms/cache/base.py +3 -1
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/deepseek.py +1 -1
- langfun/core/llms/gemini.py +2 -5
- langfun/core/llms/groq.py +1 -1
- langfun/core/llms/llama_cpp.py +1 -1
- langfun/core/llms/openai.py +7 -2
- langfun/core/llms/openai_compatible.py +136 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/vertexai.py +12 -2
- langfun/core/message.py +78 -44
- langfun/core/message_test.py +56 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/mime.py +9 -0
- langfun/core/modality.py +104 -27
- langfun/core/modality_test.py +42 -12
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/completion.py +2 -7
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/mapping.py +4 -13
- langfun/core/structured/querying.py +13 -11
- langfun/core/structured/querying_test.py +65 -29
- langfun/core/template.py +39 -13
- langfun/core/template_test.py +83 -17
- langfun/env/event_handlers/metric_writer_test.py +3 -3
- langfun/env/load_balancers_test.py +2 -2
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/RECORD +44 -44
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/top_level.txt +0 -0
langfun/core/modality_test.py
CHANGED
|
@@ -29,34 +29,64 @@ class ModalityTest(unittest.TestCase):
|
|
|
29
29
|
|
|
30
30
|
def test_basic(self):
|
|
31
31
|
v = CustomModality('a')
|
|
32
|
-
self.
|
|
32
|
+
self.assertEqual(v.id, 'custom_modality:0cc175b9')
|
|
33
33
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
|
34
34
|
self.assertEqual(v.hash, '0cc175b9')
|
|
35
35
|
|
|
36
36
|
_ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
|
|
37
|
-
self.assertEqual(v.
|
|
37
|
+
self.assertEqual(v.id, 'custom_modality:0cc175b9')
|
|
38
38
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
|
39
39
|
with modality.format_modality_as_ref():
|
|
40
|
-
self.assertEqual(str(v), '<<[[
|
|
40
|
+
self.assertEqual(str(v), '<<[[custom_modality:0cc175b9]]>>')
|
|
41
|
+
|
|
42
|
+
def test_capture_rendered_modalities(self):
|
|
43
|
+
x = CustomModality('a')
|
|
44
|
+
y = CustomModality('b')
|
|
45
|
+
z = CustomModality('b')
|
|
46
|
+
|
|
47
|
+
with modality.capture_rendered_modalities() as rendered_modalities:
|
|
48
|
+
with modality.format_modality_as_ref():
|
|
49
|
+
self.assertEqual(
|
|
50
|
+
f'Hello {x} {y} {z}',
|
|
51
|
+
(
|
|
52
|
+
'Hello <<[[custom_modality:0cc175b9]]>> '
|
|
53
|
+
'<<[[custom_modality:92eb5ffe]]>> '
|
|
54
|
+
'<<[[custom_modality:92eb5ffe]]>>'
|
|
55
|
+
)
|
|
56
|
+
)
|
|
57
|
+
self.assertEqual(len(rendered_modalities), 2)
|
|
58
|
+
self.assertIs(rendered_modalities['custom_modality:0cc175b9'].value, x)
|
|
59
|
+
# y and z share the same content will be treated as the same object.
|
|
60
|
+
self.assertIs(rendered_modalities['custom_modality:92eb5ffe'].value, z)
|
|
41
61
|
|
|
42
62
|
|
|
43
63
|
class ModalityRefTest(unittest.TestCase):
|
|
44
64
|
|
|
45
|
-
def
|
|
65
|
+
def test_placehold_and_restore(self):
|
|
46
66
|
class A(pg.Object):
|
|
47
67
|
x: Any
|
|
48
68
|
y: Any
|
|
49
69
|
|
|
50
|
-
|
|
70
|
+
image_a = CustomModality('a')
|
|
71
|
+
image_b = CustomModality('b')
|
|
72
|
+
a = A(x=dict(z=image_a), y=image_b)
|
|
73
|
+
a_placehold = modality.ModalityRef.placehold(a)
|
|
51
74
|
self.assertEqual(
|
|
52
|
-
|
|
53
|
-
A(x=dict(z=modality.ModalityRef(
|
|
75
|
+
a_placehold,
|
|
76
|
+
A(x=dict(z=modality.ModalityRef(image_a.id)),
|
|
77
|
+
y=modality.ModalityRef(image_b.id)),
|
|
78
|
+
)
|
|
79
|
+
a_restore = modality.ModalityRef.restore(
|
|
80
|
+
a_placehold.clone(),
|
|
81
|
+
{image_a.id: image_a, image_b.id: image_b},
|
|
54
82
|
)
|
|
83
|
+
self.assertTrue(pg.eq(a_restore, a))
|
|
55
84
|
self.assertEqual(
|
|
56
85
|
modality.ModalityRef.placehold(a.x),
|
|
57
|
-
|
|
58
|
-
dict(z=modality.ModalityRef('x.z')),
|
|
86
|
+
dict(z=modality.ModalityRef(image_a.id)),
|
|
59
87
|
)
|
|
88
|
+
with self.assertRaisesRegex(ValueError, 'Modality .* not found'):
|
|
89
|
+
modality.ModalityRef.restore(a_placehold, {image_a.id: image_a})
|
|
60
90
|
|
|
61
91
|
def test_from_value(self):
|
|
62
92
|
class A(pg.Object):
|
|
@@ -68,8 +98,8 @@ class ModalityRefTest(unittest.TestCase):
|
|
|
68
98
|
pg.eq(
|
|
69
99
|
modality.Modality.from_value(a),
|
|
70
100
|
{
|
|
71
|
-
'
|
|
72
|
-
'
|
|
101
|
+
'custom_modality:0cc175b9': CustomModality('a'),
|
|
102
|
+
'custom_modality:92eb5ffe': CustomModality('b'),
|
|
73
103
|
},
|
|
74
104
|
)
|
|
75
105
|
)
|
|
@@ -77,7 +107,7 @@ class ModalityRefTest(unittest.TestCase):
|
|
|
77
107
|
pg.eq(
|
|
78
108
|
modality.Modality.from_value(a.x.z),
|
|
79
109
|
{
|
|
80
|
-
'
|
|
110
|
+
'custom_modality:0cc175b9': CustomModality('a'),
|
|
81
111
|
},
|
|
82
112
|
)
|
|
83
113
|
)
|
langfun/core/sampling_test.py
CHANGED
|
@@ -39,8 +39,13 @@ class SamplingTest(unittest.TestCase):
|
|
|
39
39
|
l = LangFunc('Compute {{x}} and {{y}}', x=pg.oneof([1, 2]))
|
|
40
40
|
with component.context(lm=ExcitedEchoer()):
|
|
41
41
|
samples = list(sampling.sweep(l, y=pg.oneof([3, 4])))
|
|
42
|
-
samples = sorted(
|
|
43
|
-
|
|
42
|
+
samples = sorted(
|
|
43
|
+
samples,
|
|
44
|
+
key=lambda x: (
|
|
45
|
+
x[0].__template_input__.x,
|
|
46
|
+
x[0].__template_input__.y
|
|
47
|
+
)
|
|
48
|
+
)
|
|
44
49
|
self.assertEqual(
|
|
45
50
|
samples,
|
|
46
51
|
[
|
|
@@ -57,7 +62,12 @@ class SamplingTest(unittest.TestCase):
|
|
|
57
62
|
samples = list(
|
|
58
63
|
sampling.random_sample(l, y=pg.oneof([2, 4]), num_examples=3, seed=1)
|
|
59
64
|
)
|
|
60
|
-
samples = sorted(
|
|
65
|
+
samples = sorted(
|
|
66
|
+
samples, key=lambda x: (
|
|
67
|
+
x[0].__template_input__.x,
|
|
68
|
+
x[0].__template_input__.y
|
|
69
|
+
)
|
|
70
|
+
)
|
|
61
71
|
|
|
62
72
|
self.assertEqual(
|
|
63
73
|
samples,
|
|
@@ -97,7 +107,13 @@ class SamplingTest(unittest.TestCase):
|
|
|
97
107
|
silence_on_errors=(AttributeError,),
|
|
98
108
|
ignore_examples_with_errors=False))
|
|
99
109
|
|
|
100
|
-
samples = sorted(
|
|
110
|
+
samples = sorted(
|
|
111
|
+
samples,
|
|
112
|
+
key=lambda x: (
|
|
113
|
+
x[0].__template_input__.x,
|
|
114
|
+
x[0].__template_input__.y
|
|
115
|
+
)
|
|
116
|
+
)
|
|
101
117
|
self.assertEqual(
|
|
102
118
|
[x[0] for x in samples],
|
|
103
119
|
[
|
|
@@ -118,13 +118,8 @@ class _CompleteStructure(mapping.Mapping):
|
|
|
118
118
|
def postprocess_result(self, result: Any) -> Any:
|
|
119
119
|
"""Postprocess result."""
|
|
120
120
|
# Try restore modality objects from the input value to output value.
|
|
121
|
-
modalities
|
|
122
|
-
|
|
123
|
-
# Remove the `input` prefix for all entries.
|
|
124
|
-
modalities = pg.object_utils.flatten(
|
|
125
|
-
pg.object_utils.canonicalize(modalities)['input']
|
|
126
|
-
)
|
|
127
|
-
result.rebind(modalities)
|
|
121
|
+
if modalities := self.modalities(self.input):
|
|
122
|
+
result = lf.ModalityRef.restore(result, modalities)
|
|
128
123
|
return result
|
|
129
124
|
|
|
130
125
|
def globals(self):
|
|
@@ -407,22 +407,17 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
407
407
|
image: modalities.Image
|
|
408
408
|
name: str
|
|
409
409
|
|
|
410
|
+
image_elephant = modalities.Image.from_bytes(b'image_of_elephant')
|
|
411
|
+
image_rabbit = modalities.Image.from_bytes(b'image_of_rabbit')
|
|
410
412
|
input_value = schema_lib.mark_missing(
|
|
411
|
-
Animal.partial(
|
|
412
|
-
modalities.Image.from_bytes(b'image_of_elephant'),
|
|
413
|
-
)
|
|
413
|
+
Animal.partial(image_elephant)
|
|
414
414
|
)
|
|
415
415
|
l = completion._CompleteStructure(
|
|
416
416
|
input=input_value,
|
|
417
417
|
examples=[
|
|
418
418
|
mapping.MappingExample(
|
|
419
|
-
input=Animal.partial(
|
|
420
|
-
|
|
421
|
-
),
|
|
422
|
-
output=Animal(
|
|
423
|
-
modalities.Image.from_bytes(b'image_of_rabbit'),
|
|
424
|
-
'rabbit',
|
|
425
|
-
),
|
|
419
|
+
input=Animal.partial(image_rabbit),
|
|
420
|
+
output=Animal(image_rabbit, 'rabbit'),
|
|
426
421
|
)
|
|
427
422
|
],
|
|
428
423
|
)
|
|
@@ -430,7 +425,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
430
425
|
self.maxDiff = None
|
|
431
426
|
self.assertEqual(
|
|
432
427
|
lm_input.text,
|
|
433
|
-
inspect.cleandoc("""
|
|
428
|
+
inspect.cleandoc(f"""
|
|
434
429
|
Please generate the OUTPUT_OBJECT by completing the MISSING fields from the last INPUT_OBJECT.
|
|
435
430
|
|
|
436
431
|
INSTRUCTIONS:
|
|
@@ -457,22 +452,22 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
457
452
|
```python
|
|
458
453
|
Animal(
|
|
459
454
|
image=ModalityRef(
|
|
460
|
-
|
|
455
|
+
id='{image_rabbit.id}'
|
|
461
456
|
),
|
|
462
457
|
name=MISSING(str)
|
|
463
458
|
)
|
|
464
459
|
```
|
|
465
460
|
|
|
466
461
|
MODALITY_REFERENCES:
|
|
467
|
-
{
|
|
468
|
-
'
|
|
469
|
-
}
|
|
462
|
+
{{
|
|
463
|
+
'{image_rabbit.id}': <<[[{image_rabbit.id}]]>>
|
|
464
|
+
}}
|
|
470
465
|
|
|
471
466
|
OUTPUT_OBJECT:
|
|
472
467
|
```python
|
|
473
468
|
Animal(
|
|
474
469
|
image=ModalityRef(
|
|
475
|
-
|
|
470
|
+
id='{image_rabbit.id}'
|
|
476
471
|
),
|
|
477
472
|
name='rabbit'
|
|
478
473
|
)
|
|
@@ -483,16 +478,16 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
483
478
|
```python
|
|
484
479
|
Animal(
|
|
485
480
|
image=ModalityRef(
|
|
486
|
-
|
|
481
|
+
id='{image_elephant.id}'
|
|
487
482
|
),
|
|
488
483
|
name=MISSING(str)
|
|
489
484
|
)
|
|
490
485
|
```
|
|
491
486
|
|
|
492
487
|
MODALITY_REFERENCES:
|
|
493
|
-
{
|
|
494
|
-
'
|
|
495
|
-
}
|
|
488
|
+
{{
|
|
489
|
+
'{image_elephant.id}': <<[[{image_elephant.id}]]>>
|
|
490
|
+
}}
|
|
496
491
|
|
|
497
492
|
OUTPUT_OBJECT:
|
|
498
493
|
"""),
|
|
@@ -500,39 +495,27 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
500
495
|
self.assertTrue(
|
|
501
496
|
pg.eq(
|
|
502
497
|
{
|
|
503
|
-
'examples': lm_input.
|
|
504
|
-
'input': lm_input.
|
|
498
|
+
'examples': lm_input.__template_input__.examples,
|
|
499
|
+
'input': lm_input.__template_input__.mapping_request.input,
|
|
505
500
|
},
|
|
506
501
|
{
|
|
507
502
|
'examples': [
|
|
508
503
|
mapping.MappingExample(
|
|
509
|
-
input=Animal.partial(
|
|
510
|
-
|
|
511
|
-
b'image_of_rabbit'
|
|
512
|
-
)
|
|
513
|
-
),
|
|
514
|
-
output=Animal.partial(
|
|
515
|
-
image=modalities.Image.from_bytes(
|
|
516
|
-
b'image_of_rabbit'
|
|
517
|
-
),
|
|
518
|
-
name='rabbit',
|
|
519
|
-
),
|
|
504
|
+
input=Animal.partial(image_rabbit),
|
|
505
|
+
output=Animal.partial(image_rabbit, 'rabbit'),
|
|
520
506
|
)
|
|
521
507
|
],
|
|
522
|
-
'input': Animal(
|
|
523
|
-
image=modalities.Image.from_bytes(b'image_of_elephant'),
|
|
524
|
-
name=schema_lib.MISSING,
|
|
525
|
-
),
|
|
508
|
+
'input': Animal(image_elephant, name=schema_lib.MISSING),
|
|
526
509
|
},
|
|
527
510
|
)
|
|
528
511
|
)
|
|
529
512
|
lm_output = l(
|
|
530
513
|
input=input_value,
|
|
531
|
-
lm=fake.StaticResponse(inspect.cleandoc("""
|
|
514
|
+
lm=fake.StaticResponse(inspect.cleandoc(f"""
|
|
532
515
|
```python
|
|
533
516
|
Animal(
|
|
534
517
|
image=ModalityRef(
|
|
535
|
-
|
|
518
|
+
id='{image_elephant.id}'
|
|
536
519
|
),
|
|
537
520
|
name='elephant'
|
|
538
521
|
)
|
|
@@ -542,10 +525,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
542
525
|
self.assertTrue(
|
|
543
526
|
pg.eq(
|
|
544
527
|
lm_output.result,
|
|
545
|
-
Animal(
|
|
546
|
-
image=modalities.Image.from_bytes(b'image_of_elephant'),
|
|
547
|
-
name='elephant',
|
|
548
|
-
),
|
|
528
|
+
Animal(image=image_elephant, name='elephant'),
|
|
549
529
|
)
|
|
550
530
|
)
|
|
551
531
|
|
|
@@ -127,6 +127,8 @@ class MappingExample(lf.NaturalLanguageFormattable,
|
|
|
127
127
|
) -> str:
|
|
128
128
|
if isinstance(value, str):
|
|
129
129
|
return value
|
|
130
|
+
if isinstance(value, lf.Message):
|
|
131
|
+
return str(value)
|
|
130
132
|
if isinstance(value, lf.Modality):
|
|
131
133
|
with lf.modality.format_modality_as_ref():
|
|
132
134
|
return str(value)
|
|
@@ -192,9 +194,7 @@ class MappingExample(lf.NaturalLanguageFormattable,
|
|
|
192
194
|
|
|
193
195
|
def render_value(view, *, value, **kwargs):
|
|
194
196
|
if isinstance(value, lf.Template):
|
|
195
|
-
|
|
196
|
-
# the input.
|
|
197
|
-
value = value.clone().render()
|
|
197
|
+
value = value.render()
|
|
198
198
|
if value is None:
|
|
199
199
|
return None
|
|
200
200
|
return view.render(value, **kwargs)
|
|
@@ -286,12 +286,8 @@ class Mapping(lf.LangFunc):
|
|
|
286
286
|
@property
|
|
287
287
|
def mapping_request(self) -> MappingExample:
|
|
288
288
|
"""Returns a MappingExample as the mapping request."""
|
|
289
|
-
if isinstance(self.input, lf.Message):
|
|
290
|
-
input_value = self.input.text
|
|
291
|
-
else:
|
|
292
|
-
input_value = pg.Ref(self.input)
|
|
293
289
|
return MappingExample(
|
|
294
|
-
input=
|
|
290
|
+
input=pg.Ref(self.input),
|
|
295
291
|
schema=pg.Ref(self.schema),
|
|
296
292
|
context=self.context,
|
|
297
293
|
)
|
|
@@ -402,11 +398,6 @@ class Mapping(lf.LangFunc):
|
|
|
402
398
|
|
|
403
399
|
def transform_input(self, lm_input: lf.Message) -> lf.Message:
|
|
404
400
|
# Find modalities to fill the input message.
|
|
405
|
-
lm_input.metadata.update(
|
|
406
|
-
examples=pg.Ref(self.examples),
|
|
407
|
-
input=pg.Ref(self.input),
|
|
408
|
-
schema=pg.Ref(self.schema) if self.schema is not None else None,
|
|
409
|
-
)
|
|
410
401
|
if isinstance(self.input, lf.Message):
|
|
411
402
|
lm_input.source = self.input
|
|
412
403
|
return lm_input
|
|
@@ -529,24 +529,22 @@ def query(
|
|
|
529
529
|
).render(message_cls=lf.SystemMessage)
|
|
530
530
|
|
|
531
531
|
# Normalize query input.
|
|
532
|
-
if isinstance(prompt,
|
|
532
|
+
if isinstance(prompt, str):
|
|
533
533
|
# Query with structured output.
|
|
534
534
|
prompt_kwargs = kwargs.copy()
|
|
535
535
|
prompt_kwargs.pop('template_str', None)
|
|
536
536
|
query_input = lf.Template.from_value(prompt, **prompt_kwargs)
|
|
537
|
+
elif isinstance(prompt, lf.Message):
|
|
538
|
+
query_input = prompt
|
|
537
539
|
elif isinstance(prompt, lf.Template):
|
|
538
|
-
# Create a copy of the prompt if it has a parent object, so all child
|
|
539
|
-
# modality objects could be referred by path relative to the prompt.
|
|
540
|
-
query_input = prompt.clone() if prompt.sym_parent is not None else prompt
|
|
541
|
-
|
|
542
540
|
# Attach template metadata from kwargs. This is used to pass through fields
|
|
543
541
|
# from kwargs to the rendered message.
|
|
544
|
-
|
|
545
|
-
k: v for k, v in kwargs.items() if k.startswith('metadata_')
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
template_metadata, skip_notification=True, raise_on_no_change=False
|
|
542
|
+
prompt.rebind(
|
|
543
|
+
{k: v for k, v in kwargs.items() if k.startswith('metadata_')},
|
|
544
|
+
skip_notification=True,
|
|
545
|
+
raise_on_no_change=False
|
|
549
546
|
)
|
|
547
|
+
query_input = prompt
|
|
550
548
|
elif pg.MISSING_VALUE == prompt:
|
|
551
549
|
query_input = lf.UserMessage('')
|
|
552
550
|
else:
|
|
@@ -665,7 +663,11 @@ def query(
|
|
|
665
663
|
|
|
666
664
|
if returns_message:
|
|
667
665
|
return output_message
|
|
668
|
-
|
|
666
|
+
if schema not in (None, str):
|
|
667
|
+
return output_message.result
|
|
668
|
+
if returns_message or output_message.referred_modalities:
|
|
669
|
+
return output_message
|
|
670
|
+
return output_message.text
|
|
669
671
|
|
|
670
672
|
|
|
671
673
|
async def aquery(
|
|
@@ -249,35 +249,60 @@ class QueryTest(unittest.TestCase):
|
|
|
249
249
|
|
|
250
250
|
def test_root_modality_to_structure_render(self):
|
|
251
251
|
lm = fake.StaticResponse('1')
|
|
252
|
+
image = modalities.Image.from_bytes(b'mock_image')
|
|
252
253
|
self.assert_render(
|
|
253
|
-
|
|
254
|
+
image,
|
|
254
255
|
int,
|
|
255
256
|
lm=lm,
|
|
256
|
-
expected_snippet='\n\nREQUEST:\n <<[[
|
|
257
|
+
expected_snippet=f'\n\nREQUEST:\n <<[[{image.id}]]>>\n\n',
|
|
257
258
|
expected_modalities=1,
|
|
258
259
|
)
|
|
259
260
|
|
|
260
261
|
def test_root_modality_to_str_render(self):
|
|
261
262
|
lm = fake.StaticResponse('1')
|
|
263
|
+
modality = modalities.Image.from_bytes(b'mock_image')
|
|
262
264
|
self.assert_render(
|
|
263
|
-
|
|
265
|
+
modality,
|
|
264
266
|
None,
|
|
265
267
|
lm=lm,
|
|
266
|
-
expected_snippet='<<[[
|
|
268
|
+
expected_snippet=f'<<[[{modality.id}]]>>',
|
|
267
269
|
exact_match=True,
|
|
268
270
|
expected_modalities=1,
|
|
269
271
|
)
|
|
270
272
|
|
|
271
273
|
def test_str_with_modality_to_str_render(self):
|
|
272
274
|
lm = fake.StaticResponse('A cat and a mouse.')
|
|
275
|
+
cat_image = modalities.Image.from_bytes(b'cat_image')
|
|
276
|
+
mouse_image = modalities.Image.from_bytes(b'mouse_image')
|
|
273
277
|
self.assert_render(
|
|
274
278
|
'What are these? {{this_image}} and {{that_image}}',
|
|
275
279
|
None,
|
|
276
|
-
this_image=
|
|
277
|
-
that_image=
|
|
280
|
+
this_image=cat_image,
|
|
281
|
+
that_image=mouse_image,
|
|
278
282
|
lm=lm,
|
|
279
283
|
expected_snippet=(
|
|
280
|
-
'What are these? <<[[
|
|
284
|
+
f'What are these? <<[[{cat_image.id}]]>> and '
|
|
285
|
+
f'<<[[{mouse_image.id}]]>>'
|
|
286
|
+
),
|
|
287
|
+
exact_match=True,
|
|
288
|
+
expected_modalities=2,
|
|
289
|
+
)
|
|
290
|
+
|
|
291
|
+
def test_message_with_modality_to_str_render(self):
|
|
292
|
+
lm = fake.StaticResponse('A cat and a mouse.')
|
|
293
|
+
cat_image = modalities.Image.from_bytes(b'cat_image')
|
|
294
|
+
mouse_image = modalities.Image.from_bytes(b'mouse_image')
|
|
295
|
+
self.assert_render(
|
|
296
|
+
lf.Template(
|
|
297
|
+
'What are these? {{this_image}} and {{that_image}}',
|
|
298
|
+
this_image=cat_image,
|
|
299
|
+
that_image=mouse_image,
|
|
300
|
+
).render(),
|
|
301
|
+
None,
|
|
302
|
+
lm=lm,
|
|
303
|
+
expected_snippet=(
|
|
304
|
+
f'What are these? <<[[{cat_image.id}]]>> and '
|
|
305
|
+
f'<<[[{mouse_image.id}]]>>'
|
|
281
306
|
),
|
|
282
307
|
exact_match=True,
|
|
283
308
|
expected_modalities=2,
|
|
@@ -285,33 +310,33 @@ class QueryTest(unittest.TestCase):
|
|
|
285
310
|
|
|
286
311
|
def test_structure_with_modality_to_str_render(self):
|
|
287
312
|
lm = fake.StaticResponse('A cat and a mouse.')
|
|
313
|
+
cat_image = modalities.Image.from_bytes(b'cat_image')
|
|
314
|
+
mouse_image = modalities.Image.from_bytes(b'mouse_image')
|
|
288
315
|
self.assert_render(
|
|
289
|
-
[
|
|
290
|
-
modalities.Image.from_bytes(b'cat_image'),
|
|
291
|
-
modalities.Image.from_bytes(b'mouse_image'),
|
|
292
|
-
],
|
|
316
|
+
[cat_image, mouse_image],
|
|
293
317
|
None,
|
|
294
318
|
lm=lm,
|
|
295
|
-
expected_snippet=
|
|
319
|
+
expected_snippet=(
|
|
320
|
+
f'`[<<[[{cat_image.id}]]>>, <<[[{mouse_image.id}]]>>]`'
|
|
321
|
+
),
|
|
296
322
|
exact_match=True,
|
|
297
323
|
expected_modalities=2,
|
|
298
324
|
)
|
|
299
325
|
|
|
300
326
|
def test_structure_with_modality_to_structure_render(self):
|
|
301
327
|
lm = fake.StaticResponse('["cat", "mouse"]')
|
|
328
|
+
cat_image = modalities.Image.from_bytes(b'cat_image')
|
|
329
|
+
mouse_image = modalities.Image.from_bytes(b'mouse_image')
|
|
302
330
|
self.assert_render(
|
|
303
|
-
[
|
|
304
|
-
modalities.Image.from_bytes(b'cat_image'),
|
|
305
|
-
modalities.Image.from_bytes(b'mouse_image'),
|
|
306
|
-
],
|
|
331
|
+
[cat_image, mouse_image],
|
|
307
332
|
list[str],
|
|
308
333
|
lm=lm,
|
|
309
|
-
expected_snippet=inspect.cleandoc("""
|
|
334
|
+
expected_snippet=inspect.cleandoc(f"""
|
|
310
335
|
REQUEST:
|
|
311
336
|
```python
|
|
312
337
|
[
|
|
313
|
-
<<[[
|
|
314
|
-
<<[[
|
|
338
|
+
<<[[{cat_image.id}]]>>,
|
|
339
|
+
<<[[{mouse_image.id}]]>>
|
|
315
340
|
]
|
|
316
341
|
```
|
|
317
342
|
"""),
|
|
@@ -320,25 +345,25 @@ class QueryTest(unittest.TestCase):
|
|
|
320
345
|
|
|
321
346
|
def test_structure_with_modality_and_examples_to_structure_render(self):
|
|
322
347
|
lm = fake.StaticResponse('["cat", "mouse"]')
|
|
348
|
+
cat_image = modalities.Image.from_bytes(b'cat_image')
|
|
349
|
+
mouse_image = modalities.Image.from_bytes(b'mouse_image')
|
|
350
|
+
dog_image = modalities.Image.from_bytes(b'dog_image')
|
|
323
351
|
self.assert_render(
|
|
324
|
-
[
|
|
325
|
-
modalities.Image.from_bytes(b'cat_image'),
|
|
326
|
-
modalities.Image.from_bytes(b'mouse_image'),
|
|
327
|
-
],
|
|
352
|
+
[cat_image, mouse_image],
|
|
328
353
|
list[str],
|
|
329
354
|
examples=[
|
|
330
355
|
mapping.MappingExample(
|
|
331
|
-
input=[
|
|
356
|
+
input=[dog_image],
|
|
332
357
|
schema=list[str],
|
|
333
358
|
output=['dog'],
|
|
334
359
|
),
|
|
335
360
|
],
|
|
336
361
|
lm=lm,
|
|
337
|
-
expected_snippet=inspect.cleandoc("""
|
|
362
|
+
expected_snippet=inspect.cleandoc(f"""
|
|
338
363
|
REQUEST:
|
|
339
364
|
```python
|
|
340
365
|
[
|
|
341
|
-
<<[[
|
|
366
|
+
<<[[{dog_image.id}]]>>
|
|
342
367
|
]
|
|
343
368
|
```
|
|
344
369
|
|
|
@@ -356,8 +381,8 @@ class QueryTest(unittest.TestCase):
|
|
|
356
381
|
REQUEST:
|
|
357
382
|
```python
|
|
358
383
|
[
|
|
359
|
-
<<[[
|
|
360
|
-
<<[[
|
|
384
|
+
<<[[{cat_image.id}]]>>,
|
|
385
|
+
<<[[{mouse_image.id}]]>>
|
|
361
386
|
]
|
|
362
387
|
```
|
|
363
388
|
|
|
@@ -369,6 +394,17 @@ class QueryTest(unittest.TestCase):
|
|
|
369
394
|
expected_modalities=3,
|
|
370
395
|
)
|
|
371
396
|
|
|
397
|
+
def test_query_with_modality_output(self):
|
|
398
|
+
cat_image = modalities.Image.from_bytes(b'cat_image')
|
|
399
|
+
lm = fake.StaticResponse(
|
|
400
|
+
lf.Template('Here you go: {{image}}', image=cat_image).render(
|
|
401
|
+
message_cls=lf.AIMessage
|
|
402
|
+
)
|
|
403
|
+
)
|
|
404
|
+
response = querying.query('Generate a cat image', lm=lm)
|
|
405
|
+
self.assertIsInstance(response, lf.AIMessage)
|
|
406
|
+
self.assertEqual(response.modalities(), [cat_image])
|
|
407
|
+
|
|
372
408
|
def test_multiple_queries(self):
|
|
373
409
|
self.assertEqual(
|
|
374
410
|
querying.query(
|
|
@@ -545,7 +581,7 @@ class QueryTest(unittest.TestCase):
|
|
|
545
581
|
)
|
|
546
582
|
).input,
|
|
547
583
|
)
|
|
548
|
-
self.
|
|
584
|
+
self.assertEqual(len(output.referred_modalities), 1)
|
|
549
585
|
|
|
550
586
|
def test_query_and_reduce(self):
|
|
551
587
|
self.assertEqual(
|
langfun/core/template.py
CHANGED
|
@@ -171,6 +171,7 @@ class Template(
|
|
|
171
171
|
|
|
172
172
|
# Last render output.
|
|
173
173
|
self._cached_render_output = None
|
|
174
|
+
self._referred_modalities = None
|
|
174
175
|
|
|
175
176
|
@property
|
|
176
177
|
def render_output(self) -> message_lib.Message | None:
|
|
@@ -322,24 +323,46 @@ class Template(
|
|
|
322
323
|
compact=True,
|
|
323
324
|
python_format=True,
|
|
324
325
|
):
|
|
325
|
-
|
|
326
|
-
#
|
|
326
|
+
|
|
327
|
+
# Capture the modality objects whose references are being
|
|
328
|
+
# rendered
|
|
327
329
|
# in the template.
|
|
328
|
-
with modality.
|
|
329
|
-
|
|
330
|
+
with modality.capture_rendered_modalities() as modality_refs:
|
|
331
|
+
|
|
332
|
+
# Natural language formattable objects will be returned in
|
|
333
|
+
# natural language when they are directly returned as rendering
|
|
334
|
+
# elements in the template.
|
|
335
|
+
with modality.format_modality_as_ref():
|
|
336
|
+
rendered_text = self._template.render(**inputs)
|
|
330
337
|
|
|
331
|
-
|
|
332
|
-
|
|
338
|
+
# Carry the modality references passed from the constructor.
|
|
339
|
+
# This is to support modality objects that is already rendered
|
|
340
|
+
# in the template string.
|
|
341
|
+
if self._referred_modalities:
|
|
342
|
+
modality_refs.update(self._referred_modalities)
|
|
333
343
|
|
|
334
344
|
if self.clean:
|
|
335
345
|
rendered_text = rendered_text.strip()
|
|
336
346
|
|
|
337
|
-
metadata.
|
|
338
|
-
|
|
339
|
-
|
|
347
|
+
# Fill message metadata.
|
|
348
|
+
metadata = {
|
|
349
|
+
'__template_input__': {
|
|
350
|
+
k: pg.Ref(v) for k, v in inputs.items()
|
|
351
|
+
if not inspect.ismethod(v)
|
|
352
|
+
},
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
# Carry additional metadata.
|
|
356
|
+
# TODO(daiyip): Consider to put additional metadata into a separate
|
|
357
|
+
# key under `metadata`.
|
|
358
|
+
metadata.update(self.additional_metadata())
|
|
340
359
|
|
|
341
360
|
# Fill the variables for rendering the template as metadata.
|
|
342
|
-
message = message_cls(
|
|
361
|
+
message = message_cls(
|
|
362
|
+
text=rendered_text,
|
|
363
|
+
referred_modalities=modality_refs,
|
|
364
|
+
metadata=metadata
|
|
365
|
+
)
|
|
343
366
|
|
|
344
367
|
# Tag input as rendered message.
|
|
345
368
|
message.tag(message_lib.Message.TAG_RENDERED)
|
|
@@ -518,10 +541,13 @@ class Template(
|
|
|
518
541
|
if isinstance(value, str):
|
|
519
542
|
return cls(template_str=value, **kwargs)
|
|
520
543
|
if isinstance(value, message_lib.Message):
|
|
521
|
-
|
|
522
|
-
|
|
544
|
+
for k, v in value.metadata.sym_items(): # pylint: disable=attribute-error
|
|
545
|
+
kwargs[_ADDITIONAL_METADATA_PREFIX + k] = v
|
|
546
|
+
t = cls(template_str=value.text, **kwargs)
|
|
547
|
+
t._referred_modalities = value.referred_modalities
|
|
548
|
+
return t
|
|
523
549
|
if isinstance(value, Template):
|
|
524
|
-
lfun = cls(template_str=value.template_str, **kwargs)
|
|
550
|
+
lfun = cls(template_str=value.template_str, **kwargs) # pylint: disable=attribute-error
|
|
525
551
|
# So lfun could acccess all attributes from value.
|
|
526
552
|
lfun.sym_setparent(value)
|
|
527
553
|
return lfun
|