langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202510250803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic; see the package's registry page for more details.

Files changed (44):
  1. langfun/core/concurrent_test.py +1 -0
  2. langfun/core/data/conversion/anthropic_test.py +8 -6
  3. langfun/core/data/conversion/gemini_test.py +12 -9
  4. langfun/core/data/conversion/openai.py +134 -30
  5. langfun/core/data/conversion/openai_test.py +161 -17
  6. langfun/core/eval/base_test.py +4 -4
  7. langfun/core/eval/v2/progress_tracking_test.py +3 -0
  8. langfun/core/langfunc_test.py +6 -4
  9. langfun/core/language_model.py +15 -6
  10. langfun/core/language_model_test.py +9 -3
  11. langfun/core/llms/__init__.py +7 -1
  12. langfun/core/llms/anthropic.py +130 -0
  13. langfun/core/llms/cache/base.py +3 -1
  14. langfun/core/llms/cache/in_memory_test.py +14 -4
  15. langfun/core/llms/deepseek.py +1 -1
  16. langfun/core/llms/gemini.py +2 -5
  17. langfun/core/llms/groq.py +1 -1
  18. langfun/core/llms/llama_cpp.py +1 -1
  19. langfun/core/llms/openai.py +7 -2
  20. langfun/core/llms/openai_compatible.py +136 -27
  21. langfun/core/llms/openai_compatible_test.py +207 -20
  22. langfun/core/llms/openai_test.py +0 -2
  23. langfun/core/llms/vertexai.py +12 -2
  24. langfun/core/message.py +78 -44
  25. langfun/core/message_test.py +56 -81
  26. langfun/core/modalities/__init__.py +8 -0
  27. langfun/core/modalities/mime.py +9 -0
  28. langfun/core/modality.py +104 -27
  29. langfun/core/modality_test.py +42 -12
  30. langfun/core/sampling_test.py +20 -4
  31. langfun/core/structured/completion.py +2 -7
  32. langfun/core/structured/completion_test.py +23 -43
  33. langfun/core/structured/mapping.py +4 -13
  34. langfun/core/structured/querying.py +13 -11
  35. langfun/core/structured/querying_test.py +65 -29
  36. langfun/core/template.py +39 -13
  37. langfun/core/template_test.py +83 -17
  38. langfun/env/event_handlers/metric_writer_test.py +3 -3
  39. langfun/env/load_balancers_test.py +2 -2
  40. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/METADATA +1 -1
  41. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/RECORD +44 -44
  42. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/WHEEL +0 -0
  43. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/licenses/LICENSE +0 -0
  44. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/top_level.txt +0 -0
@@ -38,7 +38,7 @@ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
38
38
  response_format = ''
39
39
 
40
40
  choices = []
41
- for k in range(json['n']):
41
+ for k in range(json.get('n', 1)):
42
42
  if json.get('logprobs'):
43
43
  logprobs = dict(
44
44
  content=[
@@ -89,7 +89,7 @@ def mock_chat_completion_request_vision(
89
89
  c['image_url']['url']
90
90
  for c in json['messages'][0]['content'] if c['type'] == 'image_url'
91
91
  ]
92
- for k in range(json['n']):
92
+ for k in range(json.get('n', 1)):
93
93
  choices.append(pg.Dict(
94
94
  message=pg.Dict(
95
95
  content=f'Sample {k} for message: {"".join(urls)}'
@@ -111,12 +111,88 @@ def mock_chat_completion_request_vision(
111
111
  return response
112
112
 
113
113
 
114
- class OpenAIComptibleTest(unittest.TestCase):
114
+ def mock_responses_request(url: str, json: dict[str, Any], **kwargs):
115
+ del url, kwargs
116
+ _ = json['input']
117
+
118
+ system_message = ''
119
+ if 'instructions' in json:
120
+ system_message = f' system={json["instructions"]}'
121
+
122
+ response_format = ''
123
+ if 'text' in json and 'format' in json['text']:
124
+ response_format = f' format={json["text"]["format"]["type"]}'
125
+
126
+ output = [
127
+ dict(
128
+ type='message',
129
+ content=[
130
+ dict(
131
+ type='output_text',
132
+ text=(
133
+ f'Sample 0 for message.{system_message}{response_format}'
134
+ )
135
+ )
136
+ ],
137
+ )
138
+ ]
139
+
140
+ response = requests.Response()
141
+ response.status_code = 200
142
+ response._content = pg.to_json_str(
143
+ dict(
144
+ output=output,
145
+ usage=dict(
146
+ input_tokens=100,
147
+ output_tokens=100,
148
+ total_tokens=200,
149
+ ),
150
+ )
151
+ ).encode()
152
+ return response
153
+
154
+
155
+ def mock_responses_request_vision(
156
+ url: str, json: dict[str, Any], **kwargs
157
+ ):
158
+ del url, kwargs
159
+ urls = [
160
+ c['image_url']
161
+ for c in json['input'][0]['content']
162
+ if c['type'] == 'input_image'
163
+ ]
164
+ output = [
165
+ pg.Dict(
166
+ type='message',
167
+ content=[
168
+ pg.Dict(
169
+ type='output_text',
170
+ text=f'Sample 0 for message: {"".join(urls)}',
171
+ )
172
+ ],
173
+ )
174
+ ]
175
+ response = requests.Response()
176
+ response.status_code = 200
177
+ response._content = pg.to_json_str(
178
+ dict(
179
+ output=output,
180
+ usage=dict(
181
+ input_tokens=100,
182
+ output_tokens=100,
183
+ total_tokens=200,
184
+ ),
185
+ )
186
+ ).encode()
187
+ return response
188
+
189
+
190
+ class OpenAIChatCompletionAPITest(unittest.TestCase):
115
191
  """Tests for OpenAI compatible language model."""
116
192
 
117
193
  def test_request_args(self):
118
194
  self.assertEqual(
119
- openai_compatible.OpenAICompatible(
195
+ openai_compatible.OpenAIChatCompletionAPI(
120
196
  api_endpoint='https://test-server',
121
197
  model='test-model'
122
198
  )._request_args(
@@ -126,8 +202,6 @@ class OpenAIComptibleTest(unittest.TestCase):
126
202
  ),
127
203
  dict(
128
204
  model='test-model',
129
- top_logprobs=None,
130
- n=1,
131
205
  temperature=1.0,
132
206
  stop=['\n'],
133
207
  seed=123,
@@ -137,7 +211,7 @@ class OpenAIComptibleTest(unittest.TestCase):
137
211
  def test_call_chat_completion(self):
138
212
  with mock.patch('requests.Session.post') as mock_request:
139
213
  mock_request.side_effect = mock_chat_completion_request
140
- lm = openai_compatible.OpenAICompatible(
214
+ lm = openai_compatible.OpenAIChatCompletionAPI(
141
215
  api_endpoint='https://test-server', model='test-model',
142
216
  )
143
217
  self.assertEqual(
@@ -148,7 +222,7 @@ class OpenAIComptibleTest(unittest.TestCase):
148
222
  def test_call_chat_completion_with_logprobs(self):
149
223
  with mock.patch('requests.Session.post') as mock_request:
150
224
  mock_request.side_effect = mock_chat_completion_request
151
- lm = openai_compatible.OpenAICompatible(
225
+ lm = openai_compatible.OpenAIChatCompletionAPI(
152
226
  api_endpoint='https://test-server', model='test-model',
153
227
  )
154
228
  results = lm.sample(['hello'], logprobs=True)
@@ -214,13 +288,14 @@ class OpenAIComptibleTest(unittest.TestCase):
214
288
  def mime_type(self) -> str:
215
289
  return 'image/png'
216
290
 
291
+ image = FakeImage.from_uri('https://fake/image')
217
292
  with mock.patch('requests.Session.post') as mock_request:
218
293
  mock_request.side_effect = mock_chat_completion_request_vision
219
- lm_1 = openai_compatible.OpenAICompatible(
294
+ lm_1 = openai_compatible.OpenAIChatCompletionAPI(
220
295
  api_endpoint='https://test-server',
221
296
  model='test-model1',
222
297
  )
223
- lm_2 = openai_compatible.OpenAICompatible(
298
+ lm_2 = openai_compatible.OpenAIChatCompletionAPI(
224
299
  api_endpoint='https://test-server',
225
300
  model='test-model2',
226
301
  )
@@ -228,15 +303,15 @@ class OpenAIComptibleTest(unittest.TestCase):
228
303
  self.assertEqual(
229
304
  lm(
230
305
  lf.UserMessage(
231
- 'hello <<[[image]]>>',
232
- image=FakeImage.from_uri('https://fake/image')
306
+ f'hello <<[[{image.id}]]>>',
307
+ referred_modalities=[image],
233
308
  ),
234
309
  sampling_options=lf.LMSamplingOptions(n=2)
235
310
  ),
236
311
  'Sample 0 for message: https://fake/image',
237
312
  )
238
313
 
239
- class TextOnlyModel(openai_compatible.OpenAICompatible):
314
+ class TextOnlyModel(openai_compatible.OpenAIChatCompletionAPI):
240
315
 
241
316
  class ModelInfo(lf.ModelInfo):
242
317
  input_modalities: list[str] = lf.ModelInfo.TEXT_INPUT_ONLY
@@ -251,15 +326,15 @@ class OpenAIComptibleTest(unittest.TestCase):
251
326
  with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
252
327
  lm_3(
253
328
  lf.UserMessage(
254
- 'hello <<[[image]]>>',
255
- image=FakeImage.from_uri('https://fake/image')
329
+ f'hello <<[[{image.id}]]>>',
330
+ referred_modalities=[image],
256
331
  ),
257
332
  )
258
333
 
259
334
  def test_sample_chat_completion(self):
260
335
  with mock.patch('requests.Session.post') as mock_request:
261
336
  mock_request.side_effect = mock_chat_completion_request
262
- lm = openai_compatible.OpenAICompatible(
337
+ lm = openai_compatible.OpenAIChatCompletionAPI(
263
338
  api_endpoint='https://test-server', model='test-model'
264
339
  )
265
340
  results = lm.sample(
@@ -400,7 +475,7 @@ class OpenAIComptibleTest(unittest.TestCase):
400
475
  def test_sample_with_contextual_options(self):
401
476
  with mock.patch('requests.Session.post') as mock_request:
402
477
  mock_request.side_effect = mock_chat_completion_request
403
- lm = openai_compatible.OpenAICompatible(
478
+ lm = openai_compatible.OpenAIChatCompletionAPI(
404
479
  api_endpoint='https://test-server', model='test-model'
405
480
  )
406
481
  with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
@@ -458,7 +533,7 @@ class OpenAIComptibleTest(unittest.TestCase):
458
533
  def test_call_with_system_message(self):
459
534
  with mock.patch('requests.Session.post') as mock_request:
460
535
  mock_request.side_effect = mock_chat_completion_request
461
- lm = openai_compatible.OpenAICompatible(
536
+ lm = openai_compatible.OpenAIChatCompletionAPI(
462
537
  api_endpoint='https://test-server', model='test-model'
463
538
  )
464
539
  self.assertEqual(
@@ -475,7 +550,7 @@ class OpenAIComptibleTest(unittest.TestCase):
475
550
  def test_call_with_json_schema(self):
476
551
  with mock.patch('requests.Session.post') as mock_request:
477
552
  mock_request.side_effect = mock_chat_completion_request
478
- lm = openai_compatible.OpenAICompatible(
553
+ lm = openai_compatible.OpenAIChatCompletionAPI(
479
554
  api_endpoint='https://test-server', model='test-model'
480
555
  )
481
556
  self.assertEqual(
@@ -515,7 +590,7 @@ class OpenAIComptibleTest(unittest.TestCase):
515
590
 
516
591
  with mock.patch('requests.Session.post') as mock_request:
517
592
  mock_request.side_effect = mock_context_limit_error
518
- lm = openai_compatible.OpenAICompatible(
593
+ lm = openai_compatible.OpenAIChatCompletionAPI(
519
594
  api_endpoint='https://test-server', model='test-model'
520
595
  )
521
596
  with self.assertRaisesRegex(
@@ -524,5 +599,117 @@ class OpenAIComptibleTest(unittest.TestCase):
524
599
  lm(lf.UserMessage('hello'))
525
600
 
526
601
 
602
+ class OpenAIResponsesAPITest(unittest.TestCase):
603
+ """Tests for OpenAI compatible language model on Responses API."""
604
+
605
+ def test_request_args(self):
606
+ lm = openai_compatible.OpenAIResponsesAPI(
607
+ api_endpoint='https://test-server', model='test-model'
608
+ )
609
+ # Test valid args.
610
+ self.assertEqual(
611
+ lm._request_args(
612
+ lf.LMSamplingOptions(
613
+ temperature=1.0, stop=['\n'], n=1, random_seed=123
614
+ )
615
+ ),
616
+ dict(
617
+ model='test-model',
618
+ temperature=1.0,
619
+ stop=['\n'],
620
+ seed=123,
621
+ ),
622
+ )
623
+ # Test unsupported n.
624
+ with self.assertRaisesRegex(ValueError, 'n must be 1 for Responses API.'):
625
+ lm._request_args(lf.LMSamplingOptions(n=2))
626
+
627
+ # Test unsupported logprobs.
628
+ with self.assertRaisesRegex(
629
+ ValueError, 'logprobs is not supported on Responses API.'
630
+ ):
631
+ lm._request_args(lf.LMSamplingOptions(logprobs=True))
632
+
633
+ def test_call_responses(self):
634
+ with mock.patch('requests.Session.post') as mock_request:
635
+ mock_request.side_effect = mock_responses_request
636
+ lm = openai_compatible.OpenAIResponsesAPI(
637
+ api_endpoint='https://test-server',
638
+ model='test-model',
639
+ )
640
+ self.assertEqual(lm('hello'), 'Sample 0 for message.')
641
+
642
+ def test_call_responses_vision(self):
643
+ class FakeImage(lf_modalities.Image):
644
+ @property
645
+ def mime_type(self) -> str:
646
+ return 'image/png'
647
+
648
+ image = FakeImage.from_uri('https://fake/image')
649
+ with mock.patch('requests.Session.post') as mock_request:
650
+ mock_request.side_effect = mock_responses_request_vision
651
+ lm = openai_compatible.OpenAIResponsesAPI(
652
+ api_endpoint='https://test-server',
653
+ model='test-model1',
654
+ )
655
+ self.assertEqual(
656
+ lm(
657
+ lf.UserMessage(
658
+ f'hello <<[[{image.id}]]>>',
659
+ referred_modalities=[image],
660
+ )
661
+ ),
662
+ 'Sample 0 for message: https://fake/image',
663
+ )
664
+
665
+ def test_call_with_system_message(self):
666
+ with mock.patch('requests.Session.post') as mock_request:
667
+ mock_request.side_effect = mock_responses_request
668
+ lm = openai_compatible.OpenAIResponsesAPI(
669
+ api_endpoint='https://test-server', model='test-model'
670
+ )
671
+ self.assertEqual(
672
+ lm(
673
+ lf.UserMessage(
674
+ 'hello',
675
+ system_message=lf.SystemMessage('hi'),
676
+ )
677
+ ),
678
+ 'Sample 0 for message. system=hi',
679
+ )
680
+
681
+ def test_call_with_json_schema(self):
682
+ with mock.patch('requests.Session.post') as mock_request:
683
+ mock_request.side_effect = mock_responses_request
684
+ lm = openai_compatible.OpenAIResponsesAPI(
685
+ api_endpoint='https://test-server', model='test-model'
686
+ )
687
+ self.assertEqual(
688
+ lm(
689
+ lf.UserMessage(
690
+ 'hello',
691
+ json_schema={
692
+ 'type': 'object',
693
+ 'properties': {
694
+ 'name': {'type': 'string'},
695
+ },
696
+ 'required': ['name'],
697
+ 'title': 'Person',
698
+ },
699
+ )
700
+ ),
701
+ 'Sample 0 for message. format=json_schema',
702
+ )
703
+
704
+ # Test bad json schema.
705
+ with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
706
+ lm(lf.UserMessage('hello', json_schema='foo'))
707
+
708
+ with self.assertRaisesRegex(
709
+ ValueError, 'The root of `json_schema` must have a `title` field'
710
+ ):
711
+ lm(lf.UserMessage('hello', json_schema={}))
712
+
713
+
527
714
  if __name__ == '__main__':
528
715
  unittest.main()
@@ -61,8 +61,6 @@ class OpenAITest(unittest.TestCase):
61
61
  ),
62
62
  dict(
63
63
  model='gpt-4',
64
- top_logprobs=None,
65
- n=1,
66
64
  temperature=1.0,
67
65
  stop=['\n'],
68
66
  seed=123,
@@ -369,6 +369,16 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
369
369
  # pylint: disable=invalid-name
370
370
 
371
371
 
372
+ class VertexAIClaude45Haiku_20251001(VertexAIAnthropic):
373
+ """Anthropic's Claude 4.5 Haiku model on VertexAI."""
374
+ model = 'claude-haiku-4-5@20251001'
375
+
376
+
377
+ class VertexAIClaude45Sonnet_20250929(VertexAIAnthropic):
378
+ """Anthropic's Claude 4.5 Sonnet model on VertexAI."""
379
+ model = 'claude-sonnet-4-5@20250929'
380
+
381
+
372
382
  class VertexAIClaude4Opus_20250514(VertexAIAnthropic):
373
383
  """Anthropic's Claude 4 Opus model on VertexAI."""
374
384
  model = 'claude-opus-4@20250514'
@@ -487,7 +497,7 @@ _LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}
487
497
 
488
498
  @pg.use_init_args(['model'])
489
499
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
490
- class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
500
+ class VertexAILlama(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
491
501
  """Llama models on VertexAI."""
492
502
 
493
503
  model: pg.typing.Annotated[
@@ -600,7 +610,7 @@ _MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}
600
610
 
601
611
  @pg.use_init_args(['model'])
602
612
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
603
- class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
613
+ class VertexAIMistral(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
604
614
  """Mistral AI models on VertexAI."""
605
615
 
606
616
  model: pg.typing.Annotated[
langfun/core/message.py CHANGED
@@ -20,7 +20,7 @@ import contextlib
20
20
  import functools
21
21
  import inspect
22
22
  import io
23
- from typing import Annotated, Any, ClassVar, Optional, Type, Union
23
+ from typing import Annotated, Any, Callable, ClassVar, Optional, Type, Union
24
24
 
25
25
  from langfun.core import modality
26
26
  from langfun.core import natural_language
@@ -86,6 +86,11 @@ class Message(
86
86
 
87
87
  sender: Annotated[str, 'The sender of the message.']
88
88
 
89
+ referred_modalities: Annotated[
90
+ dict[str, pg.Ref[modality.Modality]],
91
+ 'The modality objects referred in the message.'
92
+ ] = pg.Dict()
93
+
89
94
  metadata: Annotated[
90
95
  dict[str, Any],
91
96
  (
@@ -111,6 +116,11 @@ class Message(
111
116
  *,
112
117
  # Default sender is specified in subclasses.
113
118
  sender: str | pg.object_utils.MissingValue = pg.MISSING_VALUE,
119
+ referred_modalities: (
120
+ list[modality.Modality]
121
+ | dict[str, modality.Modality]
122
+ | None
123
+ ) = None,
114
124
  metadata: dict[str, Any] | None = None,
115
125
  tags: list[str] | None = None,
116
126
  source: Optional['Message'] = None,
@@ -125,6 +135,7 @@ class Message(
125
135
  Args:
126
136
  text: The text in the message.
127
137
  sender: The sender name of the message.
138
+ referred_modalities: The modality objects referred in the message.
128
139
  metadata: Structured meta-data associated with this message.
129
140
  tags: Tags for the message.
130
141
  source: The source message of the current message.
@@ -138,9 +149,13 @@ class Message(
138
149
  """
139
150
  metadata = metadata or {}
140
151
  metadata.update(kwargs)
152
+ if isinstance(referred_modalities, list):
153
+ referred_modalities = {m.id: pg.Ref(m) for m in referred_modalities}
154
+
141
155
  super().__init__(
142
156
  text=text,
143
157
  metadata=metadata,
158
+ referred_modalities=referred_modalities or {},
144
159
  tags=tags or [],
145
160
  sender=sender,
146
161
  allow_partial=allow_partial,
@@ -186,7 +201,7 @@ class Message(
186
201
  A message created from the value.
187
202
  """
188
203
  if isinstance(value, modality.Modality):
189
- return cls('<<[[object]]>>', object=value)
204
+ return cls(f'<<[[{value.id}]]>>', referred_modalities=[value])
190
205
  if isinstance(value, Message):
191
206
  return value
192
207
  if isinstance(value, str):
@@ -280,8 +295,7 @@ class Message(
280
295
  if key_path == Message.PATH_TEXT:
281
296
  return self.text
282
297
  else:
283
- v = self.metadata.sym_get(key_path, default, use_inferred=True)
284
- return v.value if isinstance(v, pg.Ref) else v
298
+ return self.metadata.sym_get(key_path, default, use_inferred=True)
285
299
 
286
300
  #
287
301
  # API for accessing the structured result and error.
@@ -361,43 +375,53 @@ class Message(
361
375
  # API for supporting modalities.
362
376
  #
363
377
 
378
+ def modalities(
379
+ self,
380
+ filter: ( # pylint: disable=redefined-builtin
381
+ Type[modality.Modality]
382
+ | Callable[[modality.Modality], bool]
383
+ | None
384
+ ) = None # pylint: disable=bad-whitespace
385
+ ) -> list[modality.Modality]:
386
+ """Returns the modality objects referred in the message."""
387
+ if inspect.isclass(filter) and issubclass(filter, modality.Modality):
388
+ filter_fn = lambda v: isinstance(v, filter) # pytype: disable=wrong-arg-types
389
+ elif filter is None:
390
+ filter_fn = lambda v: True
391
+ else:
392
+ filter_fn = filter
393
+ return [v for v in self.referred_modalities.values() if filter_fn(v)]
394
+
364
395
  @property
365
- def text_with_modality_hash(self) -> str:
366
- """Returns text with modality object placeheld by their 8-byte MD5 hash."""
367
- parts = [self.text]
368
- for name, modality_obj in self.referred_modalities().items():
369
- parts.append(
370
- f'<{name}>{modality_obj.hash}</{name}>'
371
- )
372
- return ''.join(parts)
396
+ def images(self) -> list[modality.Modality]:
397
+ """Returns the image objects referred in the message."""
398
+ assert False, 'Overridden in core/modalities/__init__.py'
399
+
400
+ @property
401
+ def videos(self) -> list[modality.Modality]:
402
+ """Returns the video objects referred in the message."""
403
+ assert False, 'Overridden in core/modalities/__init__.py'
404
+
405
+ @property
406
+ def audios(self) -> list[modality.Modality]:
407
+ """Returns the audio objects referred in the message."""
408
+ assert False, 'Overridden in core/modalities/__init__.py'
373
409
 
374
410
  def get_modality(
375
- self, var_name: str, default: Any = None, from_message_chain: bool = True
411
+ self,
412
+ var_name: str,
413
+ default: Any = None
376
414
  ) -> modality.Modality | None:
377
415
  """Gets the modality object referred in the message.
378
416
 
379
417
  Args:
380
418
  var_name: The referred variable name for the modality object.
381
419
  default: default value.
382
- from_message_chain: If True, the look up will be performed from the
383
- message chain. Otherwise it will be performed in current message.
384
420
 
385
421
  Returns:
386
422
  A modality object if found, otherwise None.
387
423
  """
388
- obj = self.get(var_name, None)
389
- if isinstance(obj, modality.Modality):
390
- return obj
391
- elif obj is None and self.source is not None:
392
- return self.source.get_modality(var_name, default, from_message_chain)
393
- return default
394
-
395
- def referred_modalities(self) -> dict[str, modality.Modality]:
396
- """Returns modality objects attached on this message."""
397
- chunks = self.chunk()
398
- return {
399
- m.referred_name: m for m in chunks if isinstance(m, modality.Modality)
400
- }
424
+ return self.referred_modalities.get(var_name, default)
401
425
 
402
426
  def chunk(self, text: str | None = None) -> list[str | modality.Modality]:
403
427
  """Chunk a message into a list of str or modality objects."""
@@ -425,10 +449,15 @@ class Message(
425
449
 
426
450
  var_name = text[var_start:ref_end].strip()
427
451
  var_value = self.get_modality(var_name)
428
- if var_value is not None:
429
- add_text_chunk(text[chunk_start:ref_start].strip(' '))
430
- chunks.append(var_value)
431
- chunk_start = ref_end + len(modality.Modality.REF_END)
452
+ if var_value is None:
453
+ raise ValueError(
454
+ f'Unknown modality reference: {var_name!r}. '
455
+ 'Please make sure the modality object is present in '
456
+ f'`referred_modalities` when creating {self.__class__.__name__}.'
457
+ )
458
+ add_text_chunk(text[chunk_start:ref_start].strip(' '))
459
+ chunks.append(var_value)
460
+ chunk_start = ref_end + len(modality.Modality.REF_END)
432
461
  return chunks
433
462
 
434
463
  @classmethod
@@ -437,8 +466,8 @@ class Message(
437
466
  ) -> 'Message':
438
467
  """Assembly a message from a list of string or modality objects."""
439
468
  fused_text = io.StringIO()
440
- ref_index = 0
441
469
  metadata = dict()
470
+ referred_modalities = dict()
442
471
  last_char = None
443
472
  for i, chunk in enumerate(chunks):
444
473
  if i > 0 and last_char not in ('\t', ' ', '\n', None):
@@ -451,14 +480,16 @@ class Message(
451
480
  last_char = None
452
481
  else:
453
482
  assert isinstance(chunk, modality.Modality), chunk
454
- var_name = f'obj{ref_index}'
455
- fused_text.write(modality.Modality.text_marker(var_name))
483
+ fused_text.write(modality.Modality.text_marker(chunk.id))
456
484
  last_char = modality.Modality.REF_END[-1]
457
485
  # Make a reference if the chunk is already owned by another object
458
486
  # to avoid copy.
459
- metadata[var_name] = pg.maybe_ref(chunk)
460
- ref_index += 1
461
- return cls(fused_text.getvalue().strip(), metadata=metadata)
487
+ referred_modalities[chunk.id] = pg.Ref(chunk)
488
+ return cls(
489
+ fused_text.getvalue().strip(),
490
+ referred_modalities=referred_modalities,
491
+ metadata=metadata,
492
+ )
462
493
 
463
494
  #
464
495
  # Tagging
@@ -551,6 +582,11 @@ class Message(
551
582
  #
552
583
 
553
584
  def natural_language_format(self) -> str:
585
+ """Returns the natural language format representation."""
586
+ # Propagate the modality references to parent context if any.
587
+ if capture_context := modality.get_modality_capture_context():
588
+ for v in self.referred_modalities.values():
589
+ capture_context.capture(v)
554
590
  return self.text
555
591
 
556
592
  def __eq__(self, other: Any) -> bool:
@@ -568,8 +604,7 @@ class Message(
568
604
  def __getattr__(self, key: str) -> Any:
569
605
  if key not in self.metadata:
570
606
  raise AttributeError(key)
571
- v = self.metadata[key]
572
- return v.value if isinstance(v, pg.Ref) else v
607
+ return self.metadata[key]
573
608
 
574
609
  def _html_tree_view_content(
575
610
  self,
@@ -646,15 +681,14 @@ class Message(
646
681
  s.write(s.escape(chunk))
647
682
  else:
648
683
  assert isinstance(chunk, modality.Modality), chunk
649
- child_path = pg.KeyPath(['metadata', chunk.referred_name], root_path)
650
684
  s.write(
651
685
  pg.Html.element(
652
686
  'div',
653
687
  [
654
688
  view.render(
655
689
  chunk,
656
- name=chunk.referred_name,
657
- root_path=child_path,
690
+ name=chunk.id,
691
+ root_path=chunk.sym_path,
658
692
  collapse_level=(
659
693
  0 if collapse_modalities_in_text else 1
660
694
  ),
@@ -667,7 +701,7 @@ class Message(
667
701
  css_classes=['modality-in-text'],
668
702
  )
669
703
  )
670
- referred_chunks[chunk.referred_name] = chunk
704
+ referred_chunks[chunk.id] = chunk
671
705
  s.write('</div>')
672
706
  return s
673
707