langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@
|
|
16
16
|
import abc
|
17
17
|
import inspect
|
18
18
|
import io
|
19
|
+
import re
|
19
20
|
import textwrap
|
20
21
|
import typing
|
21
22
|
from typing import Any, Literal, Sequence, Type, Union
|
@@ -24,6 +25,17 @@ from langfun.core.coding.python import correction
|
|
24
25
|
import pyglove as pg
|
25
26
|
|
26
27
|
|
28
|
+
def include_method_in_prompt(method):
|
29
|
+
"""Decorator to include a method in the class definition of the prompt."""
|
30
|
+
setattr(method, '__show_in_prompt__', True)
|
31
|
+
return method
|
32
|
+
|
33
|
+
|
34
|
+
def should_include_method_in_prompt(method):
|
35
|
+
"""Returns true if the method should be shown in the prompt."""
|
36
|
+
return getattr(method, '__show_in_prompt__', False)
|
37
|
+
|
38
|
+
|
27
39
|
def parse_value_spec(value) -> pg.typing.ValueSpec:
|
28
40
|
"""Parses a PyGlove ValueSpec equivalence into a ValueSpec."""
|
29
41
|
if isinstance(value, pg.typing.ValueSpec):
|
@@ -79,26 +91,35 @@ class SchemaError(Exception): # pylint: disable=g-bad-exception-name
|
|
79
91
|
def __str__(self):
|
80
92
|
r = io.StringIO()
|
81
93
|
r.write(
|
82
|
-
|
94
|
+
pg.colored(
|
95
|
+
f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
|
96
|
+
)
|
97
|
+
)
|
83
98
|
|
84
99
|
r.write('\n')
|
85
|
-
r.write(
|
100
|
+
r.write(pg.colored('Schema:', 'red'))
|
86
101
|
r.write('\n\n')
|
87
102
|
r.write(textwrap.indent(
|
88
|
-
|
103
|
+
pg.colored(
|
104
|
+
schema_repr(self.protocol).repr(self.schema), 'magenta'
|
105
|
+
),
|
89
106
|
' ' * 2
|
90
107
|
))
|
91
108
|
r.write('\n\n')
|
92
|
-
r.write(
|
109
|
+
r.write(pg.colored('Generated value:', 'red'))
|
93
110
|
r.write('\n\n')
|
94
111
|
r.write(textwrap.indent(
|
95
|
-
|
112
|
+
pg.colored(value_repr(self.protocol).repr(self.value), 'magenta'),
|
96
113
|
' ' * 2
|
97
114
|
))
|
98
115
|
return r.getvalue()
|
99
116
|
|
100
117
|
|
101
|
-
class Schema(
|
118
|
+
class Schema(
|
119
|
+
lf.NaturalLanguageFormattable,
|
120
|
+
pg.Object,
|
121
|
+
pg.views.HtmlTreeView.Extension
|
122
|
+
):
|
102
123
|
"""Base class for structured data schema."""
|
103
124
|
|
104
125
|
spec: pg.typing.Annotated[
|
@@ -163,9 +184,12 @@ class Schema(lf.NaturalLanguageFormattable, pg.Object):
|
|
163
184
|
|
164
185
|
def class_dependencies(
|
165
186
|
self,
|
187
|
+
include_base_classes: bool = True,
|
166
188
|
include_subclasses: bool = True) -> list[Type[Any]]:
|
167
189
|
"""Returns a list of class dependencies for current schema."""
|
168
|
-
return class_dependencies(
|
190
|
+
return class_dependencies(
|
191
|
+
self.spec, include_base_classes, include_subclasses
|
192
|
+
)
|
169
193
|
|
170
194
|
@classmethod
|
171
195
|
def from_value(cls, value) -> 'Schema':
|
@@ -174,6 +198,29 @@ class Schema(lf.NaturalLanguageFormattable, pg.Object):
|
|
174
198
|
return value
|
175
199
|
return cls(parse_value_spec(value))
|
176
200
|
|
201
|
+
def _html_tree_view_content(
|
202
|
+
self,
|
203
|
+
*,
|
204
|
+
view: pg.views.HtmlTreeView,
|
205
|
+
**kwargs,
|
206
|
+
):
|
207
|
+
return pg.Html.element(
|
208
|
+
'div',
|
209
|
+
[pg.Html.escape(self.schema_str(protocol='python'))],
|
210
|
+
css_classes=['lf-schema-definition']
|
211
|
+
).add_style(
|
212
|
+
"""
|
213
|
+
.lf-schema-definition {
|
214
|
+
color: blue;
|
215
|
+
margin: 5px;
|
216
|
+
white-space: pre-wrap;
|
217
|
+
}
|
218
|
+
"""
|
219
|
+
)
|
220
|
+
|
221
|
+
|
222
|
+
SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]]
|
223
|
+
|
177
224
|
|
178
225
|
def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]:
|
179
226
|
"""Returns a list of top level value specs from a symbolic value."""
|
@@ -198,11 +245,12 @@ def class_dependencies(
|
|
198
245
|
Type[pg.Object],
|
199
246
|
tuple[Union[pg.typing.ValueSpec, Type[pg.Object]], ...],
|
200
247
|
],
|
248
|
+
include_base_classes: bool = True,
|
201
249
|
include_subclasses: bool = True,
|
202
250
|
) -> list[Type[Any]]:
|
203
251
|
"""Returns a list of class dependencies from a value or specs."""
|
204
252
|
if isinstance(value_or_spec, Schema):
|
205
|
-
|
253
|
+
value_or_spec = value_or_spec.spec
|
206
254
|
|
207
255
|
if inspect.isclass(value_or_spec) or isinstance(
|
208
256
|
value_or_spec, pg.typing.ValueSpec
|
@@ -236,16 +284,17 @@ def class_dependencies(
|
|
236
284
|
if vs.cls not in seen:
|
237
285
|
seen.add(vs.cls)
|
238
286
|
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
287
|
+
if include_base_classes:
|
288
|
+
# Add base classes as dependencies.
|
289
|
+
for base_cls in vs.cls.__bases__:
|
290
|
+
# We only keep track of user-defined symbolic classes.
|
291
|
+
if base_cls is not object and base_cls is not pg.Object:
|
292
|
+
_fill_dependencies(
|
293
|
+
pg.typing.Object(base_cls), include_subclasses=False
|
294
|
+
)
|
246
295
|
|
247
296
|
# Add members as dependencies.
|
248
|
-
for field in
|
297
|
+
for field in pg.schema(vs.cls).values():
|
249
298
|
_fill_dependencies(field.value, include_subclasses)
|
250
299
|
_add_dependency(vs.cls)
|
251
300
|
|
@@ -262,7 +311,7 @@ def class_dependencies(
|
|
262
311
|
_fill_dependencies(elem.value, include_subclasses)
|
263
312
|
elif isinstance(vs, pg.typing.Dict) and vs.schema:
|
264
313
|
for v in vs.schema.values():
|
265
|
-
_fill_dependencies(v, include_subclasses)
|
314
|
+
_fill_dependencies(v.value, include_subclasses)
|
266
315
|
elif isinstance(vs, pg.typing.Union):
|
267
316
|
for v in vs.candidates:
|
268
317
|
_fill_dependencies(v, include_subclasses)
|
@@ -314,23 +363,35 @@ class SchemaPythonRepr(SchemaRepr):
|
|
314
363
|
ret += f'\n\n{class_definition_str}'
|
315
364
|
return ret.strip()
|
316
365
|
|
317
|
-
def class_definitions(
|
318
|
-
|
319
|
-
|
366
|
+
def class_definitions(
|
367
|
+
self,
|
368
|
+
schema: Schema,
|
369
|
+
additional_dependencies: list[Type[Any]] | None = None,
|
370
|
+
**kwargs
|
371
|
+
) -> str | None:
|
372
|
+
"""Returns a string containing of class definitions from a schema."""
|
373
|
+
deps = schema.class_dependencies(
|
374
|
+
include_base_classes=False, include_subclasses=True
|
375
|
+
)
|
376
|
+
allowed_dependencies = set(deps)
|
377
|
+
if additional_dependencies:
|
378
|
+
allowed_dependencies.update(additional_dependencies)
|
379
|
+
return class_definitions(
|
380
|
+
deps, allowed_dependencies=allowed_dependencies, **kwargs)
|
320
381
|
|
321
382
|
def result_definition(self, schema: Schema) -> str:
|
322
383
|
return annotation(schema.spec)
|
323
384
|
|
324
385
|
|
325
|
-
def source_form(value, markdown: bool = False) -> str:
|
386
|
+
def source_form(value, compact: bool = True, markdown: bool = False) -> str:
|
326
387
|
"""Returns the source code form of an object."""
|
327
|
-
return ValuePythonRepr().repr(value, markdown=markdown)
|
388
|
+
return ValuePythonRepr().repr(value, compact=compact, markdown=markdown)
|
328
389
|
|
329
390
|
|
330
391
|
def class_definitions(
|
331
392
|
classes: Sequence[Type[Any]],
|
332
393
|
*,
|
333
|
-
|
394
|
+
allowed_dependencies: set[Type[Any]] | None = None,
|
334
395
|
strict: bool = False,
|
335
396
|
markdown: bool = False,
|
336
397
|
) -> str | None:
|
@@ -345,7 +406,7 @@ def class_definitions(
|
|
345
406
|
class_definition(
|
346
407
|
cls,
|
347
408
|
strict=strict,
|
348
|
-
|
409
|
+
allowed_dependencies=allowed_dependencies,
|
349
410
|
)
|
350
411
|
)
|
351
412
|
ret = def_str.getvalue()
|
@@ -355,15 +416,17 @@ def class_definitions(
|
|
355
416
|
|
356
417
|
|
357
418
|
def class_definition(
|
358
|
-
cls,
|
419
|
+
cls,
|
420
|
+
strict: bool = False,
|
421
|
+
allowed_dependencies: set[Type[Any]] | None = None,
|
359
422
|
) -> str:
|
360
423
|
"""Returns the Python class definition."""
|
361
424
|
out = io.StringIO()
|
362
|
-
schema =
|
425
|
+
schema = pg.schema(cls)
|
363
426
|
eligible_bases = []
|
364
427
|
for base_cls in cls.__bases__:
|
365
428
|
if base_cls is not object:
|
366
|
-
if
|
429
|
+
if allowed_dependencies is None or base_cls in allowed_dependencies:
|
367
430
|
eligible_bases.append(base_cls.__name__)
|
368
431
|
|
369
432
|
if eligible_bases:
|
@@ -383,13 +446,16 @@ def class_definition(
|
|
383
446
|
out.write('\n')
|
384
447
|
out.write(' """\n')
|
385
448
|
|
449
|
+
empty_class = True
|
386
450
|
if schema.fields:
|
387
451
|
for key, field in schema.items():
|
388
452
|
if not isinstance(key, pg.typing.ConstStrKey):
|
389
|
-
|
453
|
+
pg.logging.warning(
|
390
454
|
'Variable-length keyword arguments is not supported in '
|
391
|
-
f'structured parsing or query. Encountered: {
|
455
|
+
f'structured parsing or query. Encountered: {cls}, Schema: {schema}'
|
392
456
|
)
|
457
|
+
continue
|
458
|
+
|
393
459
|
# Write field doc string as comments before the field definition.
|
394
460
|
if field.description:
|
395
461
|
for line in field.description.split('\n'):
|
@@ -397,19 +463,54 @@ def class_definition(
|
|
397
463
|
out.write(' # ')
|
398
464
|
out.write(line)
|
399
465
|
out.write('\n')
|
400
|
-
|
466
|
+
|
467
|
+
annotation_str = annotation(
|
468
|
+
field.value, strict=strict, allowed_dependencies=allowed_dependencies
|
469
|
+
)
|
470
|
+
out.write(f' {field.key}: {annotation_str}')
|
401
471
|
out.write('\n')
|
402
|
-
|
472
|
+
empty_class = False
|
473
|
+
|
474
|
+
for method in _iter_newly_defined_methods(cls, allowed_dependencies):
|
475
|
+
source = inspect.getsource(method)
|
476
|
+
# Remove decorators from the method definition.
|
477
|
+
source = re.sub(r'\s*@.*\.include_method_in_prompt.*\n', '', source)
|
478
|
+
out.write('\n')
|
479
|
+
out.write(
|
480
|
+
textwrap.indent(
|
481
|
+
inspect.cleandoc('\n' + source), ' ' * 2)
|
482
|
+
)
|
483
|
+
out.write('\n')
|
484
|
+
empty_class = False
|
485
|
+
|
486
|
+
if empty_class:
|
403
487
|
out.write(' pass\n')
|
404
488
|
return out.getvalue()
|
405
489
|
|
406
490
|
|
491
|
+
def _iter_newly_defined_methods(
|
492
|
+
cls, allowed_dependencies: set[Type[Any]] | None):
|
493
|
+
names = {attr_name: True for attr_name in dir(cls)}
|
494
|
+
for base in cls.__bases__:
|
495
|
+
if allowed_dependencies is None or base in allowed_dependencies:
|
496
|
+
for name in dir(base):
|
497
|
+
names.pop(name, None)
|
498
|
+
for name in names.keys():
|
499
|
+
attr = getattr(cls, name)
|
500
|
+
if callable(attr) and should_include_method_in_prompt(attr):
|
501
|
+
yield attr
|
502
|
+
|
503
|
+
|
407
504
|
def annotation(
|
408
505
|
vs: pg.typing.ValueSpec,
|
409
506
|
annotate_optional: bool = True,
|
410
507
|
strict: bool = False,
|
508
|
+
allowed_dependencies: set[Type[Any]] | None = None,
|
411
509
|
) -> str:
|
412
510
|
"""Returns the annotation string for a value spec."""
|
511
|
+
child_annotation_kwargs = dict(
|
512
|
+
strict=strict, allowed_dependencies=allowed_dependencies
|
513
|
+
)
|
413
514
|
if isinstance(vs, pg.typing.Any):
|
414
515
|
return 'Any'
|
415
516
|
elif isinstance(vs, pg.typing.Enum):
|
@@ -418,7 +519,7 @@ def annotation(
|
|
418
519
|
elif isinstance(vs, pg.typing.Union):
|
419
520
|
candidate_str = ', '.join(
|
420
521
|
[
|
421
|
-
annotation(c, annotate_optional=False,
|
522
|
+
annotation(c, annotate_optional=False, **child_annotation_kwargs)
|
422
523
|
for c in vs.candidates
|
423
524
|
]
|
424
525
|
)
|
@@ -454,20 +555,23 @@ def annotation(
|
|
454
555
|
)
|
455
556
|
x += '(' + ', '.join(constraints) + ')'
|
456
557
|
elif isinstance(vs, pg.typing.Object):
|
457
|
-
|
558
|
+
if allowed_dependencies is None or vs.cls in allowed_dependencies:
|
559
|
+
x = vs.cls.__name__
|
560
|
+
else:
|
561
|
+
x = 'Any'
|
458
562
|
elif isinstance(vs, pg.typing.List):
|
459
|
-
item_str = annotation(vs.element.value,
|
563
|
+
item_str = annotation(vs.element.value, **child_annotation_kwargs)
|
460
564
|
x = f'list[{item_str}]'
|
461
565
|
elif isinstance(vs, pg.typing.Tuple):
|
462
566
|
elem_str = ', '.join(
|
463
|
-
[annotation(el.value,
|
567
|
+
[annotation(el.value, **child_annotation_kwargs) for el in vs.elements]
|
464
568
|
)
|
465
569
|
x = f'tuple[{elem_str}]'
|
466
570
|
elif isinstance(vs, pg.typing.Dict):
|
467
571
|
kv_pairs = None
|
468
572
|
if vs.schema is not None:
|
469
573
|
kv_pairs = [
|
470
|
-
(k, annotation(f.value,
|
574
|
+
(k, annotation(f.value, **child_annotation_kwargs))
|
471
575
|
for k, f in vs.schema.items()
|
472
576
|
if isinstance(k, pg.typing.ConstStrKey)
|
473
577
|
]
|
@@ -477,6 +581,9 @@ def annotation(
|
|
477
581
|
x = '{' + kv_str + '}'
|
478
582
|
if strict:
|
479
583
|
x = f'pg.typing.Dict({x})'
|
584
|
+
elif vs.schema and vs.schema.dynamic_field:
|
585
|
+
v = annotation(vs.schema.dynamic_field.value, **child_annotation_kwargs)
|
586
|
+
x = f'dict[str, {v}]'
|
480
587
|
else:
|
481
588
|
x = 'dict[str, Any]'
|
482
589
|
|
@@ -491,7 +598,8 @@ def annotation(
|
|
491
598
|
class SchemaJsonRepr(SchemaRepr):
|
492
599
|
"""JSON-representation for a schema."""
|
493
600
|
|
494
|
-
def repr(self, schema: Schema) -> str:
|
601
|
+
def repr(self, schema: Schema, **kwargs) -> str:
|
602
|
+
del kwargs
|
495
603
|
out = io.StringIO()
|
496
604
|
def _visit(node: Any) -> None:
|
497
605
|
if isinstance(node, str):
|
@@ -569,12 +677,19 @@ class ValuePythonRepr(ValueRepr):
|
|
569
677
|
cls_schema = Schema.from_value(value)
|
570
678
|
if isinstance(cls_schema.spec, pg.typing.Object):
|
571
679
|
object_code = SchemaPythonRepr().class_definitions(
|
572
|
-
cls_schema,
|
680
|
+
cls_schema,
|
681
|
+
markdown=markdown,
|
682
|
+
# We add `pg.Object` as additional dependencies to the class
|
683
|
+
# definition so exemplars for class generation could show
|
684
|
+
# pg.Object as their bases.
|
685
|
+
additional_dependencies=[pg.Object]
|
573
686
|
)
|
574
687
|
assert object_code is not None
|
575
688
|
return object_code
|
576
689
|
else:
|
577
690
|
object_code = SchemaPythonRepr().result_definition(cls_schema)
|
691
|
+
elif isinstance(value, lf.Template):
|
692
|
+
return str(value)
|
578
693
|
else:
|
579
694
|
object_code = pg.format(
|
580
695
|
value, compact=compact, verbose=verbose, python_format=True
|
@@ -649,12 +764,15 @@ class JsonError(Exception):
|
|
649
764
|
def __str__(self) -> str:
|
650
765
|
r = io.StringIO()
|
651
766
|
r.write(
|
652
|
-
|
767
|
+
pg.colored(
|
768
|
+
f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
|
769
|
+
)
|
770
|
+
)
|
653
771
|
|
654
772
|
r.write('\n\n')
|
655
|
-
r.write(
|
773
|
+
r.write(pg.colored('JSON text:', 'red'))
|
656
774
|
r.write('\n\n')
|
657
|
-
r.write(textwrap.indent(
|
775
|
+
r.write(textwrap.indent(pg.colored(self.json, 'magenta'), ' ' * 2))
|
658
776
|
return r.getvalue()
|
659
777
|
|
660
778
|
|
@@ -669,7 +787,7 @@ class ValueJsonRepr(ValueRepr):
|
|
669
787
|
"""Parse a JSON string into a structured object."""
|
670
788
|
del schema
|
671
789
|
try:
|
672
|
-
text =
|
790
|
+
text = cleanup_json(text)
|
673
791
|
v = pg.from_json_str(text, **kwargs)
|
674
792
|
except Exception as e:
|
675
793
|
raise JsonError(text, e) # pylint: disable=raise-missing-from
|
@@ -681,55 +799,56 @@ class ValueJsonRepr(ValueRepr):
|
|
681
799
|
))
|
682
800
|
return v['result']
|
683
801
|
|
684
|
-
def cleanup_json(self, json_str: str) -> str:
|
685
|
-
"""Clean up the LM responded JSON string."""
|
686
|
-
# Treatments:
|
687
|
-
# 1. Extract the JSON string with a top-level dict from the response.
|
688
|
-
# This prevents the leading and trailing texts in the response to
|
689
|
-
# be counted as part of the JSON.
|
690
|
-
# 2. Escape new lines in JSON values.
|
691
|
-
|
692
|
-
curly_brackets = 0
|
693
|
-
under_json = False
|
694
|
-
under_str = False
|
695
|
-
str_begin = -1
|
696
|
-
|
697
|
-
cleaned = io.StringIO()
|
698
|
-
for i, c in enumerate(json_str):
|
699
|
-
if c == '{' and not under_str:
|
700
|
-
cleaned.write(c)
|
701
|
-
curly_brackets += 1
|
702
|
-
under_json = True
|
703
|
-
continue
|
704
|
-
elif not under_json:
|
705
|
-
continue
|
706
802
|
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
803
|
+
def cleanup_json(json_str: str) -> str:
|
804
|
+
"""Clean up the LM responded JSON string."""
|
805
|
+
# Treatments:
|
806
|
+
# 1. Extract the JSON string with a top-level dict from the response.
|
807
|
+
# This prevents the leading and trailing texts in the response to
|
808
|
+
# be counted as part of the JSON.
|
809
|
+
# 2. Escape new lines in JSON values.
|
810
|
+
|
811
|
+
curly_brackets = 0
|
812
|
+
under_json = False
|
813
|
+
under_str = False
|
814
|
+
str_begin = -1
|
815
|
+
|
816
|
+
cleaned = io.StringIO()
|
817
|
+
for i, c in enumerate(json_str):
|
818
|
+
if c == '{' and not under_str:
|
819
|
+
cleaned.write(c)
|
820
|
+
curly_brackets += 1
|
821
|
+
under_json = True
|
822
|
+
continue
|
823
|
+
elif not under_json:
|
824
|
+
continue
|
825
|
+
|
826
|
+
if c == '}' and not under_str:
|
827
|
+
cleaned.write(c)
|
828
|
+
curly_brackets -= 1
|
829
|
+
if curly_brackets == 0:
|
830
|
+
break
|
831
|
+
elif c == '"' and json_str[i - 1] != '\\':
|
832
|
+
under_str = not under_str
|
833
|
+
if under_str:
|
834
|
+
str_begin = i
|
835
|
+
else:
|
836
|
+
assert str_begin > 0
|
837
|
+
str_value = json_str[str_begin : i + 1].replace('\n', '\\n')
|
838
|
+
cleaned.write(str_value)
|
839
|
+
str_begin = -1
|
840
|
+
elif not under_str:
|
841
|
+
cleaned.write(c)
|
842
|
+
|
843
|
+
if not under_json:
|
844
|
+
raise ValueError(f'No JSON dict in the output: {json_str}')
|
845
|
+
|
846
|
+
if curly_brackets > 0:
|
847
|
+
raise ValueError(
|
848
|
+
f'Malformated JSON: missing {curly_brackets} closing curly braces.'
|
849
|
+
)
|
731
850
|
|
732
|
-
|
851
|
+
return cleaned.getvalue()
|
733
852
|
|
734
853
|
|
735
854
|
def schema_repr(protocol: SchemaProtocol) -> SchemaRepr:
|
@@ -830,13 +949,3 @@ class Unknown(pg.Object, pg.typing.CustomTyping):
|
|
830
949
|
|
831
950
|
|
832
951
|
UNKNOWN = Unknown()
|
833
|
-
|
834
|
-
|
835
|
-
def _pg_schema(cls: Type[Any]) -> pg.Schema:
|
836
|
-
"""Returns PyGlove schema for the constructor of a class."""
|
837
|
-
schema = getattr(cls, '__schema__', None)
|
838
|
-
if schema is None:
|
839
|
-
schema = pg.symbolic.callable_schema(
|
840
|
-
cls.__init__, auto_typing=True, auto_doc=True, remove_self=True
|
841
|
-
)
|
842
|
-
return schema
|
@@ -58,7 +58,7 @@ class GenerateClass(mapping.Mapping):
|
|
58
58
|
class_name = self.context
|
59
59
|
cls = output_vars.get(class_name, None)
|
60
60
|
if cls is None:
|
61
|
-
raise
|
61
|
+
raise pg.coding.CodeError(
|
62
62
|
final_code,
|
63
63
|
TypeError(f'Class {class_name} is absent from LLM output.'),
|
64
64
|
)
|
@@ -14,8 +14,8 @@
|
|
14
14
|
import inspect
|
15
15
|
import unittest
|
16
16
|
|
17
|
-
import langfun.core.coding as lf_coding
|
18
17
|
from langfun.core.llms import fake
|
18
|
+
from langfun.core.structured import mapping
|
19
19
|
from langfun.core.structured import schema_generation
|
20
20
|
|
21
21
|
|
@@ -92,7 +92,7 @@ class GenerateClassTest(unittest.TestCase):
|
|
92
92
|
)
|
93
93
|
self.assertIs(cls.__name__, 'B')
|
94
94
|
|
95
|
-
with self.assertRaises(
|
95
|
+
with self.assertRaises(mapping.MappingError):
|
96
96
|
schema_generation.generate_class(
|
97
97
|
'Foo',
|
98
98
|
'Generate a Foo class with a field pointing to another class A',
|