langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +2 -0
- langfun/core/eval/v2/checkpointing.py +76 -7
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +92 -17
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +84 -15
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +12 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +64 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +34 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +58 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +64 -3
- langfun/core/modalities/mime_test.py +11 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
|
@@ -103,6 +103,122 @@ class ImageTest(unittest.TestCase):
|
|
|
103
103
|
image_lib.Image.from_pil_image(image), image_lib.Image
|
|
104
104
|
)
|
|
105
105
|
|
|
106
|
+
def test_from_pil_image_os_error(self):
|
|
107
|
+
img = pil_image.open(io.BytesIO(image_content))
|
|
108
|
+
with mock.patch.object(img, 'save') as mock_save:
|
|
109
|
+
mock_save.side_effect = [OSError, None]
|
|
110
|
+
with mock.patch('os.chdir') as mock_chdir:
|
|
111
|
+
with mock.patch('os.getcwd') as mock_getcwd:
|
|
112
|
+
mock_getcwd.return_value = '/curr/dir'
|
|
113
|
+
image = image_lib.Image.from_pil_image(img)
|
|
114
|
+
self.assertIsInstance(image, image_lib.Image)
|
|
115
|
+
self.assertEqual(mock_save.call_count, 2)
|
|
116
|
+
mock_save.assert_has_calls([
|
|
117
|
+
mock.call(mock.ANY, format='PNG'),
|
|
118
|
+
mock.call(mock.ANY, format='PNG'),
|
|
119
|
+
])
|
|
120
|
+
mock_chdir.assert_has_calls([
|
|
121
|
+
mock.call('/tmp'),
|
|
122
|
+
mock.call('/curr/dir'),
|
|
123
|
+
])
|
|
124
|
+
|
|
125
|
+
def test_gif_is_compatible(self):
|
|
126
|
+
# Create a simple 1x1 GIF image using PIL
|
|
127
|
+
buf = io.BytesIO()
|
|
128
|
+
img = pil_image.new('P', (1, 1))
|
|
129
|
+
img.save(buf, format='GIF')
|
|
130
|
+
gif_bytes = buf.getvalue()
|
|
131
|
+
|
|
132
|
+
gif_image = image_lib.Image.from_bytes(gif_bytes)
|
|
133
|
+
self.assertEqual(gif_image.mime_type, 'image/gif')
|
|
134
|
+
|
|
135
|
+
# GIF should be compatible if PNG is in supported types
|
|
136
|
+
self.assertTrue(gif_image._is_compatible(['image/png']))
|
|
137
|
+
self.assertTrue(gif_image._is_compatible(['image/jpeg', 'image/webp']))
|
|
138
|
+
self.assertTrue(gif_image._is_compatible(['image/png', 'image/jpeg']))
|
|
139
|
+
|
|
140
|
+
# GIF should not be compatible if only unsupported types
|
|
141
|
+
self.assertFalse(gif_image._is_compatible(['video/mp4']))
|
|
142
|
+
self.assertFalse(gif_image._is_compatible(['application/pdf']))
|
|
143
|
+
|
|
144
|
+
def test_gif_make_compatible(self):
|
|
145
|
+
# Create a simple 1x1 GIF image using PIL
|
|
146
|
+
buf = io.BytesIO()
|
|
147
|
+
img = pil_image.new('P', (1, 1))
|
|
148
|
+
img.save(buf, format='GIF')
|
|
149
|
+
gif_bytes = buf.getvalue()
|
|
150
|
+
|
|
151
|
+
gif_image = image_lib.Image.from_bytes(gif_bytes)
|
|
152
|
+
self.assertEqual(gif_image.mime_type, 'image/gif')
|
|
153
|
+
|
|
154
|
+
# Test 1: Convert to PNG (first priority when available)
|
|
155
|
+
converted = gif_image.make_compatible(['image/png', 'image/jpeg'])
|
|
156
|
+
self.assertEqual(converted.mime_type, 'image/png')
|
|
157
|
+
self.assertIsInstance(converted, image_lib.Image)
|
|
158
|
+
|
|
159
|
+
# Test 2: Convert to JPEG when PNG not supported
|
|
160
|
+
converted = gif_image.make_compatible(['image/jpeg', 'image/webp'])
|
|
161
|
+
self.assertEqual(converted.mime_type, 'image/jpeg')
|
|
162
|
+
|
|
163
|
+
# Test 3: Convert to WEBP when PNG and JPEG not supported
|
|
164
|
+
converted = gif_image.make_compatible(['image/webp'])
|
|
165
|
+
self.assertEqual(converted.mime_type, 'image/webp')
|
|
166
|
+
|
|
167
|
+
# Test 4: Should raise error when no compatible format
|
|
168
|
+
with self.assertRaises(lf.ModalityError):
|
|
169
|
+
gif_image.make_compatible(['video/mp4'])
|
|
170
|
+
|
|
171
|
+
def test_is_compatible_direct_match(self):
|
|
172
|
+
image = image_lib.Image.from_bytes(image_content) # image/png
|
|
173
|
+
self.assertTrue(image._is_compatible(['image/png', 'image/jpeg']))
|
|
174
|
+
self.assertTrue(image._is_compatible(['image/png']))
|
|
175
|
+
self.assertFalse(image._is_compatible(['image/jpeg']))
|
|
176
|
+
|
|
177
|
+
def test_make_compatible_no_conversion(self):
|
|
178
|
+
image = image_lib.Image.from_bytes(image_content) # image/png
|
|
179
|
+
converted_image = image.make_compatible(['image/png', 'image/jpeg'])
|
|
180
|
+
self.assertIs(image, converted_image)
|
|
181
|
+
|
|
182
|
+
def test_convert_to_format_jpeg_transparency(self):
|
|
183
|
+
# Create a simple RGBA PNG image
|
|
184
|
+
buf = io.BytesIO()
|
|
185
|
+
img = pil_image.new('RGBA', (1, 1), (255, 0, 0, 128))
|
|
186
|
+
img.save(buf, format='PNG')
|
|
187
|
+
rgba_png_bytes = buf.getvalue()
|
|
188
|
+
|
|
189
|
+
rgba_image = image_lib.Image.from_bytes(rgba_png_bytes)
|
|
190
|
+
self.assertEqual(rgba_image.mime_type, 'image/png')
|
|
191
|
+
|
|
192
|
+
# Convert to JPEG, should trigger transparency handling
|
|
193
|
+
converted_image = rgba_image._convert_to_format('JPEG')
|
|
194
|
+
self.assertEqual(converted_image.mime_type, 'image/jpeg')
|
|
195
|
+
pil_img = converted_image.to_pil_image()
|
|
196
|
+
self.assertEqual(pil_img.mode, 'RGB')
|
|
197
|
+
|
|
198
|
+
def test_convert_to_format_os_error(self):
|
|
199
|
+
image = image_lib.Image.from_bytes(image_content)
|
|
200
|
+
mock_pil_image = mock.MagicMock()
|
|
201
|
+
mock_save = mock_pil_image.save
|
|
202
|
+
mock_save.side_effect = [OSError, None]
|
|
203
|
+
|
|
204
|
+
with mock.patch.object(
|
|
205
|
+
image, 'to_pil_image', return_value=mock_pil_image
|
|
206
|
+
), mock.patch('os.chdir') as mock_chdir, mock.patch(
|
|
207
|
+
'os.getcwd'
|
|
208
|
+
) as mock_getcwd:
|
|
209
|
+
mock_getcwd.return_value = '/curr/dir'
|
|
210
|
+
converted_image = image._convert_to_format('PNG')
|
|
211
|
+
self.assertIsInstance(converted_image, image_lib.Image)
|
|
212
|
+
self.assertEqual(mock_save.call_count, 2)
|
|
213
|
+
mock_save.assert_has_calls([
|
|
214
|
+
mock.call(mock.ANY, format='PNG'),
|
|
215
|
+
mock.call(mock.ANY, format='PNG'),
|
|
216
|
+
])
|
|
217
|
+
mock_chdir.assert_has_calls([
|
|
218
|
+
mock.call('/tmp'),
|
|
219
|
+
mock.call('/curr/dir'),
|
|
220
|
+
])
|
|
221
|
+
|
|
106
222
|
|
|
107
223
|
if __name__ == '__main__':
|
|
108
224
|
unittest.main()
|
langfun/core/modalities/mime.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import base64
|
|
17
17
|
import functools
|
|
18
|
+
import hashlib
|
|
18
19
|
from typing import Annotated, Any, Iterable, Type, Union
|
|
19
20
|
import langfun.core as lf
|
|
20
21
|
# Placeholder for Google-internal internet access import.
|
|
@@ -36,7 +37,33 @@ def _detect_mime_type(content: bytes) -> str:
|
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
class Mime(lf.Modality):
|
|
39
|
-
"""Base for MIME
|
|
40
|
+
"""Base class for representing modality data based on MIME types.
|
|
41
|
+
|
|
42
|
+
`lf.Mime` is a subclass of `lf.Modality` that serves as a base for
|
|
43
|
+
handling various data types like images, audio, video, and PDFs,
|
|
44
|
+
identified by their MIME types. It provides unified methods for
|
|
45
|
+
loading data from URIs or bytes (`.from_uri()`, `.from_bytes()`) and
|
|
46
|
+
for accessing content (`.to_bytes()`).
|
|
47
|
+
|
|
48
|
+
Subclasses like `lf.Image`, `lf.Audio`, `lf.Video`, and `lf.PDF`
|
|
49
|
+
specialize in handling specific MIME type prefixes (e.g., 'image/', 'audio/').
|
|
50
|
+
|
|
51
|
+
**Example:**
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
import langfun as lf
|
|
55
|
+
|
|
56
|
+
# Load an image from a path
|
|
57
|
+
image = lf.Image.from_path('/path/to/image.png')
|
|
58
|
+
print(image.mime_type)
|
|
59
|
+
# Output: image/png
|
|
60
|
+
|
|
61
|
+
# Create a text document
|
|
62
|
+
text = lf.Custom.from_bytes(b'hello world', mime='text/plain')
|
|
63
|
+
print(text.mime_type)
|
|
64
|
+
# Output: text/plain
|
|
65
|
+
```
|
|
66
|
+
"""
|
|
40
67
|
|
|
41
68
|
# The regular expression that describes the MIME type str.
|
|
42
69
|
# If None, the MIME type is dynamic. Subclass could override.
|
|
@@ -48,6 +75,10 @@ class Mime(lf.Modality):
|
|
|
48
75
|
Union[str, bytes, None], 'The raw content of the MIME type.'
|
|
49
76
|
] = None
|
|
50
77
|
|
|
78
|
+
metadata: Annotated[
|
|
79
|
+
dict[str, Any], 'Additional metadata attached to this object.'
|
|
80
|
+
] = {}
|
|
81
|
+
|
|
51
82
|
@functools.cached_property
|
|
52
83
|
def mime_type(self) -> str:
|
|
53
84
|
"""Returns the MIME type."""
|
|
@@ -87,6 +118,17 @@ class Mime(lf.Modality):
|
|
|
87
118
|
"""Returns True if the MIME type is a binary type."""
|
|
88
119
|
return not self.is_text
|
|
89
120
|
|
|
121
|
+
@property
def hash(self) -> str:
  """Returns a short hash identifying this MIME object.

  The hash source is chosen in this order:

  1. `uri`, when set. Hashing the URI string avoids downloading the content
     just to compute an identifier.
  2. `content`, when set. The computation is delegated to the parent class.
  3. `metadata` otherwise, hashed through its `str()` form.

  For cases 1 and 3 the result is the first 8 hex characters of an MD5
  digest.
  """
  if self.uri is None:
    if self.content is not None:
      return super().hash
    # With neither uri nor content, metadata is the only thing left to
    # identify the object, so it is expected to be non-empty here.
    assert self.metadata
    digest_source = str(self.metadata)
  else:
    digest_source = self.uri
  return hashlib.md5(digest_source.encode()).hexdigest()[:8]
|
|
131
|
+
|
|
90
132
|
def to_text(self) -> str:
|
|
91
133
|
"""Returns the text content of the MIME type."""
|
|
92
134
|
if not self.is_text:
|
|
@@ -132,7 +174,7 @@ class Mime(lf.Modality):
|
|
|
132
174
|
|
|
133
175
|
def _on_bound(self):
|
|
134
176
|
super()._on_bound()
|
|
135
|
-
if self.uri is None and self.content is None:
|
|
177
|
+
if self.uri is None and self.content is None and not self.metadata:
|
|
136
178
|
raise ValueError('Either uri or content must be provided.')
|
|
137
179
|
|
|
138
180
|
def to_bytes(self) -> bytes:
|
|
@@ -162,6 +204,8 @@ class Mime(lf.Modality):
|
|
|
162
204
|
return cls.class_from_mime_type(mime_type).from_bytes(content, **kwargs)
|
|
163
205
|
|
|
164
206
|
if cls is Mime:
|
|
207
|
+
if 'youtube.com/watch' in uri:
|
|
208
|
+
return Custom(mime='text/html', uri=uri, **kwargs)
|
|
165
209
|
content = cls.download(uri)
|
|
166
210
|
mime = _detect_mime_type(content)
|
|
167
211
|
return cls.class_from_mime_type(mime)(uri=uri, content=content, **kwargs)
|
|
@@ -272,7 +316,24 @@ class Mime(lf.Modality):
|
|
|
272
316
|
|
|
273
317
|
@pg.use_init_args(['mime', 'content', 'uri'])
|
|
274
318
|
class Custom(Mime):
|
|
275
|
-
"""
|
|
319
|
+
"""Represents content of a custom MIME type.
|
|
320
|
+
|
|
321
|
+
`lf.modalities.Custom` is useful for representing data with MIME types
|
|
322
|
+
that do not have dedicated classes like `lf.Image` or `lf.Audio`.
|
|
323
|
+
|
|
324
|
+
**Example:**
|
|
325
|
+
|
|
326
|
+
```python
|
|
327
|
+
import langfun as lf
|
|
328
|
+
|
|
329
|
+
# Create a custom MIME object for plain text
|
|
330
|
+
text_data = lf.Custom.from_bytes(
|
|
331
|
+
b'This is a text document.', mime='text/plain'
|
|
332
|
+
)
|
|
333
|
+
print(text_data.mime_type)
|
|
334
|
+
# Output: text/plain
|
|
335
|
+
```
|
|
336
|
+
"""
|
|
276
337
|
|
|
277
338
|
mime: Annotated[
|
|
278
339
|
str, 'The MIME type of the data. E.g. text/plain, or image/png. '
|
|
@@ -109,6 +109,17 @@ class CustomMimeTest(unittest.TestCase):
|
|
|
109
109
|
with self.assertRaisesRegex(ValueError, 'Unsupported encoding'):
|
|
110
110
|
mime.Mime.from_uri('data:text/plain;base16,abcd')
|
|
111
111
|
|
|
112
|
+
# Test YouTube URI
|
|
113
|
+
yt_uri = 'https://www.youtube.com/watch?v=dQw4w9WgXcQ'
|
|
114
|
+
with mock.patch(
|
|
115
|
+
'langfun.core.modalities.mime.Mime.download'
|
|
116
|
+
) as mock_download:
|
|
117
|
+
content = mime.Mime.from_uri(yt_uri)
|
|
118
|
+
self.assertIsInstance(content, mime.Custom)
|
|
119
|
+
self.assertEqual(content.mime_type, 'text/html')
|
|
120
|
+
self.assertEqual(content.uri, yt_uri)
|
|
121
|
+
mock_download.assert_not_called()
|
|
122
|
+
|
|
112
123
|
def assert_html_content(self, html, expected):
|
|
113
124
|
expected = inspect.cleandoc(expected).strip()
|
|
114
125
|
actual = html.content.strip()
|
langfun/core/modalities/pdf.py
CHANGED
|
@@ -17,6 +17,24 @@ from langfun.core.modalities import mime
|
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
class PDF(mime.Mime):
|
|
20
|
-
"""PDF document.
|
|
20
|
+
"""Represents a PDF document for communicating with language models.
|
|
21
|
+
|
|
22
|
+
`lf.PDF` can be initialized from a URI (HTTP/HTTPS URL or local path)
|
|
23
|
+
using `lf.PDF.from_uri()` or from raw bytes using `lf.PDF.from_bytes()`.
|
|
24
|
+
|
|
25
|
+
**Example:**
|
|
26
|
+
|
|
27
|
+
```python
|
|
28
|
+
import langfun as lf
|
|
29
|
+
|
|
30
|
+
# Load PDF from path
|
|
31
|
+
pdf = lf.PDF.from_path('/path/to/document.pdf')
|
|
32
|
+
|
|
33
|
+
# Use PDF in a prompt
|
|
34
|
+
prompt = lf.Template('Summarize this document: {{pdf}}', pdf=pdf)
|
|
35
|
+
response = lf.query(prompt, lm=lf.llms.Gemini25Flash())
|
|
36
|
+
print(response)
|
|
37
|
+
```
|
|
38
|
+
"""
|
|
21
39
|
|
|
22
40
|
MIME_PREFIX = 'application/pdf'
|
langfun/core/modalities/video.py
CHANGED
|
@@ -18,7 +18,27 @@ from langfun.core.modalities import mime
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class Video(mime.Mime):
|
|
21
|
-
"""
|
|
21
|
+
"""Represents a video for communicating with language models.
|
|
22
|
+
|
|
23
|
+
`lf.Video` can be initialized from a URI (HTTP/HTTPS URL or local path)
|
|
24
|
+
using `lf.Video.from_uri()` or from raw bytes using `lf.Video.from_bytes()`.
|
|
25
|
+
|
|
26
|
+
**Example:**
|
|
27
|
+
|
|
28
|
+
```python
|
|
29
|
+
import langfun as lf
|
|
30
|
+
|
|
31
|
+
# Load video from path
|
|
32
|
+
video = lf.Video.from_path('/path/to/video.mp4')
|
|
33
|
+
|
|
34
|
+
# Use video in a prompt
|
|
35
|
+
prompt = lf.Template(
|
|
36
|
+
'What is happening in this video? {{video}}', video=video
|
|
37
|
+
)
|
|
38
|
+
response = lf.query(prompt, lm=lf.llms.Gemini25Flash())
|
|
39
|
+
print(response)
|
|
40
|
+
```
|
|
41
|
+
"""
|
|
22
42
|
|
|
23
43
|
MIME_PREFIX = 'video'
|
|
24
44
|
|
langfun/core/modality.py
CHANGED
|
@@ -14,40 +14,63 @@
|
|
|
14
14
|
"""Interface for modality (e.g. Image, Video, etc.)."""
|
|
15
15
|
|
|
16
16
|
import abc
|
|
17
|
+
import contextlib
|
|
17
18
|
import functools
|
|
18
19
|
import hashlib
|
|
19
|
-
|
|
20
|
+
import re
|
|
21
|
+
from typing import Any, ContextManager, Iterator
|
|
20
22
|
from langfun.core import component
|
|
21
23
|
import pyglove as pg
|
|
22
24
|
|
|
23
25
|
|
|
24
|
-
|
|
26
|
+
class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
27
|
+
"""Base class for representing non-text content in prompts.
|
|
25
28
|
|
|
29
|
+
`lf.Modality` is the base class for multimodal objects such as `lf.Image`,
|
|
30
|
+
`lf.Audio`, and `lf.Video`. It allows these non-text inputs to be
|
|
31
|
+
seamlessly embedded within text prompts for processing by multimodal
|
|
32
|
+
language models.
|
|
26
33
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
34
|
+
When a `Modality` object is rendered within an `lf.Template`, it is
|
|
35
|
+
replaced by a text marker (e.g., `<<[[image:b10a8db1]]>>`), and the
|
|
36
|
+
modality object itself is stored in the `referred_modalities` field of
|
|
37
|
+
the resulting `lf.Message`. This allows language models to associate
|
|
38
|
+
the placeholder with its content during processing.
|
|
32
39
|
|
|
40
|
+
**Example:**
|
|
33
41
|
|
|
34
|
-
|
|
35
|
-
|
|
42
|
+
```python
|
|
43
|
+
import langfun as lf
|
|
44
|
+
|
|
45
|
+
image = lf.Image.from_path('/path/to/image.png')
|
|
46
|
+
prompt = lf.Template('What is in this image? {{image}}', image=image)
|
|
47
|
+
|
|
48
|
+
message = prompt.render()
|
|
49
|
+
print(message.text)
|
|
50
|
+
# Output: What is in this image? <<[[image:b10a8db1]]>>
|
|
51
|
+
|
|
52
|
+
print(message.modalities())
|
|
53
|
+
# Output: [<Image object>]
|
|
54
|
+
```
|
|
55
|
+
"""
|
|
36
56
|
|
|
37
57
|
REF_START = '<<[['
|
|
38
58
|
REF_END = ']]>>'
|
|
39
59
|
|
|
40
60
|
def _on_bound(self):
|
|
41
61
|
super()._on_bound()
|
|
42
|
-
# Invalidate cached hash if modality member is changed.
|
|
62
|
+
# Invalidate cached hash and id if modality member is changed.
|
|
43
63
|
self.__dict__.pop('hash', None)
|
|
64
|
+
self.__dict__.pop('id', None)
|
|
44
65
|
|
|
45
66
|
def format(self, *args, **kwargs) -> str:
|
|
46
|
-
if
|
|
47
|
-
_TLS_MODALITY_AS_REF, False
|
|
48
|
-
):
|
|
67
|
+
if not pg.object_utils.thread_local_get(_TLS_MODALITY_AS_REF, False):
|
|
49
68
|
return super().format(*args, **kwargs)
|
|
50
|
-
|
|
69
|
+
|
|
70
|
+
capture_scope = get_modality_capture_context()
|
|
71
|
+
if capture_scope is not None:
|
|
72
|
+
capture_scope.capture(self)
|
|
73
|
+
return Modality.text_marker(self.id)
|
|
51
74
|
|
|
52
75
|
def __str_kwargs__(self) -> dict[str, Any]:
|
|
53
76
|
# For modality objects, we don't want to use markdown format when they
|
|
@@ -70,14 +93,11 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
70
93
|
"""Returns a marker in the text for this object."""
|
|
71
94
|
return Modality.REF_START + var_name + Modality.REF_END
|
|
72
95
|
|
|
73
|
-
@
|
|
74
|
-
def
|
|
96
|
+
@functools.cached_property
|
|
97
|
+
def id(self) -> str | None:
|
|
75
98
|
"""Returns the referred name of this object in its template."""
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
# Strip the metadata prefix under message.
|
|
79
|
-
path = str(self.sym_path)
|
|
80
|
-
return path[9:] if path.startswith('metadata.') else path
|
|
99
|
+
modality_type = _camel_to_snake(self.__class__.__name__)
|
|
100
|
+
return f'{modality_type}:{self.hash}'
|
|
81
101
|
|
|
82
102
|
@classmethod
|
|
83
103
|
def from_value(cls, value: pg.Symbolic) -> dict[str, 'Modality']:
|
|
@@ -86,7 +106,7 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
86
106
|
def _visit(k, v, p):
|
|
87
107
|
del k, p
|
|
88
108
|
if isinstance(v, Modality):
|
|
89
|
-
modalities[v.
|
|
109
|
+
modalities[v.id] = v
|
|
90
110
|
return pg.TraverseAction.CONTINUE
|
|
91
111
|
return pg.TraverseAction.ENTER
|
|
92
112
|
|
|
@@ -95,14 +115,47 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
95
115
|
|
|
96
116
|
|
|
97
117
|
class ModalityRef(pg.Object, pg.typing.CustomTyping):
|
|
98
|
-
"""
|
|
118
|
+
"""Lightweight placeholder for a `lf.Modality` object in a symbolic tree.
|
|
99
119
|
|
|
100
|
-
`ModalityRef`
|
|
101
|
-
|
|
102
|
-
|
|
120
|
+
`ModalityRef` acts as a reference to a `Modality` object (like `lf.Image`
|
|
121
|
+
or `lf.Audio`) within a structured object hierarchy (e.g., a `pg.Object`).
|
|
122
|
+
Instead of embedding potentially large modality data directly, `ModalityRef`
|
|
123
|
+
stores only the ID of the modality object.
|
|
124
|
+
|
|
125
|
+
This is useful in scenarios where structured objects are serialized or
|
|
126
|
+
manipulated, and it's more efficient to refer to modalities by ID rather
|
|
127
|
+
than copying their content. The `lf.ModalityRef.placehold()` class method
|
|
128
|
+
can be used to replace `Modality` instances in a symbolic object with
|
|
129
|
+
`ModalityRef` placeholders, while `lf.ModalityRef.restore()` can reinstate
|
|
130
|
+
the original `Modality` objects using a lookup table.
|
|
131
|
+
|
|
132
|
+
**Example:**
|
|
133
|
+
|
|
134
|
+
```python
|
|
135
|
+
import langfun as lf
|
|
136
|
+
import pyglove as pg
|
|
137
|
+
|
|
138
|
+
class ImagePair(pg.Object):
|
|
139
|
+
image1: lf.Image
|
|
140
|
+
image2: lf.Image
|
|
141
|
+
|
|
142
|
+
pair = ImagePair(
|
|
143
|
+
image1=lf.Image(content=b'abc'), image2=lf.Image(content=b'def')
|
|
144
|
+
)
|
|
145
|
+
modalities = lf.Modality.from_value(pair)
|
|
146
|
+
|
|
147
|
+
# Replace Image objects with ModalityRef placeholders
|
|
148
|
+
pair_with_refs = lf.ModalityRef.placehold(pair)
|
|
149
|
+
print(pair_with_refs.image1)
|
|
150
|
+
# Output: ModalityRef(id='image:d81e5a68')
|
|
151
|
+
|
|
152
|
+
# Restore Image objects from ModalityRef placeholders
|
|
153
|
+
pair_restored = lf.ModalityRef.restore(pair_with_refs, modalities)
|
|
154
|
+
assert pair_restored.image1.content == b'abc'
|
|
155
|
+
```
|
|
103
156
|
"""
|
|
104
157
|
|
|
105
|
-
|
|
158
|
+
id: str
|
|
106
159
|
|
|
107
160
|
def custom_apply(
|
|
108
161
|
self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs
|
|
@@ -122,12 +175,97 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
|
|
|
122
175
|
"""
|
|
123
176
|
|
|
124
177
|
def _placehold(k, v, p):
|
|
125
|
-
del p
|
|
178
|
+
del k, p
|
|
126
179
|
if isinstance(v, Modality):
|
|
127
|
-
return ModalityRef(
|
|
180
|
+
return ModalityRef(id=v.id)
|
|
128
181
|
return v
|
|
129
182
|
return value.clone().rebind(_placehold, raise_on_no_change=False)
|
|
130
183
|
|
|
184
|
+
@classmethod
def restore(cls, value: pg.Symbolic, modalities: dict[str, Modality]) -> Any:
  """Replaces `ModalityRef` placeholders in `value` with modality objects.

  Args:
    value: A symbolic value that may contain `ModalityRef` nodes.
    modalities: A lookup table from modality ID to modality object.

  Returns:
    The result of rebinding `value` with every `ModalityRef` swapped for the
    object registered under its `id`.
    NOTE(review): unlike `placehold`, no `clone()` is made before `rebind`,
    so `value` itself appears to be modified in place - confirm, and clone
    at the call site if the original must be preserved.

  Raises:
    ValueError: If a `ModalityRef` has an ID with no entry in `modalities`.
  """
  def _swap_ref(k, v, p):
    del k, p
    if not isinstance(v, ModalityRef):
      return v
    resolved = modalities.get(v.id)
    if resolved is None:
      raise ValueError(
          f'Modality {v.id} not found in modalities {modalities.keys()}'
      )
    return resolved

  return value.rebind(_swap_ref, raise_on_no_change=False)
|
|
198
|
+
|
|
131
199
|
|
|
132
200
|
class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
|
|
133
201
|
"""Exception raised when modality is not supported."""
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
#
|
|
205
|
+
# Context managers to deal with modality objects.
|
|
206
|
+
#
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
_TLS_MODALITY_CAPTURE_SCOPE = '__modality_capture_scope__'
|
|
210
|
+
_TLS_MODALITY_AS_REF = '__format_modality_as_ref__'
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
  """Returns a context manager that toggles reference-style formatting.

  While the flag is enabled, `Modality.format` emits the text marker for the
  modality's ID instead of the regular object representation.

  Args:
    enabled: Value of the thread-local flag to apply within the scope.

  Returns:
    A context manager scoping the thread-local flag.
  """
  # The trailing `False` is presumably the value used when the flag was never
  # set on this thread - confirm against the PyGlove helper.
  scope = pg.object_utils.thread_local_value_scope(
      _TLS_MODALITY_AS_REF, enabled, False
  )
  return scope
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class _ModalityCaptureContext:
  """Collects modality objects as their reference markers are rendered."""

  def __init__(self):
    # Keyed by modality ID. Capturing another object with the same ID
    # replaces the earlier entry.
    self._references: dict[str, pg.Ref[Modality]] = {}

  @property
  def references(self) -> dict[str, pg.Ref[Modality]]:
    """Returns the live table of captured references, keyed by modality ID."""
    return self._references

  def capture(self, modality: Modality) -> None:
    """Records `modality` under its ID, wrapped in a `pg.Ref`."""
    self._references.update({modality.id: pg.Ref(modality)})
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
@contextlib.contextmanager
def capture_rendered_modalities() -> Iterator[dict[str, pg.Ref[Modality]]]:
  """Captures modality objects whose references are being rendered.

  Scopes may be nested. Only the outermost scope installs and removes the
  thread-local capture context; inner scopes yield the same table.

  Example:

    ```
    image = lf.Image.from_url(...)
    with lf.modality.capture_rendered_modalities() as rendered_modalities:
      with lf.modality.format_modality_as_ref():
        print(f'Hello {image}')
    self.assertEqual(rendered_modalities, {'image:<hash>': pg.Ref(image)})
    ```

  Yields:
    A dict mapping modality ID to a `pg.Ref` of the captured modality.
  """
  enclosing = get_modality_capture_context()
  if enclosing is not None:
    # Nested scope: share the outer table and leave cleanup to the outer one.
    yield enclosing.references
    return

  scope = _ModalityCaptureContext()
  pg.object_utils.thread_local_set(_TLS_MODALITY_CAPTURE_SCOPE, scope)
  try:
    yield scope.references
  finally:
    pg.object_utils.thread_local_del(_TLS_MODALITY_CAPTURE_SCOPE)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def get_modality_capture_context() -> _ModalityCaptureContext | None:
  """Returns this thread's modality capture context, or None if none is set."""
  active_scope = pg.object_utils.thread_local_get(
      _TLS_MODALITY_CAPTURE_SCOPE, None
  )
  return active_scope
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _camel_to_snake(name: str) -> str:
  """Converts a CamelCase (or camelCase) name to snake_case.

  A run of capitals is treated as one acronym, and the last capital of a run
  starts a new word when a lowercase letter follows it:

    'CustomModality' -> 'custom_modality'
    'PDF'            -> 'pdf'
    'MyPDFFile'      -> 'my_pdf_file'
    'HTTPServer'     -> 'http_server'

  Args:
    name: The identifier to convert, typically a class name.

  Returns:
    The lower-cased, underscore-separated name without leading underscores.
  """
  # Insert '_' at word boundaries only:
  #   1. between a lowercase letter or digit and an uppercase letter;
  #   2. inside a run of capitals, before the capital that begins a word
  #      (that is, one followed by a lowercase letter).
  # Prefixing every run of capitals would split 'PDFFile' as 'pdff_ile'.
  separated = re.sub(
      r'(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])', '_', name
  )
  return separated.lower().lstrip('_')
|
langfun/core/modality_test.py
CHANGED
|
@@ -29,34 +29,64 @@ class ModalityTest(unittest.TestCase):
|
|
|
29
29
|
|
|
30
30
|
def test_basic(self):
|
|
31
31
|
v = CustomModality('a')
|
|
32
|
-
self.
|
|
32
|
+
self.assertEqual(v.id, 'custom_modality:0cc175b9')
|
|
33
33
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
|
34
34
|
self.assertEqual(v.hash, '0cc175b9')
|
|
35
35
|
|
|
36
36
|
_ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
|
|
37
|
-
self.assertEqual(v.
|
|
37
|
+
self.assertEqual(v.id, 'custom_modality:0cc175b9')
|
|
38
38
|
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
|
|
39
39
|
with modality.format_modality_as_ref():
|
|
40
|
-
self.assertEqual(str(v), '<<[[
|
|
40
|
+
self.assertEqual(str(v), '<<[[custom_modality:0cc175b9]]>>')
|
|
41
|
+
|
|
42
|
+
def test_capture_rendered_modalities(self):
|
|
43
|
+
x = CustomModality('a')
|
|
44
|
+
y = CustomModality('b')
|
|
45
|
+
z = CustomModality('b')
|
|
46
|
+
|
|
47
|
+
with modality.capture_rendered_modalities() as rendered_modalities:
|
|
48
|
+
with modality.format_modality_as_ref():
|
|
49
|
+
self.assertEqual(
|
|
50
|
+
f'Hello {x} {y} {z}',
|
|
51
|
+
(
|
|
52
|
+
'Hello <<[[custom_modality:0cc175b9]]>> '
|
|
53
|
+
'<<[[custom_modality:92eb5ffe]]>> '
|
|
54
|
+
'<<[[custom_modality:92eb5ffe]]>>'
|
|
55
|
+
)
|
|
56
|
+
)
|
|
57
|
+
self.assertEqual(len(rendered_modalities), 2)
|
|
58
|
+
self.assertIs(rendered_modalities['custom_modality:0cc175b9'].value, x)
|
|
59
|
+
# y and z share the same content, so they are treated as the same object.
|
|
60
|
+
self.assertIs(rendered_modalities['custom_modality:92eb5ffe'].value, z)
|
|
41
61
|
|
|
42
62
|
|
|
43
63
|
class ModalityRefTest(unittest.TestCase):
|
|
44
64
|
|
|
45
|
-
def
|
|
65
|
+
def test_placehold_and_restore(self):
|
|
46
66
|
class A(pg.Object):
|
|
47
67
|
x: Any
|
|
48
68
|
y: Any
|
|
49
69
|
|
|
50
|
-
|
|
70
|
+
image_a = CustomModality('a')
|
|
71
|
+
image_b = CustomModality('b')
|
|
72
|
+
a = A(x=dict(z=image_a), y=image_b)
|
|
73
|
+
a_placehold = modality.ModalityRef.placehold(a)
|
|
51
74
|
self.assertEqual(
|
|
52
|
-
|
|
53
|
-
A(x=dict(z=modality.ModalityRef(
|
|
75
|
+
a_placehold,
|
|
76
|
+
A(x=dict(z=modality.ModalityRef(image_a.id)),
|
|
77
|
+
y=modality.ModalityRef(image_b.id)),
|
|
78
|
+
)
|
|
79
|
+
a_restore = modality.ModalityRef.restore(
|
|
80
|
+
a_placehold.clone(),
|
|
81
|
+
{image_a.id: image_a, image_b.id: image_b},
|
|
54
82
|
)
|
|
83
|
+
self.assertTrue(pg.eq(a_restore, a))
|
|
55
84
|
self.assertEqual(
|
|
56
85
|
modality.ModalityRef.placehold(a.x),
|
|
57
|
-
|
|
58
|
-
dict(z=modality.ModalityRef('x.z')),
|
|
86
|
+
dict(z=modality.ModalityRef(image_a.id)),
|
|
59
87
|
)
|
|
88
|
+
with self.assertRaisesRegex(ValueError, 'Modality .* not found'):
|
|
89
|
+
modality.ModalityRef.restore(a_placehold, {image_a.id: image_a})
|
|
60
90
|
|
|
61
91
|
def test_from_value(self):
|
|
62
92
|
class A(pg.Object):
|
|
@@ -68,8 +98,8 @@ class ModalityRefTest(unittest.TestCase):
|
|
|
68
98
|
pg.eq(
|
|
69
99
|
modality.Modality.from_value(a),
|
|
70
100
|
{
|
|
71
|
-
'
|
|
72
|
-
'
|
|
101
|
+
'custom_modality:0cc175b9': CustomModality('a'),
|
|
102
|
+
'custom_modality:92eb5ffe': CustomModality('b'),
|
|
73
103
|
},
|
|
74
104
|
)
|
|
75
105
|
)
|
|
@@ -77,7 +107,7 @@ class ModalityRefTest(unittest.TestCase):
|
|
|
77
107
|
pg.eq(
|
|
78
108
|
modality.Modality.from_value(a.x.z),
|
|
79
109
|
{
|
|
80
|
-
'
|
|
110
|
+
'custom_modality:0cc175b9': CustomModality('a'),
|
|
81
111
|
},
|
|
82
112
|
)
|
|
83
113
|
)
|
langfun/core/natural_language.py
CHANGED
|
@@ -11,7 +11,7 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
"""Natural language
|
|
14
|
+
"""Natural language formatting."""
|
|
15
15
|
|
|
16
16
|
import abc
|
|
17
17
|
import pyglove as pg
|