langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +3 -0
  29. langfun/core/eval/v2/checkpointing.py +148 -46
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +102 -19
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +95 -20
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +88 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +14 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +90 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +52 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +78 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +78 -4
  104. langfun/core/modalities/mime_test.py +59 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
@@ -1,987 +0,0 @@
1
- # Copyright 2023 The Langfun Authors
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- """Schema for structured data."""
15
-
16
- import abc
17
- import inspect
18
- import io
19
- import re
20
- import sys
21
- import textwrap
22
- import typing
23
- from typing import Any, Literal, Sequence, Type, Union
24
- import langfun.core as lf
25
- from langfun.core.coding.python import correction
26
- import pyglove as pg
27
-
28
-
29
def include_method_in_prompt(method):
  """Tags `method` so it is rendered in LLM-facing class definitions.

  Args:
    method: The function/method object to tag.

  Returns:
    The same method object, with `__show_in_prompt__` set to True.
  """
  method.__show_in_prompt__ = True
  return method
33
-
34
-
35
def should_include_method_in_prompt(method):
  """Returns whether `method` was tagged for inclusion in the prompt."""
  try:
    return method.__show_in_prompt__
  except AttributeError:
    # Untagged methods default to being hidden from the prompt.
    return False
38
-
39
-
40
def parse_value_spec(value) -> pg.typing.ValueSpec:
  """Converts a ValueSpec equivalence into a `pg.typing.ValueSpec`.

  Args:
    value: A `pg.typing.ValueSpec`, or an equivalent structure: a dict of
      field annotations, a single-element list, or a plain type annotation.
      A one-key `{'result': ...}` dict is unwrapped first.

  Returns:
    The corresponding `pg.typing.ValueSpec`.

  Raises:
    ValueError: If the structure cannot be expressed as a supported spec.
  """
  if isinstance(value, pg.typing.ValueSpec):
    return value

  # Unwrap the conventional {'result': ...} envelope before conversion.
  if isinstance(value, dict) and len(value) == 1 and 'result' in value:
    value = value['result']

  def _convert(node) -> pg.typing.ValueSpec:
    if isinstance(node, dict):
      return pg.typing.Dict(
          [(key, _convert(child)) for key, child in node.items()]
      )
    if isinstance(node, list):
      # Lists encode homogeneous element types, so only one element makes
      # sense as an annotation.
      if len(node) != 1:
        raise ValueError(
            'Annotation with list must be a list of a single element. '
            f'Encountered: {node}'
        )
      return pg.typing.List(_convert(node[0]))
    spec = pg.typing.ValueSpec.from_annotation(node, auto_typing=True)
    unsupported = (
        pg.typing.Any,
        pg.typing.Callable,
        pg.typing.Tuple,
        pg.typing.Type,
        pg.typing.Union,
    )
    if isinstance(spec, unsupported):
      raise ValueError(f'Unsupported schema specification: {node}')
    return spec

  return _convert(value)
74
-
75
-
76
# Wire protocol used to render schemas and to parse LLM output.
SchemaProtocol = Literal['json', 'python']
77
-
78
-
79
class SchemaError(Exception):   # pylint: disable=g-bad-exception-name
  """Error raised when a generated value does not conform to a schema.

  Attributes:
    schema: The schema the value was validated against.
    value: The parsed value that failed validation.
    protocol: The representation protocol ('json' or 'python') in use.
    cause: The underlying exception raised during validation.
  """

  def __init__(self,
               schema: 'Schema',
               value: Any,
               protocol: SchemaProtocol,
               cause: Exception):
    # Initialize the Exception base so `args` and generic str()/repr()
    # (e.g. in logging or pickling) carry the cause instead of being empty.
    super().__init__(str(cause))
    self.schema = schema
    self.value = value
    self.protocol = protocol
    self.cause = cause

  def __str__(self):
    """Renders the cause, the schema and the offending value, colorized."""
    r = io.StringIO()
    r.write(
        pg.colored(
            f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
        )
    )

    r.write('\n')
    r.write(pg.colored('Schema:', 'red'))
    r.write('\n\n')
    r.write(textwrap.indent(
        pg.colored(
            schema_repr(self.protocol).repr(self.schema), 'magenta'
        ),
        ' ' * 2
    ))
    r.write('\n\n')
    r.write(pg.colored('Generated value:', 'red'))
    r.write('\n\n')
    r.write(textwrap.indent(
        pg.colored(value_repr(self.protocol).repr(self.value), 'magenta'),
        ' ' * 2
    ))
    return r.getvalue()
117
-
118
-
119
class Schema(
    lf.NaturalLanguageFormattable,
    pg.Object,
    pg.views.HtmlTreeView.Extension
):
  """Base class for structured data schema.

  Wraps a `pg.typing.ValueSpec` and provides rendering (to JSON or Python
  prompt text) and parsing (from LLM output back to a validated value).
  """

  spec: pg.typing.Annotated[
      pg.typing.Object(pg.typing.ValueSpec, transform=parse_value_spec),
      (
          'A PyGlove ValueSpec object representing the spec for the value '
          'to be parsed.'
      ),
  ]

  def schema_str(self, protocol: SchemaProtocol = 'json', **kwargs) -> str:
    """Returns the representation of the schema."""
    return schema_repr(protocol).repr(self, **kwargs)

  def value_str(
      self, value: Any, protocol: SchemaProtocol = 'json', **kwargs
  ) -> str:
    """Returns the representation of a structured value."""
    return value_repr(protocol).repr(value, self, **kwargs)

  def parse(
      self, text: str, protocol: SchemaProtocol = 'json', **kwargs
  ) -> Any:
    """Parse a LM generated text into a structured value.

    Args:
      text: LLM-generated text to parse.
      protocol: Which value representation to parse with.
      **kwargs: Forwarded to the protocol's parser.

    Returns:
      The parsed value, validated against `self.spec`.

    Raises:
      SchemaError: If the parsed value does not conform to the spec.
    """
    value = value_repr(protocol).parse(text, self, **kwargs)

    # TODO(daiyip): support autofix for schema error.
    try:
      # `spec.apply` both validates and canonicalizes the parsed value.
      return self.spec.apply(value)
    except Exception as e:
      raise SchemaError(self, value, protocol, e)  # pylint: disable=raise-missing-from

  def natural_language_format(self) -> str:
    # Natural-language form of a schema is its (JSON) schema string.
    return self.schema_str()

  def schema_dict(self) -> dict[str, Any]:
    """Returns the dict representation of the schema."""

    def _node(vs: pg.typing.ValueSpec) -> Any:
      # Primitive specs are kept as-is; containers recurse; objects expand
      # into their field dict keyed by the serialization type name.
      if isinstance(vs, pg.typing.PrimitiveType):
        return vs
      elif isinstance(vs, pg.typing.Dict):
        assert vs.schema is not None
        return {str(k): _node(f.value) for k, f in vs.schema.fields.items()}
      elif isinstance(vs, pg.typing.List):
        return [_node(vs.element.value)]
      elif isinstance(vs, pg.typing.Object):
        if issubclass(vs.cls, pg.Object):
          d = {pg.JSONConvertible.TYPE_NAME_KEY: vs.cls.__serialization_key__}
          d.update(
              {
                  str(k): _node(f.value)
                  for k, f in vs.cls.__schema__.fields.items()
              }
          )
          return d
      raise TypeError(
          'Unsupported value spec being used as the schema for '
          f'structured data: {vs}.')

    return {'result': _node(self.spec)}

  def class_dependencies(
      self,
      include_base_classes: bool = True,
      include_subclasses: bool = True,
      include_generated_subclasses: bool = False) -> list[Type[Any]]:
    """Returns a list of class dependencies for current schema."""
    return class_dependencies(
        self.spec,
        include_base_classes,
        include_subclasses,
        include_generated_subclasses
    )

  @classmethod
  def from_value(cls, value) -> 'Schema':
    """Creates a schema from an equivalent representation.

    Accepts an existing Schema (returned unchanged) or anything
    `parse_value_spec` understands.
    """
    if isinstance(value, Schema):
      return value
    return cls(parse_value_spec(value))

  def _html_tree_view_content(
      self,
      *,
      view: pg.views.HtmlTreeView,
      **kwargs,
  ):
    # Renders the Python schema string as preformatted HTML for tree views.
    return pg.Html.element(
        'div',
        [pg.Html.escape(self.schema_str(protocol='python'))],
        css_classes=['lf-schema-definition']
    ).add_style(
        """
        .lf-schema-definition {
          color: blue;
          margin: 5px;
          white-space: pre-wrap;
        }
        """
    )
225
-
226
-
227
# Anything accepted by `Schema.from_value`: a Schema, a type, a
# single-element list of a type, or a dict of field annotations.
SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]]
228
-
229
-
230
def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]:
  """Returns Object specs for the outermost `pg.Object`s inside `value`."""
  specs = []

  def _visit(path, node, parent):
    del path, parent
    if isinstance(node, pg.Object):
      specs.append(pg.typing.Object(node.__class__))
      # Stop descending: only top-level objects are collected.
      return pg.TraverseAction.CONTINUE
    return pg.TraverseAction.ENTER

  pg.traverse(value, _visit)
  return specs
243
-
244
-
245
def class_dependencies(
    value_or_spec: Union[
        pg.Symbolic,
        Schema,
        pg.typing.ValueSpec,
        Type[pg.Object],
        tuple[Union[pg.typing.ValueSpec, Type[pg.Object]], ...],
    ],
    include_base_classes: bool = True,
    include_subclasses: bool = True,
    include_generated_subclasses: bool = False,
) -> list[Type[Any]]:
  """Returns a list of class dependencies from a value or specs.

  Args:
    value_or_spec: A symbolic value, Schema, ValueSpec, class, or tuple of
      ValueSpecs/classes to extract dependencies from.
    include_base_classes: If True, user-defined base classes are included.
    include_subclasses: If True, subclasses of encountered classes are
      included recursively.
    include_generated_subclasses: If True, dynamically generated classes
      (whose `__module__` is 'builtins') are also included.

  Returns:
    Dependency classes in first-encountered order (bases/members before the
    class itself).

  Raises:
    TypeError: If a tuple element is neither a ValueSpec nor a class.
  """
  # Normalize the input into a list of value specs to walk.
  if isinstance(value_or_spec, Schema):
    value_or_spec = value_or_spec.spec

  if inspect.isclass(value_or_spec) or isinstance(
      value_or_spec, pg.typing.ValueSpec
  ):
    value_or_spec = (value_or_spec,)

  if isinstance(value_or_spec, tuple):
    value_specs = []
    for v in value_or_spec:
      if isinstance(v, pg.typing.ValueSpec):
        value_specs.append(v)
      elif inspect.isclass(v):
        value_specs.append(pg.typing.Object(v))
      else:
        raise TypeError(f'Unsupported spec type: {v!r}')
  else:
    value_specs = _top_level_object_specs_from_value(value_or_spec)

  # `seen` guards against revisiting a class (and infinite recursion on
  # cyclic references); `dependencies` preserves discovery order.
  seen = set()
  dependencies = []

  def _add_dependency(cls_or_classes):
    if isinstance(cls_or_classes, type):
      cls_or_classes = [cls_or_classes]
    for cls in cls_or_classes:
      if cls not in dependencies:
        dependencies.append(cls)

  def _fill_dependencies(vs: pg.typing.ValueSpec, include_subclasses: bool):
    if isinstance(vs, pg.typing.Object):
      if vs.cls not in seen:
        seen.add(vs.cls)

        if include_base_classes:
          # Add base classes as dependencies.
          for base_cls in vs.cls.__bases__:
            # We only keep track of user-defined symbolic classes.
            if base_cls is not object and base_cls is not pg.Object:
              _fill_dependencies(
                  pg.typing.Object(base_cls), include_subclasses=False
              )

        # Add members as dependencies.
        for field in pg.schema(vs.cls).values():
          _fill_dependencies(field.value, include_subclasses)
      _add_dependency(vs.cls)

      # Check subclasses if available.
      if include_subclasses:
        for cls in vs.cls.__subclasses__():
          # NOTE(daiyip): To prevent LLM-generated "hallucinated" classes from
          # polluting the generation space, classes dynamically created by
          # 'eval' (which have __module__ == 'builtins') are excluded from
          # dependencies by default.
          if ((include_generated_subclasses or cls.__module__ != 'builtins')
              and cls not in dependencies):
            _fill_dependencies(pg.typing.Object(cls), include_subclasses=True)

    # Recurse into container/union specs to reach nested object specs.
    if isinstance(vs, pg.typing.List):
      _fill_dependencies(vs.element.value, include_subclasses)
    elif isinstance(vs, pg.typing.Tuple):
      for elem in vs.elements:
        _fill_dependencies(elem.value, include_subclasses)
    elif isinstance(vs, pg.typing.Dict) and vs.schema:
      for v in vs.schema.values():
        _fill_dependencies(v.value, include_subclasses)
    elif isinstance(vs, pg.typing.Union):
      for v in vs.candidates:
        _fill_dependencies(v, include_subclasses)

  for value_spec in value_specs:
    _fill_dependencies(value_spec, include_subclasses)
  return dependencies
333
-
334
-
335
def schema_spec(noneable: bool = False) -> pg.typing.ValueSpec:  # pylint: disable=unused-argument
  """Returns a ValueSpec for a Schema field, auto-converting via `from_value`.

  Under static type checking this returns `Any` so annotated fields accept
  any Schema-convertible value.
  """
  if typing.TYPE_CHECKING:
    return Any
  return pg.typing.Object(
      Schema, transform=Schema.from_value, is_noneable=noneable
  )  # pylint: disable=unreachable-code
341
-
342
-
343
- #
344
- # Schema representations.
345
- #
346
-
347
-
348
class SchemaRepr(metaclass=abc.ABCMeta):
  """Base class for schema representation.

  Subclasses render a `Schema` into prompt text for a specific protocol
  (e.g. JSON or Python).
  """

  @abc.abstractmethod
  def repr(self, schema: Schema) -> str:
    """Returns the representation of the schema."""
354
-
355
-
356
class SchemaPythonRepr(SchemaRepr):
  """Renders a schema as Python source: result annotation plus class defs."""

  def repr(
      self,
      schema: Schema,
      *,
      include_result_definition: bool = True,
      markdown: bool = True,
      **kwargs,
  ) -> str:
    """Returns the Python-source representation of `schema`.

    Args:
      schema: The schema to render.
      include_result_definition: Whether to prefix the result annotation.
      markdown: Whether class definitions are wrapped in a code fence.
      **kwargs: Forwarded to `class_definitions`.

    Returns:
      The stripped Python representation string.
    """
    parts = []
    if include_result_definition:
      parts.append(self.result_definition(schema))
    class_defs = self.class_definitions(schema, markdown=markdown, **kwargs)
    if class_defs:
      parts.append(f'\n\n{class_defs}')
    return ''.join(parts).strip()

  def class_definitions(
      self,
      schema: Schema,
      additional_dependencies: list[Type[Any]] | None = None,
      **kwargs
  ) -> str | None:
    """Returns a string containing class definitions used by `schema`."""
    deps = schema.class_dependencies(
        include_base_classes=False, include_subclasses=True
    )
    # Dependencies outside this set render as `Any` in annotations.
    allowed = set(deps)
    if additional_dependencies:
      allowed.update(additional_dependencies)
    return class_definitions(deps, allowed_dependencies=allowed, **kwargs)

  def result_definition(self, schema: Schema) -> str:
    """Returns the Python annotation for the schema's result value."""
    return annotation(schema.spec)
395
-
396
-
397
def source_form(value, compact: bool = True, markdown: bool = False) -> str:
  """Returns the Python source-code form of `value`.

  Thin convenience wrapper over `ValuePythonRepr.repr`.
  """
  value_repr_ = ValuePythonRepr()
  return value_repr_.repr(value, compact=compact, markdown=markdown)
400
-
401
-
402
def class_definitions(
    classes: Sequence[Type[Any]],
    *,
    allowed_dependencies: set[Type[Any]] | None = None,
    strict: bool = False,
    markdown: bool = False,
) -> str | None:
  """Returns Python source for `classes`, or None when the list is empty.

  Args:
    classes: Classes to render, in order.
    allowed_dependencies: Classes allowed to appear by name in annotations
      and base lists; others render as `Any` / are omitted.
    strict: Whether to emit strict (pg.typing) annotations.
    markdown: Whether to wrap the output in a ```python code fence.

  Returns:
    The concatenated class definitions, or None if `classes` is empty.
  """
  if not classes:
    return None
  rendered = [
      class_definition(
          cls,
          strict=strict,
          allowed_dependencies=allowed_dependencies,
      )
      for cls in classes
  ]
  # Each definition already ends with a newline; a blank separator line
  # comes from the extra '\n' between them.
  ret = '\n'.join(rendered)
  if markdown and ret:
    ret = f'```python\n{ret}```'
  return ret
427
-
428
-
429
def class_definition(
    cls,
    strict: bool = False,
    allowed_dependencies: set[Type[Any]] | None = None,
) -> str:
  """Returns the Python class definition.

  Args:
    cls: The class to render.
    strict: Whether to emit strict (pg.typing) field annotations.
    allowed_dependencies: Classes allowed to appear by name; base classes
      outside this set are dropped and annotations fall back to `Any`.

  Returns:
    The class definition source, ending with a newline.
  """
  out = io.StringIO()
  schema = pg.schema(cls)
  # Only bases within the allowed dependency set are rendered; `object`
  # never is.
  eligible_bases = []
  for base_cls in cls.__bases__:
    if base_cls is not object:
      if allowed_dependencies is None or base_cls in allowed_dependencies:
        eligible_bases.append(base_cls.__name__)

  if eligible_bases:
    base_cls_str = ', '.join(eligible_bases)
    out.write(f'class {cls.__name__}({base_cls_str}):\n')
  else:
    out.write(f'class {cls.__name__}:\n')

  if cls.__doc__:
    doc_lines = cls.__doc__.strip().split('\n')
    if len(doc_lines) == 1:
      out.write(f'  """{cls.__doc__}"""\n')
    else:
      out.write('  """')

      # Since Python 3.13, the indentation of docstring lines is removed.
      # Therefore, we add two spaces to each non-empty line to keep the
      # indentation consistent with the class definition.
      if sys.version_info >= (3, 13):
        for i in range(1, len(doc_lines)):
          if doc_lines[i]:
            doc_lines[i] = ' ' * 2 + doc_lines[i]

      for line in doc_lines:
        out.write(line)
        out.write('\n')
      out.write('  """\n')

  # Tracks whether anything (field or method) was emitted, so an empty
  # class body gets a `pass`.
  empty_class = True
  if schema.fields:
    for key, field in schema.items():
      if not isinstance(key, pg.typing.ConstStrKey):
        pg.logging.warning(
            'Variable-length keyword arguments is not supported in '
            f'structured parsing or query. Encountered: {cls}, Schema: {schema}'
        )
        continue

      # Skip fields that are marked as excluded from the prompt sent to LLM
      # for OOP.
      if field.metadata.get('exclude_from_prompt', False):
        continue

      # Write field doc string as comments before the field definition.
      if field.description:
        for line in field.description.split('\n'):
          if line:
            out.write('  # ')
            out.write(line)
            out.write('\n')

      annotation_str = annotation(
          field.value, strict=strict, allowed_dependencies=allowed_dependencies
      )
      out.write(f'  {field.key}: {annotation_str}')
      out.write('\n')
      empty_class = False

  # Emit methods tagged with `include_method_in_prompt` that this class
  # (not an allowed base) defines.
  for method in _iter_newly_defined_methods(cls, allowed_dependencies):
    source = inspect.getsource(method)
    # Remove decorators from the method definition.
    source = re.sub(r'\s*@.*\.include_method_in_prompt.*\n', '', source)
    out.write('\n')
    out.write(
        textwrap.indent(
            inspect.cleandoc('\n' + source), ' ' * 2)
    )
    out.write('\n')
    empty_class = False

  if empty_class:
    out.write('  pass\n')
  return out.getvalue()
514
-
515
-
516
def _iter_newly_defined_methods(
    cls, allowed_dependencies: set[Type[Any]] | None):
  """Yields prompt-tagged callables defined on `cls` but not on visible bases.

  A name inherited from a base class is only excluded when that base is in
  `allowed_dependencies` (or when no restriction is given) — otherwise the
  base's definition will not be rendered, so the method counts as new here.
  """
  # dict preserves dir()'s sorted order, so yield order matches the original.
  candidates = dict.fromkeys(dir(cls))
  for base in cls.__bases__:
    if allowed_dependencies is not None and base not in allowed_dependencies:
      continue
    for inherited_name in dir(base):
      candidates.pop(inherited_name, None)
  for name in candidates:
    member = getattr(cls, name)
    if callable(member) and should_include_method_in_prompt(member):
      yield member
527
-
528
-
529
def annotation(
    vs: pg.typing.ValueSpec,
    annotate_optional: bool = True,
    strict: bool = False,
    allowed_dependencies: set[Type[Any]] | None = None,
) -> str:
  """Returns the annotation string for a value spec.

  Args:
    vs: The value spec to render.
    annotate_optional: Whether to append ' | None' for noneable specs.
    strict: Whether to render constrained specs in pg.typing form instead
      of loose pseudo-Python.
    allowed_dependencies: Classes allowed to appear by name; object specs
      outside this set render as 'Any'.

  Returns:
    The annotation string.

  Raises:
    TypeError: If `vs` is not a supported spec kind.
  """
  child_annotation_kwargs = dict(
      strict=strict, allowed_dependencies=allowed_dependencies
  )
  # These kinds return early: their form is complete on its own (Union
  # folds None into its candidate list rather than using ' | None').
  if isinstance(vs, pg.typing.Any):
    return 'Any'
  elif isinstance(vs, pg.typing.Enum):
    candidate_str = ', '.join([repr(v) for v in vs.values])
    return f'Literal[{candidate_str}]'
  elif isinstance(vs, pg.typing.Union):
    candidate_str = ', '.join(
        [
            annotation(c, annotate_optional=False, **child_annotation_kwargs)
            for c in vs.candidates
        ]
    )
    if vs.is_noneable:
      candidate_str += ', None'
    return f'Union[{candidate_str}]'

  if isinstance(vs, pg.typing.Bool):
    x = 'bool'
  elif isinstance(vs, pg.typing.Str):
    if vs.regex is None:
      x = 'str'
    else:
      # Constrained strings keep their regex; strict mode uses the real
      # pg.typing constructor form.
      if strict:
        x = f"pg.typing.Str(regex='{vs.regex.pattern}')"
      else:
        x = f"str(regex='{vs.regex.pattern}')"
  elif isinstance(vs, pg.typing.Number):
    constraints = []
    min_label = 'min_value' if strict else 'min'
    max_label = 'max_value' if strict else 'max'
    if vs.min_value is not None:
      constraints.append(f'{min_label}={vs.min_value}')
    if vs.max_value is not None:
      constraints.append(f'{max_label}={vs.max_value}')
    x = 'int' if isinstance(vs, pg.typing.Int) else 'float'
    if constraints:
      if strict:
        x = (
            'pg.typing.Int'
            if isinstance(vs, pg.typing.Int)
            else 'pg.typing.Float'
        )
      x += '(' + ', '.join(constraints) + ')'
  elif isinstance(vs, pg.typing.Object):
    # Classes outside the allowed dependency set degrade to 'Any' so the
    # prompt never references an undefined name.
    if allowed_dependencies is None or vs.cls in allowed_dependencies:
      x = vs.cls.__name__
    else:
      x = 'Any'
  elif isinstance(vs, pg.typing.List):
    item_str = annotation(vs.element.value, **child_annotation_kwargs)
    x = f'list[{item_str}]'
  elif isinstance(vs, pg.typing.Tuple):
    elem_str = ', '.join(
        [annotation(el.value, **child_annotation_kwargs) for el in vs.elements]
    )
    x = f'tuple[{elem_str}]'
  elif isinstance(vs, pg.typing.Dict):
    # Fixed-key dicts render as a literal dict of annotations; dynamic-key
    # dicts as dict[str, V]; schemaless dicts as dict[str, Any].
    kv_pairs = None
    if vs.schema is not None:
      kv_pairs = [
          (k, annotation(f.value, **child_annotation_kwargs))
          for k, f in vs.schema.items()
          if isinstance(k, pg.typing.ConstStrKey)
      ]

    if kv_pairs:
      kv_str = ', '.join(f"'{k}': {v}" for k, v in kv_pairs)
      x = '{' + kv_str + '}'
      if strict:
        x = f'pg.typing.Dict({x})'
    elif vs.schema and vs.schema.dynamic_field:
      v = annotation(vs.schema.dynamic_field.value, **child_annotation_kwargs)
      x = f'dict[str, {v}]'
    else:
      x = 'dict[str, Any]'

  else:
    raise TypeError(f'Unsupported value spec being used as schema: {vs}.')

  if annotate_optional and vs.is_noneable:
    x += ' | None'
  return x
621
-
622
-
623
class SchemaJsonRepr(SchemaRepr):
  """JSON-representation for a schema.

  Renders `Schema.schema_dict()` as a compact JSON-like template string,
  with primitive specs shown as type names (e.g. `int(min=0)`).
  """

  def repr(self, schema: Schema, **kwargs) -> str:
    del kwargs
    out = io.StringIO()
    def _visit(node: Any) -> None:
      # `node` is a schema_dict tree: str keys, 1-element lists, dicts, or
      # pg.typing specs at the leaves.
      if isinstance(node, str):
        out.write(f'"{node}"')
      elif isinstance(node, list):
        assert len(node) == 1, node
        out.write('[')
        _visit(node[0])
        out.write(']')
      elif isinstance(node, dict):
        out.write('{')
        for i, (k, v) in enumerate(node.items()):
          if i != 0:
            out.write(', ')
          out.write(f'"{k}": ')
          _visit(v)
        out.write('}')
      elif isinstance(node, pg.typing.Enum):
        # Enums render as a '|'-joined list of literal candidates.
        out.write(' | '.join(
            f'"{v}"' if isinstance(v, str) else repr(v)
            for v in node.values))
      elif isinstance(node, pg.typing.PrimitiveType):
        x = node.value_type.__name__
        # Constraints (range/regex) are appended in call-like form.
        if isinstance(node, pg.typing.Number):
          params = []
          if node.min_value is not None:
            params.append(f'min={node.min_value}')
          if node.max_value is not None:
            params.append(f'max={node.max_value}')
          if params:
            x += f'({", ".join(params)})'
        elif isinstance(node, pg.typing.Str):
          if node.regex is not None:
            x += f'(regex={node.regex.pattern})'
        if node.is_noneable:
          x = x + ' | None'
        out.write(x)
      else:
        raise ValueError(
            f'Unsupported value spec being used as schema: {node}.')
    _visit(schema.schema_dict())
    return out.getvalue()
670
-
671
-
672
- #
673
- # Value representations.
674
- #
675
-
676
-
677
class ValueRepr(metaclass=abc.ABCMeta):
  """Base class for value representation.

  Subclasses implement a round trip between structured values and their
  protocol-specific text form (render with `repr`, recover with `parse`).
  """

  @abc.abstractmethod
  def repr(self, value: Any, schema: Schema | None = None, **kwargs) -> str:
    """Returns the representation of a structured value."""

  @abc.abstractmethod
  def parse(self, text: str, schema: Schema | None = None, **kwargs) -> Any:
    """Parse a LM generated text into a structured value."""
687
-
688
-
689
class ValuePythonRepr(ValueRepr):
  """Python-representation for value.

  Renders values as Python source (classes as class definitions, other
  values via `pg.format`) and parses LLM-generated Python back into
  objects via sandboxed-free evaluation with correction support.
  """

  def repr(self,
           value: Any,
           schema: Schema | None = None,
           *,
           compact: bool = True,
           verbose: bool = False,
           markdown: bool = True,
           assign_to_var: str | None = None,
           **kwargs) -> str:
    del schema
    if inspect.isclass(value):
      # A class value renders as its (and its dependencies') definitions.
      cls_schema = Schema.from_value(value)
      if isinstance(cls_schema.spec, pg.typing.Object):
        object_code = SchemaPythonRepr().class_definitions(
            cls_schema,
            markdown=markdown,
            # We add `pg.Object` as additional dependencies to the class
            # definition so exemplars for class generation could show
            # pg.Object as their bases.
            additional_dependencies=[pg.Object]
        )
        assert object_code is not None
        return object_code
      else:
        object_code = SchemaPythonRepr().result_definition(cls_schema)
    elif isinstance(value, lf.Template):
      # Templates render as their interpolated text, never fenced.
      return str(value)
    else:
      object_code = pg.format(
          value, compact=compact, verbose=verbose, python_format=True
      )
    if assign_to_var is not None:
      object_code = f'{assign_to_var} = {object_code}'
    if markdown:
      return f'```python\n{ object_code }\n```'
    return object_code

  def parse(
      self,
      text: str,
      schema: Schema | None = None,
      *,
      additional_context: dict[str, Type[Any]] | None = None,
      permission: pg.coding.CodePermission = (
          pg.coding.CodePermission.ASSIGN | pg.coding.CodePermission.CALL
      ),
      autofix=0,
      autofix_lm: lf.LanguageModel = lf.contextual(),
      **kwargs,
  ) -> Any:
    """Parse a Python string into a structured object."""
    del kwargs
    global_vars = additional_context or {}
    if schema is not None:
      # Expose the schema's class dependencies by name so the generated
      # code can instantiate them.
      dependencies = schema.class_dependencies()
      global_vars.update({d.__name__: d for d in dependencies})
    return structure_from_python(
        text,
        global_vars=global_vars,
        autofix=autofix,
        autofix_lm=autofix_lm,
        permission=permission,
    )
755
-
756
-
757
def structure_from_python(
    code: str,
    *,
    global_vars: dict[str, Any] | None = None,
    permission: pg.coding.CodePermission = (
        pg.coding.CodePermission.ASSIGN | pg.coding.CodePermission.CALL
    ),
    autofix=0,
    autofix_lm: lf.LanguageModel = lf.contextual(),
) -> Any:
  """Evaluates structure from Python code with access to symbols.

  Args:
    code: Python source (typically LM-generated) to evaluate.
    global_vars: Optional symbols made visible to the evaluated code. The
      mapping is copied, not mutated, before common typing symbols are
      merged in.
    permission: Code permissions granted to the evaluated code.
    autofix: Maximum number of LM-driven correction attempts (0 disables).
    autofix_lm: Language model used for auto-fixing; contextual by default.

  Returns:
    The value produced by evaluating `code`.
  """
  # Copy so the caller-provided dict is not mutated by the update below
  # (previously the caller's dict was updated in place as a side effect).
  global_vars = dict(global_vars) if global_vars else {}
  global_vars.update({
      'pg': pg,
      'Object': pg.Object,
      'Any': typing.Any,
      'List': typing.List,
      'Tuple': typing.Tuple,
      'Dict': typing.Dict,
      'Sequence': typing.Sequence,
      'Optional': typing.Optional,
      'Union': typing.Union,
      # Special value markers.
      'UNKNOWN': UNKNOWN,
  })
  # We are creating objects here, so we execute the code without a sandbox.
  return correction.run_with_correction(
      code,
      global_vars=global_vars,
      sandbox=False,
      max_attempts=autofix,
      lm=autofix_lm,
      permission=permission,
  )
791
-
792
-
793
class JsonError(Exception):
  """Json parsing error."""

  def __init__(self, json: str, cause: Exception):
    # Keep both the offending JSON text and the underlying exception so
    # callers can inspect them programmatically.
    self.json = json
    self.cause = cause

  def __str__(self) -> str:
    cause_text = pg.colored(
        f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
    )
    header = pg.colored('JSON text:', 'red')
    body = textwrap.indent(pg.colored(self.json, 'magenta'), ' ' * 2)
    return f'{cause_text}\n\n{header}\n\n{body}'
813
-
814
-
815
class ValueJsonRepr(ValueRepr):
  """JSON-representation for value."""

  def repr(self, value: Any, schema: Schema | None = None, **kwargs) -> str:
    """Serializes `value` as JSON, wrapped in a dict under key `result`."""
    del schema
    return pg.to_json_str(dict(result=value))

  def parse(self, text: str, schema: Schema | None = None, **kwargs) -> Any:
    """Parse a JSON string into a structured object."""
    del schema
    try:
      text = cleanup_json(text)
      parsed = pg.from_json_str(text, **kwargs)
    except Exception as e:
      raise JsonError(text, e)  # pylint: disable=raise-missing-from

    # Enforce the `repr` convention above: a top-level dict keyed `result`.
    if not isinstance(parsed, dict) or 'result' not in parsed:
      raise JsonError(text, ValueError(
          'The root node of the JSON must be a dict with key `result`. '
          f'Encountered: {parsed}'
      ))
    return parsed['result']
837
-
838
-
839
def cleanup_json(json_str: str) -> str:
  """Clean up the LM responded JSON string.

  Treatments:
  1. Extract the JSON string with a top-level dict from the response.
     This prevents the leading and trailing texts in the response to
     be counted as part of the JSON.
  2. Escape new lines in JSON string values.

  Args:
    json_str: Raw LM response expected to contain a top-level JSON dict.

  Returns:
    The extracted, cleaned JSON text.

  Raises:
    ValueError: If no top-level dict is found, or curly braces are
      unbalanced.
  """

  def _is_escaped(i: int) -> bool:
    # A quote is escaped iff it is preceded by an odd number of
    # backslashes. (Checking only the single previous character would
    # wrongly treat a quote after an escaped backslash, e.g. a JSON
    # string ending in a literal backslash, as escaped.)
    backslashes = 0
    j = i - 1
    while j >= 0 and json_str[j] == '\\':
      backslashes += 1
      j -= 1
    return backslashes % 2 == 1

  curly_brackets = 0
  under_json = False
  under_str = False
  str_begin = -1

  cleaned = io.StringIO()
  for i, c in enumerate(json_str):
    if c == '{' and not under_str:
      cleaned.write(c)
      curly_brackets += 1
      under_json = True
      continue
    elif not under_json:
      # Skip any leading text before the first top-level '{'.
      continue

    if c == '}' and not under_str:
      cleaned.write(c)
      curly_brackets -= 1
      if curly_brackets == 0:
        # Top-level dict closed; ignore any trailing text.
        break
    elif c == '"' and not _is_escaped(i):
      under_str = not under_str
      if under_str:
        str_begin = i
      else:
        assert str_begin > 0
        # Write the whole string literal at once, escaping raw newlines
        # the LM may have emitted inside the value.
        str_value = json_str[str_begin : i + 1].replace('\n', '\\n')
        cleaned.write(str_value)
        str_begin = -1
    elif not under_str:
      cleaned.write(c)

  if not under_json:
    raise ValueError(f'No JSON dict in the output: {json_str}')

  if curly_brackets > 0:
    raise ValueError(
        f'Malformated JSON: missing {curly_brackets} closing curly braces.'
    )

  return cleaned.getvalue()
888
-
889
-
890
def schema_repr(protocol: SchemaProtocol) -> SchemaRepr:
  """Gets a SchemaRepr object from protocol."""
  if protocol == 'python':
    return SchemaPythonRepr()
  if protocol == 'json':
    return SchemaJsonRepr()
  raise ValueError(f'Unsupported protocol: {protocol}.')
897
-
898
-
899
def value_repr(protocol: SchemaProtocol) -> ValueRepr:
  """Gets a ValueRepr object from protocol."""
  if protocol == 'python':
    return ValuePythonRepr()
  if protocol == 'json':
    return ValueJsonRepr()
  raise ValueError(f'Unsupported protocol: {protocol}.')
905
-
906
-
907
- #
908
- # Special value markers.
909
- #
910
-
911
-
912
class Missing(pg.Object, pg.typing.CustomTyping):
  """Value marker for a missing field.

  This class differs from pg.MISSING_VALUE in two aspects:
  * When a field is assigned with lf.Missing(), it's considered non-partial.
  * lf.Missing() could format the value spec as Python annotations that are
    consistent with `lf.structured.Schema.schema_repr()`.
  """

  def _on_bound(self):
    super()._on_bound()
    # Filled in by `custom_apply` once the marker is bound to a field.
    self._value_spec = None

  @property
  def value_spec(self) -> pg.ValueSpec | None:
    """Returns the value spec that applies to the current missing value."""
    return self._value_spec

  def custom_apply(
      self, path: pg.KeyPath, value_spec: pg.ValueSpec, *args, **kwargs
  ) -> tuple[bool, Any]:
    # Remember the spec (so `format` can show the expected type) and keep
    # the marker itself in place without further type-checking.
    self._value_spec = value_spec
    return (False, self)

  def format(self, *args, **kwargs) -> str:
    return (
        'MISSING'
        if self._value_spec is None
        else f'MISSING({annotation(self._value_spec)})'
    )

  @classmethod
  def find_missing(cls, value: Any) -> dict[str, 'Missing']:
    """Lists all missing values contained in the value."""
    found = {}

    def _collect(key, node, parent):
      del parent
      if isinstance(node, Missing):
        found[key] = node
      return pg.TraverseAction.ENTER

    pg.traverse(value, _collect)
    return found
954
-
955
-
956
# Module-level convenience instance of `Missing`.
MISSING = Missing()
957
-
958
-
959
def mark_missing(value: Any) -> Any:
  """Replaces pg.MISSING within the value with lf.structured.Missing objects."""
  # Wrap plain containers symbolically so they can be rebound.
  if isinstance(value, list):
    value = pg.List(value)
  elif isinstance(value, dict):
    value = pg.Dict(value)
  if not isinstance(value, pg.Symbolic):
    # Non-symbolic values cannot carry missing markers; return unchanged.
    return value

  def _substitute(key, node, parent):
    del key, parent
    return Missing() if pg.MISSING_VALUE == node else node

  return value.rebind(_substitute, raise_on_no_change=False)
975
-
976
-
977
class Unknown(pg.Object, pg.typing.CustomTyping):
  """Value marker for a field that LMs could not provide."""

  def custom_apply(self, *args, **kwargs) -> tuple[bool, Any]:
    # Keep the marker as-is without type-checking it against the field
    # spec it is assigned to.
    return (False, self)

  def format(self, *args, **kwargs) -> str:
    # Render uniformly regardless of formatting options.
    return 'UNKNOWN'
985
-
986
-
987
# Module-level convenience instance of `Unknown`.
UNKNOWN = Unknown()