langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +3 -0
- langfun/core/eval/v2/checkpointing.py +148 -46
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +102 -19
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +95 -20
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +88 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +14 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +90 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +52 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +78 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +78 -4
- langfun/core/modalities/mime_test.py +59 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
|
@@ -103,6 +103,122 @@ class ImageTest(unittest.TestCase):
|
|
|
103
103
|
image_lib.Image.from_pil_image(image), image_lib.Image
|
|
104
104
|
)
|
|
105
105
|
|
|
106
|
+
def test_from_pil_image_os_error(self):
|
|
107
|
+
img = pil_image.open(io.BytesIO(image_content))
|
|
108
|
+
with mock.patch.object(img, 'save') as mock_save:
|
|
109
|
+
mock_save.side_effect = [OSError, None]
|
|
110
|
+
with mock.patch('os.chdir') as mock_chdir:
|
|
111
|
+
with mock.patch('os.getcwd') as mock_getcwd:
|
|
112
|
+
mock_getcwd.return_value = '/curr/dir'
|
|
113
|
+
image = image_lib.Image.from_pil_image(img)
|
|
114
|
+
self.assertIsInstance(image, image_lib.Image)
|
|
115
|
+
self.assertEqual(mock_save.call_count, 2)
|
|
116
|
+
mock_save.assert_has_calls([
|
|
117
|
+
mock.call(mock.ANY, format='PNG'),
|
|
118
|
+
mock.call(mock.ANY, format='PNG'),
|
|
119
|
+
])
|
|
120
|
+
mock_chdir.assert_has_calls([
|
|
121
|
+
mock.call('/tmp'),
|
|
122
|
+
mock.call('/curr/dir'),
|
|
123
|
+
])
|
|
124
|
+
|
|
125
|
+
def test_gif_is_compatible(self):
|
|
126
|
+
# Create a simple 1x1 GIF image using PIL
|
|
127
|
+
buf = io.BytesIO()
|
|
128
|
+
img = pil_image.new('P', (1, 1))
|
|
129
|
+
img.save(buf, format='GIF')
|
|
130
|
+
gif_bytes = buf.getvalue()
|
|
131
|
+
|
|
132
|
+
gif_image = image_lib.Image.from_bytes(gif_bytes)
|
|
133
|
+
self.assertEqual(gif_image.mime_type, 'image/gif')
|
|
134
|
+
|
|
135
|
+
# GIF should be compatible if PNG is in supported types
|
|
136
|
+
self.assertTrue(gif_image._is_compatible(['image/png']))
|
|
137
|
+
self.assertTrue(gif_image._is_compatible(['image/jpeg', 'image/webp']))
|
|
138
|
+
self.assertTrue(gif_image._is_compatible(['image/png', 'image/jpeg']))
|
|
139
|
+
|
|
140
|
+
# GIF should not be compatible if only unsupported types
|
|
141
|
+
self.assertFalse(gif_image._is_compatible(['video/mp4']))
|
|
142
|
+
self.assertFalse(gif_image._is_compatible(['application/pdf']))
|
|
143
|
+
|
|
144
|
+
def test_gif_make_compatible(self):
|
|
145
|
+
# Create a simple 1x1 GIF image using PIL
|
|
146
|
+
buf = io.BytesIO()
|
|
147
|
+
img = pil_image.new('P', (1, 1))
|
|
148
|
+
img.save(buf, format='GIF')
|
|
149
|
+
gif_bytes = buf.getvalue()
|
|
150
|
+
|
|
151
|
+
gif_image = image_lib.Image.from_bytes(gif_bytes)
|
|
152
|
+
self.assertEqual(gif_image.mime_type, 'image/gif')
|
|
153
|
+
|
|
154
|
+
# Test 1: Convert to PNG (first priority when available)
|
|
155
|
+
converted = gif_image.make_compatible(['image/png', 'image/jpeg'])
|
|
156
|
+
self.assertEqual(converted.mime_type, 'image/png')
|
|
157
|
+
self.assertIsInstance(converted, image_lib.Image)
|
|
158
|
+
|
|
159
|
+
# Test 2: Convert to JPEG when PNG not supported
|
|
160
|
+
converted = gif_image.make_compatible(['image/jpeg', 'image/webp'])
|
|
161
|
+
self.assertEqual(converted.mime_type, 'image/jpeg')
|
|
162
|
+
|
|
163
|
+
# Test 3: Convert to WEBP when PNG and JPEG not supported
|
|
164
|
+
converted = gif_image.make_compatible(['image/webp'])
|
|
165
|
+
self.assertEqual(converted.mime_type, 'image/webp')
|
|
166
|
+
|
|
167
|
+
# Test 4: Should raise error when no compatible format
|
|
168
|
+
with self.assertRaises(lf.ModalityError):
|
|
169
|
+
gif_image.make_compatible(['video/mp4'])
|
|
170
|
+
|
|
171
|
+
def test_is_compatible_direct_match(self):
|
|
172
|
+
image = image_lib.Image.from_bytes(image_content) # image/png
|
|
173
|
+
self.assertTrue(image._is_compatible(['image/png', 'image/jpeg']))
|
|
174
|
+
self.assertTrue(image._is_compatible(['image/png']))
|
|
175
|
+
self.assertFalse(image._is_compatible(['image/jpeg']))
|
|
176
|
+
|
|
177
|
+
def test_make_compatible_no_conversion(self):
|
|
178
|
+
image = image_lib.Image.from_bytes(image_content) # image/png
|
|
179
|
+
converted_image = image.make_compatible(['image/png', 'image/jpeg'])
|
|
180
|
+
self.assertIs(image, converted_image)
|
|
181
|
+
|
|
182
|
+
def test_convert_to_format_jpeg_transparency(self):
|
|
183
|
+
# Create a simple RGBA PNG image
|
|
184
|
+
buf = io.BytesIO()
|
|
185
|
+
img = pil_image.new('RGBA', (1, 1), (255, 0, 0, 128))
|
|
186
|
+
img.save(buf, format='PNG')
|
|
187
|
+
rgba_png_bytes = buf.getvalue()
|
|
188
|
+
|
|
189
|
+
rgba_image = image_lib.Image.from_bytes(rgba_png_bytes)
|
|
190
|
+
self.assertEqual(rgba_image.mime_type, 'image/png')
|
|
191
|
+
|
|
192
|
+
# Convert to JPEG, should trigger transparency handling
|
|
193
|
+
converted_image = rgba_image._convert_to_format('JPEG')
|
|
194
|
+
self.assertEqual(converted_image.mime_type, 'image/jpeg')
|
|
195
|
+
pil_img = converted_image.to_pil_image()
|
|
196
|
+
self.assertEqual(pil_img.mode, 'RGB')
|
|
197
|
+
|
|
198
|
+
def test_convert_to_format_os_error(self):
|
|
199
|
+
image = image_lib.Image.from_bytes(image_content)
|
|
200
|
+
mock_pil_image = mock.MagicMock()
|
|
201
|
+
mock_save = mock_pil_image.save
|
|
202
|
+
mock_save.side_effect = [OSError, None]
|
|
203
|
+
|
|
204
|
+
with mock.patch.object(
|
|
205
|
+
image, 'to_pil_image', return_value=mock_pil_image
|
|
206
|
+
), mock.patch('os.chdir') as mock_chdir, mock.patch(
|
|
207
|
+
'os.getcwd'
|
|
208
|
+
) as mock_getcwd:
|
|
209
|
+
mock_getcwd.return_value = '/curr/dir'
|
|
210
|
+
converted_image = image._convert_to_format('PNG')
|
|
211
|
+
self.assertIsInstance(converted_image, image_lib.Image)
|
|
212
|
+
self.assertEqual(mock_save.call_count, 2)
|
|
213
|
+
mock_save.assert_has_calls([
|
|
214
|
+
mock.call(mock.ANY, format='PNG'),
|
|
215
|
+
mock.call(mock.ANY, format='PNG'),
|
|
216
|
+
])
|
|
217
|
+
mock_chdir.assert_has_calls([
|
|
218
|
+
mock.call('/tmp'),
|
|
219
|
+
mock.call('/curr/dir'),
|
|
220
|
+
])
|
|
221
|
+
|
|
106
222
|
|
|
107
223
|
if __name__ == '__main__':
|
|
108
224
|
unittest.main()
|
langfun/core/modalities/mime.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import base64
|
|
17
17
|
import functools
|
|
18
|
+
import hashlib
|
|
18
19
|
from typing import Annotated, Any, Iterable, Type, Union
|
|
19
20
|
import langfun.core as lf
|
|
20
21
|
# Placeholder for Google-internal internet access import.
|
|
@@ -36,7 +37,33 @@ def _detect_mime_type(content: bytes) -> str:
|
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
class Mime(lf.Modality):
|
|
39
|
-
"""Base for MIME
|
|
40
|
+
"""Base class for representing modality data based on MIME types.
|
|
41
|
+
|
|
42
|
+
`lf.Mime` is a subclass of `lf.Modality` that serves as a base for
|
|
43
|
+
handling various data types like images, audio, video, and PDFs,
|
|
44
|
+
identified by their MIME types. It provides unified methods for
|
|
45
|
+
loading data from URIs or bytes (`.from_uri()`, `.from_bytes()`) and
|
|
46
|
+
for accessing content (`.to_bytes()`).
|
|
47
|
+
|
|
48
|
+
Subclasses like `lf.Image`, `lf.Audio`, `lf.Video`, and `lf.PDF`
|
|
49
|
+
specialize in handling specific MIME type prefixes (e.g., 'image/', 'audio/').
|
|
50
|
+
|
|
51
|
+
**Example:**
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
import langfun as lf
|
|
55
|
+
|
|
56
|
+
# Load an image from a path
|
|
57
|
+
image = lf.Image.from_path('/path/to/image.png')
|
|
58
|
+
print(image.mime_type)
|
|
59
|
+
# Output: image/png
|
|
60
|
+
|
|
61
|
+
# Create a text document
|
|
62
|
+
text = lf.Custom.from_bytes(b'hello world', mime='text/plain')
|
|
63
|
+
print(text.mime_type)
|
|
64
|
+
# Output: text/plain
|
|
65
|
+
```
|
|
66
|
+
"""
|
|
40
67
|
|
|
41
68
|
# The regular expression that describes the MIME type str.
|
|
42
69
|
# If None, the MIME type is dynamic. Subclass could override.
|
|
@@ -48,6 +75,10 @@ class Mime(lf.Modality):
|
|
|
48
75
|
Union[str, bytes, None], 'The raw content of the MIME type.'
|
|
49
76
|
] = None
|
|
50
77
|
|
|
78
|
+
metadata: Annotated[
|
|
79
|
+
dict[str, Any], 'Additional metadata attached to this object.'
|
|
80
|
+
] = {}
|
|
81
|
+
|
|
51
82
|
@functools.cached_property
|
|
52
83
|
def mime_type(self) -> str:
|
|
53
84
|
"""Returns the MIME type."""
|
|
@@ -87,13 +118,37 @@ class Mime(lf.Modality):
|
|
|
87
118
|
"""Returns True if the MIME type is a binary type."""
|
|
88
119
|
return not self.is_text
|
|
89
120
|
|
|
121
|
+
@property
|
|
122
|
+
def hash(self) -> str:
|
|
123
|
+
"""Returns the hash of the MIME content."""
|
|
124
|
+
# Hash the URI to avoid downloading the content.
|
|
125
|
+
if self.uri is not None:
|
|
126
|
+
return hashlib.md5(self.uri.encode()).hexdigest()[:8]
|
|
127
|
+
if self.content is not None:
|
|
128
|
+
return super().hash
|
|
129
|
+
assert self.metadata
|
|
130
|
+
return hashlib.md5(str(self.metadata).encode()).hexdigest()[:8]
|
|
131
|
+
|
|
90
132
|
def to_text(self) -> str:
|
|
91
133
|
"""Returns the text content of the MIME type."""
|
|
92
134
|
if not self.is_text:
|
|
93
135
|
raise lf.ModalityError(
|
|
94
136
|
f'MIME type {self.mime_type!r} cannot be converted to text.'
|
|
95
137
|
)
|
|
96
|
-
|
|
138
|
+
content = self.to_bytes()
|
|
139
|
+
# Try UTF-8 first (most common encoding).
|
|
140
|
+
try:
|
|
141
|
+
return content.decode('utf-8')
|
|
142
|
+
except UnicodeDecodeError:
|
|
143
|
+
pass
|
|
144
|
+
# Check for UTF-16 BOM (0xff 0xfe or 0xfe 0xff).
|
|
145
|
+
if content[:2] in (b'\xff\xfe', b'\xfe\xff'):
|
|
146
|
+
try:
|
|
147
|
+
return content.decode('utf-16')
|
|
148
|
+
except UnicodeDecodeError:
|
|
149
|
+
pass
|
|
150
|
+
# Fallback: decode with error replacement to avoid crashing.
|
|
151
|
+
return content.decode('utf-8', errors='replace')
|
|
97
152
|
|
|
98
153
|
def is_compatible(
|
|
99
154
|
self, mime_types: str | Iterable[str]
|
|
@@ -132,7 +187,7 @@ class Mime(lf.Modality):
|
|
|
132
187
|
|
|
133
188
|
def _on_bound(self):
|
|
134
189
|
super()._on_bound()
|
|
135
|
-
if self.uri is None and self.content is None:
|
|
190
|
+
if self.uri is None and self.content is None and not self.metadata:
|
|
136
191
|
raise ValueError('Either uri or content must be provided.')
|
|
137
192
|
|
|
138
193
|
def to_bytes(self) -> bytes:
|
|
@@ -162,6 +217,8 @@ class Mime(lf.Modality):
|
|
|
162
217
|
return cls.class_from_mime_type(mime_type).from_bytes(content, **kwargs)
|
|
163
218
|
|
|
164
219
|
if cls is Mime:
|
|
220
|
+
if 'youtube.com/watch' in uri:
|
|
221
|
+
return Custom(mime='text/html', uri=uri, **kwargs)
|
|
165
222
|
content = cls.download(uri)
|
|
166
223
|
mime = _detect_mime_type(content)
|
|
167
224
|
return cls.class_from_mime_type(mime)(uri=uri, content=content, **kwargs)
|
|
@@ -272,7 +329,24 @@ class Mime(lf.Modality):
|
|
|
272
329
|
|
|
273
330
|
@pg.use_init_args(['mime', 'content', 'uri'])
|
|
274
331
|
class Custom(Mime):
|
|
275
|
-
"""
|
|
332
|
+
"""Represents content of a custom MIME type.
|
|
333
|
+
|
|
334
|
+
`lf.modalities.Custom` is useful for representing data with MIME types
|
|
335
|
+
that do not have dedicated classes like `lf.Image` or `lf.Audio`.
|
|
336
|
+
|
|
337
|
+
**Example:**
|
|
338
|
+
|
|
339
|
+
```python
|
|
340
|
+
import langfun as lf
|
|
341
|
+
|
|
342
|
+
# Create a custom MIME object for plain text
|
|
343
|
+
text_data = lf.Custom.from_bytes(
|
|
344
|
+
b'This is a text document.', mime='text/plain'
|
|
345
|
+
)
|
|
346
|
+
print(text_data.mime_type)
|
|
347
|
+
# Output: text/plain
|
|
348
|
+
```
|
|
349
|
+
"""
|
|
276
350
|
|
|
277
351
|
mime: Annotated[
|
|
278
352
|
str, 'The MIME type of the data. E.g. text/plain, or image/png. '
|
|
@@ -109,6 +109,17 @@ class CustomMimeTest(unittest.TestCase):
|
|
|
109
109
|
with self.assertRaisesRegex(ValueError, 'Unsupported encoding'):
|
|
110
110
|
mime.Mime.from_uri('data:text/plain;base16,abcd')
|
|
111
111
|
|
|
112
|
+
# Test YouTube URI
|
|
113
|
+
yt_uri = 'https://www.youtube.com/watch?v=dQw4w9WgXcQ'
|
|
114
|
+
with mock.patch(
|
|
115
|
+
'langfun.core.modalities.mime.Mime.download'
|
|
116
|
+
) as mock_download:
|
|
117
|
+
content = mime.Mime.from_uri(yt_uri)
|
|
118
|
+
self.assertIsInstance(content, mime.Custom)
|
|
119
|
+
self.assertEqual(content.mime_type, 'text/html')
|
|
120
|
+
self.assertEqual(content.uri, yt_uri)
|
|
121
|
+
mock_download.assert_not_called()
|
|
122
|
+
|
|
112
123
|
def assert_html_content(self, html, expected):
|
|
113
124
|
expected = inspect.cleandoc(expected).strip()
|
|
114
125
|
actual = html.content.strip()
|
|
@@ -152,5 +163,53 @@ class CustomMimeTest(unittest.TestCase):
|
|
|
152
163
|
)
|
|
153
164
|
|
|
154
165
|
|
|
166
|
+
class ToTextEncodingTest(unittest.TestCase):
|
|
167
|
+
"""Tests for to_text() encoding handling."""
|
|
168
|
+
|
|
169
|
+
def test_utf8_decoding(self):
|
|
170
|
+
"""Test that valid UTF-8 content is decoded correctly."""
|
|
171
|
+
content = mime.Custom('text/plain', b'Hello, World!')
|
|
172
|
+
self.assertEqual(content.to_text(), 'Hello, World!')
|
|
173
|
+
|
|
174
|
+
# UTF-8 with multi-byte characters.
|
|
175
|
+
utf8_content = 'こんにちは'.encode('utf-8')
|
|
176
|
+
content = mime.Custom('text/plain', utf8_content)
|
|
177
|
+
self.assertEqual(content.to_text(), 'こんにちは')
|
|
178
|
+
|
|
179
|
+
def test_utf16_le_bom_decoding(self):
|
|
180
|
+
"""Test that UTF-16 Little Endian with BOM is decoded correctly."""
|
|
181
|
+
# UTF-16 LE BOM: 0xff 0xfe
|
|
182
|
+
utf16_le_content = 'Hello'.encode('utf-16-le')
|
|
183
|
+
content_with_bom = b'\xff\xfe' + utf16_le_content
|
|
184
|
+
content = mime.Custom('text/plain', content_with_bom)
|
|
185
|
+
self.assertEqual(content.to_text(), 'Hello')
|
|
186
|
+
|
|
187
|
+
def test_utf16_be_bom_decoding(self):
|
|
188
|
+
"""Test that UTF-16 Big Endian with BOM is decoded correctly."""
|
|
189
|
+
# UTF-16 BE BOM: 0xfe 0xff
|
|
190
|
+
utf16_be_content = 'Hello'.encode('utf-16-be')
|
|
191
|
+
content_with_bom = b'\xfe\xff' + utf16_be_content
|
|
192
|
+
content = mime.Custom('text/plain', content_with_bom)
|
|
193
|
+
self.assertEqual(content.to_text(), 'Hello')
|
|
194
|
+
|
|
195
|
+
def test_invalid_bytes_fallback_with_replacement(self):
|
|
196
|
+
"""Test that invalid bytes are replaced with replacement character."""
|
|
197
|
+
# 0xff alone is invalid in UTF-8 and doesn't have UTF-16 BOM pattern.
|
|
198
|
+
invalid_content = b'\xff\xfdHello'
|
|
199
|
+
content = mime.Custom('text/plain', invalid_content)
|
|
200
|
+
result = content.to_text()
|
|
201
|
+
# Invalid bytes should be replaced with U+FFFD (replacement character).
|
|
202
|
+
self.assertIn('\ufffd', result)
|
|
203
|
+
self.assertIn('Hello', result)
|
|
204
|
+
|
|
205
|
+
def test_binary_mime_type_raises_error(self):
|
|
206
|
+
"""Test that binary MIME types raise ModalityError."""
|
|
207
|
+
content = mime.Custom('application/octet-stream', b'\x00\x01\x02')
|
|
208
|
+
with self.assertRaisesRegex(
|
|
209
|
+
lf.ModalityError, 'cannot be converted to text'
|
|
210
|
+
):
|
|
211
|
+
content.to_text()
|
|
212
|
+
|
|
213
|
+
|
|
155
214
|
if __name__ == '__main__':
|
|
156
215
|
unittest.main()
|
langfun/core/modalities/pdf.py
CHANGED
|
@@ -17,6 +17,24 @@ from langfun.core.modalities import mime
|
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
class PDF(mime.Mime):
|
|
20
|
-
"""PDF document.
|
|
20
|
+
"""Represents a PDF document for communicating with language models.
|
|
21
|
+
|
|
22
|
+
`lf.PDF` can be initialized from a URI (HTTP/HTTPS URL or local path)
|
|
23
|
+
using `lf.PDF.from_uri()` or from raw bytes using `lf.PDF.from_bytes()`.
|
|
24
|
+
|
|
25
|
+
**Example:**
|
|
26
|
+
|
|
27
|
+
```python
|
|
28
|
+
import langfun as lf
|
|
29
|
+
|
|
30
|
+
# Load PDF from path
|
|
31
|
+
pdf = lf.PDF.from_path('/path/to/document.pdf')
|
|
32
|
+
|
|
33
|
+
# Use PDF in a prompt
|
|
34
|
+
prompt = lf.Template('Summarize this document: {{pdf}}', pdf=pdf)
|
|
35
|
+
response = lf.query(prompt, lm=lf.llms.Gemini25Flash())
|
|
36
|
+
print(response)
|
|
37
|
+
```
|
|
38
|
+
"""
|
|
21
39
|
|
|
22
40
|
MIME_PREFIX = 'application/pdf'
|
langfun/core/modalities/video.py
CHANGED
|
@@ -18,7 +18,27 @@ from langfun.core.modalities import mime
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class Video(mime.Mime):
|
|
21
|
-
"""
|
|
21
|
+
"""Represents a video for communicating with language models.
|
|
22
|
+
|
|
23
|
+
`lf.Video` can be initialized from a URI (HTTP/HTTPS URL or local path)
|
|
24
|
+
using `lf.Video.from_uri()` or from raw bytes using `lf.Video.from_bytes()`.
|
|
25
|
+
|
|
26
|
+
**Example:**
|
|
27
|
+
|
|
28
|
+
```python
|
|
29
|
+
import langfun as lf
|
|
30
|
+
|
|
31
|
+
# Load video from path
|
|
32
|
+
video = lf.Video.from_path('/path/to/video.mp4')
|
|
33
|
+
|
|
34
|
+
# Use video in a prompt
|
|
35
|
+
prompt = lf.Template(
|
|
36
|
+
'What is happening in this video? {{video}}', video=video
|
|
37
|
+
)
|
|
38
|
+
response = lf.query(prompt, lm=lf.llms.Gemini25Flash())
|
|
39
|
+
print(response)
|
|
40
|
+
```
|
|
41
|
+
"""
|
|
22
42
|
|
|
23
43
|
MIME_PREFIX = 'video'
|
|
24
44
|
|
langfun/core/modality.py
CHANGED
|
@@ -14,40 +14,63 @@
|
|
|
14
14
|
"""Interface for modality (e.g. Image, Video, etc.)."""
|
|
15
15
|
|
|
16
16
|
import abc
|
|
17
|
+
import contextlib
|
|
17
18
|
import functools
|
|
18
19
|
import hashlib
|
|
19
|
-
|
|
20
|
+
import re
|
|
21
|
+
from typing import Any, ContextManager, Iterator
|
|
20
22
|
from langfun.core import component
|
|
21
23
|
import pyglove as pg
|
|
22
24
|
|
|
23
25
|
|
|
24
|
-
|
|
26
|
+
class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
27
|
+
"""Base class for representing non-text content in prompts.
|
|
25
28
|
|
|
29
|
+
`lf.Modality` is the base class for multimodal objects such as `lf.Image`,
|
|
30
|
+
`lf.Audio`, and `lf.Video`. It allows these non-text inputs to be
|
|
31
|
+
seamlessly embedded within text prompts for processing by multimodal
|
|
32
|
+
language models.
|
|
26
33
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
34
|
+
When a `Modality` object is rendered within an `lf.Template`, it is
|
|
35
|
+
replaced by a text marker (e.g., `<<[[image:b10a8db1]]>>`), and the
|
|
36
|
+
modality object itself is stored in the `referred_modalities` field of
|
|
37
|
+
the resulting `lf.Message`. This allows language models to associate
|
|
38
|
+
the placeholder with its content during processing.
|
|
32
39
|
|
|
40
|
+
**Example:**
|
|
33
41
|
|
|
34
|
-
|
|
35
|
-
|
|
42
|
+
```python
|
|
43
|
+
import langfun as lf
|
|
44
|
+
|
|
45
|
+
image = lf.Image.from_path('/path/to/image.png')
|
|
46
|
+
prompt = lf.Template('What is in this image? {{image}}', image=image)
|
|
47
|
+
|
|
48
|
+
message = prompt.render()
|
|
49
|
+
print(message.text)
|
|
50
|
+
# Output: What is in this image? <<[[image:b10a8db1]]>>
|
|
51
|
+
|
|
52
|
+
print(message.modalities())
|
|
53
|
+
# Output: [<Image object>]
|
|
54
|
+
```
|
|
55
|
+
"""
|
|
36
56
|
|
|
37
57
|
REF_START = '<<[['
|
|
38
58
|
REF_END = ']]>>'
|
|
39
59
|
|
|
40
60
|
def _on_bound(self):
|
|
41
61
|
super()._on_bound()
|
|
42
|
-
# Invalidate cached hash if modality member is changed.
|
|
62
|
+
# Invalidate cached hash and id if modality member is changed.
|
|
43
63
|
self.__dict__.pop('hash', None)
|
|
64
|
+
self.__dict__.pop('id', None)
|
|
44
65
|
|
|
45
66
|
def format(self, *args, **kwargs) -> str:
|
|
46
|
-
if
|
|
47
|
-
_TLS_MODALITY_AS_REF, False
|
|
48
|
-
):
|
|
67
|
+
if not pg.object_utils.thread_local_get(_TLS_MODALITY_AS_REF, False):
|
|
49
68
|
return super().format(*args, **kwargs)
|
|
50
|
-
|
|
69
|
+
|
|
70
|
+
capture_scope = get_modality_capture_context()
|
|
71
|
+
if capture_scope is not None:
|
|
72
|
+
capture_scope.capture(self)
|
|
73
|
+
return Modality.text_marker(self.id)
|
|
51
74
|
|
|
52
75
|
def __str_kwargs__(self) -> dict[str, Any]:
|
|
53
76
|
# For modality objects, we don't want to use markdown format when they
|
|
@@ -70,14 +93,11 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
70
93
|
"""Returns a marker in the text for this object."""
|
|
71
94
|
return Modality.REF_START + var_name + Modality.REF_END
|
|
72
95
|
|
|
73
|
-
@
|
|
74
|
-
def
|
|
96
|
+
@functools.cached_property
|
|
97
|
+
def id(self) -> str | None:
|
|
75
98
|
"""Returns the referred name of this object in its template."""
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
# Strip the metadata prefix under message.
|
|
79
|
-
path = str(self.sym_path)
|
|
80
|
-
return path[9:] if path.startswith('metadata.') else path
|
|
99
|
+
modality_type = _camel_to_snake(self.__class__.__name__)
|
|
100
|
+
return f'{modality_type}:{self.hash}'
|
|
81
101
|
|
|
82
102
|
@classmethod
|
|
83
103
|
def from_value(cls, value: pg.Symbolic) -> dict[str, 'Modality']:
|
|
@@ -86,7 +106,7 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
86
106
|
def _visit(k, v, p):
|
|
87
107
|
del k, p
|
|
88
108
|
if isinstance(v, Modality):
|
|
89
|
-
modalities[v.
|
|
109
|
+
modalities[v.id] = v
|
|
90
110
|
return pg.TraverseAction.CONTINUE
|
|
91
111
|
return pg.TraverseAction.ENTER
|
|
92
112
|
|
|
@@ -95,14 +115,47 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
|
|
|
95
115
|
|
|
96
116
|
|
|
97
117
|
class ModalityRef(pg.Object, pg.typing.CustomTyping):
|
|
98
|
-
"""
|
|
118
|
+
"""Lightweight placeholder for a `lf.Modality` object in a symbolic tree.
|
|
99
119
|
|
|
100
|
-
`ModalityRef`
|
|
101
|
-
|
|
102
|
-
|
|
120
|
+
`ModalityRef` acts as a reference to a `Modality` object (like `lf.Image`
|
|
121
|
+
or `lf.Audio`) within a structured object hierarchy (e.g., a `pg.Object`).
|
|
122
|
+
Instead of embedding potentially large modality data directly, `ModalityRef`
|
|
123
|
+
stores only the ID of the modality object.
|
|
124
|
+
|
|
125
|
+
This is useful in scenarios where structured objects are serialized or
|
|
126
|
+
manipulated, and it's more efficient to refer to modalities by ID rather
|
|
127
|
+
than copying their content. The `lf.ModalityRef.placehold()` class method
|
|
128
|
+
can be used to replace `Modality` instances in a symbolic object with
|
|
129
|
+
`ModalityRef` placeholders, while `lf.ModalityRef.restore()` can reinstate
|
|
130
|
+
the original `Modality` objects using a lookup table.
|
|
131
|
+
|
|
132
|
+
**Example:**
|
|
133
|
+
|
|
134
|
+
```python
|
|
135
|
+
import langfun as lf
|
|
136
|
+
import pyglove as pg
|
|
137
|
+
|
|
138
|
+
class ImagePair(pg.Object):
|
|
139
|
+
image1: lf.Image
|
|
140
|
+
image2: lf.Image
|
|
141
|
+
|
|
142
|
+
pair = ImagePair(
|
|
143
|
+
image1=lf.Image(content=b'abc'), image2=lf.Image(content=b'def')
|
|
144
|
+
)
|
|
145
|
+
modalities = lf.Modality.from_value(pair)
|
|
146
|
+
|
|
147
|
+
# Replace Image objects with ModalityRef placeholders
|
|
148
|
+
pair_with_refs = lf.ModalityRef.placehold(pair)
|
|
149
|
+
print(pair_with_refs.image1)
|
|
150
|
+
# Output: ModalityRef(id='image:d81e5a68')
|
|
151
|
+
|
|
152
|
+
# Restore Image objects from ModalityRef placeholders
|
|
153
|
+
pair_restored = lf.ModalityRef.restore(pair_with_refs, modalities)
|
|
154
|
+
assert pair_restored.image1.content == b'abc'
|
|
155
|
+
```
|
|
103
156
|
"""
|
|
104
157
|
|
|
105
|
-
|
|
158
|
+
id: str
|
|
106
159
|
|
|
107
160
|
def custom_apply(
|
|
108
161
|
self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs
|
|
@@ -122,12 +175,97 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
|
|
|
122
175
|
"""
|
|
123
176
|
|
|
124
177
|
def _placehold(k, v, p):
|
|
125
|
-
del p
|
|
178
|
+
del k, p
|
|
126
179
|
if isinstance(v, Modality):
|
|
127
|
-
return ModalityRef(
|
|
180
|
+
return ModalityRef(id=v.id)
|
|
128
181
|
return v
|
|
129
182
|
return value.clone().rebind(_placehold, raise_on_no_change=False)
|
|
130
183
|
|
|
184
|
+
@classmethod
|
|
185
|
+
def restore(cls, value: pg.Symbolic, modalities: dict[str, Modality]) -> Any:
|
|
186
|
+
"""Returns a copy of value by replacing refs with modality objects."""
|
|
187
|
+
def _restore(k, v, p):
|
|
188
|
+
del k, p
|
|
189
|
+
if isinstance(v, ModalityRef):
|
|
190
|
+
modality_object = modalities.get(v.id)
|
|
191
|
+
if modality_object is None:
|
|
192
|
+
raise ValueError(
|
|
193
|
+
f'Modality {v.id} not found in modalities {modalities.keys()}'
|
|
194
|
+
)
|
|
195
|
+
return modality_object
|
|
196
|
+
return v
|
|
197
|
+
return value.rebind(_restore, raise_on_no_change=False)
|
|
198
|
+
|
|
131
199
|
|
|
132
200
|
class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
|
|
133
201
|
"""Exception raised when modality is not supported."""
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
#
|
|
205
|
+
# Context managers to deal with modality objects.
|
|
206
|
+
#
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
_TLS_MODALITY_CAPTURE_SCOPE = '__modality_capture_scope__'
|
|
210
|
+
_TLS_MODALITY_AS_REF = '__format_modality_as_ref__'
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
|
|
214
|
+
"""A context manager that formats modality objects as references."""
|
|
215
|
+
return pg.object_utils.thread_local_value_scope(
|
|
216
|
+
_TLS_MODALITY_AS_REF, enabled, False
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class _ModalityCaptureContext:
|
|
221
|
+
"""A context to capture modality objects when being rendered."""
|
|
222
|
+
|
|
223
|
+
def __init__(self):
|
|
224
|
+
self._references: dict[str, pg.Ref[Modality]] = {}
|
|
225
|
+
|
|
226
|
+
def capture(self, modality: Modality) -> None:
|
|
227
|
+
"""Captures the modality object."""
|
|
228
|
+
self._references[modality.id] = pg.Ref(modality)
|
|
229
|
+
|
|
230
|
+
@property
|
|
231
|
+
def references(self) -> dict[str, pg.Ref[Modality]]:
|
|
232
|
+
"""Returns the modality references captured in this context."""
|
|
233
|
+
return self._references
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
@contextlib.contextmanager
|
|
237
|
+
def capture_rendered_modalities() -> Iterator[dict[str, pg.Ref[Modality]]]:
|
|
238
|
+
"""Capture modality objects whose references is being rendered.
|
|
239
|
+
|
|
240
|
+
Example:
|
|
241
|
+
```
|
|
242
|
+
image = lf.Image.from_url(...)
|
|
243
|
+
with lf.modality.capture_rendered_modalities() as rendered_modalities:
|
|
244
|
+
with lf.modality.format_modality_as_ref():
|
|
245
|
+
print(f'Hello {image}')
|
|
246
|
+
self.assertEqual(rendered_modalities, {'image:<hash>': pg.Ref(image)})
|
|
247
|
+
```
|
|
248
|
+
"""
|
|
249
|
+
context = get_modality_capture_context()
|
|
250
|
+
top_level = context is None
|
|
251
|
+
if top_level:
|
|
252
|
+
context = _ModalityCaptureContext()
|
|
253
|
+
pg.object_utils.thread_local_set(_TLS_MODALITY_CAPTURE_SCOPE, context)
|
|
254
|
+
|
|
255
|
+
try:
|
|
256
|
+
yield context.references # pylint: disable=attribute-error
|
|
257
|
+
finally:
|
|
258
|
+
if top_level:
|
|
259
|
+
pg.object_utils.thread_local_del(_TLS_MODALITY_CAPTURE_SCOPE)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def get_modality_capture_context() -> _ModalityCaptureContext | None:
|
|
263
|
+
"""Returns the current modality capture context."""
|
|
264
|
+
return pg.object_utils.thread_local_get(_TLS_MODALITY_CAPTURE_SCOPE, None)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _camel_to_snake(name: str) -> str:
|
|
268
|
+
"""Converts a camelCase name to snake_case."""
|
|
269
|
+
return re.sub(
|
|
270
|
+
pattern=r'([A-Z]+)', repl=r'_\1', string=name
|
|
271
|
+
).lower().lstrip('_')
|