langfun 0.0.2.dev20240319__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52):
  1. langfun/__init__.py +2 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +2 -0
  7. langfun/core/eval/base.py +240 -37
  8. langfun/core/eval/base_test.py +52 -18
  9. langfun/core/eval/matching.py +26 -9
  10. langfun/core/eval/matching_test.py +3 -4
  11. langfun/core/eval/scoring.py +15 -6
  12. langfun/core/eval/scoring_test.py +2 -2
  13. langfun/core/langfunc.py +0 -5
  14. langfun/core/langfunc_test.py +6 -4
  15. langfun/core/language_model.py +124 -24
  16. langfun/core/language_model_test.py +249 -26
  17. langfun/core/llms/__init__.py +24 -5
  18. langfun/core/llms/anthropic.py +263 -0
  19. langfun/core/llms/anthropic_test.py +167 -0
  20. langfun/core/llms/cache/in_memory_test.py +37 -28
  21. langfun/core/llms/fake.py +31 -22
  22. langfun/core/llms/fake_test.py +122 -11
  23. langfun/core/llms/{gemini.py → google_genai.py} +117 -15
  24. langfun/core/llms/{gemini_test.py → google_genai_test.py} +83 -15
  25. langfun/core/llms/groq.py +260 -0
  26. langfun/core/llms/groq_test.py +170 -0
  27. langfun/core/llms/llama_cpp.py +3 -1
  28. langfun/core/llms/openai.py +97 -79
  29. langfun/core/llms/openai_test.py +285 -59
  30. langfun/core/modalities/video.py +5 -2
  31. langfun/core/structured/__init__.py +3 -0
  32. langfun/core/structured/completion_test.py +2 -2
  33. langfun/core/structured/function_generation.py +245 -0
  34. langfun/core/structured/function_generation_test.py +329 -0
  35. langfun/core/structured/mapping.py +59 -3
  36. langfun/core/structured/mapping_test.py +17 -0
  37. langfun/core/structured/parsing.py +2 -1
  38. langfun/core/structured/parsing_test.py +18 -13
  39. langfun/core/structured/prompting.py +27 -6
  40. langfun/core/structured/prompting_test.py +79 -12
  41. langfun/core/structured/schema.py +25 -22
  42. langfun/core/structured/schema_generation.py +2 -3
  43. langfun/core/structured/schema_generation_test.py +2 -2
  44. langfun/core/structured/schema_test.py +42 -27
  45. langfun/core/template.py +125 -10
  46. langfun/core/template_test.py +75 -0
  47. langfun/core/templates/selfplay_test.py +6 -2
  48. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
  49. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +52 -46
  50. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
  51. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
  52. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ from typing import Any, Callable, Type, Union
16
16
 
17
17
  import langfun.core as lf
18
18
  from langfun.core.structured import mapping
19
+ from langfun.core.structured import prompting
19
20
  from langfun.core.structured import schema as schema_lib
20
21
  import pyglove as pg
21
22
 
@@ -270,7 +271,7 @@ def call(
270
271
  return lm_output if returns_message else lm_output.text
271
272
 
272
273
  # Call `parsing_lm` for structured parsing.
273
- return parse(
274
+ return prompting.query(
274
275
  lm_output,
275
276
  schema,
276
277
  examples=parsing_examples,
@@ -17,11 +17,9 @@ import inspect
17
17
  import unittest
18
18
 
19
19
  import langfun.core as lf
20
- from langfun.core import coding
21
20
  from langfun.core.llms import fake
22
21
  from langfun.core.structured import mapping
23
22
  from langfun.core.structured import parsing
24
- from langfun.core.structured import schema as schema_lib
25
23
  import pyglove as pg
26
24
 
27
25
 
@@ -255,7 +253,7 @@ class ParseStructurePythonTest(unittest.TestCase):
255
253
  override_attrs=True,
256
254
  ):
257
255
  with self.assertRaisesRegex(
258
- coding.CodeError,
256
+ mapping.MappingError,
259
257
  'name .* is not defined',
260
258
  ):
261
259
  parsing.parse('three', int)
@@ -280,13 +278,15 @@ class ParseStructurePythonTest(unittest.TestCase):
280
278
  ),
281
279
  1,
282
280
  )
281
+ r = parsing.parse(
282
+ 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
283
+ returns_message=True
284
+ )
283
285
  self.assertEqual(
284
- parsing.parse(
285
- 'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
286
- returns_message=True
287
- ),
286
+ r,
288
287
  lf.AIMessage(
289
288
  '1', score=1.0, result=1, logprobs=None,
289
+ usage=lf.LMSamplingUsage(652, 1, 653),
290
290
  tags=['lm-response', 'lm-output', 'transformed']
291
291
  ),
292
292
  )
@@ -544,7 +544,7 @@ class ParseStructureJsonTest(unittest.TestCase):
544
544
  override_attrs=True,
545
545
  ):
546
546
  with self.assertRaisesRegex(
547
- schema_lib.JsonError,
547
+ mapping.MappingError,
548
548
  'No JSON dict in the output',
549
549
  ):
550
550
  parsing.parse('three', int, protocol='json')
@@ -634,13 +634,18 @@ class CallTest(unittest.TestCase):
634
634
  )
635
635
 
636
636
  def test_call_with_returning_message(self):
637
+ r = parsing.call(
638
+ 'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
639
+ returns_message=True
640
+ )
637
641
  self.assertEqual(
638
- parsing.call(
639
- 'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
640
- returns_message=True
641
- ),
642
+ r,
642
643
  lf.AIMessage(
643
- '3', result=3, score=1.0, logprobs=None,
644
+ '3',
645
+ result=3,
646
+ score=1.0,
647
+ logprobs=None,
648
+ usage=lf.LMSamplingUsage(315, 1, 316),
644
649
  tags=['lm-response', 'lm-output', 'transformed']
645
650
  ),
646
651
  )
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Symbolic query."""
15
15
 
16
- from typing import Any, Type, Union
16
+ from typing import Any, Callable, Type, Union
17
17
 
18
18
  import langfun.core as lf
19
19
  from langfun.core.structured import mapping
@@ -78,7 +78,9 @@ class QueryStructurePython(QueryStructure):
78
78
 
79
79
  {{ output_title }}:
80
80
  ```python
81
- Answer(final_answer=2)
81
+ Answer(
82
+ final_answer=2
83
+ )
82
84
  ```
83
85
  """
84
86
  protocol = 'python'
@@ -107,6 +109,7 @@ def query(
107
109
  lm: lf.LanguageModel | None = None,
108
110
  examples: list[mapping.MappingExample] | None = None,
109
111
  cache_seed: int | None = 0,
112
+ response_postprocess: Callable[[str], str] | None = None,
110
113
  autofix: int = 0,
111
114
  autofix_lm: lf.LanguageModel | None = None,
112
115
  protocol: schema_lib.SchemaProtocol = 'python',
@@ -159,6 +162,9 @@ def query(
159
162
  cache_seed: Seed for computing cache key. The cache key is determined by a
160
163
  tuple of (lm, prompt, cache seed). If None, cache will be disabled for
161
164
  the query even cache is configured by the LM.
165
+ response_postprocess: An optional callable object to process the raw LM
166
+ response before parsing it into the final output object. If None, the
167
+ raw LM response will not be processed.
162
168
  autofix: Number of attempts to auto fix the generated code. If 0, autofix is
163
169
  disabled. Auto-fix is not supported for 'json' protocol.
164
170
  autofix_lm: The language model to use for autofix. If not specified, the
@@ -170,8 +176,11 @@ def query(
170
176
  returning the structured `message.result`.
171
177
  skip_lm: If True, returns the rendered prompt as a UserMessage object.
172
178
  otherwise return the LLM response based on the rendered prompt.
173
- **kwargs: Keyword arguments passed to the
174
- `lf.structured.NaturalLanguageToStructureed` transform.
179
+ **kwargs: Keyword arguments passed to render the prompt or configure the
180
+ `lf.structured.Mapping` class. Notable kwargs are:
181
+ - template_str: Change the root template for query.
182
+ - preamble: Change the preamble for query.
183
+ - mapping_template: Change the template for each mapping examle.
175
184
 
176
185
  Returns:
177
186
  The result based on the schema.
@@ -188,13 +197,24 @@ def query(
188
197
  output = lf.LangFunc.from_value(prompt, **kwargs)(
189
198
  lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
190
199
  )
200
+ if response_postprocess:
201
+ processed_text = response_postprocess(output.text)
202
+ if processed_text != output.text:
203
+ output = lf.AIMessage(processed_text, source=output)
191
204
  return output if returns_message else output.text
192
205
 
193
206
  # Query with structured output.
207
+ prompt_kwargs = kwargs.copy()
208
+
209
+ # NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
210
+ # QueryStructure template string. Therefore, we pop out the argument for
211
+ # prompt rendering.
212
+ prompt_kwargs.pop('template_str', None)
213
+
194
214
  if isinstance(prompt, str):
195
- prompt = lf.Template(prompt, **kwargs)
215
+ prompt = lf.Template(prompt, **prompt_kwargs)
196
216
  elif isinstance(prompt, lf.Template):
197
- prompt = prompt.rebind(**kwargs)
217
+ prompt = prompt.rebind(**prompt_kwargs, raise_on_no_change=False)
198
218
 
199
219
  if isinstance(prompt, lf.Template):
200
220
  prompt = prompt.render(lm=lm)
@@ -206,6 +226,7 @@ def query(
206
226
  schema=schema,
207
227
  default=default,
208
228
  examples=examples,
229
+ response_postprocess=response_postprocess,
209
230
  autofix=autofix if protocol == 'python' else 0,
210
231
  **kwargs,
211
232
  )(
@@ -17,12 +17,10 @@ import inspect
17
17
  import unittest
18
18
 
19
19
  import langfun.core as lf
20
- from langfun.core import coding
21
20
  from langfun.core import modalities
22
21
  from langfun.core.llms import fake
23
22
  from langfun.core.structured import mapping
24
23
  from langfun.core.structured import prompting
25
- from langfun.core.structured import schema as schema_lib
26
24
  import pyglove as pg
27
25
 
28
26
 
@@ -77,6 +75,7 @@ class QueryTest(unittest.TestCase):
77
75
  result=1,
78
76
  score=1.0,
79
77
  logprobs=None,
78
+ usage=lf.LMSamplingUsage(323, 1, 324),
80
79
  tags=['lm-response', 'lm-output', 'transformed'],
81
80
  ),
82
81
  )
@@ -116,12 +115,59 @@ class QueryTest(unittest.TestCase):
116
115
  y=2,
117
116
  lm=lm.clone(),
118
117
  expected_snippet=(
119
- 'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT'
120
- ' according to OUTPUT_TYPE.\n\nINPUT_OBJECT:\n 1 + 1'
121
- ' =\n\nOUTPUT_TYPE:\n Answer\n\n ```python\n class Answer:\n '
122
- ' final_answer: int\n ```\n\nOUTPUT_OBJECT:\n ```python\n '
123
- ' Answer(final_answer=2)\n ```\n\nINPUT_OBJECT:\n What is 1 +'
124
- ' 2?\n\nOUTPUT_TYPE:\n int\n\nOUTPUT_OBJECT:'
118
+ 'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
119
+ 'according to OUTPUT_TYPE.\n\n'
120
+ 'INPUT_OBJECT:\n 1 + 1 =\n\n'
121
+ 'OUTPUT_TYPE:\n'
122
+ ' Answer\n\n'
123
+ ' ```python\n'
124
+ ' class Answer:\n'
125
+ ' final_answer: int\n'
126
+ ' ```\n\n'
127
+ 'OUTPUT_OBJECT:\n'
128
+ ' ```python\n'
129
+ ' Answer(\n'
130
+ ' final_answer=2\n'
131
+ ' )\n'
132
+ ' ```\n\n'
133
+ 'INPUT_OBJECT:\n'
134
+ ' What is 1 + 2?\n\n'
135
+ 'OUTPUT_TYPE:\n'
136
+ ' int\n\n'
137
+ 'OUTPUT_OBJECT:'
138
+ ),
139
+ )
140
+
141
+ def test_str_to_structure_render_custom_template(self):
142
+ lm = fake.StaticResponse('1')
143
+ self.assert_render(
144
+ 'What is {{x}} + {{y}}?',
145
+ int,
146
+ x=1,
147
+ y=2,
148
+ lm=lm.clone(),
149
+ template_str='!!{{ DEFAULT }}!!',
150
+ expected_snippet=(
151
+ '!!Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
152
+ 'according to OUTPUT_TYPE.\n\n'
153
+ 'INPUT_OBJECT:\n 1 + 1 =\n\n'
154
+ 'OUTPUT_TYPE:\n'
155
+ ' Answer\n\n'
156
+ ' ```python\n'
157
+ ' class Answer:\n'
158
+ ' final_answer: int\n'
159
+ ' ```\n\n'
160
+ 'OUTPUT_OBJECT:\n'
161
+ ' ```python\n'
162
+ ' Answer(\n'
163
+ ' final_answer=2\n'
164
+ ' )\n'
165
+ ' ```\n\n'
166
+ 'INPUT_OBJECT:\n'
167
+ ' What is 1 + 2?\n\n'
168
+ 'OUTPUT_TYPE:\n'
169
+ ' int\n\n'
170
+ 'OUTPUT_OBJECT:!!'
125
171
  ),
126
172
  )
127
173
 
@@ -264,7 +310,9 @@ class QueryStructurePythonTest(unittest.TestCase):
264
310
 
265
311
  OUTPUT_OBJECT:
266
312
  ```python
267
- Answer(final_answer=2)
313
+ Answer(
314
+ final_answer=2
315
+ )
268
316
  ```
269
317
 
270
318
  INPUT_OBJECT:
@@ -308,7 +356,9 @@ class QueryStructurePythonTest(unittest.TestCase):
308
356
 
309
357
  OUTPUT_OBJECT:
310
358
  ```python
311
- Answer(final_answer=2)
359
+ Answer(
360
+ final_answer=2
361
+ )
312
362
  ```
313
363
 
314
364
  INPUT_OBJECT:
@@ -420,7 +470,7 @@ class QueryStructurePythonTest(unittest.TestCase):
420
470
  override_attrs=True,
421
471
  ):
422
472
  with self.assertRaisesRegex(
423
- coding.CodeError,
473
+ mapping.MappingError,
424
474
  'name .* is not defined',
425
475
  ):
426
476
  prompting.query('Compute 1 + 2', int)
@@ -436,6 +486,23 @@ class QueryStructurePythonTest(unittest.TestCase):
436
486
  ])
437
487
  self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
438
488
 
489
+ def test_response_postprocess(self):
490
+ with lf.context(
491
+ lm=fake.StaticResponse('<!-- some comment-->\n3'),
492
+ override_attrs=True,
493
+ ):
494
+ self.assertEqual(
495
+ prompting.query(
496
+ 'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
497
+ '3'
498
+ )
499
+ self.assertEqual(
500
+ prompting.query(
501
+ 'Compute 1 + 2', int,
502
+ response_postprocess=lambda x: x.split('\n')[1]),
503
+ 3
504
+ )
505
+
439
506
 
440
507
  class QueryStructureJsonTest(unittest.TestCase):
441
508
 
@@ -641,7 +708,7 @@ class QueryStructureJsonTest(unittest.TestCase):
641
708
  override_attrs=True,
642
709
  ):
643
710
  with self.assertRaisesRegex(
644
- schema_lib.JsonError,
711
+ mapping.MappingError,
645
712
  'No JSON dict in the output',
646
713
  ):
647
714
  prompting.query('Compute 1 + 2', int, protocol='json')
@@ -55,10 +55,6 @@ def parse_value_spec(value) -> pg.typing.ValueSpec:
55
55
  ),
56
56
  ):
57
57
  raise ValueError(f'Unsupported schema specification: {v}')
58
- if isinstance(spec, pg.typing.Object) and not issubclass(
59
- spec.cls, pg.Symbolic
60
- ):
61
- raise ValueError(f'{v} must be a symbolic class to be parsable.')
62
58
  return spec
63
59
 
64
60
  return _parse_node(value)
@@ -208,7 +204,9 @@ def class_dependencies(
208
204
  if isinstance(value_or_spec, Schema):
209
205
  return value_or_spec.class_dependencies(include_subclasses)
210
206
 
211
- if isinstance(value_or_spec, (pg.typing.ValueSpec, pg.symbolic.ObjectMeta)):
207
+ if inspect.isclass(value_or_spec) or isinstance(
208
+ value_or_spec, pg.typing.ValueSpec
209
+ ):
212
210
  value_or_spec = (value_or_spec,)
213
211
 
214
212
  if isinstance(value_or_spec, tuple):
@@ -216,7 +214,7 @@ def class_dependencies(
216
214
  for v in value_or_spec:
217
215
  if isinstance(v, pg.typing.ValueSpec):
218
216
  value_specs.append(v)
219
- elif inspect.isclass(v) and issubclass(v, pg.Object):
217
+ elif inspect.isclass(v):
220
218
  value_specs.append(pg.typing.Object(v))
221
219
  else:
222
220
  raise TypeError(f'Unsupported spec type: {v!r}')
@@ -235,23 +233,20 @@ def class_dependencies(
235
233
 
236
234
  def _fill_dependencies(vs: pg.typing.ValueSpec, include_subclasses: bool):
237
235
  if isinstance(vs, pg.typing.Object):
238
- if issubclass(vs.cls, pg.Object) and vs.cls not in seen:
236
+ if vs.cls not in seen:
239
237
  seen.add(vs.cls)
240
238
 
241
239
  # Add base classes as dependencies.
242
240
  for base_cls in vs.cls.__bases__:
243
241
  # We only keep track of user-defined symbolic classes.
244
- if issubclass(
245
- base_cls, pg.Object
246
- ) and not base_cls.__module__.startswith('pyglove'):
242
+ if base_cls is not object and base_cls is not pg.Object:
247
243
  _fill_dependencies(
248
244
  pg.typing.Object(base_cls), include_subclasses=False
249
245
  )
250
246
 
251
247
  # Add members as dependencies.
252
- if hasattr(vs.cls, '__schema__'):
253
- for field in vs.cls.__schema__.values():
254
- _fill_dependencies(field.value, include_subclasses)
248
+ for field in _pg_schema(vs.cls).values():
249
+ _fill_dependencies(field.value, include_subclasses)
255
250
  _add_dependency(vs.cls)
256
251
 
257
252
  # Check subclasses if available.
@@ -364,17 +359,13 @@ def class_definition(
364
359
  ) -> str:
365
360
  """Returns the Python class definition."""
366
361
  out = io.StringIO()
367
- if not issubclass(cls, pg.Object):
368
- raise TypeError(
369
- 'Classes must be `pg.Object` subclasses to be used as schema. '
370
- f'Encountered: {cls}.'
371
- )
372
- schema = cls.__schema__
362
+ schema = _pg_schema(cls)
373
363
  eligible_bases = []
374
364
  for base_cls in cls.__bases__:
375
- if issubclass(base_cls, pg.Object):
365
+ if base_cls is not object:
376
366
  if include_pg_object_as_base or base_cls is not pg.Object:
377
367
  eligible_bases.append(base_cls.__name__)
368
+
378
369
  if eligible_bases:
379
370
  base_cls_str = ', '.join(eligible_bases)
380
371
  out.write(f'class {cls.__name__}({base_cls_str}):\n')
@@ -395,10 +386,12 @@ def class_definition(
395
386
  if schema.fields:
396
387
  for key, field in schema.items():
397
388
  if not isinstance(key, pg.typing.ConstStrKey):
398
- raise TypeError(
389
+ pg.logging.warning(
399
390
  'Variable-length keyword arguments is not supported in '
400
- f'structured parsing or query. Encountered: {field}'
391
+ f'structured parsing or query. Encountered: {cls}, Schema: {schema}'
401
392
  )
393
+ continue
394
+
402
395
  # Write field doc string as comments before the field definition.
403
396
  if field.description:
404
397
  for line in field.description.split('\n'):
@@ -839,3 +832,13 @@ class Unknown(pg.Object, pg.typing.CustomTyping):
839
832
 
840
833
 
841
834
  UNKNOWN = Unknown()
835
+
836
+
837
+ def _pg_schema(cls: Type[Any]) -> pg.Schema:
838
+ """Returns PyGlove schema for the constructor of a class."""
839
+ schema = getattr(cls, '__schema__', None)
840
+ if schema is None:
841
+ schema = pg.symbolic.callable_schema(
842
+ cls.__init__, auto_typing=True, auto_doc=True, remove_self=True
843
+ )
844
+ return schema
@@ -143,14 +143,14 @@ def generate_class(
143
143
 
144
144
 
145
145
  def classgen_example(
146
- class_name: str, prompt: str | pg.Symbolic, cls: Type[Any]
146
+ prompt: str | pg.Symbolic, cls: Type[Any]
147
147
  ) -> mapping.MappingExample:
148
148
  """Creates a class generation example."""
149
149
  if isinstance(prompt, lf.Template):
150
150
  prompt = prompt.render()
151
151
  return mapping.MappingExample(
152
152
  input=prompt,
153
- context=class_name,
153
+ context=cls.__name__,
154
154
  output=cls,
155
155
  )
156
156
 
@@ -168,7 +168,6 @@ def default_classgen_examples() -> list[mapping.MappingExample]:
168
168
 
169
169
  return [
170
170
  classgen_example(
171
- 'Solution',
172
171
  'How to evaluate an arithmetic expression?',
173
172
  Solution,
174
173
  )
@@ -14,8 +14,8 @@
14
14
  import inspect
15
15
  import unittest
16
16
 
17
- import langfun.core.coding as lf_coding
18
17
  from langfun.core.llms import fake
18
+ from langfun.core.structured import mapping
19
19
  from langfun.core.structured import schema_generation
20
20
 
21
21
 
@@ -92,7 +92,7 @@ class GenerateClassTest(unittest.TestCase):
92
92
  )
93
93
  self.assertIs(cls.__name__, 'B')
94
94
 
95
- with self.assertRaises(lf_coding.CodeError):
95
+ with self.assertRaises(mapping.MappingError):
96
96
  schema_generation.generate_class(
97
97
  'Foo',
98
98
  'Generate a Foo class with a field pointing to another class A',
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Tests for structured parsing."""
15
15
 
16
+ import dataclasses
16
17
  import inspect
17
18
  import typing
18
19
  import unittest
@@ -101,12 +102,7 @@ class SchemaTest(unittest.TestCase):
101
102
 
102
103
  self.assert_unsupported_annotation(typing.Type[int])
103
104
  self.assert_unsupported_annotation(typing.Union[int, str, bool])
104
-
105
- class X:
106
- pass
107
-
108
- # X must be a symbolic type to be parsable.
109
- self.assert_unsupported_annotation(X)
105
+ self.assert_unsupported_annotation(typing.Any)
110
106
 
111
107
  def test_schema_dict(self):
112
108
  schema = schema_lib.Schema([{'x': Itinerary}])
@@ -150,6 +146,25 @@ class SchemaTest(unittest.TestCase):
150
146
  schema = schema_lib.Schema([B])
151
147
  self.assertEqual(schema.class_dependencies(), [Foo, A, Bar, X, B])
152
148
 
149
+ def test_class_dependencies_non_pyglove(self):
150
+ class Baz:
151
+ def __init__(self, x: int):
152
+ pass
153
+
154
+ @dataclasses.dataclass(frozen=True)
155
+ class AA:
156
+ foo: tuple[Baz, int]
157
+
158
+ class XX(pg.Object):
159
+ pass
160
+
161
+ @dataclasses.dataclass(frozen=True)
162
+ class BB(AA):
163
+ foo2: Baz | XX
164
+
165
+ schema = schema_lib.Schema([AA])
166
+ self.assertEqual(schema.class_dependencies(), [Baz, AA, XX, BB])
167
+
153
168
  def test_schema_repr(self):
154
169
  schema = schema_lib.Schema([{'x': Itinerary}])
155
170
  self.assertEqual(
@@ -177,9 +192,9 @@ class SchemaTest(unittest.TestCase):
177
192
  self.assertEqual(schema.parse('{"result": 1}'), 1)
178
193
  schema = schema_lib.Schema(dict[str, int])
179
194
  self.assertEqual(
180
- schema.parse(
181
- '{"result": {"_type": "Unknown", "x": 1}}}', force_dict=True),
182
- dict(x=1))
195
+ schema.parse('{"result": {"x": 1}}}'),
196
+ dict(x=1)
197
+ )
183
198
  with self.assertRaisesRegex(
184
199
  schema_lib.SchemaError, 'Expect .* but encountered .*'):
185
200
  schema.parse('{"result": "def"}')
@@ -440,28 +455,22 @@ class SchemaPythonReprTest(unittest.TestCase):
440
455
  'class A(Object):\n pass\n',
441
456
  )
442
457
 
443
- class B:
444
- pass
445
-
446
- with self.assertRaisesRegex(
447
- TypeError, 'Classes must be `pg.Object` subclasses.*'):
448
- schema_lib.class_definition(B)
449
-
450
458
  class C(pg.Object):
451
459
  x: str
452
460
  __kwargs__: typing.Any
453
461
 
454
- with self.assertRaisesRegex(
455
- TypeError, 'Variable-length keyword arguments is not supported'):
456
- schema_lib.class_definition(C)
462
+ self.assertEqual(schema_lib.class_definition(C), 'class C:\n x: str\n')
457
463
 
458
464
  def test_repr(self):
459
465
  class Foo(pg.Object):
460
466
  x: int
461
467
 
462
- class Bar(pg.Object):
468
+ @dataclasses.dataclass(frozen=True)
469
+ class Bar:
470
+ """Class Bar."""
463
471
  y: str
464
472
 
473
+ @dataclasses.dataclass(frozen=True)
465
474
  class Baz(Bar): # pylint: disable=unused-variable
466
475
  pass
467
476
 
@@ -475,7 +484,7 @@ class SchemaPythonReprTest(unittest.TestCase):
475
484
  schema = schema_lib.Schema([B])
476
485
  self.assertEqual(
477
486
  schema_lib.SchemaPythonRepr().class_definitions(schema),
478
- inspect.cleandoc("""
487
+ inspect.cleandoc('''
479
488
  class Foo:
480
489
  x: int
481
490
 
@@ -483,16 +492,18 @@ class SchemaPythonReprTest(unittest.TestCase):
483
492
  foo: Foo
484
493
 
485
494
  class Bar:
495
+ """Class Bar."""
486
496
  y: str
487
497
 
488
498
  class Baz(Bar):
499
+ """Baz(y: str)"""
489
500
  y: str
490
501
 
491
502
  class B(A):
492
503
  foo: Foo
493
504
  bar: Bar
494
505
  foo2: Foo
495
- """) + '\n',
506
+ ''') + '\n',
496
507
  )
497
508
 
498
509
  self.assertEqual(
@@ -501,7 +512,7 @@ class SchemaPythonReprTest(unittest.TestCase):
501
512
 
502
513
  self.assertEqual(
503
514
  schema_lib.SchemaPythonRepr().repr(schema),
504
- inspect.cleandoc("""
515
+ inspect.cleandoc('''
505
516
  list[B]
506
517
 
507
518
  ```python
@@ -512,9 +523,11 @@ class SchemaPythonReprTest(unittest.TestCase):
512
523
  foo: Foo
513
524
 
514
525
  class Bar:
526
+ """Class Bar."""
515
527
  y: str
516
528
 
517
529
  class Baz(Bar):
530
+ """Baz(y: str)"""
518
531
  y: str
519
532
 
520
533
  class B(A):
@@ -522,7 +535,7 @@ class SchemaPythonReprTest(unittest.TestCase):
522
535
  bar: Bar
523
536
  foo2: Foo
524
537
  ```
525
- """),
538
+ '''),
526
539
  )
527
540
  self.assertEqual(
528
541
  schema_lib.SchemaPythonRepr().repr(
@@ -531,24 +544,26 @@ class SchemaPythonReprTest(unittest.TestCase):
531
544
  include_pg_object_as_base=True,
532
545
  markdown=False,
533
546
  ),
534
- inspect.cleandoc("""
547
+ inspect.cleandoc('''
535
548
  class Foo(Object):
536
549
  x: int
537
550
 
538
551
  class A(Object):
539
552
  foo: Foo
540
553
 
541
- class Bar(Object):
554
+ class Bar:
555
+ """Class Bar."""
542
556
  y: str
543
557
 
544
558
  class Baz(Bar):
559
+ """Baz(y: str)"""
545
560
  y: str
546
561
 
547
562
  class B(A):
548
563
  foo: Foo
549
564
  bar: Bar
550
565
  foo2: Foo
551
- """),
566
+ '''),
552
567
  )
553
568
 
554
569