langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
langfun/core/structured/schema/base.py (new file)
@@ -0,0 +1,664 @@
+ # Copyright 2025 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Schema and Prompting Protocol for Structured Data."""
+
+ import abc
+ import inspect
+ import io
+ import textwrap
+ import typing
+ from typing import Any, ClassVar, Type, Union
+ import langfun.core as lf
+ import pyglove as pg
+
+
+ def _parse_value_spec(value) -> pg.typing.ValueSpec:
+   """Parses a PyGlove ValueSpec equivalent into a ValueSpec.
+
+   Examples:
+     ```
+     _parse_value_spec(int) -> pg.typing.Int
+     _parse_value_spec(list[int]) -> pg.typing.List(pg.typing.Int)
+     _parse_value_spec(dict(a=int, b=str)) -> pg.typing.Dict(
+         pg.typing.Int, pg.typing.Str
+     )
+     ```
+   Args:
+     value: The value to parse. It can be a PyGlove ValueSpec, a dict with a
+       single 'result' key, or a Python type annotation.
+
+   Returns:
+     A PyGlove ValueSpec.
+   """
+   if isinstance(value, pg.typing.ValueSpec):
+     return value
+
+   if isinstance(value, dict) and len(value) == 1 and 'result' in value:
+     value = value['result']
+
+   def _parse_node(v) -> pg.typing.ValueSpec:
+     if isinstance(v, dict):
+       return pg.typing.Dict([(k, _parse_node(cv)) for k, cv in v.items()])
+     elif isinstance(v, list):
+       if len(v) != 1:
+         raise ValueError(
+             'Annotation with list must be a list of a single element. '
+             f'Encountered: {v}'
+         )
+       return pg.typing.List(_parse_node(v[0]))
+     else:
+       spec = pg.typing.ValueSpec.from_annotation(v, auto_typing=True)
+       if isinstance(
+           spec,
+           (
+               pg.typing.Any,
+               pg.typing.Callable,
+               pg.typing.Tuple,
+               pg.typing.Type,
+               pg.typing.Union,
+           ),
+       ):
+         raise ValueError(f'Unsupported schema specification: {v}')
+       return spec
+
+   return _parse_node(value)
+
+
+ class SchemaError(Exception):  # pylint: disable=g-bad-exception-name
+   """Schema error."""
+
+   def __init__(
+       self,
+       schema: 'Schema',
+       value: Any,
+       protocol: str,
+       cause: Exception
+   ):
+     self.schema = schema
+     self.value = value
+     self.protocol = protocol
+     self.cause = cause
+
+   def __str__(self):
+     r = io.StringIO()
+     r.write(
+         pg.colored(
+             f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
+         )
+     )
+
+     r.write('\n')
+     r.write(pg.colored('Schema:', 'red'))
+     r.write('\n\n')
+     r.write(textwrap.indent(
+         pg.colored(
+             schema_repr(self.schema, protocol=self.protocol), 'magenta'
+         ),
+         ' ' * 2
+     ))
+     r.write('\n\n')
+     r.write(pg.colored('Generated value:', 'red'))
+     r.write('\n\n')
+     r.write(textwrap.indent(
+         pg.colored(value_repr(self.value, protocol=self.protocol), 'magenta'),
+         ' ' * 2
+     ))
+     return r.getvalue()
+
+
+ class Schema(
+     lf.NaturalLanguageFormattable,
+     pg.Object,
+     pg.views.HtmlTreeView.Extension
+ ):
+   """Schema for structured inputs and outputs.
+
+   `lf.Schema` provides a unified representation for defining the output schema
+   used in Langfun's structured operations like `lf.query`, `lf.parse`,
+   `lf.complete`, and `lf.describe`. It acts as an abstraction layer,
+   allowing schemas to be defined using Python type annotations, `pg.Object`
+   classes, or dictionaries, and then converting them into a format that
+   language models can understand.
+
+   `lf.Schema` can be created from various types using `lf.Schema.from_value`:
+   * Built-in types: `int`, `str`, `bool`, `float`
+   * Typing constructs: `list`, `dict`, `typing.Union`, `typing.Literal`,
+     `typing.Optional`
+   * PyGlove classes: `pg.Object` subclasses
+
+   **1. Creating a Schema:**
+
+   ```python
+   import langfun as lf
+   import pyglove as pg
+   from typing import Literal, Union
+
+   # From a basic type
+   int_schema = lf.Schema.from_value(int)
+
+   # From a list type
+   list_schema = lf.Schema.from_value(list[int])
+
+   # From a dictionary
+   dict_schema = lf.Schema.from_value(dict(a=int, b=str))
+
+   # From pg.Object
+   class Point(pg.Object):
+     x: int
+     y: int
+   point_schema = lf.Schema.from_value(Point)
+
+   # From Union or Literal
+   union_schema = lf.Schema.from_value(Union[int, str])
+   literal_schema = lf.Schema.from_value(Literal['A', 'B'])
+   ```
+
+   **2. Schema Representation:**
+   Once created, a schema object can represent itself in different formats,
+   such as Python-like syntax or JSON, which is used in prompts to LLMs.
+
+   ```python
+   print(point_schema.repr('python'))
+   # Output:
+   # class Point:
+   #   x: int
+   #   y: int
+
+   print(dict_schema.repr('json'))
+   # Output:
+   # {
+   #   "a": "int",
+   #   "b": "str"
+   # }
+   ```
+   """
+
+   spec: pg.typing.Annotated[
+       pg.typing.Object(pg.typing.ValueSpec, transform=_parse_value_spec),
+       (
+           'A PyGlove ValueSpec object representing the spec for the value '
+           'to be parsed.'
+       ),
+   ]
+
+   def schema_repr(self, protocol: str = 'python', **kwargs) -> str:
+     """Returns the representation of the schema."""
+     return schema_repr(self, protocol=protocol, **kwargs)
+
+   def value_repr(
+       self, value: Any, protocol: str = 'python', **kwargs
+   ) -> str:
+     """Returns the representation of a structured value."""
+     return value_repr(value, schema=self, protocol=protocol, **kwargs)
+
+   def parse_value(
+       self, text: str, protocol: str = 'python', **kwargs
+   ) -> Any:
+     """Parses a LM generated text into a structured value."""
+     value = parse_value(text, schema=self, protocol=protocol, **kwargs)
+
+     # TODO(daiyip): support autofix for schema error.
+     try:
+       return self.spec.apply(value)
+     except Exception as e:
+       raise SchemaError(self, value, protocol, e)  # pylint: disable=raise-missing-from
+
+   def natural_language_format(self) -> str:
+     return self.schema_str()
+
+   def schema_dict(self) -> dict[str, Any]:
+     """Returns the dictionary representation of the schema."""
+
+     def _node(vs: pg.typing.ValueSpec) -> Any:
+       if isinstance(vs, pg.typing.PrimitiveType):
+         return vs
+       elif isinstance(vs, pg.typing.Dict):
+         assert vs.schema is not None
+         return {str(k): _node(f.value) for k, f in vs.schema.fields.items()}
+       elif isinstance(vs, pg.typing.List):
+         return [_node(vs.element.value)]
+       elif isinstance(vs, pg.typing.Object):
+         if issubclass(vs.cls, pg.Object):
+           d = {pg.JSONConvertible.TYPE_NAME_KEY: vs.cls.__serialization_key__}
+           d.update(
+               {
+                   str(k): _node(f.value)
+                   for k, f in vs.cls.__schema__.fields.items()
+               }
+           )
+           return d
+       raise TypeError(
+           'Unsupported value spec being used as the schema for '
+           f'structured data: {vs}.')
+
+     return {'result': _node(self.spec)}
+
+   def class_dependencies(
+       self,
+       include_base_classes: bool = True,
+       include_subclasses: bool = True,
+       include_generated_subclasses: bool = False) -> list[Type[Any]]:
+     """Returns a list of class dependencies for current schema."""
+     return class_dependencies(
+         self.spec,
+         include_base_classes,
+         include_subclasses,
+         include_generated_subclasses
+     )
+
+   @classmethod
+   def from_value(cls, value) -> 'Schema':
+     """Creates a schema from an equivalent representation."""
+     if isinstance(value, Schema):
+       return value
+     return cls(_parse_value_spec(value))
+
+   def _html_tree_view_content(
+       self,
+       *,
+       view: pg.views.HtmlTreeView,
+       **kwargs,
+   ):
+     return pg.Html.element(
+         'div',
+         [pg.Html.escape(self.schema_repr(protocol='python'))],
+         css_classes=['lf-schema-definition']
+     ).add_style(
+         """
+         .lf-schema-definition {
+           color: blue;
+           margin: 5px;
+           white-space: pre-wrap;
+         }
+         """
+     )
+
+
+ SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]]
+
+
+ def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]:
+   """Returns a list of top level value specs from a symbolic value."""
+   top_level_object_specs = []
+
+   def _collect_top_level_object_specs(k, v, p):
+     del k, p
+     if isinstance(v, pg.Object):
+       top_level_object_specs.append(pg.typing.Object(v.__class__))
+       return pg.TraverseAction.CONTINUE
+     return pg.TraverseAction.ENTER
+
+   pg.traverse(value, _collect_top_level_object_specs)
+   return top_level_object_specs
+
+
+ def class_dependencies(
+     value_or_spec: Union[
+         pg.Symbolic,
+         Schema,
+         pg.typing.ValueSpec,
+         Type[pg.Object],
+         tuple[Union[pg.typing.ValueSpec, Type[pg.Object]], ...],
+     ],
+     include_base_classes: bool = True,
+     include_subclasses: bool = True,
+     include_generated_subclasses: bool = False,
+ ) -> list[Type[Any]]:
+   """Returns a list of class dependencies from a value or specs."""
+   if isinstance(value_or_spec, Schema):
+     value_or_spec = value_or_spec.spec
+
+   if inspect.isclass(value_or_spec) or isinstance(
+       value_or_spec, pg.typing.ValueSpec
+   ):
+     value_or_spec = (value_or_spec,)
+
+   if isinstance(value_or_spec, tuple):
+     value_specs = []
+     for v in value_or_spec:
+       if isinstance(v, pg.typing.ValueSpec):
+         value_specs.append(v)
+       elif inspect.isclass(v):
+         value_specs.append(pg.typing.Object(v))
+       else:
+         raise TypeError(f'Unsupported spec type: {v!r}')
+   else:
+     value_specs = _top_level_object_specs_from_value(value_or_spec)
+
+   seen = set()
+   dependencies = []
+
+   def _add_dependency(cls_or_classes):
+     if isinstance(cls_or_classes, type):
+       cls_or_classes = [cls_or_classes]
+     for cls in cls_or_classes:
+       if cls not in dependencies:
+         dependencies.append(cls)
+
+   def _fill_dependencies(vs: pg.typing.ValueSpec, include_subclasses: bool):
+     if isinstance(vs, pg.typing.Object):
+       cls = vs.cls
+       if cls.__module__ == 'builtins':
+         return
+
+       if cls not in seen:
+         seen.add(cls)
+
+         if include_base_classes:
+           # Add base classes as dependencies.
+           for base_cls in cls.__bases__:
+             # We only keep track of user-defined symbolic classes.
+             if base_cls is not object and base_cls is not pg.Object:
+               _fill_dependencies(
+                   pg.typing.Object(base_cls), include_subclasses=False
+               )
+
+         # Add members as dependencies.
+         for field in pg.schema(cls).values():
+           _fill_dependencies(field.value, include_subclasses)
+       _add_dependency(cls)
+
+       # Check subclasses if available.
+       if include_subclasses:
+         for subcls in cls.__subclasses__():
+           # NOTE(daiyip): To prevent LLM-generated "hallucinated" classes from
+           # polluting the generation space, classes dynamically created by
+           # 'eval' (which have __module__ == 'builtins') are excluded from
+           # dependencies by default.
+           if ((include_generated_subclasses or subcls.__module__ != 'builtins')
+               and subcls not in dependencies):
+             _fill_dependencies(
+                 pg.typing.Object(subcls), include_subclasses=True
+             )
+
+     if isinstance(vs, pg.typing.List):
+       _fill_dependencies(vs.element.value, include_subclasses)
+     elif isinstance(vs, pg.typing.Tuple):
+       for elem in vs.elements:
+         _fill_dependencies(elem.value, include_subclasses)
+     elif isinstance(vs, pg.typing.Dict) and vs.schema:
+       for v in vs.schema.values():
+         _fill_dependencies(v.value, include_subclasses)
+     elif isinstance(vs, pg.typing.Union):
+       for v in vs.candidates:
+         _fill_dependencies(v, include_subclasses)
+
+   for value_spec in value_specs:
+     _fill_dependencies(value_spec, include_subclasses)
+   return dependencies
+
+
+ def schema_spec(noneable: bool = False) -> pg.typing.ValueSpec:  # pylint: disable=unused-argument
+   if typing.TYPE_CHECKING:
+     return Any
+   return pg.typing.Object(
+       Schema, transform=Schema.from_value, is_noneable=noneable
+   )  # pylint: disable=unreachable-code
+
+
+ def annotation(
+     vs: pg.typing.ValueSpec,
+     annotate_optional: bool = True,
+     strict: bool = False,
+     allowed_dependencies: set[Type[Any]] | None = None,
+ ) -> str:
+   """Returns the annotation string for a value spec."""
+   child_annotation_kwargs = dict(
+       strict=strict, allowed_dependencies=allowed_dependencies
+   )
+   if isinstance(vs, pg.typing.Any):
+     return 'Any'
+   elif isinstance(vs, pg.typing.Enum):
+     candidate_str = ', '.join([repr(v) for v in vs.values])
+     return f'Literal[{candidate_str}]'
+   elif isinstance(vs, pg.typing.Union):
+     candidate_str = ', '.join(
+         [
+             annotation(c, annotate_optional=False, **child_annotation_kwargs)
+             for c in vs.candidates
+         ]
+     )
+     if vs.is_noneable:
+       candidate_str += ', None'
+     return f'Union[{candidate_str}]'
+
+   if isinstance(vs, pg.typing.Bool):
+     x = 'bool'
+   elif isinstance(vs, pg.typing.Str):
+     if vs.regex is None:
+       x = 'str'
+     else:
+       if strict:
+         x = f"pg.typing.Str(regex='{vs.regex.pattern}')"
+       else:
+         x = f"str(regex='{vs.regex.pattern}')"
+   elif isinstance(vs, pg.typing.Number):
+     constraints = []
+     min_label = 'min_value' if strict else 'min'
+     max_label = 'max_value' if strict else 'max'
+     if vs.min_value is not None:
+       constraints.append(f'{min_label}={vs.min_value}')
+     if vs.max_value is not None:
+       constraints.append(f'{max_label}={vs.max_value}')
+     x = 'int' if isinstance(vs, pg.typing.Int) else 'float'
+     if constraints:
+       if strict:
+         x = (
+             'pg.typing.Int'
+             if isinstance(vs, pg.typing.Int)
+             else 'pg.typing.Float'
+         )
+       x += '(' + ', '.join(constraints) + ')'
+   elif isinstance(vs, pg.typing.Object):
+     if allowed_dependencies is None or vs.cls in allowed_dependencies:
+       x = vs.cls.__name__
+     else:
+       x = 'Any'
+   elif isinstance(vs, pg.typing.List):
+     item_str = annotation(vs.element.value, **child_annotation_kwargs)
+     x = f'list[{item_str}]'
+   elif isinstance(vs, pg.typing.Tuple):
+     elem_str = ', '.join(
+         [annotation(el.value, **child_annotation_kwargs) for el in vs.elements]
+     )
+     x = f'tuple[{elem_str}]'
+   elif isinstance(vs, pg.typing.Dict):
+     kv_pairs = None
+     if vs.schema is not None:
+       kv_pairs = [
+           (k, annotation(f.value, **child_annotation_kwargs))
+           for k, f in vs.schema.items()
+           if isinstance(k, pg.typing.ConstStrKey)
+       ]
+
+     if kv_pairs:
+       kv_str = ', '.join(f"'{k}': {v}" for k, v in kv_pairs)
+       x = '{' + kv_str + '}'
+       if strict:
+         x = f'pg.typing.Dict({x})'
+     elif vs.schema and vs.schema.dynamic_field:
+       v = annotation(vs.schema.dynamic_field.value, **child_annotation_kwargs)
+       x = f'dict[str, {v}]'
+     else:
+       x = 'dict[str, Any]'
+
+   else:
+     raise TypeError(f'Unsupported value spec being used as schema: {vs}.')
+
+   if annotate_optional and vs.is_noneable:
+     x += ' | None'
+   return x
+
+ #
+ # Prompting protocols for structured data.
+ #
+
+
+ class PromptingProtocol(metaclass=abc.ABCMeta):
+   """Base class for prompting protocols for structured data."""
+
+   NAME: ClassVar[str]
+
+   _PROTOCOLS: ClassVar[dict[str, Type['PromptingProtocol']]] = {}
+
+   def __init_subclass__(cls):
+     PromptingProtocol._PROTOCOLS[cls.NAME] = cls
+
+   @classmethod
+   def from_name(cls, name: str) -> 'PromptingProtocol':
+     """Returns the prompting protocol from the name."""
+     protocol_cls = cls._PROTOCOLS.get(name)
+     if protocol_cls is None:
+       raise ValueError(f'Unsupported protocol: {name}.')
+     return protocol_cls()  # pytype: disable=not-instantiable
+
+   @abc.abstractmethod
+   def schema_repr(self, schema: Schema) -> str:
+     """Returns the representation of the schema."""
+
+   @abc.abstractmethod
+   def value_repr(
+       self,
+       value: Any,
+       schema: Schema | None = None,
+       **kwargs
+   ) -> str:
+     """Returns the representation of a structured value."""
+
+   @abc.abstractmethod
+   def parse_value(
+       self,
+       text: str,
+       schema: Schema | None = None,
+       **kwargs
+   ) -> Any:
+     """Parses a LM generated text into a structured value."""
+
+
+ def schema_repr(
+     schema: Schema,
+     *,
+     protocol: str = 'python',
+     **kwargs
+ ) -> str:
+   """Returns the representation of the schema based on the protocol."""
+   return PromptingProtocol.from_name(protocol).schema_repr(schema, **kwargs)
+
+
+ def value_repr(
+     value: Any,
+     schema: Schema | None = None,
+     *,
+     protocol: str = 'python',
+     **kwargs) -> str:
+   """Returns the representation of a structured value based on the protocol."""
+   return PromptingProtocol.from_name(protocol).value_repr(
+       value, schema, **kwargs
+   )
+
+
+ def parse_value(
+     text: str,
+     schema: Schema | None = None,
+     *,
+     protocol: str = 'python',
+     **kwargs
+ ) -> Any:
+   """Parses a LM generated text into a structured value."""
+   return PromptingProtocol.from_name(protocol).parse_value(
+       text, schema=schema, **kwargs
+   )
+
+
+ #
+ # Special value markers.
+ #
+
+
+ class Missing(pg.Object, pg.typing.CustomTyping):
+   """Value marker for a missing field.
+
+   This class differs from pg.MISSING_VALUE in two aspects:
+   * When a field is assigned with lf.Missing(), it's considered non-partial.
+   * lf.Missing() could format the value spec as Python annotations that are
+     consistent with `lf.structured.Schema.schema_repr()`.
+   """
+
+   def _on_bound(self):
+     super()._on_bound()
+     self._value_spec = None
+
+   @property
+   def value_spec(self) -> pg.ValueSpec | None:
+     """Returns the value spec that applies to the current missing value."""
+     return self._value_spec
+
+   def custom_apply(
+       self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs
+   ) -> tuple[bool, Any]:
+     self._value_spec = value_spec
+     return (False, self)
+
+   def format(self, *args, **kwargs) -> str:
+     if self._value_spec is None:
+       return 'MISSING'
+     return f'MISSING({annotation(self._value_spec)})'
+
+   @classmethod
+   def find_missing(cls, value: Any) -> dict[str, 'Missing']:
+     """Lists all missing values contained in the value."""
+     missing = {}
+
+     def _visit(k, v, p):
+       del p
+       if isinstance(v, Missing):
+         missing[k] = v
+       return pg.TraverseAction.ENTER
+
+     pg.traverse(value, _visit)
+     return missing
+
+
+ MISSING = Missing()
+
+
+ def mark_missing(value: Any) -> Any:
+   """Replaces pg.MISSING within the value with lf.structured.Missing objects."""
+   if isinstance(value, list):
+     value = pg.List(value)
+   elif isinstance(value, dict):
+     value = pg.Dict(value)
+   if isinstance(value, pg.Symbolic):
+
+     def _mark_missing(k, v, p):
+       del k, p
+       if pg.MISSING_VALUE == v:
+         v = Missing()
+       return v
+
+     return value.rebind(_mark_missing, raise_on_no_change=False)
+   return value
+
+
+ class Unknown(pg.Object, pg.typing.CustomTyping):
+   """Value marker for a field that LMs could not provide."""
+
+   def custom_apply(self, *args, **kwargs) -> tuple[bool, Any]:
+     return (False, self)
+
+   def format(self, *args, **kwargs) -> str:
+     return 'UNKNOWN'
+
+
+ UNKNOWN = Unknown()
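
For readers skimming the new `schema/base.py`, the sketch below shows how `Schema`, `PromptingProtocol`, and the module-level helpers fit together. It is a minimal sketch, assuming the new wheel is installed and `langfun.core.structured.schema.base` is importable as laid out in the file list above; the `DemoProtocol` class and its `'demo'` protocol name are hypothetical stand-ins for the real `'python'` and `'json'` protocols, which live in the sibling `python.py` and `json.py` files and are not shown here.

```python
import ast
from typing import Any

import pyglove as pg
from langfun.core.structured.schema import base  # module added by this diff


class DemoProtocol(base.PromptingProtocol):
  """Hypothetical protocol; subclassing auto-registers it under NAME."""

  NAME = 'demo'

  def schema_repr(self, schema: base.Schema) -> str:
    # Render the schema as a Python-style annotation string.
    return base.annotation(schema.spec)

  def value_repr(self, value: Any, schema: base.Schema | None = None, **kwargs) -> str:
    return pg.format(value, compact=True)

  def parse_value(self, text: str, schema: base.Schema | None = None, **kwargs) -> Any:
    # Toy parser for the sketch: accepts Python literals only.
    return ast.literal_eval(text)


class Point(pg.Object):
  x: int
  y: int


schema = base.Schema.from_value(list[int])
# Schema.schema_repr/parse_value dispatch through PromptingProtocol.from_name.
print(schema.schema_repr(protocol='demo'))               # list[int]
print(schema.parse_value('[1, 2, 3]', protocol='demo'))  # [1, 2, 3], checked via spec.apply()

# class_dependencies() collects user-defined classes referenced by the spec.
print(base.Schema.from_value(Point).class_dependencies())  # [Point]
```

The `'demo'` registration works because `PromptingProtocol.__init_subclass__` records every subclass in `_PROTOCOLS` keyed by its `NAME`, which is the same lookup path the module-level `schema_repr`, `value_repr`, and `parse_value` helpers use for the built-in protocols.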