langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
@@ -103,6 +103,122 @@ class ImageTest(unittest.TestCase):
103
103
  image_lib.Image.from_pil_image(image), image_lib.Image
104
104
  )
105
105
 
106
+ def test_from_pil_image_os_error(self):
107
+ img = pil_image.open(io.BytesIO(image_content))
108
+ with mock.patch.object(img, 'save') as mock_save:
109
+ mock_save.side_effect = [OSError, None]
110
+ with mock.patch('os.chdir') as mock_chdir:
111
+ with mock.patch('os.getcwd') as mock_getcwd:
112
+ mock_getcwd.return_value = '/curr/dir'
113
+ image = image_lib.Image.from_pil_image(img)
114
+ self.assertIsInstance(image, image_lib.Image)
115
+ self.assertEqual(mock_save.call_count, 2)
116
+ mock_save.assert_has_calls([
117
+ mock.call(mock.ANY, format='PNG'),
118
+ mock.call(mock.ANY, format='PNG'),
119
+ ])
120
+ mock_chdir.assert_has_calls([
121
+ mock.call('/tmp'),
122
+ mock.call('/curr/dir'),
123
+ ])
124
+
125
+ def test_gif_is_compatible(self):
126
+ # Create a simple 1x1 GIF image using PIL
127
+ buf = io.BytesIO()
128
+ img = pil_image.new('P', (1, 1))
129
+ img.save(buf, format='GIF')
130
+ gif_bytes = buf.getvalue()
131
+
132
+ gif_image = image_lib.Image.from_bytes(gif_bytes)
133
+ self.assertEqual(gif_image.mime_type, 'image/gif')
134
+
135
+ # GIF should be compatible if PNG is in supported types
136
+ self.assertTrue(gif_image._is_compatible(['image/png']))
137
+ self.assertTrue(gif_image._is_compatible(['image/jpeg', 'image/webp']))
138
+ self.assertTrue(gif_image._is_compatible(['image/png', 'image/jpeg']))
139
+
140
+ # GIF should not be compatible if only unsupported types
141
+ self.assertFalse(gif_image._is_compatible(['video/mp4']))
142
+ self.assertFalse(gif_image._is_compatible(['application/pdf']))
143
+
144
+ def test_gif_make_compatible(self):
145
+ # Create a simple 1x1 GIF image using PIL
146
+ buf = io.BytesIO()
147
+ img = pil_image.new('P', (1, 1))
148
+ img.save(buf, format='GIF')
149
+ gif_bytes = buf.getvalue()
150
+
151
+ gif_image = image_lib.Image.from_bytes(gif_bytes)
152
+ self.assertEqual(gif_image.mime_type, 'image/gif')
153
+
154
+ # Test 1: Convert to PNG (first priority when available)
155
+ converted = gif_image.make_compatible(['image/png', 'image/jpeg'])
156
+ self.assertEqual(converted.mime_type, 'image/png')
157
+ self.assertIsInstance(converted, image_lib.Image)
158
+
159
+ # Test 2: Convert to JPEG when PNG not supported
160
+ converted = gif_image.make_compatible(['image/jpeg', 'image/webp'])
161
+ self.assertEqual(converted.mime_type, 'image/jpeg')
162
+
163
+ # Test 3: Convert to WEBP when PNG and JPEG not supported
164
+ converted = gif_image.make_compatible(['image/webp'])
165
+ self.assertEqual(converted.mime_type, 'image/webp')
166
+
167
+ # Test 4: Should raise error when no compatible format
168
+ with self.assertRaises(lf.ModalityError):
169
+ gif_image.make_compatible(['video/mp4'])
170
+
171
+ def test_is_compatible_direct_match(self):
172
+ image = image_lib.Image.from_bytes(image_content) # image/png
173
+ self.assertTrue(image._is_compatible(['image/png', 'image/jpeg']))
174
+ self.assertTrue(image._is_compatible(['image/png']))
175
+ self.assertFalse(image._is_compatible(['image/jpeg']))
176
+
177
+ def test_make_compatible_no_conversion(self):
178
+ image = image_lib.Image.from_bytes(image_content) # image/png
179
+ converted_image = image.make_compatible(['image/png', 'image/jpeg'])
180
+ self.assertIs(image, converted_image)
181
+
182
+ def test_convert_to_format_jpeg_transparency(self):
183
+ # Create a simple RGBA PNG image
184
+ buf = io.BytesIO()
185
+ img = pil_image.new('RGBA', (1, 1), (255, 0, 0, 128))
186
+ img.save(buf, format='PNG')
187
+ rgba_png_bytes = buf.getvalue()
188
+
189
+ rgba_image = image_lib.Image.from_bytes(rgba_png_bytes)
190
+ self.assertEqual(rgba_image.mime_type, 'image/png')
191
+
192
+ # Convert to JPEG, should trigger transparency handling
193
+ converted_image = rgba_image._convert_to_format('JPEG')
194
+ self.assertEqual(converted_image.mime_type, 'image/jpeg')
195
+ pil_img = converted_image.to_pil_image()
196
+ self.assertEqual(pil_img.mode, 'RGB')
197
+
198
+ def test_convert_to_format_os_error(self):
199
+ image = image_lib.Image.from_bytes(image_content)
200
+ mock_pil_image = mock.MagicMock()
201
+ mock_save = mock_pil_image.save
202
+ mock_save.side_effect = [OSError, None]
203
+
204
+ with mock.patch.object(
205
+ image, 'to_pil_image', return_value=mock_pil_image
206
+ ), mock.patch('os.chdir') as mock_chdir, mock.patch(
207
+ 'os.getcwd'
208
+ ) as mock_getcwd:
209
+ mock_getcwd.return_value = '/curr/dir'
210
+ converted_image = image._convert_to_format('PNG')
211
+ self.assertIsInstance(converted_image, image_lib.Image)
212
+ self.assertEqual(mock_save.call_count, 2)
213
+ mock_save.assert_has_calls([
214
+ mock.call(mock.ANY, format='PNG'),
215
+ mock.call(mock.ANY, format='PNG'),
216
+ ])
217
+ mock_chdir.assert_has_calls([
218
+ mock.call('/tmp'),
219
+ mock.call('/curr/dir'),
220
+ ])
221
+
106
222
 
107
223
  if __name__ == '__main__':
108
224
  unittest.main()
@@ -15,6 +15,7 @@
15
15
 
16
16
  import base64
17
17
  import functools
18
+ import hashlib
18
19
  from typing import Annotated, Any, Iterable, Type, Union
19
20
  import langfun.core as lf
20
21
  # Placeholder for Google-internal internet access import.
@@ -36,7 +37,33 @@ def _detect_mime_type(content: bytes) -> str:
36
37
 
37
38
 
38
39
  class Mime(lf.Modality):
39
- """Base for MIME data."""
40
+ """Base class for representing modality data based on MIME types.
41
+
42
+ `lf.Mime` is a subclass of `lf.Modality` that serves as a base for
43
+ handling various data types like images, audio, video, and PDFs,
44
+ identified by their MIME types. It provides unified methods for
45
+ loading data from URIs or bytes (`.from_uri()`, `.from_bytes()`) and
46
+ for accessing content (`.to_bytes()`).
47
+
48
+ Subclasses like `lf.Image`, `lf.Audio`, `lf.Video`, and `lf.PDF`
49
+ specialize in handling specific MIME type prefixes (e.g., 'image/', 'audio/').
50
+
51
+ **Example:**
52
+
53
+ ```python
54
+ import langfun as lf
55
+
56
+ # Load an image from a path
57
+ image = lf.Image.from_path('/path/to/image.png')
58
+ print(image.mime_type)
59
+ # Output: image/png
60
+
61
+ # Create a text document
62
+ text = lf.Custom.from_bytes(b'hello world', mime='text/plain')
63
+ print(text.mime_type)
64
+ # Output: text/plain
65
+ ```
66
+ """
40
67
 
41
68
  # The regular expression that describes the MIME type str.
42
69
  # If None, the MIME type is dynamic. Subclass could override.
@@ -48,6 +75,10 @@ class Mime(lf.Modality):
48
75
  Union[str, bytes, None], 'The raw content of the MIME type.'
49
76
  ] = None
50
77
 
78
+ metadata: Annotated[
79
+ dict[str, Any], 'Additional metadata attached to this object.'
80
+ ] = {}
81
+
51
82
  @functools.cached_property
52
83
  def mime_type(self) -> str:
53
84
  """Returns the MIME type."""
@@ -87,6 +118,17 @@ class Mime(lf.Modality):
87
118
  """Returns True if the MIME type is a binary type."""
88
119
  return not self.is_text
89
120
 
121
+ @property
122
+ def hash(self) -> str:
123
+ """Returns the hash of the MIME content."""
124
+ # Hash the URI to avoid downloading the content.
125
+ if self.uri is not None:
126
+ return hashlib.md5(self.uri.encode()).hexdigest()[:8]
127
+ if self.content is not None:
128
+ return super().hash
129
+ assert self.metadata
130
+ return hashlib.md5(str(self.metadata).encode()).hexdigest()[:8]
131
+
90
132
  def to_text(self) -> str:
91
133
  """Returns the text content of the MIME type."""
92
134
  if not self.is_text:
@@ -132,7 +174,7 @@ class Mime(lf.Modality):
132
174
 
133
175
  def _on_bound(self):
134
176
  super()._on_bound()
135
- if self.uri is None and self.content is None:
177
+ if self.uri is None and self.content is None and not self.metadata:
136
178
  raise ValueError('Either uri or content must be provided.')
137
179
 
138
180
  def to_bytes(self) -> bytes:
@@ -162,6 +204,8 @@ class Mime(lf.Modality):
162
204
  return cls.class_from_mime_type(mime_type).from_bytes(content, **kwargs)
163
205
 
164
206
  if cls is Mime:
207
+ if 'youtube.com/watch' in uri:
208
+ return Custom(mime='text/html', uri=uri, **kwargs)
165
209
  content = cls.download(uri)
166
210
  mime = _detect_mime_type(content)
167
211
  return cls.class_from_mime_type(mime)(uri=uri, content=content, **kwargs)
@@ -272,7 +316,24 @@ class Mime(lf.Modality):
272
316
 
273
317
  @pg.use_init_args(['mime', 'content', 'uri'])
274
318
  class Custom(Mime):
275
- """Custom MIME data."""
319
+ """Represents content of a custom MIME type.
320
+
321
+ `lf.modalities.Custom` is useful for representing data with MIME types
322
+ that do not have dedicated classes like `lf.Image` or `lf.Audio`.
323
+
324
+ **Example:**
325
+
326
+ ```python
327
+ import langfun as lf
328
+
329
+ # Create a custom MIME object for plain text
330
+ text_data = lf.Custom.from_bytes(
331
+ b'This is a text document.', mime='text/plain'
332
+ )
333
+ print(text_data.mime_type)
334
+ # Output: text/plain
335
+ ```
336
+ """
276
337
 
277
338
  mime: Annotated[
278
339
  str, 'The MIME type of the data. E.g. text/plain, or image/png. '
@@ -109,6 +109,17 @@ class CustomMimeTest(unittest.TestCase):
109
109
  with self.assertRaisesRegex(ValueError, 'Unsupported encoding'):
110
110
  mime.Mime.from_uri('data:text/plain;base16,abcd')
111
111
 
112
+ # Test YouTube URI
113
+ yt_uri = 'https://www.youtube.com/watch?v=dQw4w9WgXcQ'
114
+ with mock.patch(
115
+ 'langfun.core.modalities.mime.Mime.download'
116
+ ) as mock_download:
117
+ content = mime.Mime.from_uri(yt_uri)
118
+ self.assertIsInstance(content, mime.Custom)
119
+ self.assertEqual(content.mime_type, 'text/html')
120
+ self.assertEqual(content.uri, yt_uri)
121
+ mock_download.assert_not_called()
122
+
112
123
  def assert_html_content(self, html, expected):
113
124
  expected = inspect.cleandoc(expected).strip()
114
125
  actual = html.content.strip()
@@ -17,6 +17,24 @@ from langfun.core.modalities import mime
17
17
 
18
18
 
19
19
  class PDF(mime.Mime):
20
- """PDF document."""
20
+ """Represents a PDF document for communicating with language models.
21
+
22
+ `lf.PDF` can be initialized from a URI (HTTP/HTTPS URL or local path)
23
+ using `lf.PDF.from_uri()` or from raw bytes using `lf.PDF.from_bytes()`.
24
+
25
+ **Example:**
26
+
27
+ ```python
28
+ import langfun as lf
29
+
30
+ # Load PDF from path
31
+ pdf = lf.PDF.from_path('/path/to/document.pdf')
32
+
33
+ # Use PDF in a prompt
34
+ prompt = lf.Template('Summarize this document: {{pdf}}', pdf=pdf)
35
+ response = lf.query(prompt, lm=lf.llms.Gemini25Flash())
36
+ print(response)
37
+ ```
38
+ """
21
39
 
22
40
  MIME_PREFIX = 'application/pdf'
@@ -18,7 +18,27 @@ from langfun.core.modalities import mime
18
18
 
19
19
 
20
20
  class Video(mime.Mime):
21
- """Video."""
21
+ """Represents a video for communicating with language models.
22
+
23
+ `lf.Video` can be initialized from a URI (HTTP/HTTPS URL or local path)
24
+ using `lf.Video.from_uri()` or from raw bytes using `lf.Video.from_bytes()`.
25
+
26
+ **Example:**
27
+
28
+ ```python
29
+ import langfun as lf
30
+
31
+ # Load video from path
32
+ video = lf.Video.from_path('/path/to/video.mp4')
33
+
34
+ # Use video in a prompt
35
+ prompt = lf.Template(
36
+ 'What is happening in this video? {{video}}', video=video
37
+ )
38
+ response = lf.query(prompt, lm=lf.llms.Gemini25Flash())
39
+ print(response)
40
+ ```
41
+ """
22
42
 
23
43
  MIME_PREFIX = 'video'
24
44
 
langfun/core/modality.py CHANGED
@@ -14,40 +14,63 @@
14
14
  """Interface for modality (e.g. Image, Video, etc.)."""
15
15
 
16
16
  import abc
17
+ import contextlib
17
18
  import functools
18
19
  import hashlib
19
- from typing import Any, ContextManager
20
+ import re
21
+ from typing import Any, ContextManager, Iterator
20
22
  from langfun.core import component
21
23
  import pyglove as pg
22
24
 
23
25
 
24
- _TLS_MODALITY_AS_REF = '__format_modality_as_ref__'
26
+ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
27
+ """Base class for representing non-text content in prompts.
25
28
 
29
+ `lf.Modality` is the base class for multimodal objects such as `lf.Image`,
30
+ `lf.Audio`, and `lf.Video`. It allows these non-text inputs to be
31
+ seamlessly embedded within text prompts for processing by multimodal
32
+ language models.
26
33
 
27
- def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
28
- """A context manager that formats modality objects as references."""
29
- return pg.object_utils.thread_local_value_scope(
30
- _TLS_MODALITY_AS_REF, enabled, False
31
- )
34
+ When a `Modality` object is rendered within an `lf.Template`, it is
35
+ replaced by a text marker (e.g., `<<[[image:b10a8db1]]>>`), and the
36
+ modality object itself is stored in the `referred_modalities` field of
37
+ the resulting `lf.Message`. This allows language models to associate
38
+ the placeholder with its content during processing.
32
39
 
40
+ **Example:**
33
41
 
34
- class Modality(component.Component, pg.views.HtmlTreeView.Extension):
35
- """Base class for multimodal object."""
42
+ ```python
43
+ import langfun as lf
44
+
45
+ image = lf.Image.from_path('/path/to/image.png')
46
+ prompt = lf.Template('What is in this image? {{image}}', image=image)
47
+
48
+ message = prompt.render()
49
+ print(message.text)
50
+ # Output: What is in this image? <<[[image:b10a8db1]]>>
51
+
52
+ print(message.modalities())
53
+ # Output: [<Image object>]
54
+ ```
55
+ """
36
56
 
37
57
  REF_START = '<<[['
38
58
  REF_END = ']]>>'
39
59
 
40
60
  def _on_bound(self):
41
61
  super()._on_bound()
42
- # Invalidate cached hash if modality member is changed.
62
+ # Invalidate cached hash and id if modality member is changed.
43
63
  self.__dict__.pop('hash', None)
64
+ self.__dict__.pop('id', None)
44
65
 
45
66
  def format(self, *args, **kwargs) -> str:
46
- if self.referred_name is None or not pg.object_utils.thread_local_get(
47
- _TLS_MODALITY_AS_REF, False
48
- ):
67
+ if not pg.object_utils.thread_local_get(_TLS_MODALITY_AS_REF, False):
49
68
  return super().format(*args, **kwargs)
50
- return Modality.text_marker(self.referred_name)
69
+
70
+ capture_scope = get_modality_capture_context()
71
+ if capture_scope is not None:
72
+ capture_scope.capture(self)
73
+ return Modality.text_marker(self.id)
51
74
 
52
75
  def __str_kwargs__(self) -> dict[str, Any]:
53
76
  # For modality objects, we don't want to use markdown format when they
@@ -70,14 +93,11 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
70
93
  """Returns a marker in the text for this object."""
71
94
  return Modality.REF_START + var_name + Modality.REF_END
72
95
 
73
- @property
74
- def referred_name(self) -> str | None:
96
+ @functools.cached_property
97
+ def id(self) -> str | None:
75
98
  """Returns the ID of this object, formatted as `<snake_case_class_name>:<hash>`."""
76
- if not self.sym_path:
77
- return None
78
- # Strip the metadata prefix under message.
79
- path = str(self.sym_path)
80
- return path[9:] if path.startswith('metadata.') else path
99
+ modality_type = _camel_to_snake(self.__class__.__name__)
100
+ return f'{modality_type}:{self.hash}'
81
101
 
82
102
  @classmethod
83
103
  def from_value(cls, value: pg.Symbolic) -> dict[str, 'Modality']:
@@ -86,7 +106,7 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
86
106
  def _visit(k, v, p):
87
107
  del k, p
88
108
  if isinstance(v, Modality):
89
- modalities[v.referred_name] = v
109
+ modalities[v.id] = v
90
110
  return pg.TraverseAction.CONTINUE
91
111
  return pg.TraverseAction.ENTER
92
112
 
@@ -95,14 +115,47 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
95
115
 
96
116
 
97
117
  class ModalityRef(pg.Object, pg.typing.CustomTyping):
98
- """References of modality objects in a symbolic tree.
118
+ """Lightweight placeholder for a `lf.Modality` object in a symbolic tree.
99
119
 
100
- `ModalityRef` was introduced to placehold modality objects in a symbolic
101
- tree, to prevent message from being chunked in the middle of a Python
102
- structure.
120
+ `ModalityRef` acts as a reference to a `Modality` object (like `lf.Image`
121
+ or `lf.Audio`) within a structured object hierarchy (e.g., a `pg.Object`).
122
+ Instead of embedding potentially large modality data directly, `ModalityRef`
123
+ stores only the ID of the modality object.
124
+
125
+ This is useful in scenarios where structured objects are serialized or
126
+ manipulated, and it's more efficient to refer to modalities by ID rather
127
+ than copying their content. The `lf.ModalityRef.placehold()` class method
128
+ can be used to replace `Modality` instances in a symbolic object with
129
+ `ModalityRef` placeholders, while `lf.ModalityRef.restore()` can reinstate
130
+ the original `Modality` objects using a lookup table.
131
+
132
+ **Example:**
133
+
134
+ ```python
135
+ import langfun as lf
136
+ import pyglove as pg
137
+
138
+ class ImagePair(pg.Object):
139
+ image1: lf.Image
140
+ image2: lf.Image
141
+
142
+ pair = ImagePair(
143
+ image1=lf.Image(content=b'abc'), image2=lf.Image(content=b'def')
144
+ )
145
+ modalities = lf.Modality.from_value(pair)
146
+
147
+ # Replace Image objects with ModalityRef placeholders
148
+ pair_with_refs = lf.ModalityRef.placehold(pair)
149
+ print(pair_with_refs.image1)
150
+ # Output: ModalityRef(id='image:d81e5a68')
151
+
152
+ # Restore Image objects from ModalityRef placeholders
153
+ pair_restored = lf.ModalityRef.restore(pair_with_refs, modalities)
154
+ assert pair_restored.image1.content == b'abc'
155
+ ```
103
156
  """
104
157
 
105
- name: str
158
+ id: str
106
159
 
107
160
  def custom_apply(
108
161
  self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs
@@ -122,12 +175,97 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
122
175
  """
123
176
 
124
177
  def _placehold(k, v, p):
125
- del p
178
+ del k, p
126
179
  if isinstance(v, Modality):
127
- return ModalityRef(name=value.sym_path + k)
180
+ return ModalityRef(id=v.id)
128
181
  return v
129
182
  return value.clone().rebind(_placehold, raise_on_no_change=False)
130
183
 
184
+ @classmethod
185
+ def restore(cls, value: pg.Symbolic, modalities: dict[str, Modality]) -> Any:
186
+ """Replaces refs in `value` with modality objects in place and returns it."""
187
+ def _restore(k, v, p):
188
+ del k, p
189
+ if isinstance(v, ModalityRef):
190
+ modality_object = modalities.get(v.id)
191
+ if modality_object is None:
192
+ raise ValueError(
193
+ f'Modality {v.id} not found in modalities {modalities.keys()}'
194
+ )
195
+ return modality_object
196
+ return v
197
+ return value.rebind(_restore, raise_on_no_change=False)
198
+
131
199
 
132
200
  class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
133
201
  """Exception raised when modality is not supported."""
202
+
203
+
204
+ #
205
+ # Context managers to deal with modality objects.
206
+ #
207
+
208
+
209
+ _TLS_MODALITY_CAPTURE_SCOPE = '__modality_capture_scope__'
210
+ _TLS_MODALITY_AS_REF = '__format_modality_as_ref__'
211
+
212
+
213
+ def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
214
+ """A context manager that formats modality objects as references."""
215
+ return pg.object_utils.thread_local_value_scope(
216
+ _TLS_MODALITY_AS_REF, enabled, False
217
+ )
218
+
219
+
220
+ class _ModalityCaptureContext:
221
+ """A context to capture modality objects when being rendered."""
222
+
223
+ def __init__(self):
224
+ self._references: dict[str, pg.Ref[Modality]] = {}
225
+
226
+ def capture(self, modality: Modality) -> None:
227
+ """Captures the modality object."""
228
+ self._references[modality.id] = pg.Ref(modality)
229
+
230
+ @property
231
+ def references(self) -> dict[str, pg.Ref[Modality]]:
232
+ """Returns the modality references captured in this context."""
233
+ return self._references
234
+
235
+
236
+ @contextlib.contextmanager
237
+ def capture_rendered_modalities() -> Iterator[dict[str, pg.Ref[Modality]]]:
238
+ """Capture modality objects whose references are being rendered.
239
+
240
+ Example:
241
+ ```
242
+ image = lf.Image.from_url(...)
243
+ with lf.modality.capture_rendered_modalities() as rendered_modalities:
244
+ with lf.modality.format_modality_as_ref():
245
+ print(f'Hello {image}')
246
+ self.assertEqual(rendered_modalities, {'image:<hash>': pg.Ref(image)})
247
+ ```
248
+ """
249
+ context = get_modality_capture_context()
250
+ top_level = context is None
251
+ if top_level:
252
+ context = _ModalityCaptureContext()
253
+ pg.object_utils.thread_local_set(_TLS_MODALITY_CAPTURE_SCOPE, context)
254
+
255
+ try:
256
+ yield context.references # pylint: disable=attribute-error
257
+ finally:
258
+ if top_level:
259
+ pg.object_utils.thread_local_del(_TLS_MODALITY_CAPTURE_SCOPE)
260
+
261
+
262
+ def get_modality_capture_context() -> _ModalityCaptureContext | None:
263
+ """Returns the current modality capture context."""
264
+ return pg.object_utils.thread_local_get(_TLS_MODALITY_CAPTURE_SCOPE, None)
265
+
266
+
267
+ def _camel_to_snake(name: str) -> str:
268
+ """Converts a camelCase name to snake_case."""
269
+ return re.sub(
270
+ pattern=r'([A-Z]+)', repl=r'_\1', string=name
271
+ ).lower().lstrip('_')
@@ -29,34 +29,64 @@ class ModalityTest(unittest.TestCase):
29
29
 
30
30
  def test_basic(self):
31
31
  v = CustomModality('a')
32
- self.assertIsNone(v.referred_name)
32
+ self.assertEqual(v.id, 'custom_modality:0cc175b9')
33
33
  self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
34
34
  self.assertEqual(v.hash, '0cc175b9')
35
35
 
36
36
  _ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
37
- self.assertEqual(v.referred_name, 'x.metadata.y')
37
+ self.assertEqual(v.id, 'custom_modality:0cc175b9')
38
38
  self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
39
39
  with modality.format_modality_as_ref():
40
- self.assertEqual(str(v), '<<[[x.metadata.y]]>>')
40
+ self.assertEqual(str(v), '<<[[custom_modality:0cc175b9]]>>')
41
+
42
+ def test_capture_rendered_modalities(self):
43
+ x = CustomModality('a')
44
+ y = CustomModality('b')
45
+ z = CustomModality('b')
46
+
47
+ with modality.capture_rendered_modalities() as rendered_modalities:
48
+ with modality.format_modality_as_ref():
49
+ self.assertEqual(
50
+ f'Hello {x} {y} {z}',
51
+ (
52
+ 'Hello <<[[custom_modality:0cc175b9]]>> '
53
+ '<<[[custom_modality:92eb5ffe]]>> '
54
+ '<<[[custom_modality:92eb5ffe]]>>'
55
+ )
56
+ )
57
+ self.assertEqual(len(rendered_modalities), 2)
58
+ self.assertIs(rendered_modalities['custom_modality:0cc175b9'].value, x)
59
+ # y and z share the same content, so they are treated as the same object.
60
+ self.assertIs(rendered_modalities['custom_modality:92eb5ffe'].value, z)
41
61
 
42
62
 
43
63
  class ModalityRefTest(unittest.TestCase):
44
64
 
45
- def test_placehold(self):
65
+ def test_placehold_and_restore(self):
46
66
  class A(pg.Object):
47
67
  x: Any
48
68
  y: Any
49
69
 
50
- a = A(x=dict(z=CustomModality('a')), y=CustomModality('b'))
70
+ image_a = CustomModality('a')
71
+ image_b = CustomModality('b')
72
+ a = A(x=dict(z=image_a), y=image_b)
73
+ a_placehold = modality.ModalityRef.placehold(a)
51
74
  self.assertEqual(
52
- modality.ModalityRef.placehold(a),
53
- A(x=dict(z=modality.ModalityRef('x.z')), y=modality.ModalityRef('y')),
75
+ a_placehold,
76
+ A(x=dict(z=modality.ModalityRef(image_a.id)),
77
+ y=modality.ModalityRef(image_b.id)),
78
+ )
79
+ a_restore = modality.ModalityRef.restore(
80
+ a_placehold.clone(),
81
+ {image_a.id: image_a, image_b.id: image_b},
54
82
  )
83
+ self.assertTrue(pg.eq(a_restore, a))
55
84
  self.assertEqual(
56
85
  modality.ModalityRef.placehold(a.x),
57
- # The prefix 'x' of referred name is preserved.
58
- dict(z=modality.ModalityRef('x.z')),
86
+ dict(z=modality.ModalityRef(image_a.id)),
59
87
  )
88
+ with self.assertRaisesRegex(ValueError, 'Modality .* not found'):
89
+ modality.ModalityRef.restore(a_placehold, {image_a.id: image_a})
60
90
 
61
91
  def test_from_value(self):
62
92
  class A(pg.Object):
@@ -68,8 +98,8 @@ class ModalityRefTest(unittest.TestCase):
68
98
  pg.eq(
69
99
  modality.Modality.from_value(a),
70
100
  {
71
- 'x.z': CustomModality('a'),
72
- 'y': CustomModality('b'),
101
+ 'custom_modality:0cc175b9': CustomModality('a'),
102
+ 'custom_modality:92eb5ffe': CustomModality('b'),
73
103
  },
74
104
  )
75
105
  )
@@ -77,7 +107,7 @@ class ModalityRefTest(unittest.TestCase):
77
107
  pg.eq(
78
108
  modality.Modality.from_value(a.x.z),
79
109
  {
80
- 'x.z': CustomModality('a'),
110
+ 'custom_modality:0cc175b9': CustomModality('a'),
81
111
  },
82
112
  )
83
113
  )
@@ -11,7 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Natural language utilities."""
14
+ """Natural language formatting."""
15
15
 
16
16
  import abc
17
17
  import pyglove as pg