langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202510250803__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/concurrent_test.py +1 -0
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini_test.py +12 -9
- langfun/core/data/conversion/openai.py +134 -30
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base_test.py +4 -4
- langfun/core/eval/v2/progress_tracking_test.py +3 -0
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +15 -6
- langfun/core/language_model_test.py +9 -3
- langfun/core/llms/__init__.py +7 -1
- langfun/core/llms/anthropic.py +130 -0
- langfun/core/llms/cache/base.py +3 -1
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/deepseek.py +1 -1
- langfun/core/llms/gemini.py +2 -5
- langfun/core/llms/groq.py +1 -1
- langfun/core/llms/llama_cpp.py +1 -1
- langfun/core/llms/openai.py +7 -2
- langfun/core/llms/openai_compatible.py +136 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/vertexai.py +12 -2
- langfun/core/message.py +78 -44
- langfun/core/message_test.py +56 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/mime.py +9 -0
- langfun/core/modality.py +104 -27
- langfun/core/modality_test.py +42 -12
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/completion.py +2 -7
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/mapping.py +4 -13
- langfun/core/structured/querying.py +13 -11
- langfun/core/structured/querying_test.py +65 -29
- langfun/core/template.py +39 -13
- langfun/core/template_test.py +83 -17
- langfun/env/event_handlers/metric_writer_test.py +3 -3
- langfun/env/load_balancers_test.py +2 -2
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/RECORD +44 -44
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202510250803.dist-info}/top_level.txt +0 -0
|
@@ -38,7 +38,7 @@ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
|
|
|
38
38
|
response_format = ''
|
|
39
39
|
|
|
40
40
|
choices = []
|
|
41
|
-
for k in range(json
|
|
41
|
+
for k in range(json.get('n', 1)):
|
|
42
42
|
if json.get('logprobs'):
|
|
43
43
|
logprobs = dict(
|
|
44
44
|
content=[
|
|
@@ -89,7 +89,7 @@ def mock_chat_completion_request_vision(
|
|
|
89
89
|
c['image_url']['url']
|
|
90
90
|
for c in json['messages'][0]['content'] if c['type'] == 'image_url'
|
|
91
91
|
]
|
|
92
|
-
for k in range(json
|
|
92
|
+
for k in range(json.get('n', 1)):
|
|
93
93
|
choices.append(pg.Dict(
|
|
94
94
|
message=pg.Dict(
|
|
95
95
|
content=f'Sample {k} for message: {"".join(urls)}'
|
|
@@ -111,12 +111,88 @@ def mock_chat_completion_request_vision(
|
|
|
111
111
|
return response
|
|
112
112
|
|
|
113
113
|
|
|
114
|
-
|
|
114
|
+
def mock_responses_request(url: str, json: dict[str, Any], **kwargs):
|
|
115
|
+
del url, kwargs
|
|
116
|
+
_ = json['input']
|
|
117
|
+
|
|
118
|
+
system_message = ''
|
|
119
|
+
if 'instructions' in json:
|
|
120
|
+
system_message = f' system={json["instructions"]}'
|
|
121
|
+
|
|
122
|
+
response_format = ''
|
|
123
|
+
if 'text' in json and 'format' in json['text']:
|
|
124
|
+
response_format = f' format={json["text"]["format"]["type"]}'
|
|
125
|
+
|
|
126
|
+
output = [
|
|
127
|
+
dict(
|
|
128
|
+
type='message',
|
|
129
|
+
content=[
|
|
130
|
+
dict(
|
|
131
|
+
type='output_text',
|
|
132
|
+
text=(
|
|
133
|
+
f'Sample 0 for message.{system_message}{response_format}'
|
|
134
|
+
)
|
|
135
|
+
)
|
|
136
|
+
],
|
|
137
|
+
)
|
|
138
|
+
]
|
|
139
|
+
|
|
140
|
+
response = requests.Response()
|
|
141
|
+
response.status_code = 200
|
|
142
|
+
response._content = pg.to_json_str(
|
|
143
|
+
dict(
|
|
144
|
+
output=output,
|
|
145
|
+
usage=dict(
|
|
146
|
+
input_tokens=100,
|
|
147
|
+
output_tokens=100,
|
|
148
|
+
total_tokens=200,
|
|
149
|
+
),
|
|
150
|
+
)
|
|
151
|
+
).encode()
|
|
152
|
+
return response
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def mock_responses_request_vision(
|
|
156
|
+
url: str, json: dict[str, Any], **kwargs
|
|
157
|
+
):
|
|
158
|
+
del url, kwargs
|
|
159
|
+
urls = [
|
|
160
|
+
c['image_url']
|
|
161
|
+
for c in json['input'][0]['content']
|
|
162
|
+
if c['type'] == 'input_image'
|
|
163
|
+
]
|
|
164
|
+
output = [
|
|
165
|
+
pg.Dict(
|
|
166
|
+
type='message',
|
|
167
|
+
content=[
|
|
168
|
+
pg.Dict(
|
|
169
|
+
type='output_text',
|
|
170
|
+
text=f'Sample 0 for message: {"".join(urls)}',
|
|
171
|
+
)
|
|
172
|
+
],
|
|
173
|
+
)
|
|
174
|
+
]
|
|
175
|
+
response = requests.Response()
|
|
176
|
+
response.status_code = 200
|
|
177
|
+
response._content = pg.to_json_str(
|
|
178
|
+
dict(
|
|
179
|
+
output=output,
|
|
180
|
+
usage=dict(
|
|
181
|
+
input_tokens=100,
|
|
182
|
+
output_tokens=100,
|
|
183
|
+
total_tokens=200,
|
|
184
|
+
),
|
|
185
|
+
)
|
|
186
|
+
).encode()
|
|
187
|
+
return response
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class OpenAIChatCompletionAPITest(unittest.TestCase):
|
|
115
191
|
"""Tests for OpenAI compatible language model."""
|
|
116
192
|
|
|
117
193
|
def test_request_args(self):
|
|
118
194
|
self.assertEqual(
|
|
119
|
-
openai_compatible.
|
|
195
|
+
openai_compatible.OpenAIChatCompletionAPI(
|
|
120
196
|
api_endpoint='https://test-server',
|
|
121
197
|
model='test-model'
|
|
122
198
|
)._request_args(
|
|
@@ -126,8 +202,6 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
126
202
|
),
|
|
127
203
|
dict(
|
|
128
204
|
model='test-model',
|
|
129
|
-
top_logprobs=None,
|
|
130
|
-
n=1,
|
|
131
205
|
temperature=1.0,
|
|
132
206
|
stop=['\n'],
|
|
133
207
|
seed=123,
|
|
@@ -137,7 +211,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
137
211
|
def test_call_chat_completion(self):
|
|
138
212
|
with mock.patch('requests.Session.post') as mock_request:
|
|
139
213
|
mock_request.side_effect = mock_chat_completion_request
|
|
140
|
-
lm = openai_compatible.
|
|
214
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
141
215
|
api_endpoint='https://test-server', model='test-model',
|
|
142
216
|
)
|
|
143
217
|
self.assertEqual(
|
|
@@ -148,7 +222,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
148
222
|
def test_call_chat_completion_with_logprobs(self):
|
|
149
223
|
with mock.patch('requests.Session.post') as mock_request:
|
|
150
224
|
mock_request.side_effect = mock_chat_completion_request
|
|
151
|
-
lm = openai_compatible.
|
|
225
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
152
226
|
api_endpoint='https://test-server', model='test-model',
|
|
153
227
|
)
|
|
154
228
|
results = lm.sample(['hello'], logprobs=True)
|
|
@@ -214,13 +288,14 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
214
288
|
def mime_type(self) -> str:
|
|
215
289
|
return 'image/png'
|
|
216
290
|
|
|
291
|
+
image = FakeImage.from_uri('https://fake/image')
|
|
217
292
|
with mock.patch('requests.Session.post') as mock_request:
|
|
218
293
|
mock_request.side_effect = mock_chat_completion_request_vision
|
|
219
|
-
lm_1 = openai_compatible.
|
|
294
|
+
lm_1 = openai_compatible.OpenAIChatCompletionAPI(
|
|
220
295
|
api_endpoint='https://test-server',
|
|
221
296
|
model='test-model1',
|
|
222
297
|
)
|
|
223
|
-
lm_2 = openai_compatible.
|
|
298
|
+
lm_2 = openai_compatible.OpenAIChatCompletionAPI(
|
|
224
299
|
api_endpoint='https://test-server',
|
|
225
300
|
model='test-model2',
|
|
226
301
|
)
|
|
@@ -228,15 +303,15 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
228
303
|
self.assertEqual(
|
|
229
304
|
lm(
|
|
230
305
|
lf.UserMessage(
|
|
231
|
-
'hello <<[[image]]>>',
|
|
232
|
-
|
|
306
|
+
f'hello <<[[{image.id}]]>>',
|
|
307
|
+
referred_modalities=[image],
|
|
233
308
|
),
|
|
234
309
|
sampling_options=lf.LMSamplingOptions(n=2)
|
|
235
310
|
),
|
|
236
311
|
'Sample 0 for message: https://fake/image',
|
|
237
312
|
)
|
|
238
313
|
|
|
239
|
-
class TextOnlyModel(openai_compatible.
|
|
314
|
+
class TextOnlyModel(openai_compatible.OpenAIChatCompletionAPI):
|
|
240
315
|
|
|
241
316
|
class ModelInfo(lf.ModelInfo):
|
|
242
317
|
input_modalities: list[str] = lf.ModelInfo.TEXT_INPUT_ONLY
|
|
@@ -251,15 +326,15 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
251
326
|
with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
|
|
252
327
|
lm_3(
|
|
253
328
|
lf.UserMessage(
|
|
254
|
-
'hello <<[[image]]>>',
|
|
255
|
-
|
|
329
|
+
f'hello <<[[{image.id}]]>>',
|
|
330
|
+
referred_modalities=[image],
|
|
256
331
|
),
|
|
257
332
|
)
|
|
258
333
|
|
|
259
334
|
def test_sample_chat_completion(self):
|
|
260
335
|
with mock.patch('requests.Session.post') as mock_request:
|
|
261
336
|
mock_request.side_effect = mock_chat_completion_request
|
|
262
|
-
lm = openai_compatible.
|
|
337
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
263
338
|
api_endpoint='https://test-server', model='test-model'
|
|
264
339
|
)
|
|
265
340
|
results = lm.sample(
|
|
@@ -400,7 +475,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
400
475
|
def test_sample_with_contextual_options(self):
|
|
401
476
|
with mock.patch('requests.Session.post') as mock_request:
|
|
402
477
|
mock_request.side_effect = mock_chat_completion_request
|
|
403
|
-
lm = openai_compatible.
|
|
478
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
404
479
|
api_endpoint='https://test-server', model='test-model'
|
|
405
480
|
)
|
|
406
481
|
with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
|
|
@@ -458,7 +533,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
458
533
|
def test_call_with_system_message(self):
|
|
459
534
|
with mock.patch('requests.Session.post') as mock_request:
|
|
460
535
|
mock_request.side_effect = mock_chat_completion_request
|
|
461
|
-
lm = openai_compatible.
|
|
536
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
462
537
|
api_endpoint='https://test-server', model='test-model'
|
|
463
538
|
)
|
|
464
539
|
self.assertEqual(
|
|
@@ -475,7 +550,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
475
550
|
def test_call_with_json_schema(self):
|
|
476
551
|
with mock.patch('requests.Session.post') as mock_request:
|
|
477
552
|
mock_request.side_effect = mock_chat_completion_request
|
|
478
|
-
lm = openai_compatible.
|
|
553
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
479
554
|
api_endpoint='https://test-server', model='test-model'
|
|
480
555
|
)
|
|
481
556
|
self.assertEqual(
|
|
@@ -515,7 +590,7 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
515
590
|
|
|
516
591
|
with mock.patch('requests.Session.post') as mock_request:
|
|
517
592
|
mock_request.side_effect = mock_context_limit_error
|
|
518
|
-
lm = openai_compatible.
|
|
593
|
+
lm = openai_compatible.OpenAIChatCompletionAPI(
|
|
519
594
|
api_endpoint='https://test-server', model='test-model'
|
|
520
595
|
)
|
|
521
596
|
with self.assertRaisesRegex(
|
|
@@ -524,5 +599,117 @@ class OpenAIComptibleTest(unittest.TestCase):
|
|
|
524
599
|
lm(lf.UserMessage('hello'))
|
|
525
600
|
|
|
526
601
|
|
|
602
|
+
class OpenAIResponsesAPITest(unittest.TestCase):
|
|
603
|
+
"""Tests for OpenAI compatible language model on Responses API."""
|
|
604
|
+
|
|
605
|
+
def test_request_args(self):
|
|
606
|
+
lm = openai_compatible.OpenAIResponsesAPI(
|
|
607
|
+
api_endpoint='https://test-server', model='test-model'
|
|
608
|
+
)
|
|
609
|
+
# Test valid args.
|
|
610
|
+
self.assertEqual(
|
|
611
|
+
lm._request_args(
|
|
612
|
+
lf.LMSamplingOptions(
|
|
613
|
+
temperature=1.0, stop=['\n'], n=1, random_seed=123
|
|
614
|
+
)
|
|
615
|
+
),
|
|
616
|
+
dict(
|
|
617
|
+
model='test-model',
|
|
618
|
+
temperature=1.0,
|
|
619
|
+
stop=['\n'],
|
|
620
|
+
seed=123,
|
|
621
|
+
),
|
|
622
|
+
)
|
|
623
|
+
# Test unsupported n.
|
|
624
|
+
with self.assertRaisesRegex(ValueError, 'n must be 1 for Responses API.'):
|
|
625
|
+
lm._request_args(lf.LMSamplingOptions(n=2))
|
|
626
|
+
|
|
627
|
+
# Test unsupported logprobs.
|
|
628
|
+
with self.assertRaisesRegex(
|
|
629
|
+
ValueError, 'logprobs is not supported on Responses API.'
|
|
630
|
+
):
|
|
631
|
+
lm._request_args(lf.LMSamplingOptions(logprobs=True))
|
|
632
|
+
|
|
633
|
+
def test_call_responses(self):
|
|
634
|
+
with mock.patch('requests.Session.post') as mock_request:
|
|
635
|
+
mock_request.side_effect = mock_responses_request
|
|
636
|
+
lm = openai_compatible.OpenAIResponsesAPI(
|
|
637
|
+
api_endpoint='https://test-server',
|
|
638
|
+
model='test-model',
|
|
639
|
+
)
|
|
640
|
+
self.assertEqual(lm('hello'), 'Sample 0 for message.')
|
|
641
|
+
|
|
642
|
+
def test_call_responses_vision(self):
|
|
643
|
+
class FakeImage(lf_modalities.Image):
|
|
644
|
+
@property
|
|
645
|
+
def mime_type(self) -> str:
|
|
646
|
+
return 'image/png'
|
|
647
|
+
|
|
648
|
+
image = FakeImage.from_uri('https://fake/image')
|
|
649
|
+
with mock.patch('requests.Session.post') as mock_request:
|
|
650
|
+
mock_request.side_effect = mock_responses_request_vision
|
|
651
|
+
lm = openai_compatible.OpenAIResponsesAPI(
|
|
652
|
+
api_endpoint='https://test-server',
|
|
653
|
+
model='test-model1',
|
|
654
|
+
)
|
|
655
|
+
self.assertEqual(
|
|
656
|
+
lm(
|
|
657
|
+
lf.UserMessage(
|
|
658
|
+
f'hello <<[[{image.id}]]>>',
|
|
659
|
+
referred_modalities=[image],
|
|
660
|
+
)
|
|
661
|
+
),
|
|
662
|
+
'Sample 0 for message: https://fake/image',
|
|
663
|
+
)
|
|
664
|
+
|
|
665
|
+
def test_call_with_system_message(self):
|
|
666
|
+
with mock.patch('requests.Session.post') as mock_request:
|
|
667
|
+
mock_request.side_effect = mock_responses_request
|
|
668
|
+
lm = openai_compatible.OpenAIResponsesAPI(
|
|
669
|
+
api_endpoint='https://test-server', model='test-model'
|
|
670
|
+
)
|
|
671
|
+
self.assertEqual(
|
|
672
|
+
lm(
|
|
673
|
+
lf.UserMessage(
|
|
674
|
+
'hello',
|
|
675
|
+
system_message=lf.SystemMessage('hi'),
|
|
676
|
+
)
|
|
677
|
+
),
|
|
678
|
+
'Sample 0 for message. system=hi',
|
|
679
|
+
)
|
|
680
|
+
|
|
681
|
+
def test_call_with_json_schema(self):
|
|
682
|
+
with mock.patch('requests.Session.post') as mock_request:
|
|
683
|
+
mock_request.side_effect = mock_responses_request
|
|
684
|
+
lm = openai_compatible.OpenAIResponsesAPI(
|
|
685
|
+
api_endpoint='https://test-server', model='test-model'
|
|
686
|
+
)
|
|
687
|
+
self.assertEqual(
|
|
688
|
+
lm(
|
|
689
|
+
lf.UserMessage(
|
|
690
|
+
'hello',
|
|
691
|
+
json_schema={
|
|
692
|
+
'type': 'object',
|
|
693
|
+
'properties': {
|
|
694
|
+
'name': {'type': 'string'},
|
|
695
|
+
},
|
|
696
|
+
'required': ['name'],
|
|
697
|
+
'title': 'Person',
|
|
698
|
+
},
|
|
699
|
+
)
|
|
700
|
+
),
|
|
701
|
+
'Sample 0 for message. format=json_schema',
|
|
702
|
+
)
|
|
703
|
+
|
|
704
|
+
# Test bad json schema.
|
|
705
|
+
with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
|
|
706
|
+
lm(lf.UserMessage('hello', json_schema='foo'))
|
|
707
|
+
|
|
708
|
+
with self.assertRaisesRegex(
|
|
709
|
+
ValueError, 'The root of `json_schema` must have a `title` field'
|
|
710
|
+
):
|
|
711
|
+
lm(lf.UserMessage('hello', json_schema={}))
|
|
712
|
+
|
|
713
|
+
|
|
527
714
|
if __name__ == '__main__':
|
|
528
715
|
unittest.main()
|
langfun/core/llms/openai_test.py
CHANGED
langfun/core/llms/vertexai.py
CHANGED
|
@@ -369,6 +369,16 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
|
|
|
369
369
|
# pylint: disable=invalid-name
|
|
370
370
|
|
|
371
371
|
|
|
372
|
+
class VertexAIClaude45Haiku_20251001(VertexAIAnthropic):
|
|
373
|
+
"""Anthropic's Claude 4.5 Haiku model on VertexAI."""
|
|
374
|
+
model = 'claude-haiku-4-5@20251001'
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
class VertexAIClaude45Sonnet_20250929(VertexAIAnthropic):
|
|
378
|
+
"""Anthropic's Claude 4.5 Sonnet model on VertexAI."""
|
|
379
|
+
model = 'claude-sonnet-4-5@20250929'
|
|
380
|
+
|
|
381
|
+
|
|
372
382
|
class VertexAIClaude4Opus_20250514(VertexAIAnthropic):
|
|
373
383
|
"""Anthropic's Claude 4 Opus model on VertexAI."""
|
|
374
384
|
model = 'claude-opus-4@20250514'
|
|
@@ -487,7 +497,7 @@ _LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}
|
|
|
487
497
|
|
|
488
498
|
@pg.use_init_args(['model'])
|
|
489
499
|
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
|
490
|
-
class VertexAILlama(VertexAI, openai_compatible.
|
|
500
|
+
class VertexAILlama(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
|
|
491
501
|
"""Llama models on VertexAI."""
|
|
492
502
|
|
|
493
503
|
model: pg.typing.Annotated[
|
|
@@ -600,7 +610,7 @@ _MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}
|
|
|
600
610
|
|
|
601
611
|
@pg.use_init_args(['model'])
|
|
602
612
|
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
|
603
|
-
class VertexAIMistral(VertexAI, openai_compatible.
|
|
613
|
+
class VertexAIMistral(VertexAI, openai_compatible.OpenAIChatCompletionAPI):
|
|
604
614
|
"""Mistral AI models on VertexAI."""
|
|
605
615
|
|
|
606
616
|
model: pg.typing.Annotated[
|
langfun/core/message.py
CHANGED
|
@@ -20,7 +20,7 @@ import contextlib
|
|
|
20
20
|
import functools
|
|
21
21
|
import inspect
|
|
22
22
|
import io
|
|
23
|
-
from typing import Annotated, Any, ClassVar, Optional, Type, Union
|
|
23
|
+
from typing import Annotated, Any, Callable, ClassVar, Optional, Type, Union
|
|
24
24
|
|
|
25
25
|
from langfun.core import modality
|
|
26
26
|
from langfun.core import natural_language
|
|
@@ -86,6 +86,11 @@ class Message(
|
|
|
86
86
|
|
|
87
87
|
sender: Annotated[str, 'The sender of the message.']
|
|
88
88
|
|
|
89
|
+
referred_modalities: Annotated[
|
|
90
|
+
dict[str, pg.Ref[modality.Modality]],
|
|
91
|
+
'The modality objects referred in the message.'
|
|
92
|
+
] = pg.Dict()
|
|
93
|
+
|
|
89
94
|
metadata: Annotated[
|
|
90
95
|
dict[str, Any],
|
|
91
96
|
(
|
|
@@ -111,6 +116,11 @@ class Message(
|
|
|
111
116
|
*,
|
|
112
117
|
# Default sender is specified in subclasses.
|
|
113
118
|
sender: str | pg.object_utils.MissingValue = pg.MISSING_VALUE,
|
|
119
|
+
referred_modalities: (
|
|
120
|
+
list[modality.Modality]
|
|
121
|
+
| dict[str, modality.Modality]
|
|
122
|
+
| None
|
|
123
|
+
) = None,
|
|
114
124
|
metadata: dict[str, Any] | None = None,
|
|
115
125
|
tags: list[str] | None = None,
|
|
116
126
|
source: Optional['Message'] = None,
|
|
@@ -125,6 +135,7 @@ class Message(
|
|
|
125
135
|
Args:
|
|
126
136
|
text: The text in the message.
|
|
127
137
|
sender: The sender name of the message.
|
|
138
|
+
referred_modalities: The modality objects referred in the message.
|
|
128
139
|
metadata: Structured meta-data associated with this message.
|
|
129
140
|
tags: Tags for the message.
|
|
130
141
|
source: The source message of the current message.
|
|
@@ -138,9 +149,13 @@ class Message(
|
|
|
138
149
|
"""
|
|
139
150
|
metadata = metadata or {}
|
|
140
151
|
metadata.update(kwargs)
|
|
152
|
+
if isinstance(referred_modalities, list):
|
|
153
|
+
referred_modalities = {m.id: pg.Ref(m) for m in referred_modalities}
|
|
154
|
+
|
|
141
155
|
super().__init__(
|
|
142
156
|
text=text,
|
|
143
157
|
metadata=metadata,
|
|
158
|
+
referred_modalities=referred_modalities or {},
|
|
144
159
|
tags=tags or [],
|
|
145
160
|
sender=sender,
|
|
146
161
|
allow_partial=allow_partial,
|
|
@@ -186,7 +201,7 @@ class Message(
|
|
|
186
201
|
A message created from the value.
|
|
187
202
|
"""
|
|
188
203
|
if isinstance(value, modality.Modality):
|
|
189
|
-
return cls('<<[[
|
|
204
|
+
return cls(f'<<[[{value.id}]]>>', referred_modalities=[value])
|
|
190
205
|
if isinstance(value, Message):
|
|
191
206
|
return value
|
|
192
207
|
if isinstance(value, str):
|
|
@@ -280,8 +295,7 @@ class Message(
|
|
|
280
295
|
if key_path == Message.PATH_TEXT:
|
|
281
296
|
return self.text
|
|
282
297
|
else:
|
|
283
|
-
|
|
284
|
-
return v.value if isinstance(v, pg.Ref) else v
|
|
298
|
+
return self.metadata.sym_get(key_path, default, use_inferred=True)
|
|
285
299
|
|
|
286
300
|
#
|
|
287
301
|
# API for accessing the structured result and error.
|
|
@@ -361,43 +375,53 @@ class Message(
|
|
|
361
375
|
# API for supporting modalities.
|
|
362
376
|
#
|
|
363
377
|
|
|
378
|
+
def modalities(
|
|
379
|
+
self,
|
|
380
|
+
filter: ( # pylint: disable=redefined-builtin
|
|
381
|
+
Type[modality.Modality]
|
|
382
|
+
| Callable[[modality.Modality], bool]
|
|
383
|
+
| None
|
|
384
|
+
) = None # pylint: disable=bad-whitespace
|
|
385
|
+
) -> list[modality.Modality]:
|
|
386
|
+
"""Returns the modality objects referred in the message."""
|
|
387
|
+
if inspect.isclass(filter) and issubclass(filter, modality.Modality):
|
|
388
|
+
filter_fn = lambda v: isinstance(v, filter) # pytype: disable=wrong-arg-types
|
|
389
|
+
elif filter is None:
|
|
390
|
+
filter_fn = lambda v: True
|
|
391
|
+
else:
|
|
392
|
+
filter_fn = filter
|
|
393
|
+
return [v for v in self.referred_modalities.values() if filter_fn(v)]
|
|
394
|
+
|
|
364
395
|
@property
|
|
365
|
-
def
|
|
366
|
-
"""Returns
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
396
|
+
def images(self) -> list[modality.Modality]:
|
|
397
|
+
"""Returns the image objects referred in the message."""
|
|
398
|
+
assert False, 'Overridden in core/modalities/__init__.py'
|
|
399
|
+
|
|
400
|
+
@property
|
|
401
|
+
def videos(self) -> list[modality.Modality]:
|
|
402
|
+
"""Returns the video objects referred in the message."""
|
|
403
|
+
assert False, 'Overridden in core/modalities/__init__.py'
|
|
404
|
+
|
|
405
|
+
@property
|
|
406
|
+
def audios(self) -> list[modality.Modality]:
|
|
407
|
+
"""Returns the audio objects referred in the message."""
|
|
408
|
+
assert False, 'Overridden in core/modalities/__init__.py'
|
|
373
409
|
|
|
374
410
|
def get_modality(
|
|
375
|
-
self,
|
|
411
|
+
self,
|
|
412
|
+
var_name: str,
|
|
413
|
+
default: Any = None
|
|
376
414
|
) -> modality.Modality | None:
|
|
377
415
|
"""Gets the modality object referred in the message.
|
|
378
416
|
|
|
379
417
|
Args:
|
|
380
418
|
var_name: The referred variable name for the modality object.
|
|
381
419
|
default: default value.
|
|
382
|
-
from_message_chain: If True, the look up will be performed from the
|
|
383
|
-
message chain. Otherwise it will be performed in current message.
|
|
384
420
|
|
|
385
421
|
Returns:
|
|
386
422
|
A modality object if found, otherwise None.
|
|
387
423
|
"""
|
|
388
|
-
|
|
389
|
-
if isinstance(obj, modality.Modality):
|
|
390
|
-
return obj
|
|
391
|
-
elif obj is None and self.source is not None:
|
|
392
|
-
return self.source.get_modality(var_name, default, from_message_chain)
|
|
393
|
-
return default
|
|
394
|
-
|
|
395
|
-
def referred_modalities(self) -> dict[str, modality.Modality]:
|
|
396
|
-
"""Returns modality objects attached on this message."""
|
|
397
|
-
chunks = self.chunk()
|
|
398
|
-
return {
|
|
399
|
-
m.referred_name: m for m in chunks if isinstance(m, modality.Modality)
|
|
400
|
-
}
|
|
424
|
+
return self.referred_modalities.get(var_name, default)
|
|
401
425
|
|
|
402
426
|
def chunk(self, text: str | None = None) -> list[str | modality.Modality]:
|
|
403
427
|
"""Chunk a message into a list of str or modality objects."""
|
|
@@ -425,10 +449,15 @@ class Message(
|
|
|
425
449
|
|
|
426
450
|
var_name = text[var_start:ref_end].strip()
|
|
427
451
|
var_value = self.get_modality(var_name)
|
|
428
|
-
if var_value is
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
452
|
+
if var_value is None:
|
|
453
|
+
raise ValueError(
|
|
454
|
+
f'Unknown modality reference: {var_name!r}. '
|
|
455
|
+
'Please make sure the modality object is present in '
|
|
456
|
+
f'`referred_modalities` when creating {self.__class__.__name__}.'
|
|
457
|
+
)
|
|
458
|
+
add_text_chunk(text[chunk_start:ref_start].strip(' '))
|
|
459
|
+
chunks.append(var_value)
|
|
460
|
+
chunk_start = ref_end + len(modality.Modality.REF_END)
|
|
432
461
|
return chunks
|
|
433
462
|
|
|
434
463
|
@classmethod
|
|
@@ -437,8 +466,8 @@ class Message(
|
|
|
437
466
|
) -> 'Message':
|
|
438
467
|
"""Assembly a message from a list of string or modality objects."""
|
|
439
468
|
fused_text = io.StringIO()
|
|
440
|
-
ref_index = 0
|
|
441
469
|
metadata = dict()
|
|
470
|
+
referred_modalities = dict()
|
|
442
471
|
last_char = None
|
|
443
472
|
for i, chunk in enumerate(chunks):
|
|
444
473
|
if i > 0 and last_char not in ('\t', ' ', '\n', None):
|
|
@@ -451,14 +480,16 @@ class Message(
|
|
|
451
480
|
last_char = None
|
|
452
481
|
else:
|
|
453
482
|
assert isinstance(chunk, modality.Modality), chunk
|
|
454
|
-
|
|
455
|
-
fused_text.write(modality.Modality.text_marker(var_name))
|
|
483
|
+
fused_text.write(modality.Modality.text_marker(chunk.id))
|
|
456
484
|
last_char = modality.Modality.REF_END[-1]
|
|
457
485
|
# Make a reference if the chunk is already owned by another object
|
|
458
486
|
# to avoid copy.
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
487
|
+
referred_modalities[chunk.id] = pg.Ref(chunk)
|
|
488
|
+
return cls(
|
|
489
|
+
fused_text.getvalue().strip(),
|
|
490
|
+
referred_modalities=referred_modalities,
|
|
491
|
+
metadata=metadata,
|
|
492
|
+
)
|
|
462
493
|
|
|
463
494
|
#
|
|
464
495
|
# Tagging
|
|
@@ -551,6 +582,11 @@ class Message(
|
|
|
551
582
|
#
|
|
552
583
|
|
|
553
584
|
def natural_language_format(self) -> str:
|
|
585
|
+
"""Returns the natural language format representation."""
|
|
586
|
+
# Propagate the modality references to parent context if any.
|
|
587
|
+
if capture_context := modality.get_modality_capture_context():
|
|
588
|
+
for v in self.referred_modalities.values():
|
|
589
|
+
capture_context.capture(v)
|
|
554
590
|
return self.text
|
|
555
591
|
|
|
556
592
|
def __eq__(self, other: Any) -> bool:
|
|
@@ -568,8 +604,7 @@ class Message(
|
|
|
568
604
|
def __getattr__(self, key: str) -> Any:
|
|
569
605
|
if key not in self.metadata:
|
|
570
606
|
raise AttributeError(key)
|
|
571
|
-
|
|
572
|
-
return v.value if isinstance(v, pg.Ref) else v
|
|
607
|
+
return self.metadata[key]
|
|
573
608
|
|
|
574
609
|
def _html_tree_view_content(
|
|
575
610
|
self,
|
|
@@ -646,15 +681,14 @@ class Message(
|
|
|
646
681
|
s.write(s.escape(chunk))
|
|
647
682
|
else:
|
|
648
683
|
assert isinstance(chunk, modality.Modality), chunk
|
|
649
|
-
child_path = pg.KeyPath(['metadata', chunk.referred_name], root_path)
|
|
650
684
|
s.write(
|
|
651
685
|
pg.Html.element(
|
|
652
686
|
'div',
|
|
653
687
|
[
|
|
654
688
|
view.render(
|
|
655
689
|
chunk,
|
|
656
|
-
name=chunk.
|
|
657
|
-
root_path=
|
|
690
|
+
name=chunk.id,
|
|
691
|
+
root_path=chunk.sym_path,
|
|
658
692
|
collapse_level=(
|
|
659
693
|
0 if collapse_modalities_in_text else 1
|
|
660
694
|
),
|
|
@@ -667,7 +701,7 @@ class Message(
|
|
|
667
701
|
css_classes=['modality-in-text'],
|
|
668
702
|
)
|
|
669
703
|
)
|
|
670
|
-
referred_chunks[chunk.
|
|
704
|
+
referred_chunks[chunk.id] = chunk
|
|
671
705
|
s.write('</div>')
|
|
672
706
|
return s
|
|
673
707
|
|