langfun 0.0.2.dev20240319__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +2 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/base.py +240 -37
- langfun/core/eval/base_test.py +52 -18
- langfun/core/eval/matching.py +26 -9
- langfun/core/eval/matching_test.py +3 -4
- langfun/core/eval/scoring.py +15 -6
- langfun/core/eval/scoring_test.py +2 -2
- langfun/core/langfunc.py +0 -5
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +124 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +24 -5
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +31 -22
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/{gemini.py → google_genai.py} +117 -15
- langfun/core/llms/{gemini_test.py → google_genai_test.py} +83 -15
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +97 -79
- langfun/core/llms/openai_test.py +285 -59
- langfun/core/modalities/video.py +5 -2
- langfun/core/structured/__init__.py +3 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +59 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing.py +2 -1
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +27 -6
- langfun/core/structured/prompting_test.py +79 -12
- langfun/core/structured/schema.py +25 -22
- langfun/core/structured/schema_generation.py +2 -3
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +42 -27
- langfun/core/template.py +125 -10
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +52 -46
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ from typing import Any, Callable, Type, Union
|
|
16
16
|
|
17
17
|
import langfun.core as lf
|
18
18
|
from langfun.core.structured import mapping
|
19
|
+
from langfun.core.structured import prompting
|
19
20
|
from langfun.core.structured import schema as schema_lib
|
20
21
|
import pyglove as pg
|
21
22
|
|
@@ -270,7 +271,7 @@ def call(
|
|
270
271
|
return lm_output if returns_message else lm_output.text
|
271
272
|
|
272
273
|
# Call `parsing_lm` for structured parsing.
|
273
|
-
return
|
274
|
+
return prompting.query(
|
274
275
|
lm_output,
|
275
276
|
schema,
|
276
277
|
examples=parsing_examples,
|
@@ -17,11 +17,9 @@ import inspect
|
|
17
17
|
import unittest
|
18
18
|
|
19
19
|
import langfun.core as lf
|
20
|
-
from langfun.core import coding
|
21
20
|
from langfun.core.llms import fake
|
22
21
|
from langfun.core.structured import mapping
|
23
22
|
from langfun.core.structured import parsing
|
24
|
-
from langfun.core.structured import schema as schema_lib
|
25
23
|
import pyglove as pg
|
26
24
|
|
27
25
|
|
@@ -255,7 +253,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
255
253
|
override_attrs=True,
|
256
254
|
):
|
257
255
|
with self.assertRaisesRegex(
|
258
|
-
|
256
|
+
mapping.MappingError,
|
259
257
|
'name .* is not defined',
|
260
258
|
):
|
261
259
|
parsing.parse('three', int)
|
@@ -280,13 +278,15 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
280
278
|
),
|
281
279
|
1,
|
282
280
|
)
|
281
|
+
r = parsing.parse(
|
282
|
+
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
|
283
|
+
returns_message=True
|
284
|
+
)
|
283
285
|
self.assertEqual(
|
284
|
-
|
285
|
-
'the answer is 1', int, user_prompt='what is 0 + 1?', lm=lm,
|
286
|
-
returns_message=True
|
287
|
-
),
|
286
|
+
r,
|
288
287
|
lf.AIMessage(
|
289
288
|
'1', score=1.0, result=1, logprobs=None,
|
289
|
+
usage=lf.LMSamplingUsage(652, 1, 653),
|
290
290
|
tags=['lm-response', 'lm-output', 'transformed']
|
291
291
|
),
|
292
292
|
)
|
@@ -544,7 +544,7 @@ class ParseStructureJsonTest(unittest.TestCase):
|
|
544
544
|
override_attrs=True,
|
545
545
|
):
|
546
546
|
with self.assertRaisesRegex(
|
547
|
-
|
547
|
+
mapping.MappingError,
|
548
548
|
'No JSON dict in the output',
|
549
549
|
):
|
550
550
|
parsing.parse('three', int, protocol='json')
|
@@ -634,13 +634,18 @@ class CallTest(unittest.TestCase):
|
|
634
634
|
)
|
635
635
|
|
636
636
|
def test_call_with_returning_message(self):
|
637
|
+
r = parsing.call(
|
638
|
+
'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
|
639
|
+
returns_message=True
|
640
|
+
)
|
637
641
|
self.assertEqual(
|
638
|
-
|
639
|
-
'Compute 1 + 2', int, lm=fake.StaticSequence(['three', '3']),
|
640
|
-
returns_message=True
|
641
|
-
),
|
642
|
+
r,
|
642
643
|
lf.AIMessage(
|
643
|
-
'3',
|
644
|
+
'3',
|
645
|
+
result=3,
|
646
|
+
score=1.0,
|
647
|
+
logprobs=None,
|
648
|
+
usage=lf.LMSamplingUsage(315, 1, 316),
|
644
649
|
tags=['lm-response', 'lm-output', 'transformed']
|
645
650
|
),
|
646
651
|
)
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Symbolic query."""
|
15
15
|
|
16
|
-
from typing import Any, Type, Union
|
16
|
+
from typing import Any, Callable, Type, Union
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.structured import mapping
|
@@ -78,7 +78,9 @@ class QueryStructurePython(QueryStructure):
|
|
78
78
|
|
79
79
|
{{ output_title }}:
|
80
80
|
```python
|
81
|
-
Answer(
|
81
|
+
Answer(
|
82
|
+
final_answer=2
|
83
|
+
)
|
82
84
|
```
|
83
85
|
"""
|
84
86
|
protocol = 'python'
|
@@ -107,6 +109,7 @@ def query(
|
|
107
109
|
lm: lf.LanguageModel | None = None,
|
108
110
|
examples: list[mapping.MappingExample] | None = None,
|
109
111
|
cache_seed: int | None = 0,
|
112
|
+
response_postprocess: Callable[[str], str] | None = None,
|
110
113
|
autofix: int = 0,
|
111
114
|
autofix_lm: lf.LanguageModel | None = None,
|
112
115
|
protocol: schema_lib.SchemaProtocol = 'python',
|
@@ -159,6 +162,9 @@ def query(
|
|
159
162
|
cache_seed: Seed for computing cache key. The cache key is determined by a
|
160
163
|
tuple of (lm, prompt, cache seed). If None, cache will be disabled for
|
161
164
|
the query even cache is configured by the LM.
|
165
|
+
response_postprocess: An optional callable object to process the raw LM
|
166
|
+
response before parsing it into the final output object. If None, the
|
167
|
+
raw LM response will not be processed.
|
162
168
|
autofix: Number of attempts to auto fix the generated code. If 0, autofix is
|
163
169
|
disabled. Auto-fix is not supported for 'json' protocol.
|
164
170
|
autofix_lm: The language model to use for autofix. If not specified, the
|
@@ -170,8 +176,11 @@ def query(
|
|
170
176
|
returning the structured `message.result`.
|
171
177
|
skip_lm: If True, returns the rendered prompt as a UserMessage object.
|
172
178
|
otherwise return the LLM response based on the rendered prompt.
|
173
|
-
**kwargs: Keyword arguments passed to the
|
174
|
-
`lf.structured.
|
179
|
+
**kwargs: Keyword arguments passed to render the prompt or configure the
|
180
|
+
`lf.structured.Mapping` class. Notable kwargs are:
|
181
|
+
- template_str: Change the root template for query.
|
182
|
+
- preamble: Change the preamble for query.
|
183
|
+
- mapping_template: Change the template for each mapping examle.
|
175
184
|
|
176
185
|
Returns:
|
177
186
|
The result based on the schema.
|
@@ -188,13 +197,24 @@ def query(
|
|
188
197
|
output = lf.LangFunc.from_value(prompt, **kwargs)(
|
189
198
|
lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
|
190
199
|
)
|
200
|
+
if response_postprocess:
|
201
|
+
processed_text = response_postprocess(output.text)
|
202
|
+
if processed_text != output.text:
|
203
|
+
output = lf.AIMessage(processed_text, source=output)
|
191
204
|
return output if returns_message else output.text
|
192
205
|
|
193
206
|
# Query with structured output.
|
207
|
+
prompt_kwargs = kwargs.copy()
|
208
|
+
|
209
|
+
# NOTE(daiyip): when `template_str` is passed in, it's intended to modify the
|
210
|
+
# QueryStructure template string. Therefore, we pop out the argument for
|
211
|
+
# prompt rendering.
|
212
|
+
prompt_kwargs.pop('template_str', None)
|
213
|
+
|
194
214
|
if isinstance(prompt, str):
|
195
|
-
prompt = lf.Template(prompt, **
|
215
|
+
prompt = lf.Template(prompt, **prompt_kwargs)
|
196
216
|
elif isinstance(prompt, lf.Template):
|
197
|
-
prompt = prompt.rebind(**
|
217
|
+
prompt = prompt.rebind(**prompt_kwargs, raise_on_no_change=False)
|
198
218
|
|
199
219
|
if isinstance(prompt, lf.Template):
|
200
220
|
prompt = prompt.render(lm=lm)
|
@@ -206,6 +226,7 @@ def query(
|
|
206
226
|
schema=schema,
|
207
227
|
default=default,
|
208
228
|
examples=examples,
|
229
|
+
response_postprocess=response_postprocess,
|
209
230
|
autofix=autofix if protocol == 'python' else 0,
|
210
231
|
**kwargs,
|
211
232
|
)(
|
@@ -17,12 +17,10 @@ import inspect
|
|
17
17
|
import unittest
|
18
18
|
|
19
19
|
import langfun.core as lf
|
20
|
-
from langfun.core import coding
|
21
20
|
from langfun.core import modalities
|
22
21
|
from langfun.core.llms import fake
|
23
22
|
from langfun.core.structured import mapping
|
24
23
|
from langfun.core.structured import prompting
|
25
|
-
from langfun.core.structured import schema as schema_lib
|
26
24
|
import pyglove as pg
|
27
25
|
|
28
26
|
|
@@ -77,6 +75,7 @@ class QueryTest(unittest.TestCase):
|
|
77
75
|
result=1,
|
78
76
|
score=1.0,
|
79
77
|
logprobs=None,
|
78
|
+
usage=lf.LMSamplingUsage(323, 1, 324),
|
80
79
|
tags=['lm-response', 'lm-output', 'transformed'],
|
81
80
|
),
|
82
81
|
)
|
@@ -116,12 +115,59 @@ class QueryTest(unittest.TestCase):
|
|
116
115
|
y=2,
|
117
116
|
lm=lm.clone(),
|
118
117
|
expected_snippet=(
|
119
|
-
'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT'
|
120
|
-
'
|
121
|
-
'
|
122
|
-
'
|
123
|
-
'
|
124
|
-
'
|
118
|
+
'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
|
119
|
+
'according to OUTPUT_TYPE.\n\n'
|
120
|
+
'INPUT_OBJECT:\n 1 + 1 =\n\n'
|
121
|
+
'OUTPUT_TYPE:\n'
|
122
|
+
' Answer\n\n'
|
123
|
+
' ```python\n'
|
124
|
+
' class Answer:\n'
|
125
|
+
' final_answer: int\n'
|
126
|
+
' ```\n\n'
|
127
|
+
'OUTPUT_OBJECT:\n'
|
128
|
+
' ```python\n'
|
129
|
+
' Answer(\n'
|
130
|
+
' final_answer=2\n'
|
131
|
+
' )\n'
|
132
|
+
' ```\n\n'
|
133
|
+
'INPUT_OBJECT:\n'
|
134
|
+
' What is 1 + 2?\n\n'
|
135
|
+
'OUTPUT_TYPE:\n'
|
136
|
+
' int\n\n'
|
137
|
+
'OUTPUT_OBJECT:'
|
138
|
+
),
|
139
|
+
)
|
140
|
+
|
141
|
+
def test_str_to_structure_render_custom_template(self):
|
142
|
+
lm = fake.StaticResponse('1')
|
143
|
+
self.assert_render(
|
144
|
+
'What is {{x}} + {{y}}?',
|
145
|
+
int,
|
146
|
+
x=1,
|
147
|
+
y=2,
|
148
|
+
lm=lm.clone(),
|
149
|
+
template_str='!!{{ DEFAULT }}!!',
|
150
|
+
expected_snippet=(
|
151
|
+
'!!Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
|
152
|
+
'according to OUTPUT_TYPE.\n\n'
|
153
|
+
'INPUT_OBJECT:\n 1 + 1 =\n\n'
|
154
|
+
'OUTPUT_TYPE:\n'
|
155
|
+
' Answer\n\n'
|
156
|
+
' ```python\n'
|
157
|
+
' class Answer:\n'
|
158
|
+
' final_answer: int\n'
|
159
|
+
' ```\n\n'
|
160
|
+
'OUTPUT_OBJECT:\n'
|
161
|
+
' ```python\n'
|
162
|
+
' Answer(\n'
|
163
|
+
' final_answer=2\n'
|
164
|
+
' )\n'
|
165
|
+
' ```\n\n'
|
166
|
+
'INPUT_OBJECT:\n'
|
167
|
+
' What is 1 + 2?\n\n'
|
168
|
+
'OUTPUT_TYPE:\n'
|
169
|
+
' int\n\n'
|
170
|
+
'OUTPUT_OBJECT:!!'
|
125
171
|
),
|
126
172
|
)
|
127
173
|
|
@@ -264,7 +310,9 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
264
310
|
|
265
311
|
OUTPUT_OBJECT:
|
266
312
|
```python
|
267
|
-
Answer(
|
313
|
+
Answer(
|
314
|
+
final_answer=2
|
315
|
+
)
|
268
316
|
```
|
269
317
|
|
270
318
|
INPUT_OBJECT:
|
@@ -308,7 +356,9 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
308
356
|
|
309
357
|
OUTPUT_OBJECT:
|
310
358
|
```python
|
311
|
-
Answer(
|
359
|
+
Answer(
|
360
|
+
final_answer=2
|
361
|
+
)
|
312
362
|
```
|
313
363
|
|
314
364
|
INPUT_OBJECT:
|
@@ -420,7 +470,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
420
470
|
override_attrs=True,
|
421
471
|
):
|
422
472
|
with self.assertRaisesRegex(
|
423
|
-
|
473
|
+
mapping.MappingError,
|
424
474
|
'name .* is not defined',
|
425
475
|
):
|
426
476
|
prompting.query('Compute 1 + 2', int)
|
@@ -436,6 +486,23 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
436
486
|
])
|
437
487
|
self.assertEqual(prompting.query('what is 1 + 0', int, lm=lm, autofix=3), 1)
|
438
488
|
|
489
|
+
def test_response_postprocess(self):
|
490
|
+
with lf.context(
|
491
|
+
lm=fake.StaticResponse('<!-- some comment-->\n3'),
|
492
|
+
override_attrs=True,
|
493
|
+
):
|
494
|
+
self.assertEqual(
|
495
|
+
prompting.query(
|
496
|
+
'Compute 1 + 2', response_postprocess=lambda x: x.split('\n')[1]),
|
497
|
+
'3'
|
498
|
+
)
|
499
|
+
self.assertEqual(
|
500
|
+
prompting.query(
|
501
|
+
'Compute 1 + 2', int,
|
502
|
+
response_postprocess=lambda x: x.split('\n')[1]),
|
503
|
+
3
|
504
|
+
)
|
505
|
+
|
439
506
|
|
440
507
|
class QueryStructureJsonTest(unittest.TestCase):
|
441
508
|
|
@@ -641,7 +708,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
641
708
|
override_attrs=True,
|
642
709
|
):
|
643
710
|
with self.assertRaisesRegex(
|
644
|
-
|
711
|
+
mapping.MappingError,
|
645
712
|
'No JSON dict in the output',
|
646
713
|
):
|
647
714
|
prompting.query('Compute 1 + 2', int, protocol='json')
|
@@ -55,10 +55,6 @@ def parse_value_spec(value) -> pg.typing.ValueSpec:
|
|
55
55
|
),
|
56
56
|
):
|
57
57
|
raise ValueError(f'Unsupported schema specification: {v}')
|
58
|
-
if isinstance(spec, pg.typing.Object) and not issubclass(
|
59
|
-
spec.cls, pg.Symbolic
|
60
|
-
):
|
61
|
-
raise ValueError(f'{v} must be a symbolic class to be parsable.')
|
62
58
|
return spec
|
63
59
|
|
64
60
|
return _parse_node(value)
|
@@ -208,7 +204,9 @@ def class_dependencies(
|
|
208
204
|
if isinstance(value_or_spec, Schema):
|
209
205
|
return value_or_spec.class_dependencies(include_subclasses)
|
210
206
|
|
211
|
-
if
|
207
|
+
if inspect.isclass(value_or_spec) or isinstance(
|
208
|
+
value_or_spec, pg.typing.ValueSpec
|
209
|
+
):
|
212
210
|
value_or_spec = (value_or_spec,)
|
213
211
|
|
214
212
|
if isinstance(value_or_spec, tuple):
|
@@ -216,7 +214,7 @@ def class_dependencies(
|
|
216
214
|
for v in value_or_spec:
|
217
215
|
if isinstance(v, pg.typing.ValueSpec):
|
218
216
|
value_specs.append(v)
|
219
|
-
elif inspect.isclass(v)
|
217
|
+
elif inspect.isclass(v):
|
220
218
|
value_specs.append(pg.typing.Object(v))
|
221
219
|
else:
|
222
220
|
raise TypeError(f'Unsupported spec type: {v!r}')
|
@@ -235,23 +233,20 @@ def class_dependencies(
|
|
235
233
|
|
236
234
|
def _fill_dependencies(vs: pg.typing.ValueSpec, include_subclasses: bool):
|
237
235
|
if isinstance(vs, pg.typing.Object):
|
238
|
-
if
|
236
|
+
if vs.cls not in seen:
|
239
237
|
seen.add(vs.cls)
|
240
238
|
|
241
239
|
# Add base classes as dependencies.
|
242
240
|
for base_cls in vs.cls.__bases__:
|
243
241
|
# We only keep track of user-defined symbolic classes.
|
244
|
-
if
|
245
|
-
base_cls, pg.Object
|
246
|
-
) and not base_cls.__module__.startswith('pyglove'):
|
242
|
+
if base_cls is not object and base_cls is not pg.Object:
|
247
243
|
_fill_dependencies(
|
248
244
|
pg.typing.Object(base_cls), include_subclasses=False
|
249
245
|
)
|
250
246
|
|
251
247
|
# Add members as dependencies.
|
252
|
-
|
253
|
-
|
254
|
-
_fill_dependencies(field.value, include_subclasses)
|
248
|
+
for field in _pg_schema(vs.cls).values():
|
249
|
+
_fill_dependencies(field.value, include_subclasses)
|
255
250
|
_add_dependency(vs.cls)
|
256
251
|
|
257
252
|
# Check subclasses if available.
|
@@ -364,17 +359,13 @@ def class_definition(
|
|
364
359
|
) -> str:
|
365
360
|
"""Returns the Python class definition."""
|
366
361
|
out = io.StringIO()
|
367
|
-
|
368
|
-
raise TypeError(
|
369
|
-
'Classes must be `pg.Object` subclasses to be used as schema. '
|
370
|
-
f'Encountered: {cls}.'
|
371
|
-
)
|
372
|
-
schema = cls.__schema__
|
362
|
+
schema = _pg_schema(cls)
|
373
363
|
eligible_bases = []
|
374
364
|
for base_cls in cls.__bases__:
|
375
|
-
if
|
365
|
+
if base_cls is not object:
|
376
366
|
if include_pg_object_as_base or base_cls is not pg.Object:
|
377
367
|
eligible_bases.append(base_cls.__name__)
|
368
|
+
|
378
369
|
if eligible_bases:
|
379
370
|
base_cls_str = ', '.join(eligible_bases)
|
380
371
|
out.write(f'class {cls.__name__}({base_cls_str}):\n')
|
@@ -395,10 +386,12 @@ def class_definition(
|
|
395
386
|
if schema.fields:
|
396
387
|
for key, field in schema.items():
|
397
388
|
if not isinstance(key, pg.typing.ConstStrKey):
|
398
|
-
|
389
|
+
pg.logging.warning(
|
399
390
|
'Variable-length keyword arguments is not supported in '
|
400
|
-
f'structured parsing or query. Encountered: {
|
391
|
+
f'structured parsing or query. Encountered: {cls}, Schema: {schema}'
|
401
392
|
)
|
393
|
+
continue
|
394
|
+
|
402
395
|
# Write field doc string as comments before the field definition.
|
403
396
|
if field.description:
|
404
397
|
for line in field.description.split('\n'):
|
@@ -839,3 +832,13 @@ class Unknown(pg.Object, pg.typing.CustomTyping):
|
|
839
832
|
|
840
833
|
|
841
834
|
UNKNOWN = Unknown()
|
835
|
+
|
836
|
+
|
837
|
+
def _pg_schema(cls: Type[Any]) -> pg.Schema:
|
838
|
+
"""Returns PyGlove schema for the constructor of a class."""
|
839
|
+
schema = getattr(cls, '__schema__', None)
|
840
|
+
if schema is None:
|
841
|
+
schema = pg.symbolic.callable_schema(
|
842
|
+
cls.__init__, auto_typing=True, auto_doc=True, remove_self=True
|
843
|
+
)
|
844
|
+
return schema
|
@@ -143,14 +143,14 @@ def generate_class(
|
|
143
143
|
|
144
144
|
|
145
145
|
def classgen_example(
|
146
|
-
|
146
|
+
prompt: str | pg.Symbolic, cls: Type[Any]
|
147
147
|
) -> mapping.MappingExample:
|
148
148
|
"""Creates a class generation example."""
|
149
149
|
if isinstance(prompt, lf.Template):
|
150
150
|
prompt = prompt.render()
|
151
151
|
return mapping.MappingExample(
|
152
152
|
input=prompt,
|
153
|
-
context=
|
153
|
+
context=cls.__name__,
|
154
154
|
output=cls,
|
155
155
|
)
|
156
156
|
|
@@ -168,7 +168,6 @@ def default_classgen_examples() -> list[mapping.MappingExample]:
|
|
168
168
|
|
169
169
|
return [
|
170
170
|
classgen_example(
|
171
|
-
'Solution',
|
172
171
|
'How to evaluate an arithmetic expression?',
|
173
172
|
Solution,
|
174
173
|
)
|
@@ -14,8 +14,8 @@
|
|
14
14
|
import inspect
|
15
15
|
import unittest
|
16
16
|
|
17
|
-
import langfun.core.coding as lf_coding
|
18
17
|
from langfun.core.llms import fake
|
18
|
+
from langfun.core.structured import mapping
|
19
19
|
from langfun.core.structured import schema_generation
|
20
20
|
|
21
21
|
|
@@ -92,7 +92,7 @@ class GenerateClassTest(unittest.TestCase):
|
|
92
92
|
)
|
93
93
|
self.assertIs(cls.__name__, 'B')
|
94
94
|
|
95
|
-
with self.assertRaises(
|
95
|
+
with self.assertRaises(mapping.MappingError):
|
96
96
|
schema_generation.generate_class(
|
97
97
|
'Foo',
|
98
98
|
'Generate a Foo class with a field pointing to another class A',
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Tests for structured parsing."""
|
15
15
|
|
16
|
+
import dataclasses
|
16
17
|
import inspect
|
17
18
|
import typing
|
18
19
|
import unittest
|
@@ -101,12 +102,7 @@ class SchemaTest(unittest.TestCase):
|
|
101
102
|
|
102
103
|
self.assert_unsupported_annotation(typing.Type[int])
|
103
104
|
self.assert_unsupported_annotation(typing.Union[int, str, bool])
|
104
|
-
|
105
|
-
class X:
|
106
|
-
pass
|
107
|
-
|
108
|
-
# X must be a symbolic type to be parsable.
|
109
|
-
self.assert_unsupported_annotation(X)
|
105
|
+
self.assert_unsupported_annotation(typing.Any)
|
110
106
|
|
111
107
|
def test_schema_dict(self):
|
112
108
|
schema = schema_lib.Schema([{'x': Itinerary}])
|
@@ -150,6 +146,25 @@ class SchemaTest(unittest.TestCase):
|
|
150
146
|
schema = schema_lib.Schema([B])
|
151
147
|
self.assertEqual(schema.class_dependencies(), [Foo, A, Bar, X, B])
|
152
148
|
|
149
|
+
def test_class_dependencies_non_pyglove(self):
|
150
|
+
class Baz:
|
151
|
+
def __init__(self, x: int):
|
152
|
+
pass
|
153
|
+
|
154
|
+
@dataclasses.dataclass(frozen=True)
|
155
|
+
class AA:
|
156
|
+
foo: tuple[Baz, int]
|
157
|
+
|
158
|
+
class XX(pg.Object):
|
159
|
+
pass
|
160
|
+
|
161
|
+
@dataclasses.dataclass(frozen=True)
|
162
|
+
class BB(AA):
|
163
|
+
foo2: Baz | XX
|
164
|
+
|
165
|
+
schema = schema_lib.Schema([AA])
|
166
|
+
self.assertEqual(schema.class_dependencies(), [Baz, AA, XX, BB])
|
167
|
+
|
153
168
|
def test_schema_repr(self):
|
154
169
|
schema = schema_lib.Schema([{'x': Itinerary}])
|
155
170
|
self.assertEqual(
|
@@ -177,9 +192,9 @@ class SchemaTest(unittest.TestCase):
|
|
177
192
|
self.assertEqual(schema.parse('{"result": 1}'), 1)
|
178
193
|
schema = schema_lib.Schema(dict[str, int])
|
179
194
|
self.assertEqual(
|
180
|
-
schema.parse(
|
181
|
-
|
182
|
-
|
195
|
+
schema.parse('{"result": {"x": 1}}}'),
|
196
|
+
dict(x=1)
|
197
|
+
)
|
183
198
|
with self.assertRaisesRegex(
|
184
199
|
schema_lib.SchemaError, 'Expect .* but encountered .*'):
|
185
200
|
schema.parse('{"result": "def"}')
|
@@ -440,28 +455,22 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
440
455
|
'class A(Object):\n pass\n',
|
441
456
|
)
|
442
457
|
|
443
|
-
class B:
|
444
|
-
pass
|
445
|
-
|
446
|
-
with self.assertRaisesRegex(
|
447
|
-
TypeError, 'Classes must be `pg.Object` subclasses.*'):
|
448
|
-
schema_lib.class_definition(B)
|
449
|
-
|
450
458
|
class C(pg.Object):
|
451
459
|
x: str
|
452
460
|
__kwargs__: typing.Any
|
453
461
|
|
454
|
-
|
455
|
-
TypeError, 'Variable-length keyword arguments is not supported'):
|
456
|
-
schema_lib.class_definition(C)
|
462
|
+
self.assertEqual(schema_lib.class_definition(C), 'class C:\n x: str\n')
|
457
463
|
|
458
464
|
def test_repr(self):
|
459
465
|
class Foo(pg.Object):
|
460
466
|
x: int
|
461
467
|
|
462
|
-
|
468
|
+
@dataclasses.dataclass(frozen=True)
|
469
|
+
class Bar:
|
470
|
+
"""Class Bar."""
|
463
471
|
y: str
|
464
472
|
|
473
|
+
@dataclasses.dataclass(frozen=True)
|
465
474
|
class Baz(Bar): # pylint: disable=unused-variable
|
466
475
|
pass
|
467
476
|
|
@@ -475,7 +484,7 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
475
484
|
schema = schema_lib.Schema([B])
|
476
485
|
self.assertEqual(
|
477
486
|
schema_lib.SchemaPythonRepr().class_definitions(schema),
|
478
|
-
inspect.cleandoc(
|
487
|
+
inspect.cleandoc('''
|
479
488
|
class Foo:
|
480
489
|
x: int
|
481
490
|
|
@@ -483,16 +492,18 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
483
492
|
foo: Foo
|
484
493
|
|
485
494
|
class Bar:
|
495
|
+
"""Class Bar."""
|
486
496
|
y: str
|
487
497
|
|
488
498
|
class Baz(Bar):
|
499
|
+
"""Baz(y: str)"""
|
489
500
|
y: str
|
490
501
|
|
491
502
|
class B(A):
|
492
503
|
foo: Foo
|
493
504
|
bar: Bar
|
494
505
|
foo2: Foo
|
495
|
-
|
506
|
+
''') + '\n',
|
496
507
|
)
|
497
508
|
|
498
509
|
self.assertEqual(
|
@@ -501,7 +512,7 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
501
512
|
|
502
513
|
self.assertEqual(
|
503
514
|
schema_lib.SchemaPythonRepr().repr(schema),
|
504
|
-
inspect.cleandoc(
|
515
|
+
inspect.cleandoc('''
|
505
516
|
list[B]
|
506
517
|
|
507
518
|
```python
|
@@ -512,9 +523,11 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
512
523
|
foo: Foo
|
513
524
|
|
514
525
|
class Bar:
|
526
|
+
"""Class Bar."""
|
515
527
|
y: str
|
516
528
|
|
517
529
|
class Baz(Bar):
|
530
|
+
"""Baz(y: str)"""
|
518
531
|
y: str
|
519
532
|
|
520
533
|
class B(A):
|
@@ -522,7 +535,7 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
522
535
|
bar: Bar
|
523
536
|
foo2: Foo
|
524
537
|
```
|
525
|
-
|
538
|
+
'''),
|
526
539
|
)
|
527
540
|
self.assertEqual(
|
528
541
|
schema_lib.SchemaPythonRepr().repr(
|
@@ -531,24 +544,26 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
531
544
|
include_pg_object_as_base=True,
|
532
545
|
markdown=False,
|
533
546
|
),
|
534
|
-
inspect.cleandoc(
|
547
|
+
inspect.cleandoc('''
|
535
548
|
class Foo(Object):
|
536
549
|
x: int
|
537
550
|
|
538
551
|
class A(Object):
|
539
552
|
foo: Foo
|
540
553
|
|
541
|
-
class Bar
|
554
|
+
class Bar:
|
555
|
+
"""Class Bar."""
|
542
556
|
y: str
|
543
557
|
|
544
558
|
class Baz(Bar):
|
559
|
+
"""Baz(y: str)"""
|
545
560
|
y: str
|
546
561
|
|
547
562
|
class B(A):
|
548
563
|
foo: Foo
|
549
564
|
bar: Bar
|
550
565
|
foo2: Foo
|
551
|
-
|
566
|
+
'''),
|
552
567
|
)
|
553
568
|
|
554
569
|
|