langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202510250803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. See the package's registry page for more details.

Files changed (44)
  1. langfun/core/concurrent_test.py +1 -0
  2. langfun/core/data/conversion/anthropic_test.py +8 -6
  3. langfun/core/data/conversion/gemini_test.py +12 -9
  4. langfun/core/data/conversion/openai.py +134 -30
  5. langfun/core/data/conversion/openai_test.py +161 -17
  6. langfun/core/eval/base_test.py +4 -4
  7. langfun/core/eval/v2/progress_tracking_test.py +3 -0
  8. langfun/core/langfunc_test.py +6 -4
  9. langfun/core/language_model.py +15 -6
  10. langfun/core/language_model_test.py +9 -3
  11. langfun/core/llms/__init__.py +7 -1
  12. langfun/core/llms/anthropic.py +130 -0
  13. langfun/core/llms/cache/base.py +3 -1
  14. langfun/core/llms/cache/in_memory_test.py +14 -4
  15. langfun/core/llms/deepseek.py +1 -1
  16. langfun/core/llms/gemini.py +2 -5
  17. langfun/core/llms/groq.py +1 -1
  18. langfun/core/llms/llama_cpp.py +1 -1
  19. langfun/core/llms/openai.py +7 -2
  20. langfun/core/llms/openai_compatible.py +136 -27
  21. langfun/core/llms/openai_compatible_test.py +207 -20
  22. langfun/core/llms/openai_test.py +0 -2
  23. langfun/core/llms/vertexai.py +12 -2
  24. langfun/core/message.py +78 -44
  25. langfun/core/message_test.py +56 -81
  26. langfun/core/modalities/__init__.py +8 -0
  27. langfun/core/modalities/mime.py +9 -0
  28. langfun/core/modality.py +104 -27
  29. langfun/core/modality_test.py +42 -12
  30. langfun/core/sampling_test.py +20 -4
  31. langfun/core/structured/completion.py +2 -7
  32. langfun/core/structured/completion_test.py +23 -43
  33. langfun/core/structured/mapping.py +4 -13
  34. langfun/core/structured/querying.py +13 -11
  35. langfun/core/structured/querying_test.py +65 -29
  36. langfun/core/template.py +39 -13
  37. langfun/core/template_test.py +83 -17
  38. langfun/env/event_handlers/metric_writer_test.py +3 -3
  39. langfun/env/load_balancers_test.py +2 -2
  40. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/METADATA +1 -1
  41. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/RECORD +44 -44
  42. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/WHEEL +0 -0
  43. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/licenses/LICENSE +0 -0
  44. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/top_level.txt +0 -0
@@ -29,34 +29,64 @@ class ModalityTest(unittest.TestCase):
29
29
 
30
30
  def test_basic(self):
31
31
  v = CustomModality('a')
32
- self.assertIsNone(v.referred_name)
32
+ self.assertEqual(v.id, 'custom_modality:0cc175b9')
33
33
  self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
34
34
  self.assertEqual(v.hash, '0cc175b9')
35
35
 
36
36
  _ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
37
- self.assertEqual(v.referred_name, 'x.metadata.y')
37
+ self.assertEqual(v.id, 'custom_modality:0cc175b9')
38
38
  self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
39
39
  with modality.format_modality_as_ref():
40
- self.assertEqual(str(v), '<<[[x.metadata.y]]>>')
40
+ self.assertEqual(str(v), '<<[[custom_modality:0cc175b9]]>>')
41
+
42
+ def test_capture_rendered_modalities(self):
43
+ x = CustomModality('a')
44
+ y = CustomModality('b')
45
+ z = CustomModality('b')
46
+
47
+ with modality.capture_rendered_modalities() as rendered_modalities:
48
+ with modality.format_modality_as_ref():
49
+ self.assertEqual(
50
+ f'Hello {x} {y} {z}',
51
+ (
52
+ 'Hello <<[[custom_modality:0cc175b9]]>> '
53
+ '<<[[custom_modality:92eb5ffe]]>> '
54
+ '<<[[custom_modality:92eb5ffe]]>>'
55
+ )
56
+ )
57
+ self.assertEqual(len(rendered_modalities), 2)
58
+ self.assertIs(rendered_modalities['custom_modality:0cc175b9'].value, x)
59
+ # y and z share the same content will be treated as the same object.
60
+ self.assertIs(rendered_modalities['custom_modality:92eb5ffe'].value, z)
41
61
 
42
62
 
43
63
  class ModalityRefTest(unittest.TestCase):
44
64
 
45
- def test_placehold(self):
65
+ def test_placehold_and_restore(self):
46
66
  class A(pg.Object):
47
67
  x: Any
48
68
  y: Any
49
69
 
50
- a = A(x=dict(z=CustomModality('a')), y=CustomModality('b'))
70
+ image_a = CustomModality('a')
71
+ image_b = CustomModality('b')
72
+ a = A(x=dict(z=image_a), y=image_b)
73
+ a_placehold = modality.ModalityRef.placehold(a)
51
74
  self.assertEqual(
52
- modality.ModalityRef.placehold(a),
53
- A(x=dict(z=modality.ModalityRef('x.z')), y=modality.ModalityRef('y')),
75
+ a_placehold,
76
+ A(x=dict(z=modality.ModalityRef(image_a.id)),
77
+ y=modality.ModalityRef(image_b.id)),
78
+ )
79
+ a_restore = modality.ModalityRef.restore(
80
+ a_placehold.clone(),
81
+ {image_a.id: image_a, image_b.id: image_b},
54
82
  )
83
+ self.assertTrue(pg.eq(a_restore, a))
55
84
  self.assertEqual(
56
85
  modality.ModalityRef.placehold(a.x),
57
- # The prefix 'x' of referred name is preserved.
58
- dict(z=modality.ModalityRef('x.z')),
86
+ dict(z=modality.ModalityRef(image_a.id)),
59
87
  )
88
+ with self.assertRaisesRegex(ValueError, 'Modality .* not found'):
89
+ modality.ModalityRef.restore(a_placehold, {image_a.id: image_a})
60
90
 
61
91
  def test_from_value(self):
62
92
  class A(pg.Object):
@@ -68,8 +98,8 @@ class ModalityRefTest(unittest.TestCase):
68
98
  pg.eq(
69
99
  modality.Modality.from_value(a),
70
100
  {
71
- 'x.z': CustomModality('a'),
72
- 'y': CustomModality('b'),
101
+ 'custom_modality:0cc175b9': CustomModality('a'),
102
+ 'custom_modality:92eb5ffe': CustomModality('b'),
73
103
  },
74
104
  )
75
105
  )
@@ -77,7 +107,7 @@ class ModalityRefTest(unittest.TestCase):
77
107
  pg.eq(
78
108
  modality.Modality.from_value(a.x.z),
79
109
  {
80
- 'x.z': CustomModality('a'),
110
+ 'custom_modality:0cc175b9': CustomModality('a'),
81
111
  },
82
112
  )
83
113
  )
@@ -39,8 +39,13 @@ class SamplingTest(unittest.TestCase):
39
39
  l = LangFunc('Compute {{x}} and {{y}}', x=pg.oneof([1, 2]))
40
40
  with component.context(lm=ExcitedEchoer()):
41
41
  samples = list(sampling.sweep(l, y=pg.oneof([3, 4])))
42
- samples = sorted(samples, key=lambda x: (x[0].x, x[0].y))
43
-
42
+ samples = sorted(
43
+ samples,
44
+ key=lambda x: (
45
+ x[0].__template_input__.x,
46
+ x[0].__template_input__.y
47
+ )
48
+ )
44
49
  self.assertEqual(
45
50
  samples,
46
51
  [
@@ -57,7 +62,12 @@ class SamplingTest(unittest.TestCase):
57
62
  samples = list(
58
63
  sampling.random_sample(l, y=pg.oneof([2, 4]), num_examples=3, seed=1)
59
64
  )
60
- samples = sorted(samples, key=lambda x: (x[0].x, x[0].y))
65
+ samples = sorted(
66
+ samples, key=lambda x: (
67
+ x[0].__template_input__.x,
68
+ x[0].__template_input__.y
69
+ )
70
+ )
61
71
 
62
72
  self.assertEqual(
63
73
  samples,
@@ -97,7 +107,13 @@ class SamplingTest(unittest.TestCase):
97
107
  silence_on_errors=(AttributeError,),
98
108
  ignore_examples_with_errors=False))
99
109
 
100
- samples = sorted(samples, key=lambda x: (x[0].x, x[0].y))
110
+ samples = sorted(
111
+ samples,
112
+ key=lambda x: (
113
+ x[0].__template_input__.x,
114
+ x[0].__template_input__.y
115
+ )
116
+ )
101
117
  self.assertEqual(
102
118
  [x[0] for x in samples],
103
119
  [
@@ -118,13 +118,8 @@ class _CompleteStructure(mapping.Mapping):
118
118
  def postprocess_result(self, result: Any) -> Any:
119
119
  """Postprocess result."""
120
120
  # Try restore modality objects from the input value to output value.
121
- modalities = self.modalities(self.input)
122
- if modalities:
123
- # Remove the `input` prefix for all entries.
124
- modalities = pg.object_utils.flatten(
125
- pg.object_utils.canonicalize(modalities)['input']
126
- )
127
- result.rebind(modalities)
121
+ if modalities := self.modalities(self.input):
122
+ result = lf.ModalityRef.restore(result, modalities)
128
123
  return result
129
124
 
130
125
  def globals(self):
@@ -407,22 +407,17 @@ class CompleteStructureTest(unittest.TestCase):
407
407
  image: modalities.Image
408
408
  name: str
409
409
 
410
+ image_elephant = modalities.Image.from_bytes(b'image_of_elephant')
411
+ image_rabbit = modalities.Image.from_bytes(b'image_of_rabbit')
410
412
  input_value = schema_lib.mark_missing(
411
- Animal.partial(
412
- modalities.Image.from_bytes(b'image_of_elephant'),
413
- )
413
+ Animal.partial(image_elephant)
414
414
  )
415
415
  l = completion._CompleteStructure(
416
416
  input=input_value,
417
417
  examples=[
418
418
  mapping.MappingExample(
419
- input=Animal.partial(
420
- modalities.Image.from_bytes(b'image_of_rabbit')
421
- ),
422
- output=Animal(
423
- modalities.Image.from_bytes(b'image_of_rabbit'),
424
- 'rabbit',
425
- ),
419
+ input=Animal.partial(image_rabbit),
420
+ output=Animal(image_rabbit, 'rabbit'),
426
421
  )
427
422
  ],
428
423
  )
@@ -430,7 +425,7 @@ class CompleteStructureTest(unittest.TestCase):
430
425
  self.maxDiff = None
431
426
  self.assertEqual(
432
427
  lm_input.text,
433
- inspect.cleandoc("""
428
+ inspect.cleandoc(f"""
434
429
  Please generate the OUTPUT_OBJECT by completing the MISSING fields from the last INPUT_OBJECT.
435
430
 
436
431
  INSTRUCTIONS:
@@ -457,22 +452,22 @@ class CompleteStructureTest(unittest.TestCase):
457
452
  ```python
458
453
  Animal(
459
454
  image=ModalityRef(
460
- name='examples[0].input.image'
455
+ id='{image_rabbit.id}'
461
456
  ),
462
457
  name=MISSING(str)
463
458
  )
464
459
  ```
465
460
 
466
461
  MODALITY_REFERENCES:
467
- {
468
- 'examples[0].input.image': <<[[examples[0].input.image]]>>
469
- }
462
+ {{
463
+ '{image_rabbit.id}': <<[[{image_rabbit.id}]]>>
464
+ }}
470
465
 
471
466
  OUTPUT_OBJECT:
472
467
  ```python
473
468
  Animal(
474
469
  image=ModalityRef(
475
- name='examples[0].output.image'
470
+ id='{image_rabbit.id}'
476
471
  ),
477
472
  name='rabbit'
478
473
  )
@@ -483,16 +478,16 @@ class CompleteStructureTest(unittest.TestCase):
483
478
  ```python
484
479
  Animal(
485
480
  image=ModalityRef(
486
- name='input.image'
481
+ id='{image_elephant.id}'
487
482
  ),
488
483
  name=MISSING(str)
489
484
  )
490
485
  ```
491
486
 
492
487
  MODALITY_REFERENCES:
493
- {
494
- 'input.image': <<[[input.image]]>>
495
- }
488
+ {{
489
+ '{image_elephant.id}': <<[[{image_elephant.id}]]>>
490
+ }}
496
491
 
497
492
  OUTPUT_OBJECT:
498
493
  """),
@@ -500,39 +495,27 @@ class CompleteStructureTest(unittest.TestCase):
500
495
  self.assertTrue(
501
496
  pg.eq(
502
497
  {
503
- 'examples': lm_input.get('examples'),
504
- 'input': lm_input.get('input'),
498
+ 'examples': lm_input.__template_input__.examples,
499
+ 'input': lm_input.__template_input__.mapping_request.input,
505
500
  },
506
501
  {
507
502
  'examples': [
508
503
  mapping.MappingExample(
509
- input=Animal.partial(
510
- image=modalities.Image.from_bytes(
511
- b'image_of_rabbit'
512
- )
513
- ),
514
- output=Animal.partial(
515
- image=modalities.Image.from_bytes(
516
- b'image_of_rabbit'
517
- ),
518
- name='rabbit',
519
- ),
504
+ input=Animal.partial(image_rabbit),
505
+ output=Animal.partial(image_rabbit, 'rabbit'),
520
506
  )
521
507
  ],
522
- 'input': Animal(
523
- image=modalities.Image.from_bytes(b'image_of_elephant'),
524
- name=schema_lib.MISSING,
525
- ),
508
+ 'input': Animal(image_elephant, name=schema_lib.MISSING),
526
509
  },
527
510
  )
528
511
  )
529
512
  lm_output = l(
530
513
  input=input_value,
531
- lm=fake.StaticResponse(inspect.cleandoc("""
514
+ lm=fake.StaticResponse(inspect.cleandoc(f"""
532
515
  ```python
533
516
  Animal(
534
517
  image=ModalityRef(
535
- name='input.image'
518
+ id='{image_elephant.id}'
536
519
  ),
537
520
  name='elephant'
538
521
  )
@@ -542,10 +525,7 @@ class CompleteStructureTest(unittest.TestCase):
542
525
  self.assertTrue(
543
526
  pg.eq(
544
527
  lm_output.result,
545
- Animal(
546
- image=modalities.Image.from_bytes(b'image_of_elephant'),
547
- name='elephant',
548
- ),
528
+ Animal(image=image_elephant, name='elephant'),
549
529
  )
550
530
  )
551
531
 
@@ -127,6 +127,8 @@ class MappingExample(lf.NaturalLanguageFormattable,
127
127
  ) -> str:
128
128
  if isinstance(value, str):
129
129
  return value
130
+ if isinstance(value, lf.Message):
131
+ return str(value)
130
132
  if isinstance(value, lf.Modality):
131
133
  with lf.modality.format_modality_as_ref():
132
134
  return str(value)
@@ -192,9 +194,7 @@ class MappingExample(lf.NaturalLanguageFormattable,
192
194
 
193
195
  def render_value(view, *, value, **kwargs):
194
196
  if isinstance(value, lf.Template):
195
- # Make a shallow copy to make sure modalities are rooted by
196
- # the input.
197
- value = value.clone().render()
197
+ value = value.render()
198
198
  if value is None:
199
199
  return None
200
200
  return view.render(value, **kwargs)
@@ -286,12 +286,8 @@ class Mapping(lf.LangFunc):
286
286
  @property
287
287
  def mapping_request(self) -> MappingExample:
288
288
  """Returns a MappingExample as the mapping request."""
289
- if isinstance(self.input, lf.Message):
290
- input_value = self.input.text
291
- else:
292
- input_value = pg.Ref(self.input)
293
289
  return MappingExample(
294
- input=input_value,
290
+ input=pg.Ref(self.input),
295
291
  schema=pg.Ref(self.schema),
296
292
  context=self.context,
297
293
  )
@@ -402,11 +398,6 @@ class Mapping(lf.LangFunc):
402
398
 
403
399
  def transform_input(self, lm_input: lf.Message) -> lf.Message:
404
400
  # Find modalities to fill the input message.
405
- lm_input.metadata.update(
406
- examples=pg.Ref(self.examples),
407
- input=pg.Ref(self.input),
408
- schema=pg.Ref(self.schema) if self.schema is not None else None,
409
- )
410
401
  if isinstance(self.input, lf.Message):
411
402
  lm_input.source = self.input
412
403
  return lm_input
@@ -529,24 +529,22 @@ def query(
529
529
  ).render(message_cls=lf.SystemMessage)
530
530
 
531
531
  # Normalize query input.
532
- if isinstance(prompt, (lf.Message, str)):
532
+ if isinstance(prompt, str):
533
533
  # Query with structured output.
534
534
  prompt_kwargs = kwargs.copy()
535
535
  prompt_kwargs.pop('template_str', None)
536
536
  query_input = lf.Template.from_value(prompt, **prompt_kwargs)
537
+ elif isinstance(prompt, lf.Message):
538
+ query_input = prompt
537
539
  elif isinstance(prompt, lf.Template):
538
- # Create a copy of the prompt if it has a parent object, so all child
539
- # modality objects could be referred by path relative to the prompt.
540
- query_input = prompt.clone() if prompt.sym_parent is not None else prompt
541
-
542
540
  # Attach template metadata from kwargs. This is used to pass through fields
543
541
  # from kwargs to the rendered message.
544
- template_metadata = {
545
- k: v for k, v in kwargs.items() if k.startswith('metadata_')
546
- }
547
- query_input.rebind(
548
- template_metadata, skip_notification=True, raise_on_no_change=False
542
+ prompt.rebind(
543
+ {k: v for k, v in kwargs.items() if k.startswith('metadata_')},
544
+ skip_notification=True,
545
+ raise_on_no_change=False
549
546
  )
547
+ query_input = prompt
550
548
  elif pg.MISSING_VALUE == prompt:
551
549
  query_input = lf.UserMessage('')
552
550
  else:
@@ -665,7 +663,11 @@ def query(
665
663
 
666
664
  if returns_message:
667
665
  return output_message
668
- return output_message.text if schema in (None, str) else output_message.result
666
+ if schema not in (None, str):
667
+ return output_message.result
668
+ if returns_message or output_message.referred_modalities:
669
+ return output_message
670
+ return output_message.text
669
671
 
670
672
 
671
673
  async def aquery(
@@ -249,35 +249,60 @@ class QueryTest(unittest.TestCase):
249
249
 
250
250
  def test_root_modality_to_structure_render(self):
251
251
  lm = fake.StaticResponse('1')
252
+ image = modalities.Image.from_bytes(b'mock_image')
252
253
  self.assert_render(
253
- modalities.Image.from_bytes(b'mock_image'),
254
+ image,
254
255
  int,
255
256
  lm=lm,
256
- expected_snippet='\n\nREQUEST:\n <<[[input]]>>\n\n',
257
+ expected_snippet=f'\n\nREQUEST:\n <<[[{image.id}]]>>\n\n',
257
258
  expected_modalities=1,
258
259
  )
259
260
 
260
261
  def test_root_modality_to_str_render(self):
261
262
  lm = fake.StaticResponse('1')
263
+ modality = modalities.Image.from_bytes(b'mock_image')
262
264
  self.assert_render(
263
- modalities.Image.from_bytes(b'mock_image'),
265
+ modality,
264
266
  None,
265
267
  lm=lm,
266
- expected_snippet='<<[[input]]>>',
268
+ expected_snippet=f'<<[[{modality.id}]]>>',
267
269
  exact_match=True,
268
270
  expected_modalities=1,
269
271
  )
270
272
 
271
273
  def test_str_with_modality_to_str_render(self):
272
274
  lm = fake.StaticResponse('A cat and a mouse.')
275
+ cat_image = modalities.Image.from_bytes(b'cat_image')
276
+ mouse_image = modalities.Image.from_bytes(b'mouse_image')
273
277
  self.assert_render(
274
278
  'What are these? {{this_image}} and {{that_image}}',
275
279
  None,
276
- this_image=modalities.Image.from_bytes(b'cat_image'),
277
- that_image=modalities.Image.from_bytes(b'mouse_image'),
280
+ this_image=cat_image,
281
+ that_image=mouse_image,
278
282
  lm=lm,
279
283
  expected_snippet=(
280
- 'What are these? <<[[this_image]]>> and <<[[that_image]]>>'
284
+ f'What are these? <<[[{cat_image.id}]]>> and '
285
+ f'<<[[{mouse_image.id}]]>>'
286
+ ),
287
+ exact_match=True,
288
+ expected_modalities=2,
289
+ )
290
+
291
+ def test_message_with_modality_to_str_render(self):
292
+ lm = fake.StaticResponse('A cat and a mouse.')
293
+ cat_image = modalities.Image.from_bytes(b'cat_image')
294
+ mouse_image = modalities.Image.from_bytes(b'mouse_image')
295
+ self.assert_render(
296
+ lf.Template(
297
+ 'What are these? {{this_image}} and {{that_image}}',
298
+ this_image=cat_image,
299
+ that_image=mouse_image,
300
+ ).render(),
301
+ None,
302
+ lm=lm,
303
+ expected_snippet=(
304
+ f'What are these? <<[[{cat_image.id}]]>> and '
305
+ f'<<[[{mouse_image.id}]]>>'
281
306
  ),
282
307
  exact_match=True,
283
308
  expected_modalities=2,
@@ -285,33 +310,33 @@ class QueryTest(unittest.TestCase):
285
310
 
286
311
  def test_structure_with_modality_to_str_render(self):
287
312
  lm = fake.StaticResponse('A cat and a mouse.')
313
+ cat_image = modalities.Image.from_bytes(b'cat_image')
314
+ mouse_image = modalities.Image.from_bytes(b'mouse_image')
288
315
  self.assert_render(
289
- [
290
- modalities.Image.from_bytes(b'cat_image'),
291
- modalities.Image.from_bytes(b'mouse_image'),
292
- ],
316
+ [cat_image, mouse_image],
293
317
  None,
294
318
  lm=lm,
295
- expected_snippet='`[<<[[input[0]]]>>, <<[[input[1]]]>>]`',
319
+ expected_snippet=(
320
+ f'`[<<[[{cat_image.id}]]>>, <<[[{mouse_image.id}]]>>]`'
321
+ ),
296
322
  exact_match=True,
297
323
  expected_modalities=2,
298
324
  )
299
325
 
300
326
  def test_structure_with_modality_to_structure_render(self):
301
327
  lm = fake.StaticResponse('["cat", "mouse"]')
328
+ cat_image = modalities.Image.from_bytes(b'cat_image')
329
+ mouse_image = modalities.Image.from_bytes(b'mouse_image')
302
330
  self.assert_render(
303
- [
304
- modalities.Image.from_bytes(b'cat_image'),
305
- modalities.Image.from_bytes(b'mouse_image'),
306
- ],
331
+ [cat_image, mouse_image],
307
332
  list[str],
308
333
  lm=lm,
309
- expected_snippet=inspect.cleandoc("""
334
+ expected_snippet=inspect.cleandoc(f"""
310
335
  REQUEST:
311
336
  ```python
312
337
  [
313
- <<[[input[0]]]>>,
314
- <<[[input[1]]]>>
338
+ <<[[{cat_image.id}]]>>,
339
+ <<[[{mouse_image.id}]]>>
315
340
  ]
316
341
  ```
317
342
  """),
@@ -320,25 +345,25 @@ class QueryTest(unittest.TestCase):
320
345
 
321
346
  def test_structure_with_modality_and_examples_to_structure_render(self):
322
347
  lm = fake.StaticResponse('["cat", "mouse"]')
348
+ cat_image = modalities.Image.from_bytes(b'cat_image')
349
+ mouse_image = modalities.Image.from_bytes(b'mouse_image')
350
+ dog_image = modalities.Image.from_bytes(b'dog_image')
323
351
  self.assert_render(
324
- [
325
- modalities.Image.from_bytes(b'cat_image'),
326
- modalities.Image.from_bytes(b'mouse_image'),
327
- ],
352
+ [cat_image, mouse_image],
328
353
  list[str],
329
354
  examples=[
330
355
  mapping.MappingExample(
331
- input=[modalities.Image.from_bytes(b'dog_image')],
356
+ input=[dog_image],
332
357
  schema=list[str],
333
358
  output=['dog'],
334
359
  ),
335
360
  ],
336
361
  lm=lm,
337
- expected_snippet=inspect.cleandoc("""
362
+ expected_snippet=inspect.cleandoc(f"""
338
363
  REQUEST:
339
364
  ```python
340
365
  [
341
- <<[[examples[0].input[0]]]>>
366
+ <<[[{dog_image.id}]]>>
342
367
  ]
343
368
  ```
344
369
 
@@ -356,8 +381,8 @@ class QueryTest(unittest.TestCase):
356
381
  REQUEST:
357
382
  ```python
358
383
  [
359
- <<[[input[0]]]>>,
360
- <<[[input[1]]]>>
384
+ <<[[{cat_image.id}]]>>,
385
+ <<[[{mouse_image.id}]]>>
361
386
  ]
362
387
  ```
363
388
 
@@ -369,6 +394,17 @@ class QueryTest(unittest.TestCase):
369
394
  expected_modalities=3,
370
395
  )
371
396
 
397
+ def test_query_with_modality_output(self):
398
+ cat_image = modalities.Image.from_bytes(b'cat_image')
399
+ lm = fake.StaticResponse(
400
+ lf.Template('Here you go: {{image}}', image=cat_image).render(
401
+ message_cls=lf.AIMessage
402
+ )
403
+ )
404
+ response = querying.query('Generate a cat image', lm=lm)
405
+ self.assertIsInstance(response, lf.AIMessage)
406
+ self.assertEqual(response.modalities(), [cat_image])
407
+
372
408
  def test_multiple_queries(self):
373
409
  self.assertEqual(
374
410
  querying.query(
@@ -545,7 +581,7 @@ class QueryTest(unittest.TestCase):
545
581
  )
546
582
  ).input,
547
583
  )
548
- self.assertIsNotNone(output.get_modality('image'))
584
+ self.assertEqual(len(output.referred_modalities), 1)
549
585
 
550
586
  def test_query_and_reduce(self):
551
587
  self.assertEqual(
langfun/core/template.py CHANGED
@@ -171,6 +171,7 @@ class Template(
171
171
 
172
172
  # Last render output.
173
173
  self._cached_render_output = None
174
+ self._referred_modalities = None
174
175
 
175
176
  @property
176
177
  def render_output(self) -> message_lib.Message | None:
@@ -322,24 +323,46 @@ class Template(
322
323
  compact=True,
323
324
  python_format=True,
324
325
  ):
325
- # Natural language formattable objects will be returned in natural
326
- # language when they are directly returned as rendering elements
326
+
327
+ # Capture the modality objects whose references are being
328
+ # rendered
327
329
  # in the template.
328
- with modality.format_modality_as_ref():
329
- rendered_text = self._template.render(**inputs)
330
+ with modality.capture_rendered_modalities() as modality_refs:
331
+
332
+ # Natural language formattable objects will be returned in
333
+ # natural language when they are directly returned as rendering
334
+ # elements in the template.
335
+ with modality.format_modality_as_ref():
336
+ rendered_text = self._template.render(**inputs)
330
337
 
331
- # Carry additional metadata.
332
- metadata = self.additional_metadata()
338
+ # Carry the modality references passed from the constructor.
339
+ # This is to support modality objects that is already rendered
340
+ # in the template string.
341
+ if self._referred_modalities:
342
+ modality_refs.update(self._referred_modalities)
333
343
 
334
344
  if self.clean:
335
345
  rendered_text = rendered_text.strip()
336
346
 
337
- metadata.update(
338
- {k: pg.Ref(v) for k, v in inputs.items() if not inspect.ismethod(v)}
339
- )
347
+ # Fill message metadata.
348
+ metadata = {
349
+ '__template_input__': {
350
+ k: pg.Ref(v) for k, v in inputs.items()
351
+ if not inspect.ismethod(v)
352
+ },
353
+ }
354
+
355
+ # Carry additional metadata.
356
+ # TODO(daiyip): Consider to put additional metadata into a separate
357
+ # key under `metadata`.
358
+ metadata.update(self.additional_metadata())
340
359
 
341
360
  # Fill the variables for rendering the template as metadata.
342
- message = message_cls(text=rendered_text, metadata=metadata)
361
+ message = message_cls(
362
+ text=rendered_text,
363
+ referred_modalities=modality_refs,
364
+ metadata=metadata
365
+ )
343
366
 
344
367
  # Tag input as rendered message.
345
368
  message.tag(message_lib.Message.TAG_RENDERED)
@@ -518,10 +541,13 @@ class Template(
518
541
  if isinstance(value, str):
519
542
  return cls(template_str=value, **kwargs)
520
543
  if isinstance(value, message_lib.Message):
521
- kwargs.update(value.metadata)
522
- return cls(template_str=value.text, **kwargs)
544
+ for k, v in value.metadata.sym_items(): # pylint: disable=attribute-error
545
+ kwargs[_ADDITIONAL_METADATA_PREFIX + k] = v
546
+ t = cls(template_str=value.text, **kwargs)
547
+ t._referred_modalities = value.referred_modalities
548
+ return t
523
549
  if isinstance(value, Template):
524
- lfun = cls(template_str=value.template_str, **kwargs)
550
+ lfun = cls(template_str=value.template_str, **kwargs) # pylint: disable=attribute-error
525
551
  # So lfun could acccess all attributes from value.
526
552
  lfun.sym_setparent(value)
527
553
  return lfun