langfun 0.1.2.dev202510230805__py3-none-any.whl → 0.1.2.dev202511160804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic.

Files changed (146):
  1. langfun/core/__init__.py +1 -0
  2. langfun/core/agentic/action.py +107 -12
  3. langfun/core/agentic/action_eval.py +9 -2
  4. langfun/core/agentic/action_test.py +25 -0
  5. langfun/core/async_support.py +32 -3
  6. langfun/core/coding/python/correction.py +19 -9
  7. langfun/core/coding/python/execution.py +14 -12
  8. langfun/core/coding/python/generation.py +21 -16
  9. langfun/core/coding/python/sandboxing.py +23 -3
  10. langfun/core/component.py +42 -3
  11. langfun/core/concurrent.py +70 -6
  12. langfun/core/concurrent_test.py +1 -0
  13. langfun/core/console.py +1 -1
  14. langfun/core/data/conversion/anthropic.py +12 -3
  15. langfun/core/data/conversion/anthropic_test.py +8 -6
  16. langfun/core/data/conversion/gemini.py +9 -2
  17. langfun/core/data/conversion/gemini_test.py +12 -9
  18. langfun/core/data/conversion/openai.py +145 -31
  19. langfun/core/data/conversion/openai_test.py +161 -17
  20. langfun/core/eval/base.py +47 -43
  21. langfun/core/eval/base_test.py +4 -4
  22. langfun/core/eval/matching.py +5 -2
  23. langfun/core/eval/patching.py +3 -3
  24. langfun/core/eval/scoring.py +4 -3
  25. langfun/core/eval/v2/__init__.py +1 -0
  26. langfun/core/eval/v2/checkpointing.py +39 -5
  27. langfun/core/eval/v2/checkpointing_test.py +1 -1
  28. langfun/core/eval/v2/eval_test_helper.py +96 -0
  29. langfun/core/eval/v2/evaluation.py +87 -15
  30. langfun/core/eval/v2/evaluation_test.py +9 -3
  31. langfun/core/eval/v2/example.py +45 -39
  32. langfun/core/eval/v2/example_test.py +3 -3
  33. langfun/core/eval/v2/experiment.py +51 -8
  34. langfun/core/eval/v2/metric_values.py +31 -3
  35. langfun/core/eval/v2/metric_values_test.py +32 -0
  36. langfun/core/eval/v2/metrics.py +157 -44
  37. langfun/core/eval/v2/metrics_test.py +39 -18
  38. langfun/core/eval/v2/progress.py +30 -1
  39. langfun/core/eval/v2/progress_test.py +27 -0
  40. langfun/core/eval/v2/progress_tracking_test.py +3 -0
  41. langfun/core/eval/v2/reporting.py +90 -71
  42. langfun/core/eval/v2/reporting_test.py +20 -6
  43. langfun/core/eval/v2/runners/__init__.py +26 -0
  44. langfun/core/eval/v2/{runners.py → runners/base.py} +22 -124
  45. langfun/core/eval/v2/runners/debug.py +40 -0
  46. langfun/core/eval/v2/runners/debug_test.py +79 -0
  47. langfun/core/eval/v2/runners/parallel.py +100 -0
  48. langfun/core/eval/v2/runners/parallel_test.py +98 -0
  49. langfun/core/eval/v2/runners/sequential.py +47 -0
  50. langfun/core/eval/v2/runners/sequential_test.py +175 -0
  51. langfun/core/langfunc.py +45 -130
  52. langfun/core/langfunc_test.py +6 -4
  53. langfun/core/language_model.py +103 -16
  54. langfun/core/language_model_test.py +9 -3
  55. langfun/core/llms/__init__.py +7 -1
  56. langfun/core/llms/anthropic.py +157 -2
  57. langfun/core/llms/azure_openai.py +29 -17
  58. langfun/core/llms/cache/base.py +25 -3
  59. langfun/core/llms/cache/in_memory.py +48 -7
  60. langfun/core/llms/cache/in_memory_test.py +14 -4
  61. langfun/core/llms/compositional.py +25 -1
  62. langfun/core/llms/deepseek.py +30 -2
  63. langfun/core/llms/fake.py +32 -1
  64. langfun/core/llms/gemini.py +14 -9
  65. langfun/core/llms/google_genai.py +29 -1
  66. langfun/core/llms/groq.py +28 -3
  67. langfun/core/llms/llama_cpp.py +23 -4
  68. langfun/core/llms/openai.py +36 -3
  69. langfun/core/llms/openai_compatible.py +148 -27
  70. langfun/core/llms/openai_compatible_test.py +207 -20
  71. langfun/core/llms/openai_test.py +0 -2
  72. langfun/core/llms/rest.py +12 -1
  73. langfun/core/llms/vertexai.py +51 -8
  74. langfun/core/logging.py +1 -1
  75. langfun/core/mcp/client.py +77 -22
  76. langfun/core/mcp/client_test.py +8 -35
  77. langfun/core/mcp/session.py +94 -29
  78. langfun/core/mcp/session_test.py +54 -0
  79. langfun/core/mcp/tool.py +151 -22
  80. langfun/core/mcp/tool_test.py +197 -0
  81. langfun/core/memory.py +1 -0
  82. langfun/core/message.py +160 -55
  83. langfun/core/message_test.py +65 -81
  84. langfun/core/modalities/__init__.py +8 -0
  85. langfun/core/modalities/audio.py +21 -1
  86. langfun/core/modalities/image.py +19 -1
  87. langfun/core/modalities/mime.py +62 -3
  88. langfun/core/modalities/pdf.py +19 -1
  89. langfun/core/modalities/video.py +21 -1
  90. langfun/core/modality.py +167 -29
  91. langfun/core/modality_test.py +42 -12
  92. langfun/core/natural_language.py +1 -1
  93. langfun/core/sampling.py +4 -4
  94. langfun/core/sampling_test.py +20 -4
  95. langfun/core/structured/__init__.py +2 -24
  96. langfun/core/structured/completion.py +34 -44
  97. langfun/core/structured/completion_test.py +23 -43
  98. langfun/core/structured/description.py +54 -50
  99. langfun/core/structured/function_generation.py +29 -12
  100. langfun/core/structured/mapping.py +81 -37
  101. langfun/core/structured/parsing.py +95 -79
  102. langfun/core/structured/parsing_test.py +0 -3
  103. langfun/core/structured/querying.py +215 -142
  104. langfun/core/structured/querying_test.py +65 -29
  105. langfun/core/structured/schema/__init__.py +48 -0
  106. langfun/core/structured/schema/base.py +664 -0
  107. langfun/core/structured/schema/base_test.py +531 -0
  108. langfun/core/structured/schema/json.py +174 -0
  109. langfun/core/structured/schema/json_test.py +121 -0
  110. langfun/core/structured/schema/python.py +316 -0
  111. langfun/core/structured/schema/python_test.py +410 -0
  112. langfun/core/structured/schema_generation.py +33 -14
  113. langfun/core/structured/scoring.py +47 -36
  114. langfun/core/structured/tokenization.py +26 -11
  115. langfun/core/subscription.py +2 -2
  116. langfun/core/template.py +174 -49
  117. langfun/core/template_test.py +123 -17
  118. langfun/env/__init__.py +8 -2
  119. langfun/env/base_environment.py +320 -128
  120. langfun/env/base_environment_test.py +473 -0
  121. langfun/env/base_feature.py +92 -15
  122. langfun/env/base_feature_test.py +228 -0
  123. langfun/env/base_sandbox.py +84 -361
  124. langfun/env/base_sandbox_test.py +1235 -0
  125. langfun/env/event_handlers/__init__.py +1 -1
  126. langfun/env/event_handlers/chain.py +233 -0
  127. langfun/env/event_handlers/chain_test.py +253 -0
  128. langfun/env/event_handlers/event_logger.py +95 -98
  129. langfun/env/event_handlers/event_logger_test.py +21 -21
  130. langfun/env/event_handlers/metric_writer.py +225 -140
  131. langfun/env/event_handlers/metric_writer_test.py +23 -6
  132. langfun/env/interface.py +854 -40
  133. langfun/env/interface_test.py +112 -2
  134. langfun/env/load_balancers_test.py +23 -2
  135. langfun/env/test_utils.py +126 -84
  136. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/METADATA +1 -1
  137. langfun-0.1.2.dev202511160804.dist-info/RECORD +211 -0
  138. langfun/core/eval/v2/runners_test.py +0 -343
  139. langfun/core/structured/schema.py +0 -987
  140. langfun/core/structured/schema_test.py +0 -982
  141. langfun/env/base_test.py +0 -1481
  142. langfun/env/event_handlers/base.py +0 -350
  143. langfun-0.1.2.dev202510230805.dist-info/RECORD +0 -195
  144. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/WHEEL +0 -0
  145. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/licenses/LICENSE +0 -0
  146. {langfun-0.1.2.dev202510230805.dist-info → langfun-0.1.2.dev202511160804.dist-info}/top_level.txt +0 -0
langfun/core/modality.py CHANGED
@@ -14,40 +14,63 @@
14
14
  """Interface for modality (e.g. Image, Video, etc.)."""
15
15
 
16
16
  import abc
17
+ import contextlib
17
18
  import functools
18
19
  import hashlib
19
- from typing import Any, ContextManager
20
+ import re
21
+ from typing import Any, ContextManager, Iterator
20
22
  from langfun.core import component
21
23
  import pyglove as pg
22
24
 
23
25
 
24
- _TLS_MODALITY_AS_REF = '__format_modality_as_ref__'
26
+ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
27
+ """Base class for representing non-text content in prompts.
25
28
 
29
+ `lf.Modality` is the base class for multimodal objects such as `lf.Image`,
30
+ `lf.Audio`, and `lf.Video`. It allows these non-text inputs to be
31
+ seamlessly embedded within text prompts for processing by multimodal
32
+ language models.
26
33
 
27
- def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
28
- """A context manager that formats modality objects as references."""
29
- return pg.object_utils.thread_local_value_scope(
30
- _TLS_MODALITY_AS_REF, enabled, False
31
- )
34
+ When a `Modality` object is rendered within an `lf.Template`, it is
35
+ replaced by a text marker (e.g., `<<[[image:b10a8db1]]>>`), and the
36
+ modality object itself is stored in the `referred_modalities` field of
37
+ the resulting `lf.Message`. This allows language models to associate
38
+ the placeholder with its content during processing.
32
39
 
40
+ **Example:**
33
41
 
34
- class Modality(component.Component, pg.views.HtmlTreeView.Extension):
35
- """Base class for multimodal object."""
42
+ ```python
43
+ import langfun as lf
44
+
45
+ image = lf.Image.from_path('/path/to/image.png')
46
+ prompt = lf.Template('What is in this image? {{image}}', image=image)
47
+
48
+ message = prompt.render()
49
+ print(message.text)
50
+ # Output: What is in this image? <<[[image:b10a8db1]]>>
51
+
52
+ print(message.modalities())
53
+ # Output: [<Image object>]
54
+ ```
55
+ """
36
56
 
37
57
  REF_START = '<<[['
38
58
  REF_END = ']]>>'
39
59
 
40
60
  def _on_bound(self):
41
61
  super()._on_bound()
42
- # Invalidate cached hash if modality member is changed.
62
+ # Invalidate cached hash and id if modality member is changed.
43
63
  self.__dict__.pop('hash', None)
64
+ self.__dict__.pop('id', None)
44
65
 
45
66
  def format(self, *args, **kwargs) -> str:
46
- if self.referred_name is None or not pg.object_utils.thread_local_get(
47
- _TLS_MODALITY_AS_REF, False
48
- ):
67
+ if not pg.object_utils.thread_local_get(_TLS_MODALITY_AS_REF, False):
49
68
  return super().format(*args, **kwargs)
50
- return Modality.text_marker(self.referred_name)
69
+
70
+ capture_scope = get_modality_capture_context()
71
+ if capture_scope is not None:
72
+ capture_scope.capture(self)
73
+ return Modality.text_marker(self.id)
51
74
 
52
75
  def __str_kwargs__(self) -> dict[str, Any]:
53
76
  # For modality objects, we don't want to use markdown format when they
@@ -70,14 +93,11 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
70
93
  """Returns a marker in the text for this object."""
71
94
  return Modality.REF_START + var_name + Modality.REF_END
72
95
 
73
- @property
74
- def referred_name(self) -> str | None:
96
+ @functools.cached_property
97
+ def id(self) -> str | None:
75
98
  """Returns the referred name of this object in its template."""
76
- if not self.sym_path:
77
- return None
78
- # Strip the metadata prefix under message.
79
- path = str(self.sym_path)
80
- return path[9:] if path.startswith('metadata.') else path
99
+ modality_type = _camel_to_snake(self.__class__.__name__)
100
+ return f'{modality_type}:{self.hash}'
81
101
 
82
102
  @classmethod
83
103
  def from_value(cls, value: pg.Symbolic) -> dict[str, 'Modality']:
@@ -86,7 +106,7 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
86
106
  def _visit(k, v, p):
87
107
  del k, p
88
108
  if isinstance(v, Modality):
89
- modalities[v.referred_name] = v
109
+ modalities[v.id] = v
90
110
  return pg.TraverseAction.CONTINUE
91
111
  return pg.TraverseAction.ENTER
92
112
 
@@ -95,14 +115,47 @@ class Modality(component.Component, pg.views.HtmlTreeView.Extension):
95
115
 
96
116
 
97
117
  class ModalityRef(pg.Object, pg.typing.CustomTyping):
98
- """References of modality objects in a symbolic tree.
118
+ """Lightweight placeholder for a `lf.Modality` object in a symbolic tree.
99
119
 
100
- `ModalityRef` was introduced to placehold modality objects in a symbolic
101
- tree, to prevent message from being chunked in the middle of a Python
102
- structure.
120
+ `ModalityRef` acts as a reference to a `Modality` object (like `lf.Image`
121
+ or `lf.Audio`) within a structured object hierarchy (e.g., a `pg.Object`).
122
+ Instead of embedding potentially large modality data directly, `ModalityRef`
123
+ stores only the ID of the modality object.
124
+
125
+ This is useful in scenarios where structured objects are serialized or
126
+ manipulated, and it's more efficient to refer to modalities by ID rather
127
+ than copying their content. The `lf.ModalityRef.placehold()` class method
128
+ can be used to replace `Modality` instances in a symbolic object with
129
+ `ModalityRef` placeholders, while `lf.ModalityRef.restore()` can reinstate
130
+ the original `Modality` objects using a lookup table.
131
+
132
+ **Example:**
133
+
134
+ ```python
135
+ import langfun as lf
136
+ import pyglove as pg
137
+
138
+ class ImagePair(pg.Object):
139
+ image1: lf.Image
140
+ image2: lf.Image
141
+
142
+ pair = ImagePair(
143
+ image1=lf.Image(content=b'abc'), image2=lf.Image(content=b'def')
144
+ )
145
+ modalities = lf.Modality.from_value(pair)
146
+
147
+ # Replace Image objects with ModalityRef placeholders
148
+ pair_with_refs = lf.ModalityRef.placehold(pair)
149
+ print(pair_with_refs.image1)
150
+ # Output: ModalityRef(id='image:d81e5a68')
151
+
152
+ # Restore Image objects from ModalityRef placeholders
153
+ pair_restored = lf.ModalityRef.restore(pair_with_refs, modalities)
154
+ assert pair_restored.image1.content == b'abc'
155
+ ```
103
156
  """
104
157
 
105
- name: str
158
+ id: str
106
159
 
107
160
  def custom_apply(
108
161
  self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs
@@ -122,12 +175,97 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
122
175
  """
123
176
 
124
177
  def _placehold(k, v, p):
125
- del p
178
+ del k, p
126
179
  if isinstance(v, Modality):
127
- return ModalityRef(name=value.sym_path + k)
180
+ return ModalityRef(id=v.id)
128
181
  return v
129
182
  return value.clone().rebind(_placehold, raise_on_no_change=False)
130
183
 
184
+ @classmethod
185
+ def restore(cls, value: pg.Symbolic, modalities: dict[str, Modality]) -> Any:
186
+ """Returns a copy of value by replacing refs with modality objects."""
187
+ def _restore(k, v, p):
188
+ del k, p
189
+ if isinstance(v, ModalityRef):
190
+ modality_object = modalities.get(v.id)
191
+ if modality_object is None:
192
+ raise ValueError(
193
+ f'Modality {v.id} not found in modalities {modalities.keys()}'
194
+ )
195
+ return modality_object
196
+ return v
197
+ return value.rebind(_restore, raise_on_no_change=False)
198
+
131
199
 
132
200
  class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
133
201
  """Exception raised when modality is not supported."""
202
+
203
+
204
+ #
205
+ # Context managers to deal with modality objects.
206
+ #
207
+
208
+
209
+ _TLS_MODALITY_CAPTURE_SCOPE = '__modality_capture_scope__'
210
+ _TLS_MODALITY_AS_REF = '__format_modality_as_ref__'
211
+
212
+
213
+ def format_modality_as_ref(enabled: bool = True) -> ContextManager[None]:
214
+ """A context manager that formats modality objects as references."""
215
+ return pg.object_utils.thread_local_value_scope(
216
+ _TLS_MODALITY_AS_REF, enabled, False
217
+ )
218
+
219
+
220
+ class _ModalityCaptureContext:
221
+ """A context to capture modality objects when being rendered."""
222
+
223
+ def __init__(self):
224
+ self._references: dict[str, pg.Ref[Modality]] = {}
225
+
226
+ def capture(self, modality: Modality) -> None:
227
+ """Captures the modality object."""
228
+ self._references[modality.id] = pg.Ref(modality)
229
+
230
+ @property
231
+ def references(self) -> dict[str, pg.Ref[Modality]]:
232
+ """Returns the modality references captured in this context."""
233
+ return self._references
234
+
235
+
236
+ @contextlib.contextmanager
237
+ def capture_rendered_modalities() -> Iterator[dict[str, pg.Ref[Modality]]]:
238
+ """Capture modality objects whose references is being rendered.
239
+
240
+ Example:
241
+ ```
242
+ image = lf.Image.from_url(...)
243
+ with lf.modality.capture_rendered_modalities() as rendered_modalities:
244
+ with lf.modality.format_modality_as_ref():
245
+ print(f'Hello {image}')
246
+ self.assertEqual(rendered_modalities, {'image:<hash>': pg.Ref(image)})
247
+ ```
248
+ """
249
+ context = get_modality_capture_context()
250
+ top_level = context is None
251
+ if top_level:
252
+ context = _ModalityCaptureContext()
253
+ pg.object_utils.thread_local_set(_TLS_MODALITY_CAPTURE_SCOPE, context)
254
+
255
+ try:
256
+ yield context.references # pylint: disable=attribute-error
257
+ finally:
258
+ if top_level:
259
+ pg.object_utils.thread_local_del(_TLS_MODALITY_CAPTURE_SCOPE)
260
+
261
+
262
+ def get_modality_capture_context() -> _ModalityCaptureContext | None:
263
+ """Returns the current modality capture context."""
264
+ return pg.object_utils.thread_local_get(_TLS_MODALITY_CAPTURE_SCOPE, None)
265
+
266
+
267
+ def _camel_to_snake(name: str) -> str:
268
+ """Converts a camelCase name to snake_case."""
269
+ return re.sub(
270
+ pattern=r'([A-Z]+)', repl=r'_\1', string=name
271
+ ).lower().lstrip('_')
langfun/core/modality_test.py CHANGED
@@ -29,34 +29,64 @@ class ModalityTest(unittest.TestCase):
29
29
 
30
30
  def test_basic(self):
31
31
  v = CustomModality('a')
32
- self.assertIsNone(v.referred_name)
32
+ self.assertEqual(v.id, 'custom_modality:0cc175b9')
33
33
  self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
34
34
  self.assertEqual(v.hash, '0cc175b9')
35
35
 
36
36
  _ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
37
- self.assertEqual(v.referred_name, 'x.metadata.y')
37
+ self.assertEqual(v.id, 'custom_modality:0cc175b9')
38
38
  self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
39
39
  with modality.format_modality_as_ref():
40
- self.assertEqual(str(v), '<<[[x.metadata.y]]>>')
40
+ self.assertEqual(str(v), '<<[[custom_modality:0cc175b9]]>>')
41
+
42
+ def test_capture_rendered_modalities(self):
43
+ x = CustomModality('a')
44
+ y = CustomModality('b')
45
+ z = CustomModality('b')
46
+
47
+ with modality.capture_rendered_modalities() as rendered_modalities:
48
+ with modality.format_modality_as_ref():
49
+ self.assertEqual(
50
+ f'Hello {x} {y} {z}',
51
+ (
52
+ 'Hello <<[[custom_modality:0cc175b9]]>> '
53
+ '<<[[custom_modality:92eb5ffe]]>> '
54
+ '<<[[custom_modality:92eb5ffe]]>>'
55
+ )
56
+ )
57
+ self.assertEqual(len(rendered_modalities), 2)
58
+ self.assertIs(rendered_modalities['custom_modality:0cc175b9'].value, x)
59
+ # y and z share the same content will be treated as the same object.
60
+ self.assertIs(rendered_modalities['custom_modality:92eb5ffe'].value, z)
41
61
 
42
62
 
43
63
  class ModalityRefTest(unittest.TestCase):
44
64
 
45
- def test_placehold(self):
65
+ def test_placehold_and_restore(self):
46
66
  class A(pg.Object):
47
67
  x: Any
48
68
  y: Any
49
69
 
50
- a = A(x=dict(z=CustomModality('a')), y=CustomModality('b'))
70
+ image_a = CustomModality('a')
71
+ image_b = CustomModality('b')
72
+ a = A(x=dict(z=image_a), y=image_b)
73
+ a_placehold = modality.ModalityRef.placehold(a)
51
74
  self.assertEqual(
52
- modality.ModalityRef.placehold(a),
53
- A(x=dict(z=modality.ModalityRef('x.z')), y=modality.ModalityRef('y')),
75
+ a_placehold,
76
+ A(x=dict(z=modality.ModalityRef(image_a.id)),
77
+ y=modality.ModalityRef(image_b.id)),
78
+ )
79
+ a_restore = modality.ModalityRef.restore(
80
+ a_placehold.clone(),
81
+ {image_a.id: image_a, image_b.id: image_b},
54
82
  )
83
+ self.assertTrue(pg.eq(a_restore, a))
55
84
  self.assertEqual(
56
85
  modality.ModalityRef.placehold(a.x),
57
- # The prefix 'x' of referred name is preserved.
58
- dict(z=modality.ModalityRef('x.z')),
86
+ dict(z=modality.ModalityRef(image_a.id)),
59
87
  )
88
+ with self.assertRaisesRegex(ValueError, 'Modality .* not found'):
89
+ modality.ModalityRef.restore(a_placehold, {image_a.id: image_a})
60
90
 
61
91
  def test_from_value(self):
62
92
  class A(pg.Object):
@@ -68,8 +98,8 @@ class ModalityRefTest(unittest.TestCase):
68
98
  pg.eq(
69
99
  modality.Modality.from_value(a),
70
100
  {
71
- 'x.z': CustomModality('a'),
72
- 'y': CustomModality('b'),
101
+ 'custom_modality:0cc175b9': CustomModality('a'),
102
+ 'custom_modality:92eb5ffe': CustomModality('b'),
73
103
  },
74
104
  )
75
105
  )
@@ -77,7 +107,7 @@ class ModalityRefTest(unittest.TestCase):
77
107
  pg.eq(
78
108
  modality.Modality.from_value(a.x.z),
79
109
  {
80
- 'x.z': CustomModality('a'),
110
+ 'custom_modality:0cc175b9': CustomModality('a'),
81
111
  },
82
112
  )
83
113
  )
langfun/core/natural_language.py CHANGED
@@ -11,7 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Natural language utilities."""
14
+ """Natural language formatting."""
15
15
 
16
16
  import abc
17
17
  import pyglove as pg
langfun/core/sampling.py CHANGED
@@ -38,10 +38,10 @@ def sweep(
38
38
  Union[message_lib.Message, BaseException, None], # LM output.
39
39
  ],
40
40
  ]:
41
- """Sweeps the input/output of this LangFunc concurrently.
41
+ """Sweeps the input/output of a LangFunc search space concurrently.
42
42
 
43
43
  Args:
44
- lfun: An LangFunc object that contains `pg.oneof` as the search space
44
+ lfun: An LangFunc object that contains `pg.oneof` as the search space
45
45
  for sampling.
46
46
  num_examples: Number of examples to sample.
47
47
  max_workers: Max number of concurrent workers to do sampling.
@@ -84,10 +84,10 @@ def random_sample(
84
84
  Union[message_lib.Message, BaseException, None], # LM output.
85
85
  ],
86
86
  ]:
87
- """Random samples the input/output of this LangFunc concurrently.
87
+ """Random samples the input/output of a LangFunc search space concurrently.
88
88
 
89
89
  Args:
90
- lfun: An LangFunc object that contains `pg.oneof` as the search space
90
+ lfun: An LangFunc object that contains `pg.oneof` as the search space
91
91
  for sampling.
92
92
  num_examples: Number of examples to sample.
93
93
  max_workers: Max number of concurrent workers to do sampling.
langfun/core/sampling_test.py CHANGED
@@ -39,8 +39,13 @@ class SamplingTest(unittest.TestCase):
39
39
  l = LangFunc('Compute {{x}} and {{y}}', x=pg.oneof([1, 2]))
40
40
  with component.context(lm=ExcitedEchoer()):
41
41
  samples = list(sampling.sweep(l, y=pg.oneof([3, 4])))
42
- samples = sorted(samples, key=lambda x: (x[0].x, x[0].y))
43
-
42
+ samples = sorted(
43
+ samples,
44
+ key=lambda x: (
45
+ x[0].__template_input__.x,
46
+ x[0].__template_input__.y
47
+ )
48
+ )
44
49
  self.assertEqual(
45
50
  samples,
46
51
  [
@@ -57,7 +62,12 @@ class SamplingTest(unittest.TestCase):
57
62
  samples = list(
58
63
  sampling.random_sample(l, y=pg.oneof([2, 4]), num_examples=3, seed=1)
59
64
  )
60
- samples = sorted(samples, key=lambda x: (x[0].x, x[0].y))
65
+ samples = sorted(
66
+ samples, key=lambda x: (
67
+ x[0].__template_input__.x,
68
+ x[0].__template_input__.y
69
+ )
70
+ )
61
71
 
62
72
  self.assertEqual(
63
73
  samples,
@@ -97,7 +107,13 @@ class SamplingTest(unittest.TestCase):
97
107
  silence_on_errors=(AttributeError,),
98
108
  ignore_examples_with_errors=False))
99
109
 
100
- samples = sorted(samples, key=lambda x: (x[0].x, x[0].y))
110
+ samples = sorted(
111
+ samples,
112
+ key=lambda x: (
113
+ x[0].__template_input__.x,
114
+ x[0].__template_input__.y
115
+ )
116
+ )
101
117
  self.assertEqual(
102
118
  [x[0] for x in samples],
103
119
  [
langfun/core/structured/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The Langfun Authors
1
+ # Copyright 2025 The Langfun Authors
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,29 +16,7 @@
16
16
  # pylint: disable=g-bad-import-order
17
17
  # pylint: disable=g-importing-member
18
18
 
19
- from langfun.core.structured.schema import include_method_in_prompt
20
-
21
- from langfun.core.structured.schema import Missing
22
- from langfun.core.structured.schema import MISSING
23
- from langfun.core.structured.schema import Unknown
24
- from langfun.core.structured.schema import UNKNOWN
25
-
26
- from langfun.core.structured.schema import Schema
27
- from langfun.core.structured.schema import SchemaProtocol
28
- from langfun.core.structured.schema import schema_spec
29
-
30
- from langfun.core.structured.schema import SchemaError
31
- from langfun.core.structured.schema import JsonError
32
-
33
- from langfun.core.structured.schema import class_dependencies
34
- from langfun.core.structured.schema import class_definition
35
- from langfun.core.structured.schema import class_definitions
36
- from langfun.core.structured.schema import annotation
37
- from langfun.core.structured.schema import structure_from_python
38
-
39
- from langfun.core.structured.schema import schema_repr
40
- from langfun.core.structured.schema import source_form
41
- from langfun.core.structured.schema import value_repr
19
+ from langfun.core.structured.schema import *
42
20
 
43
21
  from langfun.core.structured.schema_generation import generate_class
44
22
  from langfun.core.structured.schema_generation import classgen_example
langfun/core/structured/completion.py CHANGED
@@ -116,15 +116,10 @@ class _CompleteStructure(mapping.Mapping):
116
116
  )
117
117
 
118
118
  def postprocess_result(self, result: Any) -> Any:
119
- """Postprocess result."""
119
+ """Postprocesses result."""
120
120
  # Try restore modality objects from the input value to output value.
121
- modalities = self.modalities(self.input)
122
- if modalities:
123
- # Remove the `input` prefix for all entries.
124
- modalities = pg.object_utils.flatten(
125
- pg.object_utils.canonicalize(modalities)['input']
126
- )
127
- result.rebind(modalities)
121
+ if modalities := self.modalities(self.input):
122
+ result = lf.ModalityRef.restore(result, modalities)
128
123
  return result
129
124
 
130
125
  def globals(self):
@@ -156,7 +151,7 @@ class _CompleteStructure(mapping.Mapping):
156
151
  #
157
152
 
158
153
  def has_modality_refs(self, value: Any) -> bool:
159
- """Returns true if the value has modalities."""
154
+ """Returns True if the value has modalities."""
160
155
  return not isinstance(value, lf.Modality) and pg.contains(
161
156
  value, type=lf.Modality
162
157
  )
@@ -186,41 +181,36 @@ def complete(
186
181
  returns_message: bool = False,
187
182
  **kwargs,
188
183
  ) -> Any:
189
- """Complete a symbolic value by filling its missing fields.
190
-
191
- Examples:
192
-
193
- ```
194
- class FlightDuration:
195
- hours: int
196
- minutes: int
197
-
198
- class Flight(pg.Object):
199
- airline: str
200
- flight_number: str
201
- departure_airport_code: str
202
- arrival_airport_code: str
203
- departure_time: str
204
- arrival_time: str
205
- duration: FlightDuration
206
- stops: int
207
- price: float
208
-
209
- prompt = '''
210
- Information about flight UA2631.
211
- '''
212
-
213
- r = lf.query(prompt, Flight)
214
- assert isinstance(r, Flight)
215
- assert r.airline == 'United Airlines'
216
- assert r.departure_airport_code == 'SFO'
217
- assert r.duration.hour = 7
218
- ```
184
+ """Completes a symbolic value by filling its missing fields using an LLM.
185
+
186
+ `lf.complete` is used to fill in missing information in structured
187
+ data. It takes a partially defined `pg.Object` instance where some fields
188
+ are marked as `lf.MISSING`, and uses a language model to infer and
189
+ populate those fields based on the provided values.
190
+
191
+ **Example:**
192
+
193
+ ```python
194
+ import langfun as lf
195
+ import pyglove as pg
196
+
197
+ class Country(pg.Object):
198
+ name: str
199
+ capital: str = lf.MISSING
200
+ population: int = lf.MISSING
201
+
202
+ # Filling missing fields of Country(name='France')
203
+ country = lf.complete(Country(name='France'), lm=lf.llms.Gemini25Flash())
204
+ print(country)
205
+ # Output: Country(name='France', capital='Paris', population=67000000)
206
+ ```
219
207
 
220
208
  Args:
221
- input_value: A symbolic value that may contain missing values.
222
- default: The default value if parsing failed. If not specified, error will
223
- be raised.
209
+ input_value: A symbolic value that may contain missing values marked
210
+ by `lf.MISSING`.
211
+ default: The default value to return if parsing fails. If
212
+ `lf.RAISE_IF_HAS_ERROR` is used (default), an error will be raised
213
+ instead.
224
214
  lm: The language model to use. If not specified, the language model from
225
215
  `lf.context` context manager will be used.
226
216
  examples: An optional list of fewshot examples for helping parsing. If None,
@@ -236,10 +226,10 @@ def complete(
236
226
  returns_message: If True, returns `lf.Message` as the output, instead of
237
227
  returning the structured `message.result`.
238
228
  **kwargs: Keyword arguments passed to the
239
- `lf.structured.NaturalLanguageToStructureed` transform.
229
+ `lf.structured.Mapping` transform.
240
230
 
241
231
  Returns:
242
- The result based on the schema.
232
+ The input object with missing fields completed by LLM.
243
233
  """
244
234
  t = _CompleteStructure(
245
235
  input=schema_lib.mark_missing(input_value),