langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145) hide show
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@
16
16
  import abc
17
17
  import inspect
18
18
  import io
19
+ import re
19
20
  import textwrap
20
21
  import typing
21
22
  from typing import Any, Literal, Sequence, Type, Union
@@ -24,6 +25,17 @@ from langfun.core.coding.python import correction
24
25
  import pyglove as pg
25
26
 
26
27
 
28
+ def include_method_in_prompt(method):
29
+ """Decorator to include a method in the class definition of the prompt."""
30
+ setattr(method, '__show_in_prompt__', True)
31
+ return method
32
+
33
+
34
+ def should_include_method_in_prompt(method):
35
+ """Returns true if the method should be shown in the prompt."""
36
+ return getattr(method, '__show_in_prompt__', False)
37
+
38
+
27
39
  def parse_value_spec(value) -> pg.typing.ValueSpec:
28
40
  """Parses a PyGlove ValueSpec equivalence into a ValueSpec."""
29
41
  if isinstance(value, pg.typing.ValueSpec):
@@ -79,26 +91,35 @@ class SchemaError(Exception): # pylint: disable=g-bad-exception-name
79
91
  def __str__(self):
80
92
  r = io.StringIO()
81
93
  r.write(
82
- lf.colored(f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'))
94
+ pg.colored(
95
+ f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
96
+ )
97
+ )
83
98
 
84
99
  r.write('\n')
85
- r.write(lf.colored('Schema:', 'red'))
100
+ r.write(pg.colored('Schema:', 'red'))
86
101
  r.write('\n\n')
87
102
  r.write(textwrap.indent(
88
- lf.colored(schema_repr(self.protocol).repr(self.schema), 'magenta'),
103
+ pg.colored(
104
+ schema_repr(self.protocol).repr(self.schema), 'magenta'
105
+ ),
89
106
  ' ' * 2
90
107
  ))
91
108
  r.write('\n\n')
92
- r.write(lf.colored('Generated value:', 'red'))
109
+ r.write(pg.colored('Generated value:', 'red'))
93
110
  r.write('\n\n')
94
111
  r.write(textwrap.indent(
95
- lf.colored(value_repr(self.protocol).repr(self.value), 'magenta'),
112
+ pg.colored(value_repr(self.protocol).repr(self.value), 'magenta'),
96
113
  ' ' * 2
97
114
  ))
98
115
  return r.getvalue()
99
116
 
100
117
 
101
- class Schema(lf.NaturalLanguageFormattable, pg.Object):
118
+ class Schema(
119
+ lf.NaturalLanguageFormattable,
120
+ pg.Object,
121
+ pg.views.HtmlTreeView.Extension
122
+ ):
102
123
  """Base class for structured data schema."""
103
124
 
104
125
  spec: pg.typing.Annotated[
@@ -163,9 +184,12 @@ class Schema(lf.NaturalLanguageFormattable, pg.Object):
163
184
 
164
185
  def class_dependencies(
165
186
  self,
187
+ include_base_classes: bool = True,
166
188
  include_subclasses: bool = True) -> list[Type[Any]]:
167
189
  """Returns a list of class dependencies for current schema."""
168
- return class_dependencies(self.spec, include_subclasses)
190
+ return class_dependencies(
191
+ self.spec, include_base_classes, include_subclasses
192
+ )
169
193
 
170
194
  @classmethod
171
195
  def from_value(cls, value) -> 'Schema':
@@ -174,6 +198,29 @@ class Schema(lf.NaturalLanguageFormattable, pg.Object):
174
198
  return value
175
199
  return cls(parse_value_spec(value))
176
200
 
201
+ def _html_tree_view_content(
202
+ self,
203
+ *,
204
+ view: pg.views.HtmlTreeView,
205
+ **kwargs,
206
+ ):
207
+ return pg.Html.element(
208
+ 'div',
209
+ [pg.Html.escape(self.schema_str(protocol='python'))],
210
+ css_classes=['lf-schema-definition']
211
+ ).add_style(
212
+ """
213
+ .lf-schema-definition {
214
+ color: blue;
215
+ margin: 5px;
216
+ white-space: pre-wrap;
217
+ }
218
+ """
219
+ )
220
+
221
+
222
+ SchemaType = Union[Schema, Type[Any], list[Type[Any]], dict[str, Any]]
223
+
177
224
 
178
225
  def _top_level_object_specs_from_value(value: pg.Symbolic) -> list[Type[Any]]:
179
226
  """Returns a list of top level value specs from a symbolic value."""
@@ -198,11 +245,12 @@ def class_dependencies(
198
245
  Type[pg.Object],
199
246
  tuple[Union[pg.typing.ValueSpec, Type[pg.Object]], ...],
200
247
  ],
248
+ include_base_classes: bool = True,
201
249
  include_subclasses: bool = True,
202
250
  ) -> list[Type[Any]]:
203
251
  """Returns a list of class dependencies from a value or specs."""
204
252
  if isinstance(value_or_spec, Schema):
205
- return value_or_spec.class_dependencies(include_subclasses)
253
+ value_or_spec = value_or_spec.spec
206
254
 
207
255
  if inspect.isclass(value_or_spec) or isinstance(
208
256
  value_or_spec, pg.typing.ValueSpec
@@ -236,16 +284,17 @@ def class_dependencies(
236
284
  if vs.cls not in seen:
237
285
  seen.add(vs.cls)
238
286
 
239
- # Add base classes as dependencies.
240
- for base_cls in vs.cls.__bases__:
241
- # We only keep track of user-defined symbolic classes.
242
- if base_cls is not object and base_cls is not pg.Object:
243
- _fill_dependencies(
244
- pg.typing.Object(base_cls), include_subclasses=False
245
- )
287
+ if include_base_classes:
288
+ # Add base classes as dependencies.
289
+ for base_cls in vs.cls.__bases__:
290
+ # We only keep track of user-defined symbolic classes.
291
+ if base_cls is not object and base_cls is not pg.Object:
292
+ _fill_dependencies(
293
+ pg.typing.Object(base_cls), include_subclasses=False
294
+ )
246
295
 
247
296
  # Add members as dependencies.
248
- for field in _pg_schema(vs.cls).values():
297
+ for field in pg.schema(vs.cls).values():
249
298
  _fill_dependencies(field.value, include_subclasses)
250
299
  _add_dependency(vs.cls)
251
300
 
@@ -262,7 +311,7 @@ def class_dependencies(
262
311
  _fill_dependencies(elem.value, include_subclasses)
263
312
  elif isinstance(vs, pg.typing.Dict) and vs.schema:
264
313
  for v in vs.schema.values():
265
- _fill_dependencies(v, include_subclasses)
314
+ _fill_dependencies(v.value, include_subclasses)
266
315
  elif isinstance(vs, pg.typing.Union):
267
316
  for v in vs.candidates:
268
317
  _fill_dependencies(v, include_subclasses)
@@ -314,23 +363,35 @@ class SchemaPythonRepr(SchemaRepr):
314
363
  ret += f'\n\n{class_definition_str}'
315
364
  return ret.strip()
316
365
 
317
- def class_definitions(self, schema: Schema, **kwargs) -> str | None:
318
- deps = schema.class_dependencies(include_subclasses=True)
319
- return class_definitions(deps, **kwargs)
366
+ def class_definitions(
367
+ self,
368
+ schema: Schema,
369
+ additional_dependencies: list[Type[Any]] | None = None,
370
+ **kwargs
371
+ ) -> str | None:
372
+ """Returns a string containing of class definitions from a schema."""
373
+ deps = schema.class_dependencies(
374
+ include_base_classes=False, include_subclasses=True
375
+ )
376
+ allowed_dependencies = set(deps)
377
+ if additional_dependencies:
378
+ allowed_dependencies.update(additional_dependencies)
379
+ return class_definitions(
380
+ deps, allowed_dependencies=allowed_dependencies, **kwargs)
320
381
 
321
382
  def result_definition(self, schema: Schema) -> str:
322
383
  return annotation(schema.spec)
323
384
 
324
385
 
325
- def source_form(value, markdown: bool = False) -> str:
386
+ def source_form(value, compact: bool = True, markdown: bool = False) -> str:
326
387
  """Returns the source code form of an object."""
327
- return ValuePythonRepr().repr(value, markdown=markdown)
388
+ return ValuePythonRepr().repr(value, compact=compact, markdown=markdown)
328
389
 
329
390
 
330
391
  def class_definitions(
331
392
  classes: Sequence[Type[Any]],
332
393
  *,
333
- include_pg_object_as_base: bool = False,
394
+ allowed_dependencies: set[Type[Any]] | None = None,
334
395
  strict: bool = False,
335
396
  markdown: bool = False,
336
397
  ) -> str | None:
@@ -345,7 +406,7 @@ def class_definitions(
345
406
  class_definition(
346
407
  cls,
347
408
  strict=strict,
348
- include_pg_object_as_base=include_pg_object_as_base,
409
+ allowed_dependencies=allowed_dependencies,
349
410
  )
350
411
  )
351
412
  ret = def_str.getvalue()
@@ -355,15 +416,17 @@ def class_definitions(
355
416
 
356
417
 
357
418
  def class_definition(
358
- cls, strict: bool = False, include_pg_object_as_base: bool = False
419
+ cls,
420
+ strict: bool = False,
421
+ allowed_dependencies: set[Type[Any]] | None = None,
359
422
  ) -> str:
360
423
  """Returns the Python class definition."""
361
424
  out = io.StringIO()
362
- schema = _pg_schema(cls)
425
+ schema = pg.schema(cls)
363
426
  eligible_bases = []
364
427
  for base_cls in cls.__bases__:
365
428
  if base_cls is not object:
366
- if include_pg_object_as_base or base_cls is not pg.Object:
429
+ if allowed_dependencies is None or base_cls in allowed_dependencies:
367
430
  eligible_bases.append(base_cls.__name__)
368
431
 
369
432
  if eligible_bases:
@@ -383,13 +446,16 @@ def class_definition(
383
446
  out.write('\n')
384
447
  out.write(' """\n')
385
448
 
449
+ empty_class = True
386
450
  if schema.fields:
387
451
  for key, field in schema.items():
388
452
  if not isinstance(key, pg.typing.ConstStrKey):
389
- raise TypeError(
453
+ pg.logging.warning(
390
454
  'Variable-length keyword arguments is not supported in '
391
- f'structured parsing or query. Encountered: {field}'
455
+ f'structured parsing or query. Encountered: {cls}, Schema: {schema}'
392
456
  )
457
+ continue
458
+
393
459
  # Write field doc string as comments before the field definition.
394
460
  if field.description:
395
461
  for line in field.description.split('\n'):
@@ -397,19 +463,54 @@ def class_definition(
397
463
  out.write(' # ')
398
464
  out.write(line)
399
465
  out.write('\n')
400
- out.write(f' {field.key}: {annotation(field.value, strict=strict)}')
466
+
467
+ annotation_str = annotation(
468
+ field.value, strict=strict, allowed_dependencies=allowed_dependencies
469
+ )
470
+ out.write(f' {field.key}: {annotation_str}')
401
471
  out.write('\n')
402
- else:
472
+ empty_class = False
473
+
474
+ for method in _iter_newly_defined_methods(cls, allowed_dependencies):
475
+ source = inspect.getsource(method)
476
+ # Remove decorators from the method definition.
477
+ source = re.sub(r'\s*@.*\.include_method_in_prompt.*\n', '', source)
478
+ out.write('\n')
479
+ out.write(
480
+ textwrap.indent(
481
+ inspect.cleandoc('\n' + source), ' ' * 2)
482
+ )
483
+ out.write('\n')
484
+ empty_class = False
485
+
486
+ if empty_class:
403
487
  out.write(' pass\n')
404
488
  return out.getvalue()
405
489
 
406
490
 
491
+ def _iter_newly_defined_methods(
492
+ cls, allowed_dependencies: set[Type[Any]] | None):
493
+ names = {attr_name: True for attr_name in dir(cls)}
494
+ for base in cls.__bases__:
495
+ if allowed_dependencies is None or base in allowed_dependencies:
496
+ for name in dir(base):
497
+ names.pop(name, None)
498
+ for name in names.keys():
499
+ attr = getattr(cls, name)
500
+ if callable(attr) and should_include_method_in_prompt(attr):
501
+ yield attr
502
+
503
+
407
504
  def annotation(
408
505
  vs: pg.typing.ValueSpec,
409
506
  annotate_optional: bool = True,
410
507
  strict: bool = False,
508
+ allowed_dependencies: set[Type[Any]] | None = None,
411
509
  ) -> str:
412
510
  """Returns the annotation string for a value spec."""
511
+ child_annotation_kwargs = dict(
512
+ strict=strict, allowed_dependencies=allowed_dependencies
513
+ )
413
514
  if isinstance(vs, pg.typing.Any):
414
515
  return 'Any'
415
516
  elif isinstance(vs, pg.typing.Enum):
@@ -418,7 +519,7 @@ def annotation(
418
519
  elif isinstance(vs, pg.typing.Union):
419
520
  candidate_str = ', '.join(
420
521
  [
421
- annotation(c, annotate_optional=False, strict=strict)
522
+ annotation(c, annotate_optional=False, **child_annotation_kwargs)
422
523
  for c in vs.candidates
423
524
  ]
424
525
  )
@@ -454,20 +555,23 @@ def annotation(
454
555
  )
455
556
  x += '(' + ', '.join(constraints) + ')'
456
557
  elif isinstance(vs, pg.typing.Object):
457
- x = vs.cls.__name__
558
+ if allowed_dependencies is None or vs.cls in allowed_dependencies:
559
+ x = vs.cls.__name__
560
+ else:
561
+ x = 'Any'
458
562
  elif isinstance(vs, pg.typing.List):
459
- item_str = annotation(vs.element.value, strict=strict)
563
+ item_str = annotation(vs.element.value, **child_annotation_kwargs)
460
564
  x = f'list[{item_str}]'
461
565
  elif isinstance(vs, pg.typing.Tuple):
462
566
  elem_str = ', '.join(
463
- [annotation(el.value, strict=strict) for el in vs.elements]
567
+ [annotation(el.value, **child_annotation_kwargs) for el in vs.elements]
464
568
  )
465
569
  x = f'tuple[{elem_str}]'
466
570
  elif isinstance(vs, pg.typing.Dict):
467
571
  kv_pairs = None
468
572
  if vs.schema is not None:
469
573
  kv_pairs = [
470
- (k, annotation(f.value, strict=strict))
574
+ (k, annotation(f.value, **child_annotation_kwargs))
471
575
  for k, f in vs.schema.items()
472
576
  if isinstance(k, pg.typing.ConstStrKey)
473
577
  ]
@@ -477,6 +581,9 @@ def annotation(
477
581
  x = '{' + kv_str + '}'
478
582
  if strict:
479
583
  x = f'pg.typing.Dict({x})'
584
+ elif vs.schema and vs.schema.dynamic_field:
585
+ v = annotation(vs.schema.dynamic_field.value, **child_annotation_kwargs)
586
+ x = f'dict[str, {v}]'
480
587
  else:
481
588
  x = 'dict[str, Any]'
482
589
 
@@ -491,7 +598,8 @@ def annotation(
491
598
  class SchemaJsonRepr(SchemaRepr):
492
599
  """JSON-representation for a schema."""
493
600
 
494
- def repr(self, schema: Schema) -> str:
601
+ def repr(self, schema: Schema, **kwargs) -> str:
602
+ del kwargs
495
603
  out = io.StringIO()
496
604
  def _visit(node: Any) -> None:
497
605
  if isinstance(node, str):
@@ -569,12 +677,19 @@ class ValuePythonRepr(ValueRepr):
569
677
  cls_schema = Schema.from_value(value)
570
678
  if isinstance(cls_schema.spec, pg.typing.Object):
571
679
  object_code = SchemaPythonRepr().class_definitions(
572
- cls_schema, markdown=markdown, include_pg_object_as_base=True
680
+ cls_schema,
681
+ markdown=markdown,
682
+ # We add `pg.Object` as additional dependencies to the class
683
+ # definition so exemplars for class generation could show
684
+ # pg.Object as their bases.
685
+ additional_dependencies=[pg.Object]
573
686
  )
574
687
  assert object_code is not None
575
688
  return object_code
576
689
  else:
577
690
  object_code = SchemaPythonRepr().result_definition(cls_schema)
691
+ elif isinstance(value, lf.Template):
692
+ return str(value)
578
693
  else:
579
694
  object_code = pg.format(
580
695
  value, compact=compact, verbose=verbose, python_format=True
@@ -649,12 +764,15 @@ class JsonError(Exception):
649
764
  def __str__(self) -> str:
650
765
  r = io.StringIO()
651
766
  r.write(
652
- lf.colored(f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'))
767
+ pg.colored(
768
+ f'{self.cause.__class__.__name__}: {self.cause}', 'magenta'
769
+ )
770
+ )
653
771
 
654
772
  r.write('\n\n')
655
- r.write(lf.colored('JSON text:', 'red'))
773
+ r.write(pg.colored('JSON text:', 'red'))
656
774
  r.write('\n\n')
657
- r.write(textwrap.indent(lf.colored(self.json, 'magenta'), ' ' * 2))
775
+ r.write(textwrap.indent(pg.colored(self.json, 'magenta'), ' ' * 2))
658
776
  return r.getvalue()
659
777
 
660
778
 
@@ -669,7 +787,7 @@ class ValueJsonRepr(ValueRepr):
669
787
  """Parse a JSON string into a structured object."""
670
788
  del schema
671
789
  try:
672
- text = self.cleanup_json(text)
790
+ text = cleanup_json(text)
673
791
  v = pg.from_json_str(text, **kwargs)
674
792
  except Exception as e:
675
793
  raise JsonError(text, e) # pylint: disable=raise-missing-from
@@ -681,55 +799,56 @@ class ValueJsonRepr(ValueRepr):
681
799
  ))
682
800
  return v['result']
683
801
 
684
- def cleanup_json(self, json_str: str) -> str:
685
- """Clean up the LM responded JSON string."""
686
- # Treatments:
687
- # 1. Extract the JSON string with a top-level dict from the response.
688
- # This prevents the leading and trailing texts in the response to
689
- # be counted as part of the JSON.
690
- # 2. Escape new lines in JSON values.
691
-
692
- curly_brackets = 0
693
- under_json = False
694
- under_str = False
695
- str_begin = -1
696
-
697
- cleaned = io.StringIO()
698
- for i, c in enumerate(json_str):
699
- if c == '{' and not under_str:
700
- cleaned.write(c)
701
- curly_brackets += 1
702
- under_json = True
703
- continue
704
- elif not under_json:
705
- continue
706
802
 
707
- if c == '}' and not under_str:
708
- cleaned.write(c)
709
- curly_brackets -= 1
710
- if curly_brackets == 0:
711
- break
712
- elif c == '"' and json_str[i - 1] != '\\':
713
- under_str = not under_str
714
- if under_str:
715
- str_begin = i
716
- else:
717
- assert str_begin > 0
718
- str_value = json_str[str_begin : i + 1].replace('\n', '\\n')
719
- cleaned.write(str_value)
720
- str_begin = -1
721
- elif not under_str:
722
- cleaned.write(c)
723
-
724
- if not under_json:
725
- raise ValueError(f'No JSON dict in the output: {json_str}')
726
-
727
- if curly_brackets > 0:
728
- raise ValueError(
729
- f'Malformated JSON: missing {curly_brackets} closing curly braces.'
730
- )
803
+ def cleanup_json(json_str: str) -> str:
804
+ """Clean up the LM responded JSON string."""
805
+ # Treatments:
806
+ # 1. Extract the JSON string with a top-level dict from the response.
807
+ # This prevents the leading and trailing texts in the response to
808
+ # be counted as part of the JSON.
809
+ # 2. Escape new lines in JSON values.
810
+
811
+ curly_brackets = 0
812
+ under_json = False
813
+ under_str = False
814
+ str_begin = -1
815
+
816
+ cleaned = io.StringIO()
817
+ for i, c in enumerate(json_str):
818
+ if c == '{' and not under_str:
819
+ cleaned.write(c)
820
+ curly_brackets += 1
821
+ under_json = True
822
+ continue
823
+ elif not under_json:
824
+ continue
825
+
826
+ if c == '}' and not under_str:
827
+ cleaned.write(c)
828
+ curly_brackets -= 1
829
+ if curly_brackets == 0:
830
+ break
831
+ elif c == '"' and json_str[i - 1] != '\\':
832
+ under_str = not under_str
833
+ if under_str:
834
+ str_begin = i
835
+ else:
836
+ assert str_begin > 0
837
+ str_value = json_str[str_begin : i + 1].replace('\n', '\\n')
838
+ cleaned.write(str_value)
839
+ str_begin = -1
840
+ elif not under_str:
841
+ cleaned.write(c)
842
+
843
+ if not under_json:
844
+ raise ValueError(f'No JSON dict in the output: {json_str}')
845
+
846
+ if curly_brackets > 0:
847
+ raise ValueError(
848
+ f'Malformated JSON: missing {curly_brackets} closing curly braces.'
849
+ )
731
850
 
732
- return cleaned.getvalue()
851
+ return cleaned.getvalue()
733
852
 
734
853
 
735
854
  def schema_repr(protocol: SchemaProtocol) -> SchemaRepr:
@@ -830,13 +949,3 @@ class Unknown(pg.Object, pg.typing.CustomTyping):
830
949
 
831
950
 
832
951
  UNKNOWN = Unknown()
833
-
834
-
835
- def _pg_schema(cls: Type[Any]) -> pg.Schema:
836
- """Returns PyGlove schema for the constructor of a class."""
837
- schema = getattr(cls, '__schema__', None)
838
- if schema is None:
839
- schema = pg.symbolic.callable_schema(
840
- cls.__init__, auto_typing=True, auto_doc=True, remove_self=True
841
- )
842
- return schema
@@ -58,7 +58,7 @@ class GenerateClass(mapping.Mapping):
58
58
  class_name = self.context
59
59
  cls = output_vars.get(class_name, None)
60
60
  if cls is None:
61
- raise correction.errors.CodeError(
61
+ raise pg.coding.CodeError(
62
62
  final_code,
63
63
  TypeError(f'Class {class_name} is absent from LLM output.'),
64
64
  )
@@ -14,8 +14,8 @@
14
14
  import inspect
15
15
  import unittest
16
16
 
17
- import langfun.core.coding as lf_coding
18
17
  from langfun.core.llms import fake
18
+ from langfun.core.structured import mapping
19
19
  from langfun.core.structured import schema_generation
20
20
 
21
21
 
@@ -92,7 +92,7 @@ class GenerateClassTest(unittest.TestCase):
92
92
  )
93
93
  self.assertIs(cls.__name__, 'B')
94
94
 
95
- with self.assertRaises(lf_coding.CodeError):
95
+ with self.assertRaises(mapping.MappingError):
96
96
  schema_generation.generate_class(
97
97
  'Foo',
98
98
  'Generate a Foo class with a field pointing to another class A',