langfun 0.1.2.dev202509020804__py3-none-any.whl → 0.1.2.dev202511110805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of langfun might be problematic.
Files changed (133)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +6 -1
  3. langfun/core/agentic/__init__.py +4 -0
  4. langfun/core/agentic/action.py +412 -103
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +68 -6
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +9 -2
  20. langfun/core/data/conversion/gemini_test.py +12 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +47 -43
  24. langfun/core/eval/base_test.py +4 -4
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +1 -0
  29. langfun/core/eval/v2/checkpointing.py +30 -4
  30. langfun/core/eval/v2/eval_test_helper.py +1 -1
  31. langfun/core/eval/v2/evaluation.py +60 -14
  32. langfun/core/eval/v2/example.py +22 -11
  33. langfun/core/eval/v2/experiment.py +51 -8
  34. langfun/core/eval/v2/metric_values.py +31 -3
  35. langfun/core/eval/v2/metric_values_test.py +32 -0
  36. langfun/core/eval/v2/metrics.py +39 -4
  37. langfun/core/eval/v2/metrics_test.py +14 -0
  38. langfun/core/eval/v2/progress.py +30 -1
  39. langfun/core/eval/v2/progress_test.py +27 -0
  40. langfun/core/eval/v2/progress_tracking_test.py +6 -0
  41. langfun/core/eval/v2/reporting.py +90 -71
  42. langfun/core/eval/v2/reporting_test.py +20 -6
  43. langfun/core/eval/v2/runners.py +27 -7
  44. langfun/core/eval/v2/runners_test.py +3 -0
  45. langfun/core/langfunc.py +45 -130
  46. langfun/core/langfunc_test.py +6 -4
  47. langfun/core/language_model.py +151 -31
  48. langfun/core/language_model_test.py +9 -3
  49. langfun/core/llms/__init__.py +12 -1
  50. langfun/core/llms/anthropic.py +157 -2
  51. langfun/core/llms/azure_openai.py +29 -17
  52. langfun/core/llms/cache/base.py +25 -3
  53. langfun/core/llms/cache/in_memory.py +48 -7
  54. langfun/core/llms/cache/in_memory_test.py +14 -4
  55. langfun/core/llms/compositional.py +25 -1
  56. langfun/core/llms/deepseek.py +30 -2
  57. langfun/core/llms/fake.py +39 -1
  58. langfun/core/llms/fake_test.py +9 -0
  59. langfun/core/llms/gemini.py +43 -7
  60. langfun/core/llms/google_genai.py +34 -1
  61. langfun/core/llms/groq.py +28 -3
  62. langfun/core/llms/llama_cpp.py +23 -4
  63. langfun/core/llms/openai.py +93 -3
  64. langfun/core/llms/openai_compatible.py +148 -27
  65. langfun/core/llms/openai_compatible_test.py +207 -20
  66. langfun/core/llms/openai_test.py +0 -2
  67. langfun/core/llms/rest.py +16 -1
  68. langfun/core/llms/vertexai.py +59 -8
  69. langfun/core/logging.py +1 -1
  70. langfun/core/mcp/__init__.py +10 -0
  71. langfun/core/mcp/client.py +177 -0
  72. langfun/core/mcp/client_test.py +71 -0
  73. langfun/core/mcp/session.py +241 -0
  74. langfun/core/mcp/session_test.py +54 -0
  75. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  76. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  77. langfun/core/mcp/tool.py +256 -0
  78. langfun/core/mcp/tool_test.py +197 -0
  79. langfun/core/memory.py +1 -0
  80. langfun/core/message.py +160 -55
  81. langfun/core/message_test.py +65 -81
  82. langfun/core/modalities/__init__.py +8 -0
  83. langfun/core/modalities/audio.py +21 -1
  84. langfun/core/modalities/image.py +19 -1
  85. langfun/core/modalities/mime.py +62 -3
  86. langfun/core/modalities/pdf.py +19 -1
  87. langfun/core/modalities/video.py +21 -1
  88. langfun/core/modality.py +167 -29
  89. langfun/core/modality_test.py +42 -12
  90. langfun/core/natural_language.py +1 -1
  91. langfun/core/sampling.py +4 -4
  92. langfun/core/sampling_test.py +20 -4
  93. langfun/core/structured/completion.py +34 -44
  94. langfun/core/structured/completion_test.py +23 -43
  95. langfun/core/structured/description.py +54 -50
  96. langfun/core/structured/function_generation.py +29 -12
  97. langfun/core/structured/mapping.py +74 -28
  98. langfun/core/structured/parsing.py +90 -74
  99. langfun/core/structured/parsing_test.py +0 -3
  100. langfun/core/structured/querying.py +242 -156
  101. langfun/core/structured/querying_test.py +95 -64
  102. langfun/core/structured/schema.py +70 -10
  103. langfun/core/structured/schema_generation.py +33 -14
  104. langfun/core/structured/scoring.py +45 -34
  105. langfun/core/structured/tokenization.py +24 -9
  106. langfun/core/subscription.py +2 -2
  107. langfun/core/template.py +175 -50
  108. langfun/core/template_test.py +123 -17
  109. langfun/env/__init__.py +43 -0
  110. langfun/env/base_environment.py +827 -0
  111. langfun/env/base_environment_test.py +473 -0
  112. langfun/env/base_feature.py +304 -0
  113. langfun/env/base_feature_test.py +228 -0
  114. langfun/env/base_sandbox.py +842 -0
  115. langfun/env/base_sandbox_test.py +1235 -0
  116. langfun/env/event_handlers/__init__.py +14 -0
  117. langfun/env/event_handlers/chain.py +233 -0
  118. langfun/env/event_handlers/chain_test.py +253 -0
  119. langfun/env/event_handlers/event_logger.py +472 -0
  120. langfun/env/event_handlers/event_logger_test.py +304 -0
  121. langfun/env/event_handlers/metric_writer.py +726 -0
  122. langfun/env/event_handlers/metric_writer_test.py +214 -0
  123. langfun/env/interface.py +1640 -0
  124. langfun/env/interface_test.py +151 -0
  125. langfun/env/load_balancers.py +59 -0
  126. langfun/env/load_balancers_test.py +139 -0
  127. langfun/env/test_utils.py +497 -0
  128. {langfun-0.1.2.dev202509020804.dist-info → langfun-0.1.2.dev202511110805.dist-info}/METADATA +7 -3
  129. langfun-0.1.2.dev202511110805.dist-info/RECORD +200 -0
  130. langfun-0.1.2.dev202509020804.dist-info/RECORD +0 -172
  131. {langfun-0.1.2.dev202509020804.dist-info → langfun-0.1.2.dev202511110805.dist-info}/WHEEL +0 -0
  132. {langfun-0.1.2.dev202509020804.dist-info → langfun-0.1.2.dev202511110805.dist-info}/licenses/LICENSE +0 -0
  133. {langfun-0.1.2.dev202509020804.dist-info → langfun-0.1.2.dev202511110805.dist-info}/top_level.txt +0 -0
langfun/core/component.py CHANGED
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""langfun Component."""
+"""Base component for Langfun."""
 
 from typing import ContextManager
 import pyglove as pg
@@ -22,7 +22,37 @@ RAISE_IF_HAS_ERROR = (pg.MISSING_VALUE,)
 
 
 class Component(pg.ContextualObject):
-  """Base class for langfun components."""
+  """Base class for Langfun components.
+
+  Langfun components are context-aware symbolic objects powered by PyGlove.
+  (See [PyGlove basics](https://pyglove.readthedocs.io/en/latest/basics.html)
+  for more details).
+
+  **Context-awareness**
+
+  Langfun components can have contextual attributes using `lf.contextual`,
+  whose values can be provided or overridden via `lf.context` or
+  `lf.use_settings`.
+
+  Example:
+    ```python
+    import langfun as lf
+
+    class Bar(lf.Component):
+      y = lf.contextual(1)
+
+    class Foo(lf.Component):
+      x = lf.contextual(0)
+      bar = Bar()
+
+    f = Foo()
+    assert f.x == 0 and f.bar.y == 1
+
+    # `lf.context` overrides `lf.contextual` attributes.
+    with lf.context(x=10, y=20):
+      assert f.x == 10 and f.bar.y == 20
+    ```
+  """
 
   # Allow symbolic assignment, which invalidates the object and recomputes
   # states upon update.
@@ -78,6 +108,15 @@ def use_settings(
 ) -> ContextManager[dict[str, pg.utils.ContextualOverride]]:
   """Shortcut method for overriding component attributes.
 
+  Example:
+
+    ```
+    with lf.use_settings(
+        lm=lf.llms.Gpt35(),
+        temperature=0.0):
+      lf.query('who are you?')
+    ```
+
   Args:
     cascade: If True, this override will apply to both current scope and nested
       scope, meaning that this `lf.context` will take precedence over all
@@ -85,6 +124,6 @@ def use_settings(
     **settings: Key/values as override for component attributes.
 
   Returns:
-    A dict of attribute names to their contextual overrides.
+    A context manager for overriding settings.
   """
   return context(cascade=cascade, override_attrs=True, **settings)
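The `use_settings` change above clarifies that the call returns a context manager rather than a dict of overrides. A minimal sketch of that documented behavior, using a hypothetical component `Bot` with a made-up contextual `temperature` attribute:

```python
import langfun as lf


class Bot(lf.Component):
  # Hypothetical contextual attribute, used only for illustration.
  temperature = lf.contextual(0.7)


bot = Bot()
assert bot.temperature == 0.7

# `lf.use_settings` returns a context manager, so the override is scoped
# to the `with` block and restored when the block exits.
with lf.use_settings(temperature=0.0):
  assert bot.temperature == 0.0

assert bot.temperature == 0.7
```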
langfun/core/concurrent.py CHANGED
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Utility library for handling concurrency in langfun."""
+"""Utilities for concurrency in Langfun."""
 
 import abc
 import collections
@@ -97,7 +97,7 @@ class RetryError(RuntimeError):
 
 
 def with_retry(
-    func: Callable[[Any], Any],
+    func: Callable[..., Any],
     retry_on_errors: Union[
         Union[Type[BaseException], Tuple[Type[BaseException], str]],
         Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
@@ -108,10 +108,25 @@
     max_retry_interval: int = 300,
     seed: int | None = None,
 ) -> Callable[..., Any]:
-  """Derives a user function with retry on error.
+  """Decorator-like function to add retry mechanism to a function.
+
+  Example:
+
+    ```
+    def flaky_function():
+      if random.random() < 0.5:
+        raise ValueError('error')
+      return 1
+
+    reliable_function = lf.with_retry(
+        flaky_function,
+        retry_on_errors=ValueError,
+        max_attempts=3)
+    reliable_function()
+    ```
 
   Args:
-    func: A user function.
+    func: The function to add retry mechanism.
     retry_on_errors: A sequence of exception types or tuples of exception type
       and error messages (described in regular expression) as the desired
       exception types to retry.
@@ -128,8 +143,7 @@ def with_retry(
       determined based on current time.
 
   Returns:
-    A function with the same signature of the input function, with the retry
-    capability.
+    A function with the same signature of `func`, but with retry capability.
   """
 
   def _func(*args, **kwargs):
@@ -179,6 +193,24 @@ def concurrent_execute(
 ) -> list[Any]:
   """Executes a function concurrently under current component context.
 
+  `lf.concurrent_execute` applies a function to each item in an iterable of
+  inputs in parallel and returns a list of results in the same order as the
+  inputs. It is a convenient wrapper around `lf.concurrent_map` for synchronous
+  bulk processing.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+
+  def square(x):
+    return x ** 2
+
+  results = lf.concurrent_execute(square, [1, 2, 3, 4], max_workers=2)
+  print(results)
+  # Output: [1, 4, 9, 16]
+  ```
+
   Args:
     func: A user function.
     parallel_inputs: The inputs for `func` which will be processed in parallel.
@@ -649,6 +681,38 @@ def concurrent_map(
 ) -> Iterator[Any]:
   """Maps inputs to outptus via func concurrently under current context.
 
+  `lf.concurrent_map` applies a function to each item in an iterable of
+  inputs in parallel and yields `(input, output, error)` tuples as they are
+  completed. It supports features like ordered/unordered results, progress
+  bars, timeouts, and automatic retries for transient errors.
+
+  **Example:**
+
+  ```python
+  import langfun as lf
+  import time
+  import random
+
+  def flaky_square(x):
+    time.sleep(random.random())
+    if random.random() < 0.3:
+      raise ValueError("Flaky error")
+    return x ** 2
+
+  # Unordered execution with progress bar and retries
+  for input, output, error in lf.concurrent_map(
+      flaky_square,
+      range(10),
+      max_workers=3,
+      show_progress=True,
+      retry_on_errors=ValueError,
+      max_attempts=3):
+    if error:
+      print(f"Input {input} failed with error: {error}")
+    else:
+      print(f"Input {input} succeeded with output: {output}")
+  ```
+
   Args:
     func: A user function.
     parallel_inputs: The inputs for `func` which will be processed in parallel.
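Besides plain exception types, the `retry_on_errors` argument documented above also accepts `(exception type, message regex)` tuples. A hedged sketch of that form; the flaky function and the pattern are invented for illustration:

```python
import random

import langfun as lf


def fetch_page():
  # Simulated transient failure whose message matches the pattern below.
  if random.random() < 0.5:
    raise ConnectionError('connection timed out')
  return '<html>ok</html>'


# Retry only when both the exception type and its message (matched as a
# regular expression) agree, as described in the `with_retry` docstring.
reliable_fetch = lf.with_retry(
    fetch_page,
    retry_on_errors=[(ConnectionError, '.*timed out.*')],
    max_attempts=3,
)
print(reliable_fetch())
```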
langfun/core/concurrent_test.py CHANGED
@@ -262,6 +262,7 @@ class ProgressControlTest(unittest.TestCase):
     with contextlib.redirect_stderr(string_io):
       ctrl.update(1)
       ctrl.refresh()
+      sys.stderr.flush()
     self.assertEqual(string_io.getvalue(), '')
     concurrent.progress_bar = 'tqdm'
 
@@ -274,6 +275,7 @@ class ProgressControlTest(unittest.TestCase):
       ctrl.set_status('bar')
       ctrl.update(10)
       ctrl.refresh()
+      sys.stderr.flush()
     self.assertEqual(
         string_io.getvalue(),
         '\x1b[1m\x1b[31mfoo\x1b[0m: \x1b[34m10% (10/100)\x1b[0m : bar\n'
@@ -288,6 +290,7 @@ class ProgressControlTest(unittest.TestCase):
       self.assertIsInstance(ctrl, concurrent._TqdmProgressControl)
       ctrl.update(10)
       ctrl.refresh()
+      sys.stderr.flush()
     self.assertIn('10/100', string_io.getvalue())
 
     tqdm = concurrent.tqdm
@@ -316,6 +319,7 @@ class ProgressBarTest(unittest.TestCase):
       for _ in concurrent.concurrent_execute(fun, range(5)):
         concurrent.ProgressBar.refresh()
       concurrent.ProgressBar.uninstall(bar_id)
+      sys.stderr.flush()
     output_str = string_io.getvalue()
     self.assertIn('100%', output_str)
     self.assertIn('5/5', output_str)
@@ -332,7 +336,7 @@ class ProgressBarTest(unittest.TestCase):
       concurrent.ProgressBar.update(bar_id, 0, status=1)
       concurrent.ProgressBar.uninstall(bar_id)
       sys.stderr.flush()
-      time.sleep(1)
+      time.sleep(1)
     self.assertIn('1/4', string_io.getvalue())
     # TODO(daiyip): Re-enable once flakiness is fixed.
     # self.assertIn('2/4', string_io.getvalue())
@@ -564,7 +568,8 @@ class ConcurrentMapTest(unittest.TestCase):
             fun, [1, 2, 3], timeout=1.5, max_workers=1, show_progress=True
         )
     ], key=lambda x: x[0])
-    string_io.flush()
+    sys.stderr.flush()
+
     self.assertEqual(  # pylint: disable=g-generic-assert
         output,
         [
@@ -592,6 +597,7 @@ class ConcurrentMapTest(unittest.TestCase):
            show_progress=bar_id, status_fn=lambda p: dict(x=1, y=1)
        )
    ], key=lambda x: x[0])
+    sys.stderr.flush()
 
    self.assertEqual(  # pylint: disable=g-generic-assert
        output,
@@ -602,6 +608,7 @@ class ConcurrentMapTest(unittest.TestCase):
        ],
    )
    concurrent.ProgressBar.uninstall(bar_id)
+    concurrent.ProgressBar.refresh()
    self.assertIn('100%', string_io.getvalue())
 
 
langfun/core/console.py CHANGED
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Console utilities."""
+"""Utilities for console output and notebook display."""
 
 import sys
 from typing import Any
langfun/core/data/conversion/anthropic.py CHANGED
@@ -21,7 +21,14 @@ from langfun.core import modalities as lf_modalities
 
 
 class AnthropicMessageConverter(lf.MessageConverter):
-  """Converter to Anthropic public API."""
+  """Converter between Langfun messages and Anthropic API message format.
+
+  This converter translates `lf.Message` objects into the JSON format required
+  by the Anthropic API and vice versa. It handles text and modalities like
+  images and PDFs by encoding them in base64 format as expected by Anthropic.
+  An optional `chunk_preprocessor` can be provided to modify or filter
+  chunks before conversion.
+  """
 
   FORMAT_ID = 'anthropic'
 
@@ -30,12 +37,12 @@ class AnthropicMessageConverter(lf.MessageConverter):
       (
           'Chunk preprocessor for Langfun chunk to Anthropic chunk conversion. '
           'It will be applied before each Langfun chunk is converted. '
-          'If returns None, the chunk will be skipped.'
+          'If it returns None, the chunk will be skipped.'
       )
   ] = None
 
   def to_value(self, message: lf.Message) -> dict[str, Any]:
-    """Converts a Langfun message to Gemini API."""
+    """Converts a Langfun message to Anthropic API."""
     content = []
     for chunk in message.chunk():
       if self.chunk_preprocessor:
@@ -97,6 +104,8 @@ class AnthropicMessageConverter(lf.MessageConverter):
                 self._safe_read(source, 'media_type')
             ).from_bytes(base64.b64decode(self._safe_read(source, 'data')))
         )
+      elif t in ('server_tool_use', 'web_search_tool_result'):
+        continue
      else:
        raise ValueError(f'Unsupported content part: {part!r}.')
    message = message_cls.from_chunks(chunks)
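A minimal usage sketch for the converter documented above, constructing `AnthropicMessageConverter` directly; the message text is arbitrary and the exact JSON layout follows the converter code rather than a captured API response:

```python
import langfun as lf
from langfun.core.data.conversion import anthropic

converter = anthropic.AnthropicMessageConverter()

# `to_value` renders a Langfun message as an Anthropic-style dict with a
# role and a list of content blocks; `from_value` converts such a dict
# back into a Langfun message.
message = lf.UserMessage('Hello, Claude!')
anthropic_json = converter.to_value(message)
restored = converter.from_value(anthropic_json)
```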
langfun/core/data/conversion/anthropic_test.py CHANGED
@@ -253,14 +253,16 @@ class AnthropicConversionTest(unittest.TestCase):
     )
     self.assertEqual(
         m.text,
-        'What are the common words from <<[[obj0]]>> and <<[[obj1]]>> ?'
+        'What are the common words from <<[[image:dc6e1e43]]>> and'
+        ' <<[[pdf:5daf5f31]]>> ?'
     )
-    self.assertIsInstance(m.obj0, lf_modalities.Image)
-    self.assertEqual(m.obj0.mime_type, 'image/png')
-    self.assertEqual(m.obj0.to_bytes(), image_content)
+    modalities = m.modalities()
+    self.assertIsInstance(modalities[0], lf_modalities.Image)
+    self.assertEqual(modalities[0].mime_type, 'image/png')
+    self.assertEqual(modalities[0].content, image_content)
 
-    self.assertIsInstance(m.obj1, lf_modalities.PDF)
-    self.assertEqual(m.obj1.to_bytes(), pdf_content)
+    self.assertIsInstance(modalities[1], lf_modalities.PDF)
+    self.assertEqual(modalities[1].content, pdf_content)
 
 
 if __name__ == '__main__':
langfun/core/data/conversion/gemini.py CHANGED
@@ -21,7 +21,14 @@ from langfun.core import modalities as lf_modalities
 
 
 class GeminiMessageConverter(lf.MessageConverter):
-  """Converter to Gemini public API."""
+  """Converter between Langfun messages and Gemini API message format.
+
+  This converter translates `lf.Message` objects into the JSON format required
+  by the public Gemini API (e.g., via Vertex AI or Google AI Studio) and
+  vice versa. It handles text and modalities like images, extracting thought
+  chunks if present. An optional `chunk_preprocessor` can be provided to
+  modify or filter chunks before conversion.
+  """
 
   FORMAT_ID = 'gemini'
 
@@ -30,7 +37,7 @@ class GeminiMessageConverter(lf.MessageConverter):
       (
          'Chunk preprocessor for Langfun chunk to Gemini chunk conversion. '
          'It will be applied before each Langfun chunk is converted. '
-          'If returns None, the chunk will be skipped.'
+          'If it returns None, the chunk will be skipped.'
      ),
  ] = None
 
langfun/core/data/conversion/gemini_test.py CHANGED
@@ -225,19 +225,22 @@ class GeminiConversionTest(unittest.TestCase):
     self.assertEqual(
         m.text,
         (
-            'What are the common words from <<[[obj0]]>> , <<[[obj1]]>> '
-            'and <<[[obj2]]>> ?'
+            'What are the common words from <<[[image:dc6e1e43]]>> , '
+            '<<[[pdf:4dc12e93]]>> and <<[[video:7e169565]]>> ?'
         )
     )
-    self.assertIsInstance(m.obj0, lf_modalities.Image)
-    self.assertEqual(m.obj0.mime_type, 'image/png')
-    self.assertEqual(m.obj0.to_bytes(), image_content)
+    self.assertIsInstance(m.modalities()[0], lf_modalities.Image)
+    self.assertEqual(m.modalities()[0].mime_type, 'image/png')
+    self.assertEqual(m.modalities()[0].to_bytes(), image_content)
 
-    self.assertIsInstance(m.obj1, lf_modalities.PDF)
-    self.assertEqual(m.obj1.uri, 'https://my.pdf')
+    self.assertIsInstance(m.modalities()[1], lf_modalities.PDF)
+    self.assertEqual(m.modalities()[1].uri, 'https://my.pdf')
 
-    self.assertIsInstance(m.obj2, lf_modalities.Video)
-    self.assertEqual(m.obj2.uri, 'https://www.youtube.com/watch?v=abcd')
+    self.assertIsInstance(m.modalities()[2], lf_modalities.Video)
+    self.assertEqual(
+        m.modalities()[2].uri,
+        'https://www.youtube.com/watch?v=abcd'
+    )
 
 
 if __name__ == '__main__':
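The test updates above reflect the new modality referencing scheme: placeholders are content-addressed (for example `<<[[image:dc6e1e43]]>>` instead of `<<[[obj0]]>>`), and attachments are read positionally via `m.modalities()` rather than through auto-named attributes such as `m.obj0`. A small sketch, assuming a placeholder image URL:

```python
import langfun as lf
from langfun.core import modalities as lf_modalities

# Hypothetical image reference; the URI is a placeholder.
image = lf_modalities.Image.from_uri('https://example.com/cat.png')

# A message assembled from a text chunk and a modality chunk.
m = lf.UserMessage.from_chunks(['What is in this image?', image])

# Attachments are now accessed positionally via `modalities()`.
assert isinstance(m.modalities()[0], lf_modalities.Image)
```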
langfun/core/data/conversion/openai.py CHANGED
@@ -19,17 +19,25 @@ import langfun.core as lf
 from langfun.core import modalities as lf_modalities
 
 
-class OpenAIMessageConverter(lf.MessageConverter):
-  """Converter to OpenAI API."""
+class OpenAIChatCompletionAPIMessageConverter(lf.MessageConverter):
+  """Converter for OpenAI Chat Completion API.
 
-  FORMAT_ID = 'openai'
+  This converter translates `lf.Message` objects into the JSON format
+  required by the OpenAI Chat Completions API
+  (https://platform.openai.com/docs/api-reference/chat) and vice versa.
+  It handles text and image modalities, mapping Langfun roles to OpenAI
+  roles ('system', 'user', 'assistant'). An optional `chunk_preprocessor`
+  can be provided to modify or filter chunks before conversion.
+  """
+
+  FORMAT_ID = 'openai_chat_completion_api'
 
   chunk_preprocessor: Annotated[
       Callable[[str | lf.Modality], Any] | None,
       (
          'Chunk preprocessor for Langfun chunk to OpenAI chunk conversion. '
          'It will be applied before each Langfun chunk is converted. '
-          'If returns None, the chunk will be skipped.'
+          'If it returns None, the chunk will be skipped.'
      )
  ] = None
 
@@ -41,22 +49,29 @@ class OpenAIMessageConverter(lf.MessageConverter):
         chunk = self.chunk_preprocessor(chunk)
         if chunk is None:
           continue
-
-      if isinstance(chunk, str):
-        item = dict(type='text', text=chunk)
-      elif isinstance(chunk, lf_modalities.Image):
-        item = dict(
-            type='image_url', image_url=dict(url=chunk.embeddable_uri)
-        )
-      # TODO(daiyip): Support audio_input.
-      else:
-        raise ValueError(f'Unsupported content type: {chunk!r}.')
-      parts.append(item)
+      parts.append(self.chunk_to_json(type(message), chunk))
     return dict(
         role=self.get_role(message),
         content=parts,
     )
 
+  def chunk_to_json(
+      self,
+      message_cls: type[lf.Message],
+      chunk: str | lf.Modality
+  ) -> dict[str, Any]:
+    """Converts a Langfun chunk to OpenAI chunk."""
+    del message_cls
+    if isinstance(chunk, str):
+      return dict(type='text', text=chunk)
+    elif isinstance(chunk, lf_modalities.Image):
+      return dict(
+          type='image_url', image_url=dict(url=chunk.embeddable_uri)
+      )
+    # TODO(daiyip): Support audio_input.
+    else:
+      raise ValueError(f'Unsupported content type: {chunk!r}.')
+
   def get_role(self, message: lf.Message) -> str:
     """Returns the role of the message."""
     if isinstance(message, lf.SystemMessage):
@@ -92,40 +107,139 @@ class OpenAIMessageConverter(lf.MessageConverter):
     assert isinstance(content, list)
     chunks = []
     for item in content:
-      t = self._safe_read(item, 'type')
-      if t == 'text':
-        chunk = self._safe_read(item, 'text')
-      elif t == 'image_url':
-        chunk = lf_modalities.Image.from_uri(
-            self._safe_read(self._safe_read(item, 'image_url'), 'url')
-        )
-      else:
-        raise ValueError(f'Unsupported content type: {item!r}.')
-      chunks.append(chunk)
+      chunks.append(self.json_to_chunk(item))
     return message_cls.from_chunks(chunks)
 
+  def json_to_chunk(self, json: dict[str, Any]) -> str | lf.Modality:
+    """Returns a Langfun chunk from OpenAI chunk JSON."""
+    t = self._safe_read(json, 'type')
+    if t == 'text':
+      return self._safe_read(json, 'text')
+    elif t == 'image_url':
+      return lf_modalities.Image.from_uri(
+          self._safe_read(self._safe_read(json, 'image_url'), 'url')
+      )
+    else:
+      raise ValueError(f'Unsupported content type: {json!r}.')
+
 
-def _as_openai_format(
+def _as_openai_chat_completion_api_format(
     self,
     chunk_preprocessor: Callable[[str | lf.Modality], Any] | None = None,
     **kwargs
 ) -> dict[str, Any]:
   """Returns an OpenAI format message."""
-  return OpenAIMessageConverter(
+  return OpenAIChatCompletionAPIMessageConverter(
       chunk_preprocessor=chunk_preprocessor, **kwargs
   ).to_value(self)
 
 
 @classmethod
-def _from_openai_format(
+def _from_openai_chat_completion_api_format(
     cls,
     openai_message: dict[str, Any],
     **kwargs
 ) -> lf.Message:
   """Creates a Langfun message from the OpenAI format message."""
   del cls
-  return OpenAIMessageConverter(**kwargs).from_value(openai_message)
+  return OpenAIChatCompletionAPIMessageConverter(
+      **kwargs
+  ).from_value(openai_message)
 
 # Set shortcut methods in lf.Message.
-lf.Message.as_openai_format = _as_openai_format
-lf.Message.from_openai_format = _from_openai_format
+lf.Message.as_openai_chat_completion_api_format = (
+    _as_openai_chat_completion_api_format
+)
+
+lf.Message.from_openai_chat_completion_api_format = (
+    _from_openai_chat_completion_api_format
+)
+
+
+#
+# OpenAI Responses API message converter.
+#
+
+
+class OpenAIResponsesAPIMessageConverter(
+    OpenAIChatCompletionAPIMessageConverter
+):
+  """Converter for OpenAI Responses API.
+
+  This converter translates `lf.Message` objects into the JSON format
+  required by the OpenAI Responses API
+  (https://platform.openai.com/docs/api-reference/responses/create),
+  which is used for human-in-the-loop rating, and vice versa.
+  It extends `OpenAIChatCompletionAPIMessageConverter` but uses different
+  type names for content chunks (e.g., 'input_text', 'output_image').
+  """
+
+  FORMAT_ID = 'openai_responses_api'
+
+  def to_value(self, message: lf.Message) -> dict[str, Any]:
+    """Converts a Langfun message to OpenAI API."""
+    message_json = super().to_value(message)
+    message_json['type'] = 'message'
+    return message_json
+
+  def chunk_to_json(
+      self,
+      message_cls: type[lf.Message],
+      chunk: str | lf.Modality
+  ) -> dict[str, Any]:
+    """Converts a Langfun chunk to OpenAI chunk."""
+    source = 'output' if issubclass(message_cls, lf.AIMessage) else 'input'
+
+    if isinstance(chunk, str):
+      return dict(type=f'{source}_text', text=chunk)
+    elif isinstance(chunk, lf_modalities.Image):
+      return dict(
+          type=f'{source}_image', image_url=chunk.embeddable_uri
+      )
+    # TODO(daiyip): Support audio_input.
+    else:
+      raise ValueError(f'Unsupported content type: {chunk!r}.')
+
+  def json_to_chunk(self, json: dict[str, Any]) -> str | lf.Modality:
+    """Returns a Langfun chunk from OpenAI chunk JSON."""
+    t = self._safe_read(json, 'type')
+    if t in ('input_text', 'output_text'):
+      return self._safe_read(json, 'text')
+    elif t in ('input_image', 'output_image'):
+      return lf_modalities.Image.from_uri(self._safe_read(json, 'image_url'))
+    else:
+      raise ValueError(f'Unsupported content type: {json!r}.')
+
+
+def _as_openai_responses_api_format(
+    self,
+    chunk_preprocessor: Callable[[str | lf.Modality], Any] | None = None,
+    **kwargs
+) -> dict[str, Any]:
+  """Returns an OpenAI format message."""
+  return OpenAIResponsesAPIMessageConverter(
+      chunk_preprocessor=chunk_preprocessor, **kwargs
+  ).to_value(self)
+
+
+@classmethod
+def _from_openai_responses_api_format(
+    cls,
+    openai_message: dict[str, Any],
+    **kwargs
+) -> lf.Message:
+  """Creates a Langfun message from the OpenAI format message."""
+  del cls
+  return OpenAIResponsesAPIMessageConverter(
+      **kwargs
+  ).from_value(openai_message)
+
+
+# Set shortcut methods in lf.Message.
+lf.Message.as_openai_responses_api_format = (
+    _as_openai_responses_api_format
+)
+
+lf.Message.from_openai_responses_api_format = (
+    _from_openai_responses_api_format
+)
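A short sketch of the renamed and newly added `lf.Message` shortcuts wired up above; the message content is arbitrary, and the commented shapes follow the converter code rather than a captured run:

```python
import langfun as lf

m = lf.AIMessage('The answer is 42.')

# Formerly `as_openai_format()`; now named after the Chat Completions API.
chat_json = m.as_openai_chat_completion_api_format()
# Roughly: {'role': 'assistant',
#           'content': [{'type': 'text', 'text': 'The answer is 42.'}]}

# New Responses API variant: adds 'type': 'message' and emits
# 'output_text' chunks for AI messages ('input_text' for others).
responses_json = m.as_openai_responses_api_format()

# The reverse direction is also available.
restored = lf.Message.from_openai_chat_completion_api_format(chat_json)
```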