langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@
|
|
16
16
|
import abc
|
17
17
|
import inspect
|
18
18
|
import io
|
19
|
+
import re
|
19
20
|
import textwrap
|
20
21
|
import typing
|
21
22
|
from typing import Any, Literal, Sequence, Type, Union
|
@@ -24,6 +25,17 @@ from langfun.core.coding.python import correction
|
|
24
25
|
import pyglove as pg
|
25
26
|
|
26
27
|
|
28
|
+
def include_method_in_prompt(method):
  """Decorator to include a method in the class definition of the prompt.

  Marks `method` by setting its `__show_in_prompt__` attribute to True, which
  is the flag that `should_include_method_in_prompt` reads back.

  Args:
    method: The function object to mark.

  Returns:
    The very same `method` object, so this can be applied as a decorator.
  """
  # The attribute name is a constant, so plain assignment is equivalent to
  # `setattr` and easier to read.
  method.__show_in_prompt__ = True
  return method
|
32
|
+
|
33
|
+
|
34
|
+
def should_include_method_in_prompt(method):
  """Returns the `__show_in_prompt__` flag of `method`, or False if unset."""
  try:
    return method.__show_in_prompt__
  except AttributeError:
    # No marker attribute on the object: it was never flagged for the prompt.
    return False
|
37
|
+
|
38
|
+
|
27
39
|
def parse_value_spec(value) -> pg.typing.ValueSpec:
|
28
40
|
"""Parses a PyGlove ValueSpec equivalence into a ValueSpec."""
|
29
41
|
if isinstance(value, pg.typing.ValueSpec):
|
@@ -79,26 +91,35 @@ class SchemaError(Exception): # pylint: disable=g-bad-exception-name
|
|
79
91
|
def __str__(self):
|
80
92
|
r = io.StringIO()
|
81
93
|
r.write(
|
82
|
-
|
94
|
+
pg.colored(
|
95
|
+
f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
|
96
|
+
)
|
97
|
+
)
|
83
98
|
|
84
99
|
r.write('\n')
|
85
|
-
r.write(
|
100
|
+
r.write(pg.colored('Schema:', 'red'))
|
86
101
|
r.write('\n\n')
|
87
102
|
r.write(textwrap.indent(
|
88
|
-
|
103
|
+
pg.colored(
|
104
|
+
schema_repr(self.protocol).repr(self.schema), 'magenta'
|
105
|
+
),
|
89
106
|
' ' * 2
|
90
107
|
))
|
91
108
|
r.write('\n\n')
|
92
|
-
r.write(
|
109
|
+
r.write(pg.colored('Generated value:', 'red'))
|
93
110
|
r.write('\n\n')
|
94
111
|
r.write(textwrap.indent(
|
95
|
-
|
112
|
+
pg.colored(value_repr(self.protocol).repr(self.value), 'magenta'),
|
96
113
|
' ' * 2
|
97
114
|
))
|
98
115
|
return r.getvalue()
|
99
116
|
|
100
117
|
|
101
|
-
class Schema(
|
118
|
+
class Schema(
|
119
|
+
lf.NaturalLanguageFormattable,
|
120
|
+
pg.Object,
|
121
|
+
pg.views.HtmlTreeView.Extension
|
122
|
+
):
|
102
123
|
"""Base class for structured data schema."""
|
103
124
|
|
104
125
|
spec: pg.typing.Annotated[
|
@@ -163,9 +184,12 @@ class Schema(lf.NaturalLanguageFormattable, pg.Object):
|
|
163
184
|
|
164
185
|
def class_dependencies(
|
165
186
|
self,
|
187
|
+
include_base_classes: bool = True,
|
166
188
|
include_subclasses: bool = True) -> list[Type[Any]]:
|
167
189
|
"""Returns a list of class dependencies for current schema."""
|
168
|
-
return class_dependencies(
|
190
|
+
return class_dependencies(
|
191
|
+
self.spec, include_base_classes, include_subclasses
|
192
|
+
)
|
169
193
|
|
170
194
|
@classmethod
|
171
195
|
def from_value(cls, value) -> 'Schema':
|
@@ -174,6 +198,29 @@ class Schema(lf.NaturalLanguageFormattable, pg.Object):
|
|
174
198
|
return value
|
175
199
|
return cls(parse_value_spec(value))
|
176
200
|
|
201
|
+
def _html_tree_view_content(
    self,
    *,
    view: pg.views.HtmlTreeView,
    **kwargs,
):
  """Builds the HTML content shown for this schema in the tree view.

  The content is a single `div` holding the HTML-escaped Python-protocol
  schema string, with a stylesheet attached for the `lf-schema-definition`
  CSS class.

  Args:
    view: The HTML tree view requesting the content (not used here).
    **kwargs: Additional rendering options (not used here).

  Returns:
    The `div` element with its style attached.
  """
  escaped_definition = pg.Html.escape(self.schema_str(protocol='python'))
  container = pg.Html.element(
      'div',
      [escaped_definition],
      css_classes=['lf-schema-definition']
  )
  return container.add_style(
      """
      .lf-schema-definition {
          color: blue;
          margin: 5px;
          white-space: pre-wrap;
      }
      """
  )
|
220
|
+
|
221
|
+
|
222
|
+
SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]]
|
223
|
+
|
177
224
|
|
178
225
|
def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]:
|
179
226
|
"""Returns a list of top level value specs from a symbolic value."""
|
@@ -198,11 +245,12 @@ def class_dependencies(
|
|
198
245
|
Type[pg.Object],
|
199
246
|
tuple[Union[pg.typing.ValueSpec, Type[pg.Object]], ...],
|
200
247
|
],
|
248
|
+
include_base_classes: bool = True,
|
201
249
|
include_subclasses: bool = True,
|
202
250
|
) -> list[Type[Any]]:
|
203
251
|
"""Returns a list of class dependencies from a value or specs."""
|
204
252
|
if isinstance(value_or_spec, Schema):
|
205
|
-
|
253
|
+
value_or_spec = value_or_spec.spec
|
206
254
|
|
207
255
|
if inspect.isclass(value_or_spec) or isinstance(
|
208
256
|
value_or_spec, pg.typing.ValueSpec
|
@@ -236,16 +284,17 @@ def class_dependencies(
|
|
236
284
|
if vs.cls not in seen:
|
237
285
|
seen.add(vs.cls)
|
238
286
|
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
287
|
+
if include_base_classes:
|
288
|
+
# Add base classes as dependencies.
|
289
|
+
for base_cls in vs.cls.__bases__:
|
290
|
+
# We only keep track of user-defined symbolic classes.
|
291
|
+
if base_cls is not object and base_cls is not pg.Object:
|
292
|
+
_fill_dependencies(
|
293
|
+
pg.typing.Object(base_cls), include_subclasses=False
|
294
|
+
)
|
246
295
|
|
247
296
|
# Add members as dependencies.
|
248
|
-
for field in
|
297
|
+
for field in pg.schema(vs.cls).values():
|
249
298
|
_fill_dependencies(field.value, include_subclasses)
|
250
299
|
_add_dependency(vs.cls)
|
251
300
|
|
@@ -262,7 +311,7 @@ def class_dependencies(
|
|
262
311
|
_fill_dependencies(elem.value, include_subclasses)
|
263
312
|
elif isinstance(vs, pg.typing.Dict) and vs.schema:
|
264
313
|
for v in vs.schema.values():
|
265
|
-
_fill_dependencies(v, include_subclasses)
|
314
|
+
_fill_dependencies(v.value, include_subclasses)
|
266
315
|
elif isinstance(vs, pg.typing.Union):
|
267
316
|
for v in vs.candidates:
|
268
317
|
_fill_dependencies(v, include_subclasses)
|
@@ -314,23 +363,35 @@ class SchemaPythonRepr(SchemaRepr):
|
|
314
363
|
ret += f'\n\n{class_definition_str}'
|
315
364
|
return ret.strip()
|
316
365
|
|
317
|
-
def class_definitions(
|
318
|
-
|
319
|
-
|
366
|
+
def class_definitions(
|
367
|
+
self,
|
368
|
+
schema: Schema,
|
369
|
+
additional_dependencies: list[Type[Any]] | None = None,
|
370
|
+
**kwargs
|
371
|
+
) -> str | None:
|
372
|
+
"""Returns a string containing of class definitions from a schema."""
|
373
|
+
deps = schema.class_dependencies(
|
374
|
+
include_base_classes=False, include_subclasses=True
|
375
|
+
)
|
376
|
+
allowed_dependencies = set(deps)
|
377
|
+
if additional_dependencies:
|
378
|
+
allowed_dependencies.update(additional_dependencies)
|
379
|
+
return class_definitions(
|
380
|
+
deps, allowed_dependencies=allowed_dependencies, **kwargs)
|
320
381
|
|
321
382
|
def result_definition(self, schema: Schema) -> str:
  """Returns the annotation string produced for the schema's value spec."""
  value_spec = schema.spec
  return annotation(value_spec)
|
323
384
|
|
324
385
|
|
325
|
-
def source_form(value, markdown: bool = False) -> str:
|
386
|
+
def source_form(value, compact: bool = True, markdown: bool = False) -> str:
|
326
387
|
"""Returns the source code form of an object."""
|
327
|
-
return ValuePythonRepr().repr(value, markdown=markdown)
|
388
|
+
return ValuePythonRepr().repr(value, compact=compact, markdown=markdown)
|
328
389
|
|
329
390
|
|
330
391
|
def class_definitions(
|
331
392
|
classes: Sequence[Type[Any]],
|
332
393
|
*,
|
333
|
-
|
394
|
+
allowed_dependencies: set[Type[Any]] | None = None,
|
334
395
|
strict: bool = False,
|
335
396
|
markdown: bool = False,
|
336
397
|
) -> str | None:
|
@@ -345,7 +406,7 @@ def class_definitions(
|
|
345
406
|
class_definition(
|
346
407
|
cls,
|
347
408
|
strict=strict,
|
348
|
-
|
409
|
+
allowed_dependencies=allowed_dependencies,
|
349
410
|
)
|
350
411
|
)
|
351
412
|
ret = def_str.getvalue()
|
@@ -355,15 +416,17 @@ def class_definitions(
|
|
355
416
|
|
356
417
|
|
357
418
|
def class_definition(
|
358
|
-
cls,
|
419
|
+
cls,
|
420
|
+
strict: bool = False,
|
421
|
+
allowed_dependencies: set[Type[Any]] | None = None,
|
359
422
|
) -> str:
|
360
423
|
"""Returns the Python class definition."""
|
361
424
|
out = io.StringIO()
|
362
|
-
schema =
|
425
|
+
schema = pg.schema(cls)
|
363
426
|
eligible_bases = []
|
364
427
|
for base_cls in cls.__bases__:
|
365
428
|
if base_cls is not object:
|
366
|
-
if
|
429
|
+
if allowed_dependencies is None or base_cls in allowed_dependencies:
|
367
430
|
eligible_bases.append(base_cls.__name__)
|
368
431
|
|
369
432
|
if eligible_bases:
|
@@ -383,6 +446,7 @@ def class_definition(
|
|
383
446
|
out.write('\n')
|
384
447
|
out.write(' """\n')
|
385
448
|
|
449
|
+
empty_class = True
|
386
450
|
if schema.fields:
|
387
451
|
for key, field in schema.items():
|
388
452
|
if not isinstance(key, pg.typing.ConstStrKey):
|
@@ -399,19 +463,54 @@ def class_definition(
|
|
399
463
|
out.write(' # ')
|
400
464
|
out.write(line)
|
401
465
|
out.write('\n')
|
402
|
-
|
466
|
+
|
467
|
+
annotation_str = annotation(
|
468
|
+
field.value, strict=strict, allowed_dependencies=allowed_dependencies
|
469
|
+
)
|
470
|
+
out.write(f' {field.key}: {annotation_str}')
|
403
471
|
out.write('\n')
|
404
|
-
|
472
|
+
empty_class = False
|
473
|
+
|
474
|
+
for method in _iter_newly_defined_methods(cls, allowed_dependencies):
|
475
|
+
source = inspect.getsource(method)
|
476
|
+
# Remove decorators from the method definition.
|
477
|
+
source = re.sub(r'\s*@.*\.include_method_in_prompt.*\n', '', source)
|
478
|
+
out.write('\n')
|
479
|
+
out.write(
|
480
|
+
textwrap.indent(
|
481
|
+
inspect.cleandoc('\n' + source), ' ' * 2)
|
482
|
+
)
|
483
|
+
out.write('\n')
|
484
|
+
empty_class = False
|
485
|
+
|
486
|
+
if empty_class:
|
405
487
|
out.write(' pass\n')
|
406
488
|
return out.getvalue()
|
407
489
|
|
408
490
|
|
491
|
+
def _iter_newly_defined_methods(
    cls, allowed_dependencies: set[Type[Any]] | None):
  """Yields prompt-visible callables that `cls` adds on top of its bases.

  A name counts as inherited when it appears in `dir()` of a direct base class
  that is allowed (every direct base is allowed when `allowed_dependencies` is
  None). Names not inherited this way are looked up on `cls` and yielded when
  they are callable and flagged by `should_include_method_in_prompt`.

  Args:
    cls: The class whose attributes are inspected.
    allowed_dependencies: Classes whose members are treated as inherited, or
      None to treat all direct bases that way.

  Yields:
    The matching attributes, in the order `dir(cls)` lists their names.
  """
  inherited_names = set()
  for parent in cls.__bases__:
    if allowed_dependencies is not None and parent not in allowed_dependencies:
      continue
    inherited_names.update(dir(parent))

  for attr_name in dir(cls):
    if attr_name in inherited_names:
      continue
    candidate = getattr(cls, attr_name)
    if not callable(candidate):
      continue
    if should_include_method_in_prompt(candidate):
      yield candidate
|
502
|
+
|
503
|
+
|
409
504
|
def annotation(
|
410
505
|
vs: pg.typing.ValueSpec,
|
411
506
|
annotate_optional: bool = True,
|
412
507
|
strict: bool = False,
|
508
|
+
allowed_dependencies: set[Type[Any]] | None = None,
|
413
509
|
) -> str:
|
414
510
|
"""Returns the annotation string for a value spec."""
|
511
|
+
child_annotation_kwargs = dict(
|
512
|
+
strict=strict, allowed_dependencies=allowed_dependencies
|
513
|
+
)
|
415
514
|
if isinstance(vs, pg.typing.Any):
|
416
515
|
return 'Any'
|
417
516
|
elif isinstance(vs, pg.typing.Enum):
|
@@ -420,7 +519,7 @@ def annotation(
|
|
420
519
|
elif isinstance(vs, pg.typing.Union):
|
421
520
|
candidate_str = ', '.join(
|
422
521
|
[
|
423
|
-
annotation(c, annotate_optional=False,
|
522
|
+
annotation(c, annotate_optional=False, **child_annotation_kwargs)
|
424
523
|
for c in vs.candidates
|
425
524
|
]
|
426
525
|
)
|
@@ -456,20 +555,23 @@ def annotation(
|
|
456
555
|
)
|
457
556
|
x += '(' + ', '.join(constraints) + ')'
|
458
557
|
elif isinstance(vs, pg.typing.Object):
|
459
|
-
|
558
|
+
if allowed_dependencies is None or vs.cls in allowed_dependencies:
|
559
|
+
x = vs.cls.__name__
|
560
|
+
else:
|
561
|
+
x = 'Any'
|
460
562
|
elif isinstance(vs, pg.typing.List):
|
461
|
-
item_str = annotation(vs.element.value,
|
563
|
+
item_str = annotation(vs.element.value, **child_annotation_kwargs)
|
462
564
|
x = f'list[{item_str}]'
|
463
565
|
elif isinstance(vs, pg.typing.Tuple):
|
464
566
|
elem_str = ', '.join(
|
465
|
-
[annotation(el.value,
|
567
|
+
[annotation(el.value, **child_annotation_kwargs) for el in vs.elements]
|
466
568
|
)
|
467
569
|
x = f'tuple[{elem_str}]'
|
468
570
|
elif isinstance(vs, pg.typing.Dict):
|
469
571
|
kv_pairs = None
|
470
572
|
if vs.schema is not None:
|
471
573
|
kv_pairs = [
|
472
|
-
(k, annotation(f.value,
|
574
|
+
(k, annotation(f.value, **child_annotation_kwargs))
|
473
575
|
for k, f in vs.schema.items()
|
474
576
|
if isinstance(k, pg.typing.ConstStrKey)
|
475
577
|
]
|
@@ -479,6 +581,9 @@ def annotation(
|
|
479
581
|
x = '{' + kv_str + '}'
|
480
582
|
if strict:
|
481
583
|
x = f'pg.typing.Dict({x})'
|
584
|
+
elif vs.schema and vs.schema.dynamic_field:
|
585
|
+
v = annotation(vs.schema.dynamic_field.value, **child_annotation_kwargs)
|
586
|
+
x = f'dict[str, {v}]'
|
482
587
|
else:
|
483
588
|
x = 'dict[str, Any]'
|
484
589
|
|
@@ -493,7 +598,8 @@ def annotation(
|
|
493
598
|
class SchemaJsonRepr(SchemaRepr):
|
494
599
|
"""JSON-representation for a schema."""
|
495
600
|
|
496
|
-
def repr(self, schema: Schema) -> str:
|
601
|
+
def repr(self, schema: Schema, **kwargs) -> str:
|
602
|
+
del kwargs
|
497
603
|
out = io.StringIO()
|
498
604
|
def _visit(node: Any) -> None:
|
499
605
|
if isinstance(node, str):
|
@@ -571,12 +677,19 @@ class ValuePythonRepr(ValueRepr):
|
|
571
677
|
cls_schema = Schema.from_value(value)
|
572
678
|
if isinstance(cls_schema.spec, pg.typing.Object):
|
573
679
|
object_code = SchemaPythonRepr().class_definitions(
|
574
|
-
cls_schema,
|
680
|
+
cls_schema,
|
681
|
+
markdown=markdown,
|
682
|
+
# We add `pg.Object` as additional dependencies to the class
|
683
|
+
# definition so exemplars for class generation could show
|
684
|
+
# pg.Object as their bases.
|
685
|
+
additional_dependencies=[pg.Object]
|
575
686
|
)
|
576
687
|
assert object_code is not None
|
577
688
|
return object_code
|
578
689
|
else:
|
579
690
|
object_code = SchemaPythonRepr().result_definition(cls_schema)
|
691
|
+
elif isinstance(value, lf.Template):
|
692
|
+
return str(value)
|
580
693
|
else:
|
581
694
|
object_code = pg.format(
|
582
695
|
value, compact=compact, verbose=verbose, python_format=True
|
@@ -651,12 +764,15 @@ class JsonError(Exception):
|
|
651
764
|
def __str__(self) -> str:
|
652
765
|
r = io.StringIO()
|
653
766
|
r.write(
|
654
|
-
|
767
|
+
pg.colored(
|
768
|
+
f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
|
769
|
+
)
|
770
|
+
)
|
655
771
|
|
656
772
|
r.write('\n\n')
|
657
|
-
r.write(
|
773
|
+
r.write(pg.colored('JSON text:', 'red'))
|
658
774
|
r.write('\n\n')
|
659
|
-
r.write(textwrap.indent(
|
775
|
+
r.write(textwrap.indent(pg.colored(self.json, 'magenta'), ' ' * 2))
|
660
776
|
return r.getvalue()
|
661
777
|
|
662
778
|
|
@@ -671,7 +787,7 @@ class ValueJsonRepr(ValueRepr):
|
|
671
787
|
"""Parse a JSON string into a structured object."""
|
672
788
|
del schema
|
673
789
|
try:
|
674
|
-
text =
|
790
|
+
text = cleanup_json(text)
|
675
791
|
v = pg.from_json_str(text, **kwargs)
|
676
792
|
except Exception as e:
|
677
793
|
raise JsonError(text, e) # pylint: disable=raise-missing-from
|
@@ -683,55 +799,56 @@ class ValueJsonRepr(ValueRepr):
|
|
683
799
|
))
|
684
800
|
return v['result']
|
685
801
|
|
686
|
-
def cleanup_json(self, json_str: str) -> str:
|
687
|
-
"""Clean up the LM responded JSON string."""
|
688
|
-
# Treatments:
|
689
|
-
# 1. Extract the JSON string with a top-level dict from the response.
|
690
|
-
# This prevents the leading and trailing texts in the response to
|
691
|
-
# be counted as part of the JSON.
|
692
|
-
# 2. Escape new lines in JSON values.
|
693
|
-
|
694
|
-
curly_brackets = 0
|
695
|
-
under_json = False
|
696
|
-
under_str = False
|
697
|
-
str_begin = -1
|
698
|
-
|
699
|
-
cleaned = io.StringIO()
|
700
|
-
for i, c in enumerate(json_str):
|
701
|
-
if c == '{' and not under_str:
|
702
|
-
cleaned.write(c)
|
703
|
-
curly_brackets += 1
|
704
|
-
under_json = True
|
705
|
-
continue
|
706
|
-
elif not under_json:
|
707
|
-
continue
|
708
802
|
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
803
|
+
def cleanup_json(json_str: str) -> str:
  """Clean up the LM responded JSON string.

  Treatments:
    1. Extract the first top-level JSON dict from the response, so leading and
       trailing texts around it are not counted as part of the JSON.
    2. Escape raw new lines that appear inside JSON string values.

  Args:
    json_str: The raw text that is expected to contain a JSON dict.

  Returns:
    The extracted JSON dict text, from its opening `{` through the matching
    `}`, with new lines inside string values replaced by `\\n`.

  Raises:
    ValueError: If the text contains no `{`, or if the input ends before all
      opened curly braces are closed.
  """
  curly_brackets = 0
  under_json = False
  under_str = False
  # True when the previous character inside a string was an unescaped
  # backslash. Tracking this (rather than peeking at the previous character)
  # lets `"C:\\"` close properly: its second backslash is itself escaped, so
  # the quote after it does end the string.
  escaped = False
  str_begin = -1

  cleaned = io.StringIO()
  for i, c in enumerate(json_str):
    if under_str:
      # String content is emitted as a whole once the closing quote is found.
      if escaped:
        escaped = False
      elif c == '\\':
        escaped = True
      elif c == '"':
        under_str = False
        cleaned.write(json_str[str_begin : i + 1].replace('\n', '\\n'))
      continue

    if c == '{':
      cleaned.write(c)
      curly_brackets += 1
      under_json = True
    elif not under_json:
      # Skip everything that precedes the first opening curly brace.
      continue
    elif c == '}':
      cleaned.write(c)
      curly_brackets -= 1
      if curly_brackets == 0:
        # The top-level dict is complete; ignore any trailing text.
        break
    elif c == '"' and json_str[i - 1] != '\\':
      # `i - 1` is always valid here since an opening brace came first. A
      # backslash-prefixed quote outside of a string is copied through
      # verbatim (below) instead of opening a string.
      under_str = True
      str_begin = i
    else:
      cleaned.write(c)

  if not under_json:
    raise ValueError(f'No JSON dict in the output: {json_str}')

  if curly_brackets > 0:
    raise ValueError(
        f'Malformated JSON: missing {curly_brackets} closing curly braces.'
    )

  return cleaned.getvalue()
|
735
852
|
|
736
853
|
|
737
854
|
def schema_repr(protocol: SchemaProtocol) -> SchemaRepr:
|
@@ -832,13 +949,3 @@ class Unknown(pg.Object, pg.typing.CustomTyping):
|
|
832
949
|
|
833
950
|
|
834
951
|
UNKNOWN = Unknown()
|
835
|
-
|
836
|
-
|
837
|
-
def _pg_schema(cls: Type[Any]) -> pg.Schema:
|
838
|
-
"""Returns PyGlove schema for the constructor of a class."""
|
839
|
-
schema = getattr(cls, '__schema__', None)
|
840
|
-
if schema is None:
|
841
|
-
schema = pg.symbolic.callable_schema(
|
842
|
-
cls.__init__, auto_typing=True, auto_doc=True, remove_self=True
|
843
|
-
)
|
844
|
-
return schema
|
@@ -58,7 +58,7 @@ class GenerateClass(mapping.Mapping):
|
|
58
58
|
class_name = self.context
|
59
59
|
cls = output_vars.get(class_name, None)
|
60
60
|
if cls is None:
|
61
|
-
raise
|
61
|
+
raise pg.coding.CodeError(
|
62
62
|
final_code,
|
63
63
|
TypeError(f'Class {class_name} is absent from LLM output.'),
|
64
64
|
)
|