langfun 0.1.2.dev202510240805__py3-none-any.whl → 0.1.2.dev202510250803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun has been flagged as potentially problematic; see the package registry page for details.

Files changed (41):
  1. langfun/core/concurrent_test.py +1 -0
  2. langfun/core/data/conversion/anthropic_test.py +8 -6
  3. langfun/core/data/conversion/gemini_test.py +12 -9
  4. langfun/core/data/conversion/openai.py +134 -30
  5. langfun/core/data/conversion/openai_test.py +161 -17
  6. langfun/core/eval/v2/progress_tracking_test.py +3 -0
  7. langfun/core/langfunc_test.py +4 -2
  8. langfun/core/language_model.py +6 -6
  9. langfun/core/language_model_test.py +9 -3
  10. langfun/core/llms/__init__.py +2 -1
  11. langfun/core/llms/cache/base.py +3 -1
  12. langfun/core/llms/cache/in_memory_test.py +14 -4
  13. langfun/core/llms/deepseek.py +1 -1
  14. langfun/core/llms/groq.py +1 -1
  15. langfun/core/llms/llama_cpp.py +1 -1
  16. langfun/core/llms/openai.py +7 -2
  17. langfun/core/llms/openai_compatible.py +134 -27
  18. langfun/core/llms/openai_compatible_test.py +207 -20
  19. langfun/core/llms/openai_test.py +0 -2
  20. langfun/core/llms/vertexai.py +2 -2
  21. langfun/core/message.py +78 -44
  22. langfun/core/message_test.py +56 -81
  23. langfun/core/modalities/__init__.py +8 -0
  24. langfun/core/modalities/mime.py +9 -0
  25. langfun/core/modality.py +104 -27
  26. langfun/core/modality_test.py +42 -12
  27. langfun/core/sampling_test.py +20 -4
  28. langfun/core/structured/completion.py +2 -7
  29. langfun/core/structured/completion_test.py +23 -43
  30. langfun/core/structured/mapping.py +4 -13
  31. langfun/core/structured/querying.py +13 -11
  32. langfun/core/structured/querying_test.py +65 -29
  33. langfun/core/template.py +39 -13
  34. langfun/core/template_test.py +83 -17
  35. langfun/env/event_handlers/metric_writer_test.py +3 -3
  36. langfun/env/load_balancers_test.py +2 -2
  37. {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/METADATA +1 -1
  38. {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/RECORD +41 -41
  39. {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/WHEEL +0 -0
  40. {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/licenses/LICENSE +0 -0
  41. {langfun-0.1.2.dev202510240805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/top_level.txt +0 -0
langfun/core/modalities/mime.py CHANGED
@@ -15,6 +15,7 @@
15
15
 
16
16
  import base64
17
17
  import functools
18
+ import hashlib
18
19
  from typing import Annotated, Any, Iterable, Type, Union
19
20
  import langfun.core as lf
20
21
  # Placeholder for Google-internal internet access import.
@@ -87,6 +88,14 @@ class Mime(lf.Modality):
87
88
  """Returns True if the MIME type is a binary type."""
88
89
  return not self.is_text
89
90
 
91
+ @property
92
+ def hash(self) -> str:
93
+ """Returns the hash of the MIME content."""
94
+ # Hash the URI to avoid downloading the content.
95
+ if self.uri is not None:
96
+ return hashlib.md5(self.uri.encode()).hexdigest()[:8]
97
+ return super().hash
98
+
90
99
  def to_text(self) -> str:
91
100
  """Returns the text content of the MIME type."""
92
101
  if not self.is_text:
langfun/core/modality.py CHANGED
@@ -14,23 +14,15 @@
14
14
  """Interface for modality (e.g. Image, Video, etc.)."""
15
15
 
16
16
  import abc
17
+ import contextlib
17
18
  import functools
18
19
  import hashlib
19
- from typing import Any, ContextManager
20
+ import re
21
+ from typing import Any, ContextManager, Iterator
20
22
  from langfun.core import component
21
23
  import pyglove as pg
22
24
 
23
25
 
24
- _TLS_MODALITY_AS_REF = '__format_modality_as_ref__'
25
-
26
-
27
- def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
28
- """A context manager that formats modality objects as references."""
29
- return pg.object_utils.thread_local_value_scope(
30
- _TLS_MODALITY_AS_REF, enabled, False
31
- )
32
-
33
-
34
26
  class Modality(component.Component, pg.views.HtmlTreeView.Extension):
35
27
  """Base class for multimodal object."""
36
28
 
@@ -39,15 +31,18 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
39
31
 
40
32
  def _on_bound(self):
41
33
  super()._on_bound()
42
- # Invalidate cached hash if modality member is changed.
34
+ # Invalidate cached hash and id if modality member is changed.
43
35
  self.__dict__.pop('hash', None)
36
+ self.__dict__.pop('id', None)
44
37
 
45
38
  def format(self, *args, **kwargs) -> str:
46
- if self.referred_name is None or not pg.object_utils.thread_local_get(
47
- _TLS_MODALITY_AS_REF, False
48
- ):
39
+ if not pg.object_utils.thread_local_get(_TLS_MODALITY_AS_REF, False):
49
40
  return super().format(*args, **kwargs)
50
- return Modality.text_marker(self.referred_name)
41
+
42
+ capture_scope = get_modality_capture_context()
43
+ if capture_scope is not None:
44
+ capture_scope.capture(self)
45
+ return Modality.text_marker(self.id)
51
46
 
52
47
  def __str_kwargs__(self) -> dict[str, Any]:
53
48
  # For modality objects, we don't want to use markdown format when they
@@ -70,14 +65,11 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
70
65
  """Returns a marker in the text for this object."""
71
66
  return Modality.REF_START + var_name + Modality.REF_END
72
67
 
73
- @property
74
- def referred_name(self) -> str | None:
68
+ @functools.cached_property
69
+ def id(self) -> str | None:
75
70
  """Returns the referred name of this object in its template."""
76
- if not self.sym_path:
77
- return None
78
- # Strip the metadata prefix under message.
79
- path = str(self.sym_path)
80
- return path[9:] if path.startswith('metadata.') else path
71
+ modality_type = _camel_to_snake(self.__class__.__name__)
72
+ return f'{modality_type}:{self.hash}'
81
73
 
82
74
  @classmethod
83
75
  def from_value(cls, value: pg.Symbolic) -> dict[str, 'Modality']:
@@ -86,7 +78,7 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
86
78
  def _visit(k, v, p):
87
79
  del k, p
88
80
  if isinstance(v, Modality):
89
- modalities[v.referred_name] = v
81
+ modalities[v.id] = v
90
82
  return pg.TraverseAction.CONTINUE
91
83
  return pg.TraverseAction.ENTER
92
84
 
@@ -102,7 +94,7 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
102
94
  structure.
103
95
  """
104
96
 
105
- name: str
97
+ id: str
106
98
 
107
99
  def custom_apply(
108
100
  self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs
@@ -122,12 +114,97 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
122
114
  """
123
115
 
124
116
  def _placehold(k, v, p):
125
- del p
117
+ del k, p
126
118
  if isinstance(v, Modality):
127
- return ModalityRef(name=value.sym_path + k)
119
+ return ModalityRef(id=v.id)
128
120
  return v
129
121
  return value.clone().rebind(_placehold, raise_on_no_change=False)
130
122
 
123
+ @classmethod
124
+ def restore(cls, value: pg.Symbolic, modalities: dict[str, Modality]) -> Any:
125
+ """Returns a copy of value by replacing refs with modality objects."""
126
+ def _restore(k, v, p):
127
+ del k, p
128
+ if isinstance(v, ModalityRef):
129
+ modality_object = modalities.get(v.id)
130
+ if modality_object is None:
131
+ raise ValueError(
132
+ f'Modality {v.id} not found in modalities {modalities.keys()}'
133
+ )
134
+ return modality_object
135
+ return v
136
+ return value.rebind(_restore, raise_on_no_change=False)
137
+
131
138
 
132
139
  class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
133
140
  """Exception raised when modality is not supported."""
141
+
142
+
143
+ #
144
+ # Context managers to deal with modality objects.
145
+ #
146
+
147
+
148
+ _TLS_MODALITY_CAPTURE_SCOPE = '__modality_capture_scope__'
149
+ _TLS_MODALITY_AS_REF = '__format_modality_as_ref__'
150
+
151
+
152
+ def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
153
+ """A context manager that formats modality objects as references."""
154
+ return pg.object_utils.thread_local_value_scope(
155
+ _TLS_MODALITY_AS_REF, enabled, False
156
+ )
157
+
158
+
159
+ class _ModalityCaptureContext:
160
+ """A context to capture modality objects when being rendered."""
161
+
162
+ def __init__(self):
163
+ self._references: dict[str, pg.Ref[Modality]] = {}
164
+
165
+ def capture(self, modality: Modality) -> None:
166
+ """Captures the modality object."""
167
+ self._references[modality.id] = pg.Ref(modality)
168
+
169
+ @property
170
+ def references(self) -> dict[str, pg.Ref[Modality]]:
171
+ """Returns the modality references captured in this context."""
172
+ return self._references
173
+
174
+
175
+ @contextlib.contextmanager
176
+ def capture_rendered_modalities() -> Iterator[dict[str, pg.Ref[Modality]]]:
177
+ """Capture modality objects whose references is being rendered.
178
+
179
+ Example:
180
+ ```
181
+ image = lf.Image.from_url(...)
182
+ with lf.modality.capture_rendered_modalities() as rendered_modalities:
183
+ with lf.modality.format_modality_as_ref():
184
+ print(f'Hello {image}')
185
+ self.assertEqual(rendered_modalities, {'image:<hash>': pg.Ref(image)})
186
+ ```
187
+ """
188
+ context = get_modality_capture_context()
189
+ top_level = context is None
190
+ if top_level:
191
+ context = _ModalityCaptureContext()
192
+ pg.object_utils.thread_local_set(_TLS_MODALITY_CAPTURE_SCOPE, context)
193
+
194
+ try:
195
+ yield context.references # pylint: disable=attribute-error
196
+ finally:
197
+ if top_level:
198
+ pg.object_utils.thread_local_del(_TLS_MODALITY_CAPTURE_SCOPE)
199
+
200
+
201
+ def get_modality_capture_context() -> _ModalityCaptureContext | None:
202
+ """Returns the current modality capture context."""
203
+ return pg.object_utils.thread_local_get(_TLS_MODALITY_CAPTURE_SCOPE, None)
204
+
205
+
206
+ def _camel_to_snake(name: str) -> str:
207
+ """Converts a camelCase name to snake_case."""
208
+ return re.sub(
209
+ pattern=r'([A-Z]+)', repl=r'_\1', string=name
210
+ ).lower().lstrip('_')
langfun/core/modality_test.py CHANGED
@@ -29,34 +29,64 @@ class ModalityTest(unittest.TestCase):
29
29
 
30
30
  def test_basic(self):
31
31
  v = CustomModality('a')
32
- self.assertIsNone(v.referred_name)
32
+ self.assertEqual(v.id, 'custom_modality:0cc175b9')
33
33
  self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
34
34
  self.assertEqual(v.hash, '0cc175b9')
35
35
 
36
36
  _ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
37
- self.assertEqual(v.referred_name, 'x.metadata.y')
37
+ self.assertEqual(v.id, 'custom_modality:0cc175b9')
38
38
  self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
39
39
  with modality.format_modality_as_ref():
40
- self.assertEqual(str(v), '<<[[x.metadata.y]]>>')
40
+ self.assertEqual(str(v), '<<[[custom_modality:0cc175b9]]>>')
41
+
42
+ def test_capture_rendered_modalities(self):
43
+ x = CustomModality('a')
44
+ y = CustomModality('b')
45
+ z = CustomModality('b')
46
+
47
+ with modality.capture_rendered_modalities() as rendered_modalities:
48
+ with modality.format_modality_as_ref():
49
+ self.assertEqual(
50
+ f'Hello {x} {y} {z}',
51
+ (
52
+ 'Hello <<[[custom_modality:0cc175b9]]>> '
53
+ '<<[[custom_modality:92eb5ffe]]>> '
54
+ '<<[[custom_modality:92eb5ffe]]>>'
55
+ )
56
+ )
57
+ self.assertEqual(len(rendered_modalities), 2)
58
+ self.assertIs(rendered_modalities['custom_modality:0cc175b9'].value, x)
59
+ # y and z share the same content will be treated as the same object.
60
+ self.assertIs(rendered_modalities['custom_modality:92eb5ffe'].value, z)
41
61
 
42
62
 
43
63
  class ModalityRefTest(unittest.TestCase):
44
64
 
45
- def test_placehold(self):
65
+ def test_placehold_and_restore(self):
46
66
  class A(pg.Object):
47
67
  x: Any
48
68
  y: Any
49
69
 
50
- a = A(x=dict(z=CustomModality('a')), y=CustomModality('b'))
70
+ image_a = CustomModality('a')
71
+ image_b = CustomModality('b')
72
+ a = A(x=dict(z=image_a), y=image_b)
73
+ a_placehold = modality.ModalityRef.placehold(a)
51
74
  self.assertEqual(
52
- modality.ModalityRef.placehold(a),
53
- A(x=dict(z=modality.ModalityRef('x.z')), y=modality.ModalityRef('y')),
75
+ a_placehold,
76
+ A(x=dict(z=modality.ModalityRef(image_a.id)),
77
+ y=modality.ModalityRef(image_b.id)),
78
+ )
79
+ a_restore = modality.ModalityRef.restore(
80
+ a_placehold.clone(),
81
+ {image_a.id: image_a, image_b.id: image_b},
54
82
  )
83
+ self.assertTrue(pg.eq(a_restore, a))
55
84
  self.assertEqual(
56
85
  modality.ModalityRef.placehold(a.x),
57
- # The prefix 'x' of referred name is preserved.
58
- dict(z=modality.ModalityRef('x.z')),
86
+ dict(z=modality.ModalityRef(image_a.id)),
59
87
  )
88
+ with self.assertRaisesRegex(ValueError, 'Modality .* not found'):
89
+ modality.ModalityRef.restore(a_placehold, {image_a.id: image_a})
60
90
 
61
91
  def test_from_value(self):
62
92
  class A(pg.Object):
@@ -68,8 +98,8 @@ class ModalityRefTest(unittest.TestCase):
68
98
  pg.eq(
69
99
  modality.Modality.from_value(a),
70
100
  {
71
- 'x.z': CustomModality('a'),
72
- 'y': CustomModality('b'),
101
+ 'custom_modality:0cc175b9': CustomModality('a'),
102
+ 'custom_modality:92eb5ffe': CustomModality('b'),
73
103
  },
74
104
  )
75
105
  )
@@ -77,7 +107,7 @@ class ModalityRefTest(unittest.TestCase):
77
107
  pg.eq(
78
108
  modality.Modality.from_value(a.x.z),
79
109
  {
80
- 'x.z': CustomModality('a'),
110
+ 'custom_modality:0cc175b9': CustomModality('a'),
81
111
  },
82
112
  )
83
113
  )
langfun/core/sampling_test.py CHANGED
@@ -39,8 +39,13 @@ class SamplingTest(unittest.TestCase):
39
39
  l = LangFunc('Compute {{x}} and {{y}}', x=pg.oneof([1, 2]))
40
40
  with component.context(lm=ExcitedEchoer()):
41
41
  samples = list(sampling.sweep(l, y=pg.oneof([3, 4])))
42
- samples = sorted(samples, key=lambda x: (x[0].x, x[0].y))
43
-
42
+ samples = sorted(
43
+ samples,
44
+ key=lambda x: (
45
+ x[0].__template_input__.x,
46
+ x[0].__template_input__.y
47
+ )
48
+ )
44
49
  self.assertEqual(
45
50
  samples,
46
51
  [
@@ -57,7 +62,12 @@ class SamplingTest(unittest.TestCase):
57
62
  samples = list(
58
63
  sampling.random_sample(l, y=pg.oneof([2, 4]), num_examples=3, seed=1)
59
64
  )
60
- samples = sorted(samples, key=lambda x: (x[0].x, x[0].y))
65
+ samples = sorted(
66
+ samples, key=lambda x: (
67
+ x[0].__template_input__.x,
68
+ x[0].__template_input__.y
69
+ )
70
+ )
61
71
 
62
72
  self.assertEqual(
63
73
  samples,
@@ -97,7 +107,13 @@ class SamplingTest(unittest.TestCase):
97
107
  silence_on_errors=(AttributeError,),
98
108
  ignore_examples_with_errors=False))
99
109
 
100
- samples = sorted(samples, key=lambda x: (x[0].x, x[0].y))
110
+ samples = sorted(
111
+ samples,
112
+ key=lambda x: (
113
+ x[0].__template_input__.x,
114
+ x[0].__template_input__.y
115
+ )
116
+ )
101
117
  self.assertEqual(
102
118
  [x[0] for x in samples],
103
119
  [
langfun/core/structured/completion.py CHANGED
@@ -118,13 +118,8 @@ class _CompleteStructure(mapping.Mapping):
118
118
  def postprocess_result(self, result: Any) -> Any:
119
119
  """Postprocess result."""
120
120
  # Try restore modality objects from the input value to output value.
121
- modalities = self.modalities(self.input)
122
- if modalities:
123
- # Remove the `input` prefix for all entries.
124
- modalities = pg.object_utils.flatten(
125
- pg.object_utils.canonicalize(modalities)['input']
126
- )
127
- result.rebind(modalities)
121
+ if modalities := self.modalities(self.input):
122
+ result = lf.ModalityRef.restore(result, modalities)
128
123
  return result
129
124
 
130
125
  def globals(self):
langfun/core/structured/completion_test.py CHANGED
@@ -407,22 +407,17 @@ class CompleteStructureTest(unittest.TestCase):
407
407
  image: modalities.Image
408
408
  name: str
409
409
 
410
+ image_elephant = modalities.Image.from_bytes(b'image_of_elephant')
411
+ image_rabbit = modalities.Image.from_bytes(b'image_of_rabbit')
410
412
  input_value = schema_lib.mark_missing(
411
- Animal.partial(
412
- modalities.Image.from_bytes(b'image_of_elephant'),
413
- )
413
+ Animal.partial(image_elephant)
414
414
  )
415
415
  l = completion._CompleteStructure(
416
416
  input=input_value,
417
417
  examples=[
418
418
  mapping.MappingExample(
419
- input=Animal.partial(
420
- modalities.Image.from_bytes(b'image_of_rabbit')
421
- ),
422
- output=Animal(
423
- modalities.Image.from_bytes(b'image_of_rabbit'),
424
- 'rabbit',
425
- ),
419
+ input=Animal.partial(image_rabbit),
420
+ output=Animal(image_rabbit, 'rabbit'),
426
421
  )
427
422
  ],
428
423
  )
@@ -430,7 +425,7 @@ class CompleteStructureTest(unittest.TestCase):
430
425
  self.maxDiff = None
431
426
  self.assertEqual(
432
427
  lm_input.text,
433
- inspect.cleandoc("""
428
+ inspect.cleandoc(f"""
434
429
  Please generate the OUTPUT_OBJECT by completing the MISSING fields from the last INPUT_OBJECT.
435
430
 
436
431
  INSTRUCTIONS:
@@ -457,22 +452,22 @@ class CompleteStructureTest(unittest.TestCase):
457
452
  ```python
458
453
  Animal(
459
454
  image=ModalityRef(
460
- name='examples[0].input.image'
455
+ id='{image_rabbit.id}'
461
456
  ),
462
457
  name=MISSING(str)
463
458
  )
464
459
  ```
465
460
 
466
461
  MODALITY_REFERENCES:
467
- {
468
- 'examples[0].input.image': <<[[examples[0].input.image]]>>
469
- }
462
+ {{
463
+ '{image_rabbit.id}': <<[[{image_rabbit.id}]]>>
464
+ }}
470
465
 
471
466
  OUTPUT_OBJECT:
472
467
  ```python
473
468
  Animal(
474
469
  image=ModalityRef(
475
- name='examples[0].output.image'
470
+ id='{image_rabbit.id}'
476
471
  ),
477
472
  name='rabbit'
478
473
  )
@@ -483,16 +478,16 @@ class CompleteStructureTest(unittest.TestCase):
483
478
  ```python
484
479
  Animal(
485
480
  image=ModalityRef(
486
- name='input.image'
481
+ id='{image_elephant.id}'
487
482
  ),
488
483
  name=MISSING(str)
489
484
  )
490
485
  ```
491
486
 
492
487
  MODALITY_REFERENCES:
493
- {
494
- 'input.image': <<[[input.image]]>>
495
- }
488
+ {{
489
+ '{image_elephant.id}': <<[[{image_elephant.id}]]>>
490
+ }}
496
491
 
497
492
  OUTPUT_OBJECT:
498
493
  """),
@@ -500,39 +495,27 @@ class CompleteStructureTest(unittest.TestCase):
500
495
  self.assertTrue(
501
496
  pg.eq(
502
497
  {
503
- 'examples': lm_input.get('examples'),
504
- 'input': lm_input.get('input'),
498
+ 'examples': lm_input.__template_input__.examples,
499
+ 'input': lm_input.__template_input__.mapping_request.input,
505
500
  },
506
501
  {
507
502
  'examples': [
508
503
  mapping.MappingExample(
509
- input=Animal.partial(
510
- image=modalities.Image.from_bytes(
511
- b'image_of_rabbit'
512
- )
513
- ),
514
- output=Animal.partial(
515
- image=modalities.Image.from_bytes(
516
- b'image_of_rabbit'
517
- ),
518
- name='rabbit',
519
- ),
504
+ input=Animal.partial(image_rabbit),
505
+ output=Animal.partial(image_rabbit, 'rabbit'),
520
506
  )
521
507
  ],
522
- 'input': Animal(
523
- image=modalities.Image.from_bytes(b'image_of_elephant'),
524
- name=schema_lib.MISSING,
525
- ),
508
+ 'input': Animal(image_elephant, name=schema_lib.MISSING),
526
509
  },
527
510
  )
528
511
  )
529
512
  lm_output = l(
530
513
  input=input_value,
531
- lm=fake.StaticResponse(inspect.cleandoc("""
514
+ lm=fake.StaticResponse(inspect.cleandoc(f"""
532
515
  ```python
533
516
  Animal(
534
517
  image=ModalityRef(
535
- name='input.image'
518
+ id='{image_elephant.id}'
536
519
  ),
537
520
  name='elephant'
538
521
  )
@@ -542,10 +525,7 @@ class CompleteStructureTest(unittest.TestCase):
542
525
  self.assertTrue(
543
526
  pg.eq(
544
527
  lm_output.result,
545
- Animal(
546
- image=modalities.Image.from_bytes(b'image_of_elephant'),
547
- name='elephant',
548
- ),
528
+ Animal(image=image_elephant, name='elephant'),
549
529
  )
550
530
  )
551
531
 
langfun/core/structured/mapping.py CHANGED
@@ -127,6 +127,8 @@ class MappingExample(lf.NaturalLanguageFormattable,
127
127
  ) -> str:
128
128
  if isinstance(value, str):
129
129
  return value
130
+ if isinstance(value, lf.Message):
131
+ return str(value)
130
132
  if isinstance(value, lf.Modality):
131
133
  with lf.modality.format_modality_as_ref():
132
134
  return str(value)
@@ -192,9 +194,7 @@ class MappingExample(lf.NaturalLanguageFormattable,
192
194
 
193
195
  def render_value(view, *, value, **kwargs):
194
196
  if isinstance(value, lf.Template):
195
- # Make a shallow copy to make sure modalities are rooted by
196
- # the input.
197
- value = value.clone().render()
197
+ value = value.render()
198
198
  if value is None:
199
199
  return None
200
200
  return view.render(value, **kwargs)
@@ -286,12 +286,8 @@ class Mapping(lf.LangFunc):
286
286
  @property
287
287
  def mapping_request(self) -> MappingExample:
288
288
  """Returns a MappingExample as the mapping request."""
289
- if isinstance(self.input, lf.Message):
290
- input_value = self.input.text
291
- else:
292
- input_value = pg.Ref(self.input)
293
289
  return MappingExample(
294
- input=input_value,
290
+ input=pg.Ref(self.input),
295
291
  schema=pg.Ref(self.schema),
296
292
  context=self.context,
297
293
  )
@@ -402,11 +398,6 @@ class Mapping(lf.LangFunc):
402
398
 
403
399
  def transform_input(self, lm_input: lf.Message) -> lf.Message:
404
400
  # Find modalities to fill the input message.
405
- lm_input.metadata.update(
406
- examples=pg.Ref(self.examples),
407
- input=pg.Ref(self.input),
408
- schema=pg.Ref(self.schema) if self.schema is not None else None,
409
- )
410
401
  if isinstance(self.input, lf.Message):
411
402
  lm_input.source = self.input
412
403
  return lm_input
langfun/core/structured/querying.py CHANGED
@@ -529,24 +529,22 @@ def query(
529
529
  ).render(message_cls=lf.SystemMessage)
530
530
 
531
531
  # Normalize query input.
532
- if isinstance(prompt, (lf.Message, str)):
532
+ if isinstance(prompt, str):
533
533
  # Query with structured output.
534
534
  prompt_kwargs = kwargs.copy()
535
535
  prompt_kwargs.pop('template_str', None)
536
536
  query_input = lf.Template.from_value(prompt, **prompt_kwargs)
537
+ elif isinstance(prompt, lf.Message):
538
+ query_input = prompt
537
539
  elif isinstance(prompt, lf.Template):
538
- # Create a copy of the prompt if it has a parent object, so all child
539
- # modality objects could be referred by path relative to the prompt.
540
- query_input = prompt.clone() if prompt.sym_parent is not None else prompt
541
-
542
540
  # Attach template metadata from kwargs. This is used to pass through fields
543
541
  # from kwargs to the rendered message.
544
- template_metadata = {
545
- k: v for k, v in kwargs.items() if k.startswith('metadata_')
546
- }
547
- query_input.rebind(
548
- template_metadata, skip_notification=True, raise_on_no_change=False
542
+ prompt.rebind(
543
+ {k: v for k, v in kwargs.items() if k.startswith('metadata_')},
544
+ skip_notification=True,
545
+ raise_on_no_change=False
549
546
  )
547
+ query_input = prompt
550
548
  elif pg.MISSING_VALUE == prompt:
551
549
  query_input = lf.UserMessage('')
552
550
  else:
@@ -665,7 +663,11 @@ def query(
665
663
 
666
664
  if returns_message:
667
665
  return output_message
668
- return output_message.text if schema in (None, str) else output_message.result
666
+ if schema not in (None, str):
667
+ return output_message.result
668
+ if returns_message or output_message.referred_modalities:
669
+ return output_message
670
+ return output_message.text
669
671
 
670
672
 
671
673
  async def aquery(