langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +2 -0
- langfun/core/eval/v2/checkpointing.py +76 -7
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +92 -17
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +84 -15
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +12 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +64 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +34 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +58 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +64 -3
- langfun/core/modalities/mime_test.py +11 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Tests for MCP tool."""
|
|
15
|
+
|
|
16
|
+
import base64
|
|
17
|
+
import inspect
|
|
18
|
+
import unittest
|
|
19
|
+
|
|
20
|
+
from langfun.core import async_support
|
|
21
|
+
from langfun.core import message as lf_message
|
|
22
|
+
from langfun.core import modalities as lf_modalities
|
|
23
|
+
from langfun.core.mcp import client as mcp_client
|
|
24
|
+
from langfun.core.mcp import tool as mcp_tool
|
|
25
|
+
import mcp
|
|
26
|
+
from mcp.server import fastmcp as fastmcp_lib
|
|
27
|
+
import pyglove as pg
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# MCP server setup for testing.
|
|
31
|
+
# In-process FastMCP server shared by the tests in this module.
# NOTE(review): the tests reach this server through
# `McpClient.from_fastmcp`; whether that path ever binds `host`/`port` is not
# visible from this file.
_mcp_server = fastmcp_lib.FastMCP(
    host='0.0.0.0',
    port=1235,
)


async def add(a: int, b: int) -> int:
  """Adds two integers."""
  total = a + b
  return total


# Register `add` as a tool on the test server. This is the same as applying
# `@_mcp_server.tool()` as a decorator: the decorator returns the function
# unchanged after registering it.
add = _mcp_server.tool()(add)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class McpToolTest(unittest.TestCase):
  """Tests for the MCP tool wrappers in `langfun.core.mcp.tool`."""

  def setUp(self):
    super().setUp()
    # A fresh client bound to the in-module FastMCP server, plus the tool
    # classes it exposes, keyed by tool name (e.g. 'add').
    self.client = mcp_client.McpClient.from_fastmcp(_mcp_server)
    self.tools = self.client.list_tools()

  def test_snake_to_camel(self):
    for snake_name, camel_name in (('foo_bar', 'FooBar'), ('foo', 'Foo')):
      self.assertEqual(mcp_tool._snake_to_camel(snake_name), camel_name)

  def test_base64_decode(self):
    encoded = base64.b64encode(b'foo').decode('utf-8')
    self.assertEqual(mcp_tool._base64_decode(encoded), b'foo')

  def test_make_input_class(self):
    x_field = pg.typing.Field('x', pg.typing.Int(), 'Integer x.')
    y_field = pg.typing.Field('y', pg.typing.Str(), 'String y.')
    schema = pg.Schema(description='Foo input.', fields=[x_field, y_field])
    input_cls = mcp_tool.McpToolInput.make_class('foo_input', schema)

    # Properties of the generated class itself.
    self.assertTrue(issubclass(input_cls, mcp_tool.McpToolInput))
    self.assertEqual(input_cls.__name__, 'FooInput')
    self.assertEqual(input_cls.__doc__, 'Foo input.')
    self.assertEqual(list(input_cls.__schema__.fields.keys()), ['x', 'y'])
    self.assertEqual(repr(input_cls), "<input-class 'FooInput'>")

    # Properties of an instance of the generated class.
    instance = input_cls(x=1, y='abc')
    self.assertEqual(repr(instance), "FooInput(x=1, y='abc')")

  def test_make_tool_class(self):
    input_schema = {
        'type': 'object',
        'properties': {
            'a': {'type': 'integer', 'description': 'Integer a.'},
            'b': {'type': 'string', 'description': 'String b.'},
        },
        'required': ['a'],
    }
    definition = mcp.Tool(
        name='my_tool',
        inputSchema=input_schema,
        description='My tool.',
    )
    tool_cls = mcp_tool.McpTool.make_class(definition)

    # Properties of the generated tool class.
    self.assertTrue(issubclass(tool_cls, mcp_tool.McpTool))
    self.assertEqual(tool_cls.__name__, 'MyTool')
    self.assertEqual(tool_cls.TOOL_NAME, 'my_tool')
    self.assertEqual(tool_cls.__doc__, 'My tool.')

    # Fields derived from the JSON input schema.
    fields = tool_cls.__schema__.fields
    self.assertEqual(list(fields.keys()), ['a', 'b'])
    self.assertEqual(repr(tool_cls), "<tool-class 'MyTool'>")
    self.assertEqual(fields['a'].description, 'Integer a.')
    self.assertEqual(fields['b'].description, 'String b.')

    # The Python class definition shared by both rendering modes. The optional
    # field `b` (not listed in 'required') is rendered as `str | None`.
    class_definition = inspect.cleandoc(
        '''
        class MyTool:
          """My tool."""
          # Integer a.
          a: int
          # String b.
          b: str | None
        '''
    )
    self.assertEqual(
        tool_cls.python_definition(markdown=True),
        f'MyTool\n\n```python\n{class_definition}\n```',
    )
    self.assertEqual(
        tool_cls.python_definition(markdown=False),
        f'MyTool\n\n{class_definition}',
    )

  def test_input_parameters(self):
    add_tool = self.tools['add'](a=1, b=2)
    self.assertEqual(add_tool.input_parameters(), {'a': 1, 'b': 2})

  def test_result_to_message(self):
    image_b64 = base64.b64encode(b'image-data').decode('utf-8')
    audio_b64 = base64.b64encode(b'audio-data').decode('utf-8')

    add_tool_cls = self.tools['add']
    call_result = mcp.types.CallToolResult(
        content=[
            mcp.types.TextContent(type='text', text='hello'),
            mcp.types.ImageContent(
                type='image', data=image_b64, mimeType='image/png'
            ),
            mcp.types.AudioContent(
                type='audio', data=audio_b64, mimeType='audio/wav'
            ),
        ],
        structuredContent={'x': 1},
    )
    tool_message = add_tool_cls.result_to_message(call_result)

    # Text content is inlined; binary content appears as modality markers.
    self.assertIsInstance(tool_message, lf_message.ToolMessage)
    for snippet in ('hello', '<<[[image', '<<[[audio'):
      self.assertIn(snippet, tool_message.text)
    self.assertEqual(tool_message.metadata, {'x': 1})

    # Binary content is decoded into modality objects, in content order.
    referred = tool_message.modalities()
    self.assertEqual(len(referred), 2)
    expectations = (
        (lf_modalities.Image, b'image-data'),
        (lf_modalities.Audio, b'audio-data'),
    )
    for item, (expected_type, expected_bytes) in zip(referred, expectations):
      self.assertIsInstance(item, expected_type)
      self.assertEqual(item.to_bytes(), expected_bytes)

  def test_sync_call(self):
    add_tool_cls = self.tools['add']
    with self.client.session() as session:
      # By default the call returns the structured result.
      self.assertEqual(add_tool_cls(a=1, b=2)(session), 3)

      # With `returns_message=True` the call returns the full tool message.
      reply = add_tool_cls(a=1, b=2)(session, returns_message=True)
      self.assertEqual(reply, lf_message.ToolMessage(text='3', result=3))

  def test_async_call(self):

    async def _run_async_checks():
      add_tool_cls = self.tools['add']
      async with self.client.session() as session:
        # By default the call returns the structured result.
        self.assertEqual(await add_tool_cls(a=1, b=2).acall(session), 3)

        # With `returns_message=True` the call returns the full tool message.
        reply = await add_tool_cls(a=1, b=2).acall(
            session, returns_message=True
        )
        self.assertEqual(reply, lf_message.ToolMessage(text='3', result=3))

    async_support.invoke_sync(_run_async_checks)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
# Run this module's tests when it is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
|
langfun/core/memory.py
CHANGED
langfun/core/message.py
CHANGED
|
@@ -20,7 +20,7 @@ import contextlib
|
|
|
20
20
|
import functools
|
|
21
21
|
import inspect
|
|
22
22
|
import io
|
|
23
|
-
from typing import Annotated, Any, ClassVar, Optional, Type, Union
|
|
23
|
+
from typing import Annotated, Any, Callable, ClassVar, Optional, Type, Union
|
|
24
24
|
|
|
25
25
|
from langfun.core import modality
|
|
26
26
|
from langfun.core import natural_language
|
|
@@ -32,15 +32,49 @@ class Message(
|
|
|
32
32
|
pg.Object,
|
|
33
33
|
pg.views.HtmlTreeView.Extension
|
|
34
34
|
):
|
|
35
|
-
"""Message.
|
|
35
|
+
"""Message between users, LLMs and tools.
|
|
36
36
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
meta-data.
|
|
37
|
+
`lf.Message` is the fundamental unit of communication in Langfun. It
|
|
38
|
+
standardizes interactions with LLMs by encapsulating not only text but also
|
|
39
|
+
multi-modal content, as well as the sender's role and structured metadata.
|
|
41
40
|
|
|
42
|
-
|
|
43
|
-
|
|
41
|
+
**Key Components:**
|
|
42
|
+
|
|
43
|
+
* **`text`**: The natural language content of the message.
|
|
44
|
+
* **`sender`**: An identifier for the message originator (e.g., 'User',
|
|
45
|
+
'AI', 'System').
|
|
46
|
+
* **`metadata`**: A dictionary for structured data, such as tool inputs/
|
|
47
|
+
outputs, scores, or other contextual information.
|
|
48
|
+
* **`referred_modalities`**: A dictionary of modality objects (e.g.,
|
|
49
|
+
`lf.Image`, `lf.Audio`) referenced within the message text via placeholders
|
|
50
|
+
like `<<[[image_id]]>>`.
|
|
51
|
+
|
|
52
|
+
Subclasses like `lf.UserMessage`, `lf.AIMessage`, and `lf.ToolMessage`
|
|
53
|
+
represent messages from specific roles, enabling more complex conversational
|
|
54
|
+
flows and agentic behaviors.
|
|
55
|
+
|
|
56
|
+
**Example:**
|
|
57
|
+
|
|
58
|
+
```python
|
|
59
|
+
import langfun as lf
|
|
60
|
+
|
|
61
|
+
# Creating a user message with an image
|
|
62
|
+
image = lf.Image.from_path('/path/to/image.png')
|
|
63
|
+
user_message = lf.UserMessage(
|
|
64
|
+
f'What is in this image <<[[{image.id}]]>>?',
|
|
65
|
+
referred_modalities=[image])
|
|
66
|
+
|
|
67
|
+
# Creating an AI message with structured results
|
|
68
|
+
ai_message = lf.AIMessage(
|
|
69
|
+
'It is a cat.',
|
|
70
|
+
metadata=dict(result=dict(label='cat', confidence=0.9)))
|
|
71
|
+
|
|
72
|
+
print(user_message.chunk())
|
|
73
|
+
# Output: ['What is in this image', <lf.Image object>, '?']
|
|
74
|
+
|
|
75
|
+
print(ai_message.result)
|
|
76
|
+
# Output: {'label': 'cat', 'confidence': 0.9}
|
|
77
|
+
```
|
|
44
78
|
"""
|
|
45
79
|
|
|
46
80
|
#
|
|
@@ -86,6 +120,11 @@ class Message(
|
|
|
86
120
|
|
|
87
121
|
sender: Annotated[str, 'The sender of the message.']
|
|
88
122
|
|
|
123
|
+
referred_modalities: Annotated[
|
|
124
|
+
dict[str, pg.Ref[modality.Modality]],
|
|
125
|
+
'The modality objects referred in the message.'
|
|
126
|
+
] = pg.Dict()
|
|
127
|
+
|
|
89
128
|
metadata: Annotated[
|
|
90
129
|
dict[str, Any],
|
|
91
130
|
(
|
|
@@ -111,6 +150,11 @@ class Message(
|
|
|
111
150
|
*,
|
|
112
151
|
# Default sender is specified in subclasses.
|
|
113
152
|
sender: str | pg.object_utils.MissingValue = pg.MISSING_VALUE,
|
|
153
|
+
referred_modalities: (
|
|
154
|
+
list[modality.Modality]
|
|
155
|
+
| dict[str, modality.Modality]
|
|
156
|
+
| None
|
|
157
|
+
) = None,
|
|
114
158
|
metadata: dict[str, Any] | None = None,
|
|
115
159
|
tags: list[str] | None = None,
|
|
116
160
|
source: Optional['Message'] = None,
|
|
@@ -125,6 +169,7 @@ class Message(
|
|
|
125
169
|
Args:
|
|
126
170
|
text: The text in the message.
|
|
127
171
|
sender: The sender name of the message.
|
|
172
|
+
referred_modalities: The modality objects referred in the message.
|
|
128
173
|
metadata: Structured meta-data associated with this message.
|
|
129
174
|
tags: Tags for the message.
|
|
130
175
|
source: The source message of the current message.
|
|
@@ -138,9 +183,13 @@ class Message(
|
|
|
138
183
|
"""
|
|
139
184
|
metadata = metadata or {}
|
|
140
185
|
metadata.update(kwargs)
|
|
186
|
+
if isinstance(referred_modalities, list):
|
|
187
|
+
referred_modalities = {m.id: pg.Ref(m) for m in referred_modalities}
|
|
188
|
+
|
|
141
189
|
super().__init__(
|
|
142
190
|
text=text,
|
|
143
191
|
metadata=metadata,
|
|
192
|
+
referred_modalities=referred_modalities or {},
|
|
144
193
|
tags=tags or [],
|
|
145
194
|
sender=sender,
|
|
146
195
|
allow_partial=allow_partial,
|
|
@@ -186,7 +235,7 @@ class Message(
|
|
|
186
235
|
A message created from the value.
|
|
187
236
|
"""
|
|
188
237
|
if isinstance(value, modality.Modality):
|
|
189
|
-
return cls('<<[[
|
|
238
|
+
return cls(f'<<[[{value.id}]]>>', referred_modalities=[value])
|
|
190
239
|
if isinstance(value, Message):
|
|
191
240
|
return value
|
|
192
241
|
if isinstance(value, str):
|
|
@@ -224,6 +273,11 @@ class Message(
|
|
|
224
273
|
"""
|
|
225
274
|
return MessageConverter.get(format_or_type, **kwargs).to_value(self)
|
|
226
275
|
|
|
276
|
+
@classmethod
|
|
277
|
+
def is_convertible(cls, format_or_type: str | Type[Any]) -> bool:
|
|
278
|
+
"""Returns True if the value can be converted to a message."""
|
|
279
|
+
return MessageConverter.is_convertible(format_or_type)
|
|
280
|
+
|
|
227
281
|
@classmethod
|
|
228
282
|
def convertible_formats(cls) -> list[str]:
|
|
229
283
|
"""Returns supported format for message conversion."""
|
|
@@ -280,8 +334,7 @@ class Message(
|
|
|
280
334
|
if key_path == Message.PATH_TEXT:
|
|
281
335
|
return self.text
|
|
282
336
|
else:
|
|
283
|
-
|
|
284
|
-
return v.value if isinstance(v, pg.Ref) else v
|
|
337
|
+
return self.metadata.sym_get(key_path, default, use_inferred=True)
|
|
285
338
|
|
|
286
339
|
#
|
|
287
340
|
# API for accessing the structured result and error.
|
|
@@ -361,46 +414,63 @@ class Message(
|
|
|
361
414
|
# API for supporting modalities.
|
|
362
415
|
#
|
|
363
416
|
|
|
417
|
+
def modalities(
|
|
418
|
+
self,
|
|
419
|
+
filter: ( # pylint: disable=redefined-builtin
|
|
420
|
+
Type[modality.Modality]
|
|
421
|
+
| Callable[[modality.Modality], bool]
|
|
422
|
+
| None
|
|
423
|
+
) = None # pylint: disable=bad-whitespace
|
|
424
|
+
) -> list[modality.Modality]:
|
|
425
|
+
"""Returns the modality objects referred in the message."""
|
|
426
|
+
if inspect.isclass(filter) and issubclass(filter, modality.Modality):
|
|
427
|
+
filter_fn = lambda v: isinstance(v, filter) # pytype: disable=wrong-arg-types
|
|
428
|
+
elif filter is None:
|
|
429
|
+
filter_fn = lambda v: True
|
|
430
|
+
else:
|
|
431
|
+
filter_fn = filter
|
|
432
|
+
return [v for v in self.referred_modalities.values() if filter_fn(v)]
|
|
433
|
+
|
|
364
434
|
@property
|
|
365
|
-
def
|
|
366
|
-
"""Returns
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
435
|
+
def images(self) -> list[modality.Modality]:
|
|
436
|
+
"""Returns the image objects referred in the message."""
|
|
437
|
+
assert False, 'Overridden in core/modalities/__init__.py'
|
|
438
|
+
|
|
439
|
+
@property
|
|
440
|
+
def videos(self) -> list[modality.Modality]:
|
|
441
|
+
"""Returns the video objects referred in the message."""
|
|
442
|
+
assert False, 'Overridden in core/modalities/__init__.py'
|
|
443
|
+
|
|
444
|
+
@property
|
|
445
|
+
def audios(self) -> list[modality.Modality]:
|
|
446
|
+
"""Returns the audio objects referred in the message."""
|
|
447
|
+
assert False, 'Overridden in core/modalities/__init__.py'
|
|
373
448
|
|
|
374
449
|
def get_modality(
|
|
375
|
-
self,
|
|
450
|
+
self,
|
|
451
|
+
var_name: str,
|
|
452
|
+
default: Any = None
|
|
376
453
|
) -> modality.Modality | None:
|
|
377
|
-
"""
|
|
454
|
+
"""Returns modality object referred in the message by its variable name.
|
|
378
455
|
|
|
379
456
|
Args:
|
|
380
457
|
var_name: The referred variable name for the modality object.
|
|
381
458
|
default: default value.
|
|
382
|
-
from_message_chain: If True, the look up will be performed from the
|
|
383
|
-
message chain. Otherwise it will be performed in current message.
|
|
384
459
|
|
|
385
460
|
Returns:
|
|
386
461
|
A modality object if found, otherwise None.
|
|
387
462
|
"""
|
|
388
|
-
|
|
389
|
-
if isinstance(obj, modality.Modality):
|
|
390
|
-
return obj
|
|
391
|
-
elif obj is None and self.source is not None:
|
|
392
|
-
return self.source.get_modality(var_name, default, from_message_chain)
|
|
393
|
-
return default
|
|
394
|
-
|
|
395
|
-
def referred_modalities(self) -> dict[str, modality.Modality]:
|
|
396
|
-
"""Returns modality objects attached on this message."""
|
|
397
|
-
chunks = self.chunk()
|
|
398
|
-
return {
|
|
399
|
-
m.referred_name: m for m in chunks if isinstance(m, modality.Modality)
|
|
400
|
-
}
|
|
463
|
+
return self.referred_modalities.get(var_name, default)
|
|
401
464
|
|
|
402
465
|
def chunk(self, text: str | None = None) -> list[str | modality.Modality]:
|
|
403
|
-
"""
|
|
466
|
+
"""Chunks message into a list of text and modality chunks.
|
|
467
|
+
|
|
468
|
+
Args:
|
|
469
|
+
text: The text to chunk. If None, use `self.text`.
|
|
470
|
+
|
|
471
|
+
Returns:
|
|
472
|
+
A list of text and modality chunks.
|
|
473
|
+
"""
|
|
404
474
|
chunks = []
|
|
405
475
|
|
|
406
476
|
def add_text_chunk(text_piece: str) -> None:
|
|
@@ -425,20 +495,25 @@ class Message(
|
|
|
425
495
|
|
|
426
496
|
var_name = text[var_start:ref_end].strip()
|
|
427
497
|
var_value = self.get_modality(var_name)
|
|
428
|
-
if var_value is
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
498
|
+
if var_value is None:
|
|
499
|
+
raise ValueError(
|
|
500
|
+
f'Unknown modality reference: {var_name!r}. '
|
|
501
|
+
'Please make sure the modality object is present in '
|
|
502
|
+
f'`referred_modalities` when creating {self.__class__.__name__}.'
|
|
503
|
+
)
|
|
504
|
+
add_text_chunk(text[chunk_start:ref_start].strip(' '))
|
|
505
|
+
chunks.append(var_value)
|
|
506
|
+
chunk_start = ref_end + len(modality.Modality.REF_END)
|
|
432
507
|
return chunks
|
|
433
508
|
|
|
434
509
|
@classmethod
|
|
435
510
|
def from_chunks(
|
|
436
511
|
cls, chunks: list[str | modality.Modality], separator: str = ' '
|
|
437
512
|
) -> 'Message':
|
|
438
|
-
"""
|
|
513
|
+
"""Assembles a message from a list of string or modality objects."""
|
|
439
514
|
fused_text = io.StringIO()
|
|
440
|
-
ref_index = 0
|
|
441
515
|
metadata = dict()
|
|
516
|
+
referred_modalities = dict()
|
|
442
517
|
last_char = None
|
|
443
518
|
for i, chunk in enumerate(chunks):
|
|
444
519
|
if i > 0 and last_char not in ('\t', ' ', '\n', None):
|
|
@@ -451,14 +526,16 @@ class Message(
|
|
|
451
526
|
last_char = None
|
|
452
527
|
else:
|
|
453
528
|
assert isinstance(chunk, modality.Modality), chunk
|
|
454
|
-
|
|
455
|
-
fused_text.write(modality.Modality.text_marker(var_name))
|
|
529
|
+
fused_text.write(modality.Modality.text_marker(chunk.id))
|
|
456
530
|
last_char = modality.Modality.REF_END[-1]
|
|
457
531
|
# Make a reference if the chunk is already owned by another object
|
|
458
532
|
# to avoid copy.
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
533
|
+
referred_modalities[chunk.id] = pg.Ref(chunk)
|
|
534
|
+
return cls(
|
|
535
|
+
fused_text.getvalue().strip(),
|
|
536
|
+
referred_modalities=referred_modalities,
|
|
537
|
+
metadata=metadata,
|
|
538
|
+
)
|
|
462
539
|
|
|
463
540
|
#
|
|
464
541
|
# Tagging
|
|
@@ -523,7 +600,7 @@ class Message(
|
|
|
523
600
|
return self.trace(Message.TAG_LM_OUTPUT)
|
|
524
601
|
|
|
525
602
|
def last(self, tag: str) -> Optional['Message']:
|
|
526
|
-
"""
|
|
603
|
+
"""Returns the last message with a given tag."""
|
|
527
604
|
current = self
|
|
528
605
|
while current is not None:
|
|
529
606
|
if tag in current.tags:
|
|
@@ -551,6 +628,11 @@ class Message(
|
|
|
551
628
|
#
|
|
552
629
|
|
|
553
630
|
def natural_language_format(self) -> str:
|
|
631
|
+
"""Returns the natural language format representation."""
|
|
632
|
+
# Propagate the modality references to parent context if any.
|
|
633
|
+
if capture_context := modality.get_modality_capture_context():
|
|
634
|
+
for v in self.referred_modalities.values():
|
|
635
|
+
capture_context.capture(v)
|
|
554
636
|
return self.text
|
|
555
637
|
|
|
556
638
|
def __eq__(self, other: Any) -> bool:
|
|
@@ -568,8 +650,7 @@ class Message(
|
|
|
568
650
|
def __getattr__(self, key: str) -> Any:
|
|
569
651
|
if key not in self.metadata:
|
|
570
652
|
raise AttributeError(key)
|
|
571
|
-
|
|
572
|
-
return v.value if isinstance(v, pg.Ref) else v
|
|
653
|
+
return self.metadata[key]
|
|
573
654
|
|
|
574
655
|
def _html_tree_view_content(
|
|
575
656
|
self,
|
|
@@ -646,15 +727,14 @@ class Message(
|
|
|
646
727
|
s.write(s.escape(chunk))
|
|
647
728
|
else:
|
|
648
729
|
assert isinstance(chunk, modality.Modality), chunk
|
|
649
|
-
child_path = pg.KeyPath(['metadata', chunk.referred_name], root_path)
|
|
650
730
|
s.write(
|
|
651
731
|
pg.Html.element(
|
|
652
732
|
'div',
|
|
653
733
|
[
|
|
654
734
|
view.render(
|
|
655
735
|
chunk,
|
|
656
|
-
name=chunk.
|
|
657
|
-
root_path=
|
|
736
|
+
name=chunk.id,
|
|
737
|
+
root_path=chunk.sym_path,
|
|
658
738
|
collapse_level=(
|
|
659
739
|
0 if collapse_modalities_in_text else 1
|
|
660
740
|
),
|
|
@@ -667,7 +747,7 @@ class Message(
|
|
|
667
747
|
css_classes=['modality-in-text'],
|
|
668
748
|
)
|
|
669
749
|
)
|
|
670
|
-
referred_chunks[chunk.
|
|
750
|
+
referred_chunks[chunk.id] = chunk
|
|
671
751
|
s.write('</div>')
|
|
672
752
|
return s
|
|
673
753
|
|
|
@@ -874,6 +954,12 @@ class _MessageConverterRegistry:
|
|
|
874
954
|
if converter.OUTPUT_TYPE is not None:
|
|
875
955
|
self._type_to_converters[converter.OUTPUT_TYPE].append(converter)
|
|
876
956
|
|
|
957
|
+
def unregister(self, converter: Type['MessageConverter']) -> None:
|
|
958
|
+
"""Unregisters a message converter."""
|
|
959
|
+
self._name_to_converter.pop(converter.FORMAT_ID, None)
|
|
960
|
+
if converter.OUTPUT_TYPE is not None:
|
|
961
|
+
self._type_to_converters[converter.OUTPUT_TYPE].remove(converter)
|
|
962
|
+
|
|
877
963
|
def get_by_type(self, t: Type[Any], **kwargs) -> 'MessageConverter':
|
|
878
964
|
"""Returns a message converter for the given type."""
|
|
879
965
|
t = self._type_to_converters[t]
|
|
@@ -904,6 +990,13 @@ class _MessageConverterRegistry:
|
|
|
904
990
|
assert isinstance(format_or_type, type), format_or_type
|
|
905
991
|
return self.get_by_type(format_or_type, **kwargs)
|
|
906
992
|
|
|
993
|
+
def is_convertible(self, format_or_type: str | Type[Any]) -> bool:
|
|
994
|
+
"""Returns whether the message is convertible to the given format or type."""
|
|
995
|
+
if isinstance(format_or_type, str):
|
|
996
|
+
return format_or_type in self._name_to_converter
|
|
997
|
+
assert isinstance(format_or_type, type), format_or_type
|
|
998
|
+
return bool(self._type_to_converters.get(format_or_type))
|
|
999
|
+
|
|
907
1000
|
def convertible_formats(self) -> list[str]:
|
|
908
1001
|
"""Returns a list of converter names."""
|
|
909
1002
|
return sorted(list(self._name_to_converter.keys()))
|
|
@@ -995,6 +1088,11 @@ class MessageConverter(pg.Object):
|
|
|
995
1088
|
"""Returns a message converter for the given type."""
|
|
996
1089
|
return cls._REGISTRY.get_by_type(t, **kwargs)
|
|
997
1090
|
|
|
1091
|
+
@classmethod
|
|
1092
|
+
def is_convertible(cls, format_or_type: str | Type[Any]) -> bool:
|
|
1093
|
+
"""Returns whether the message is convertible to the given format or type."""
|
|
1094
|
+
return cls._REGISTRY.is_convertible(format_or_type)
|
|
1095
|
+
|
|
998
1096
|
@classmethod
|
|
999
1097
|
def convertible_formats(cls) -> list[str]:
|
|
1000
1098
|
"""Returns a list of converter names."""
|
|
@@ -1036,3 +1134,10 @@ class MemoryRecord(Message):
|
|
|
1036
1134
|
"""Message used as a memory record."""
|
|
1037
1135
|
|
|
1038
1136
|
sender = 'Memory'
|
|
1137
|
+
|
|
1138
|
+
|
|
1139
|
+
@pg.use_init_args(['text', 'sender', 'metadata'])
|
|
1140
|
+
class ToolMessage(Message):
|
|
1141
|
+
"""Message used as a tool call."""
|
|
1142
|
+
|
|
1143
|
+
sender = 'Tool'
|