langfun 0.0.2.dev20240314__tar.gz → 0.0.2.dev20240316__tar.gz

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the contents of each version exactly as it appears in its respective public registry.
Files changed (107)
  1. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/PKG-INFO +2 -2
  2. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/__init__.py +3 -0
  3. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/langfunc_test.py +2 -2
  4. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/language_model.py +40 -6
  5. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/language_model_test.py +1 -1
  6. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/gemini.py +1 -6
  7. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/llama_cpp.py +2 -9
  8. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/openai.py +2 -12
  9. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/__init__.py +6 -2
  10. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/description.py +53 -50
  11. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/mapping.py +14 -12
  12. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/parsing.py +18 -16
  13. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/schema.py +56 -16
  14. langfun-0.0.2.dev20240316/langfun/core/structured/schema_generation.py +175 -0
  15. langfun-0.0.2.dev20240316/langfun/core/structured/schema_generation_test.py +104 -0
  16. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/schema_test.py +44 -0
  17. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/template.py +6 -3
  18. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/template_test.py +4 -0
  19. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun.egg-info/PKG-INFO +2 -2
  20. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun.egg-info/SOURCES.txt +2 -0
  21. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun.egg-info/requires.txt +1 -1
  22. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/LICENSE +0 -0
  23. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/README.md +0 -0
  24. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/__init__.py +0 -0
  25. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/__init__.py +0 -0
  26. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/__init__.py +0 -0
  27. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/correction.py +0 -0
  28. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/correction_test.py +0 -0
  29. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/errors.py +0 -0
  30. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/errors_test.py +0 -0
  31. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/execution.py +0 -0
  32. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/execution_test.py +0 -0
  33. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/generation.py +0 -0
  34. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/generation_test.py +0 -0
  35. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/parsing.py +0 -0
  36. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/parsing_test.py +0 -0
  37. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/permissions.py +0 -0
  38. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/coding/python/permissions_test.py +0 -0
  39. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/component.py +0 -0
  40. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/component_test.py +0 -0
  41. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/concurrent.py +0 -0
  42. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/concurrent_test.py +0 -0
  43. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/console.py +0 -0
  44. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/console_test.py +0 -0
  45. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/eval/__init__.py +0 -0
  46. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/eval/base.py +0 -0
  47. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/eval/base_test.py +0 -0
  48. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/eval/matching.py +0 -0
  49. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/eval/matching_test.py +0 -0
  50. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/eval/scoring.py +0 -0
  51. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/eval/scoring_test.py +0 -0
  52. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/langfunc.py +0 -0
  53. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/__init__.py +0 -0
  54. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/cache/__init__.py +0 -0
  55. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/cache/base.py +0 -0
  56. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/cache/in_memory.py +0 -0
  57. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/cache/in_memory_test.py +0 -0
  58. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/fake.py +0 -0
  59. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/fake_test.py +0 -0
  60. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/gemini_test.py +0 -0
  61. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/llama_cpp_test.py +0 -0
  62. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/llms/openai_test.py +0 -0
  63. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/memories/__init__.py +0 -0
  64. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/memories/conversation_history.py +0 -0
  65. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/memories/conversation_history_test.py +0 -0
  66. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/memory.py +0 -0
  67. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/message.py +0 -0
  68. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/message_test.py +0 -0
  69. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/modalities/__init__.py +0 -0
  70. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/modalities/image.py +0 -0
  71. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/modalities/image_test.py +0 -0
  72. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/modalities/mime.py +0 -0
  73. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/modalities/mime_test.py +0 -0
  74. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/modalities/video.py +0 -0
  75. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/modalities/video_test.py +0 -0
  76. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/modality.py +0 -0
  77. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/modality_test.py +0 -0
  78. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/natural_language.py +0 -0
  79. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/natural_language_test.py +0 -0
  80. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/sampling.py +0 -0
  81. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/sampling_test.py +0 -0
  82. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/completion.py +0 -0
  83. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/completion_test.py +0 -0
  84. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/description_test.py +0 -0
  85. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/mapping_test.py +0 -0
  86. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/parsing_test.py +0 -0
  87. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/prompting.py +0 -0
  88. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/prompting_test.py +0 -0
  89. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/scoring.py +0 -0
  90. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/structured/scoring_test.py +0 -0
  91. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/subscription.py +0 -0
  92. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/subscription_test.py +0 -0
  93. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/templates/__init__.py +0 -0
  94. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/templates/completion.py +0 -0
  95. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/templates/completion_test.py +0 -0
  96. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/templates/conversation.py +0 -0
  97. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/templates/conversation_test.py +0 -0
  98. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/templates/demonstration.py +0 -0
  99. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/templates/demonstration_test.py +0 -0
  100. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/templates/selfplay.py +0 -0
  101. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/templates/selfplay_test.py +0 -0
  102. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/text_formatting.py +0 -0
  103. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun/core/text_formatting_test.py +0 -0
  104. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun.egg-info/dependency_links.txt +0 -0
  105. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/langfun.egg-info/top_level.txt +0 -0
  106. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/setup.cfg +0 -0
  107. {langfun-0.0.2.dev20240314 → langfun-0.0.2.dev20240316}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240314
3
+ Version: 0.0.2.dev20240316
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -24,7 +24,7 @@ License-File: LICENSE
24
24
  Requires-Dist: google-generativeai>=0.3.2
25
25
  Requires-Dist: jinja2>=3.1.2
26
26
  Requires-Dist: openai==0.27.2
27
- Requires-Dist: pyglove>=0.4.5.dev20240201
27
+ Requires-Dist: pyglove>=0.4.5.dev20240314
28
28
  Requires-Dist: python-magic>=0.4.27
29
29
  Requires-Dist: requests>=2.31.0
30
30
  Requires-Dist: termcolor==1.1.0
@@ -31,6 +31,9 @@ query = structured.query
31
31
  describe = structured.describe
32
32
  complete = structured.complete
33
33
  score = structured.score
34
+ generate_class = structured.generate_class
35
+
36
+ source_form = structured.source_form
34
37
 
35
38
  from langfun.core import eval # pylint: disable=redefined-builtin
36
39
  from langfun.core import templates
@@ -95,8 +95,8 @@ class LangFuncCallTest(unittest.TestCase):
95
95
  ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0,'
96
96
  ' max_tokens=1024, n=1, top_k=40, top_p=None, stop=None,'
97
97
  ' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
98
- ' timeout=120.0, max_attempts=5, retry_interval=(5, 60),'
99
- ' exponential_backoff=True, debug=False))',
98
+ ' max_concurrency=None, timeout=120.0, max_attempts=5,'
99
+ ' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
100
100
  )
101
101
 
102
102
  l = LangFunc('Hello')
@@ -17,8 +17,9 @@ import abc
17
17
  import dataclasses
18
18
  import enum
19
19
  import time
20
- from typing import Annotated, Any
20
+ from typing import Annotated, Any, Callable, Sequence, Tuple, Type, Union
21
21
  from langfun.core import component
22
+ from langfun.core import concurrent
22
23
  from langfun.core import console
23
24
  from langfun.core import message as message_lib
24
25
  import pyglove as pg
@@ -209,6 +210,22 @@ class LanguageModel(component.Component):
209
210
  )
210
211
  ] = component.contextual(default=None)
211
212
 
213
+ max_concurrency: Annotated[
214
+ int | None,
215
+ (
216
+ 'Max concurrent requests being sent to the server. '
217
+ 'If None, there is no limit. '
218
+ 'Please note that the concurrency control is based on the '
219
+ '`resource_id` property, meaning that model instances shared '
220
+ 'the same resource ID will be accounted under the same concurrency '
221
+ 'control key. This allows a process-level concurrency control '
222
+ 'for specific models regardless the number of LM (client) instances '
223
+ 'created by the program. Subclasses could override this number or '
224
+ 'replace it with a `max_concurrency` property to allow dynamic '
225
+ 'concurrency control.'
226
+ ),
227
+ ] = None
228
+
212
229
  timeout: Annotated[
213
230
  float | None, 'Timeout in seconds. If None, there is no timeout.'
214
231
  ] = 120.0
@@ -284,11 +301,6 @@ class LanguageModel(component.Component):
284
301
  """Resource ID for performing request parallism control."""
285
302
  return self.model_id
286
303
 
287
- @property
288
- def max_concurrency(self) -> int:
289
- """Max concurrent requests."""
290
- return 32
291
-
292
304
  def sample(
293
305
  self,
294
306
  prompts: list[str | message_lib.Message],
@@ -355,6 +367,28 @@ class LanguageModel(component.Component):
355
367
  ) -> list[LMSamplingResult]:
356
368
  """Subclass should override."""
357
369
 
370
+ def _parallel_execute_with_currency_control(
371
+ self,
372
+ action: Callable[..., Any],
373
+ inputs: Sequence[Any],
374
+ retry_on_errors: Union[
375
+ None,
376
+ Union[Type[Exception], Tuple[Type[Exception], str]],
377
+ Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
378
+ ] = None,
379
+ ) -> Any:
380
+ """Helper method for subclasses for implementing _sample."""
381
+ return concurrent.concurrent_execute(
382
+ action,
383
+ inputs,
384
+ executor=self.resource_id if self.max_concurrency else None,
385
+ max_workers=self.max_concurrency or len(inputs),
386
+ retry_on_errors=retry_on_errors,
387
+ max_attempts=self.max_attempts,
388
+ retry_interval=self.retry_interval,
389
+ exponential_backoff=self.exponential_backoff,
390
+ )
391
+
358
392
  def __call__(
359
393
  self, prompt: message_lib.Message, *, cache_seed: int = 0, **kwargs
360
394
  ) -> message_lib.Message:
@@ -89,7 +89,7 @@ class LanguageModelTest(unittest.TestCase):
89
89
  lm = MockModel(1, temperature=0.5, top_k=2, max_attempts=2)
90
90
  self.assertEqual(lm.model_id, 'MockModel')
91
91
  self.assertEqual(lm.resource_id, 'MockModel')
92
- self.assertEqual(lm.max_concurrency, 32)
92
+ self.assertIsNone(lm.max_concurrency)
93
93
  self.assertEqual(lm.failures_before_attempt, 1)
94
94
  self.assertEqual(lm.sampling_options.temperature, 0.5)
95
95
  self.assertEqual(lm.sampling_options.top_k, 2)
@@ -133,14 +133,9 @@ class Gemini(lf.LanguageModel):
133
133
 
134
134
  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
135
135
  assert self._api_initialized, 'Vertex AI API is not initialized.'
136
- return lf.concurrent_execute(
136
+ return self._parallel_execute_with_currency_control(
137
137
  self._sample_single,
138
138
  prompts,
139
- executor=self.resource_id,
140
- max_workers=self.max_concurrency,
141
- # NOTE(daiyip): Vertex has its own policy on handling
142
- # with rate limit, so we do not retry on errors.
143
- retry_on_errors=None,
144
139
  )
145
140
 
146
141
  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
@@ -67,13 +67,6 @@ class LlamaCppRemote(lf.LanguageModel):
67
67
  results.append(result)
68
68
  return results
69
69
 
70
- return lf.concurrent_execute(
71
- _complete_fn,
72
- [prompts],
73
- executor=self.resource_id,
74
- max_workers=self.max_concurrency,
75
- retry_on_errors=(),
76
- max_attempts=self.max_attempts,
77
- retry_interval=self.retry_interval,
78
- exponential_backoff=self.exponential_backoff,
70
+ return self._parallel_execute_with_currency_control(
71
+ _complete_fn, [prompts]
79
72
  )[0]
@@ -214,18 +214,13 @@ class OpenAI(lf.LanguageModel):
214
214
  for index in sorted(samples_by_index.keys())
215
215
  ]
216
216
 
217
- return lf.concurrent_execute(
217
+ return self._parallel_execute_with_currency_control(
218
218
  _open_ai_completion,
219
219
  [prompts],
220
- executor=self.resource_id,
221
- max_workers=self.max_concurrency,
222
220
  retry_on_errors=(
223
221
  openai_error.ServiceUnavailableError,
224
222
  openai_error.RateLimitError,
225
223
  ),
226
- max_attempts=self.max_attempts,
227
- retry_interval=self.retry_interval,
228
- exponential_backoff=self.exponential_backoff,
229
224
  )[0]
230
225
 
231
226
  def _chat_complete_batch(
@@ -280,18 +275,13 @@ class OpenAI(lf.LanguageModel):
280
275
  ),
281
276
  )
282
277
 
283
- return lf.concurrent_execute(
278
+ return self._parallel_execute_with_currency_control(
284
279
  _open_ai_chat_completion,
285
280
  prompts,
286
- executor=self.resource_id,
287
- max_workers=self.max_concurrency,
288
281
  retry_on_errors=(
289
282
  openai_error.ServiceUnavailableError,
290
283
  openai_error.RateLimitError,
291
284
  ),
292
- max_attempts=self.max_attempts,
293
- retry_interval=self.retry_interval,
294
- exponential_backoff=self.exponential_backoff,
295
285
  )
296
286
 
297
287
 
@@ -41,8 +41,12 @@ from langfun.core.structured.schema import ValueRepr
41
41
  from langfun.core.structured.schema import ValueJsonRepr
42
42
  from langfun.core.structured.schema import ValuePythonRepr
43
43
  from langfun.core.structured.schema import schema_repr
44
+ from langfun.core.structured.schema import source_form
44
45
  from langfun.core.structured.schema import value_repr
45
46
 
47
+ from langfun.core.structured.schema_generation import generate_class
48
+ from langfun.core.structured.schema_generation import classgen_example
49
+ from langfun.core.structured.schema_generation import default_classgen_examples
46
50
 
47
51
  from langfun.core.structured.mapping import Mapping
48
52
  from langfun.core.structured.mapping import MappingExample
@@ -68,8 +72,8 @@ from langfun.core.structured.scoring import score
68
72
 
69
73
  # Expose default examples for structured operations so users could refer to
70
74
  # them.
71
- from langfun.core.structured.parsing import DEFAULT_PARSE_EXAMPLES
72
- from langfun.core.structured.description import DEFAULT_DESCRIBE_EXAMPLES
75
+ from langfun.core.structured.parsing import default_parse_examples
76
+ from langfun.core.structured.description import default_describe_examples
73
77
 
74
78
  # Default examples.
75
79
 
@@ -106,58 +106,61 @@ def describe(
106
106
  Returns:
107
107
  The parsed result based on the schema.
108
108
  """
109
- if examples is None:
110
- examples = DEFAULT_DESCRIBE_EXAMPLES
111
109
  return DescribeStructure(
112
- input=value, context=context, examples=examples, **kwargs
110
+ input=value,
111
+ context=context,
112
+ examples=examples or default_describe_examples(),
113
+ **kwargs,
113
114
  )(lm=lm, cache_seed=cache_seed).text
114
115
 
115
116
 
116
- class _Country(pg.Object):
117
- """A example dataclass for structured mapping."""
118
-
119
- name: str
120
- continents: list[
121
- Literal[
122
- 'Africa',
123
- 'Asia',
124
- 'Europe',
125
- 'Oceania',
126
- 'North America',
127
- 'South America',
128
- ]
117
+ def default_describe_examples() -> list[mapping.MappingExample]:
118
+ """Default describe examples."""
119
+
120
+ class Country(pg.Object):
121
+ """A example dataclass for structured mapping."""
122
+
123
+ name: str
124
+ continents: list[
125
+ Literal[
126
+ 'Africa',
127
+ 'Asia',
128
+ 'Europe',
129
+ 'Oceania',
130
+ 'North America',
131
+ 'South America',
132
+ ]
133
+ ]
134
+ num_states: int
135
+ neighbor_countries: list[str]
136
+ population: int
137
+ capital: str | None
138
+ president: str | None
139
+
140
+ return [
141
+ mapping.MappingExample(
142
+ context='Brief intro to United States',
143
+ input=Country(
144
+ name='The United States of America',
145
+ continents=['North America'],
146
+ num_states=50,
147
+ neighbor_countries=[
148
+ 'Canada',
149
+ 'Mexico',
150
+ 'Bahamas',
151
+ 'Cuba',
152
+ 'Russia',
153
+ ],
154
+ population=333000000,
155
+ capital='Washington, D.C',
156
+ president=None,
157
+ ),
158
+ output=inspect.cleandoc("""
159
+ The United States of America is a country primarily located in North America
160
+ consisting of fifty states. It shares land borders with Canada to its north
161
+ and with Mexico to its south and has maritime borders with the Bahamas, Cuba,
162
+ Russia, and other nations. With a population of over 333 million. The national
163
+ capital of the United States is Washington, D.C.
164
+ """),
165
+ ),
129
166
  ]
130
- num_states: int
131
- neighbor_countries: list[str]
132
- population: int
133
- capital: str | None
134
- president: str | None
135
-
136
-
137
- DEFAULT_DESCRIBE_EXAMPLES: list[mapping.MappingExample] = [
138
- mapping.MappingExample(
139
- context='Brief intro to United States',
140
- input=_Country(
141
- name='The United States of America',
142
- continents=['North America'],
143
- num_states=50,
144
- neighbor_countries=[
145
- 'Canada',
146
- 'Mexico',
147
- 'Bahamas',
148
- 'Cuba',
149
- 'Russia',
150
- ],
151
- population=333000000,
152
- capital='Washington, D.C',
153
- president=None,
154
- ),
155
- output=inspect.cleandoc("""
156
- The United States of America is a country primarily located in North America
157
- consisting of fifty states. It shares land borders with Canada to its north
158
- and with Mexico to its south and has maritime borders with the Bahamas, Cuba,
159
- Russia, and other nations. With a population of over 333 million. The national
160
- capital of the United States is Washington, D.C.
161
- """),
162
- ),
163
- ]
@@ -293,25 +293,27 @@ class Mapping(lf.LangFunc):
293
293
 
294
294
  def transform_output(self, lm_output: lf.Message) -> lf.Message:
295
295
  """Transforms LM response into structure if schema is present."""
296
- schema = self.mapping_request.schema
297
- if schema is None:
298
- return lm_output
299
-
300
296
  try:
301
- result = schema.parse(
302
- lm_output.text,
303
- protocol=self.protocol,
304
- additional_context=self.globals(),
305
- autofix=self.autofix,
306
- autofix_lm=self.autofix_lm or self.lm,
307
- )
308
- lm_output.result = self.postprocess_result(result)
297
+ lm_output.result = self.postprocess_result(self.parse_result(lm_output))
309
298
  except Exception as e: # pylint: disable=broad-exception-caught
310
299
  if self.default == lf.RAISE_IF_HAS_ERROR:
311
300
  raise e
312
301
  lm_output.result = self.default
313
302
  return lm_output
314
303
 
304
+ def parse_result(self, lm_output: lf.Message) -> Any:
305
+ """Parse result from LLM response."""
306
+ schema = self.mapping_request.schema
307
+ if schema is None:
308
+ return None
309
+ return schema.parse(
310
+ lm_output.text,
311
+ protocol=self.protocol,
312
+ additional_context=self.globals(),
313
+ autofix=self.autofix,
314
+ autofix_lm=self.autofix_lm or self.lm,
315
+ )
316
+
315
317
  def postprocess_result(self, result: Any) -> Any:
316
318
  """Post process structured output."""
317
319
  return result
@@ -162,11 +162,11 @@ def parse(
162
162
  message.source = lf.UserMessage(user_prompt, tags=['lm-input'])
163
163
  context = getattr(message.lm_input, 'text', None) if include_context else None
164
164
 
165
- if examples is None:
166
- examples = DEFAULT_PARSE_EXAMPLES
167
-
168
165
  t = _parse_structure_cls(protocol)(
169
- schema=schema, context=context, default=default, examples=examples
166
+ schema=schema,
167
+ context=context,
168
+ default=default,
169
+ examples=examples or default_parse_examples(),
170
170
  )
171
171
 
172
172
  # Setting up context.
@@ -296,17 +296,19 @@ def _parse_structure_cls(
296
296
  raise ValueError(f'Unknown protocol: {protocol!r}.')
297
297
 
298
298
 
299
- class _AdditionResults(pg.Object):
300
- one_plus_one_equals: int | None
301
- two_plus_two_equals: int | None
299
+ def default_parse_examples() -> list[mapping.MappingExample]:
300
+ """Default parsing examples."""
302
301
 
302
+ class AdditionResults(pg.Object):
303
+ one_plus_one_equals: int | None
304
+ two_plus_two_equals: int | None
303
305
 
304
- DEFAULT_PARSE_EXAMPLES: list[mapping.MappingExample] = [
305
- mapping.MappingExample(
306
- input='Two plus two equals four. Three plus three equals six.',
307
- schema=_AdditionResults,
308
- output=_AdditionResults(
309
- one_plus_one_equals=None, two_plus_two_equals=4
310
- ),
311
- ),
312
- ]
306
+ return [
307
+ mapping.MappingExample(
308
+ input='Two plus two equals four. Three plus three equals six.',
309
+ schema=AdditionResults,
310
+ output=AdditionResults(
311
+ one_plus_one_equals=None, two_plus_two_equals=4
312
+ ),
313
+ ),
314
+ ]
@@ -301,23 +301,43 @@ class SchemaRepr(metaclass=abc.ABCMeta):
301
301
  class SchemaPythonRepr(SchemaRepr):
302
302
  """Python-representation for a schema."""
303
303
 
304
- def repr(self, schema: Schema) -> str:
305
- ret = self.result_definition(schema)
306
- class_definition_str = self.class_definitions(schema)
304
+ def repr(
305
+ self,
306
+ schema: Schema,
307
+ *,
308
+ include_result_definition: bool = True,
309
+ markdown: bool = True,
310
+ **kwargs,
311
+ ) -> str:
312
+ ret = ''
313
+ if include_result_definition:
314
+ ret += self.result_definition(schema)
315
+ class_definition_str = self.class_definitions(
316
+ schema, markdown=markdown, **kwargs
317
+ )
307
318
  if class_definition_str:
308
- ret += f'\n\n```python\n{class_definition_str}```'
309
- return ret
319
+ ret += f'\n\n{class_definition_str}'
320
+ return ret.strip()
310
321
 
311
- def class_definitions(self, schema: Schema) -> str | None:
322
+ def class_definitions(self, schema: Schema, **kwargs) -> str | None:
312
323
  deps = schema.class_dependencies(include_subclasses=True)
313
- return class_definitions(deps)
324
+ return class_definitions(deps, **kwargs)
314
325
 
315
326
  def result_definition(self, schema: Schema) -> str:
316
327
  return annotation(schema.spec)
317
328
 
318
329
 
330
+ def source_form(value, markdown: bool = False) -> str:
331
+ """Returns the source code form of an object."""
332
+ return ValuePythonRepr().repr(value, markdown=markdown)
333
+
334
+
319
335
  def class_definitions(
320
- classes: Sequence[Type[Any]], strict: bool = False, markdown: bool = False
336
+ classes: Sequence[Type[Any]],
337
+ *,
338
+ include_pg_object_as_base: bool = False,
339
+ strict: bool = False,
340
+ markdown: bool = False,
321
341
  ) -> str | None:
322
342
  """Returns a str for class definitions."""
323
343
  if not classes:
@@ -326,14 +346,22 @@ def class_definitions(
326
346
  for i, cls in enumerate(classes):
327
347
  if i > 0:
328
348
  def_str.write('\n')
329
- def_str.write(class_definition(cls, strict))
349
+ def_str.write(
350
+ class_definition(
351
+ cls,
352
+ strict=strict,
353
+ include_pg_object_as_base=include_pg_object_as_base,
354
+ )
355
+ )
330
356
  ret = def_str.getvalue()
331
357
  if markdown and ret:
332
358
  ret = f'```python\n{ret}```'
333
359
  return ret
334
360
 
335
361
 
336
- def class_definition(cls, strict: bool = False) -> str:
362
+ def class_definition(
363
+ cls, strict: bool = False, include_pg_object_as_base: bool = False
364
+ ) -> str:
337
365
  """Returns the Python class definition."""
338
366
  out = io.StringIO()
339
367
  if not issubclass(cls, pg.Object):
@@ -344,10 +372,9 @@ def class_definition(cls, strict: bool = False) -> str:
344
372
  schema = cls.__schema__
345
373
  eligible_bases = []
346
374
  for base_cls in cls.__bases__:
347
- if issubclass(base_cls, pg.Symbolic) and not base_cls.__module__.startswith(
348
- 'pyglove'
349
- ):
350
- eligible_bases.append(base_cls.__name__)
375
+ if issubclass(base_cls, pg.Object):
376
+ if include_pg_object_as_base or base_cls is not pg.Object:
377
+ eligible_bases.append(base_cls.__name__)
351
378
  if eligible_bases:
352
379
  base_cls_str = ', '.join(eligible_bases)
353
380
  out.write(f'class {cls.__name__}({base_cls_str}):\n')
@@ -547,8 +574,20 @@ class ValuePythonRepr(ValueRepr):
547
574
  markdown: bool = True,
548
575
  **kwargs) -> str:
549
576
  del schema
550
- object_code = pg.format(
551
- value, compact=compact, verbose=verbose, python_format=True)
577
+ if inspect.isclass(value):
578
+ cls_schema = Schema.from_value(value)
579
+ if isinstance(cls_schema.spec, pg.typing.Object):
580
+ object_code = SchemaPythonRepr().class_definitions(
581
+ cls_schema, markdown=markdown, include_pg_object_as_base=True
582
+ )
583
+ assert object_code is not None
584
+ return object_code
585
+ else:
586
+ object_code = SchemaPythonRepr().result_definition(cls_schema)
587
+ else:
588
+ object_code = pg.format(
589
+ value, compact=compact, verbose=verbose, python_format=True
590
+ )
552
591
  if markdown:
553
592
  return f'```python\n{ object_code }\n```'
554
593
  return object_code
@@ -588,6 +627,7 @@ def structure_from_python(
588
627
  global_vars = global_vars or {}
589
628
  global_vars.update({
590
629
  'pg': pg,
630
+ 'Object': pg.Object,
591
631
  'Any': typing.Any,
592
632
  'List': typing.List,
593
633
  'Tuple': typing.Tuple,