langfun 0.0.2.dev20240315__py3-none-any.whl → 0.0.2.dev20240316__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +3 -0
- langfun/core/structured/__init__.py +6 -2
- langfun/core/structured/description.py +53 -50
- langfun/core/structured/mapping.py +14 -12
- langfun/core/structured/parsing.py +18 -16
- langfun/core/structured/schema.py +56 -16
- langfun/core/structured/schema_generation.py +175 -0
- langfun/core/structured/schema_generation_test.py +104 -0
- langfun/core/structured/schema_test.py +44 -0
- {langfun-0.0.2.dev20240315.dist-info → langfun-0.0.2.dev20240316.dist-info}/METADATA +1 -1
- {langfun-0.0.2.dev20240315.dist-info → langfun-0.0.2.dev20240316.dist-info}/RECORD +14 -12
- {langfun-0.0.2.dev20240315.dist-info → langfun-0.0.2.dev20240316.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240315.dist-info → langfun-0.0.2.dev20240316.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240315.dist-info → langfun-0.0.2.dev20240316.dist-info}/top_level.txt +0 -0
langfun/__init__.py
CHANGED
@@ -31,6 +31,9 @@ query = structured.query
|
|
31
31
|
describe = structured.describe
|
32
32
|
complete = structured.complete
|
33
33
|
score = structured.score
|
34
|
+
generate_class = structured.generate_class
|
35
|
+
|
36
|
+
source_form = structured.source_form
|
34
37
|
|
35
38
|
from langfun.core import eval # pylint: disable=redefined-builtin
|
36
39
|
from langfun.core import templates
|
@@ -41,8 +41,12 @@ from langfun.core.structured.schema import ValueRepr
|
|
41
41
|
from langfun.core.structured.schema import ValueJsonRepr
|
42
42
|
from langfun.core.structured.schema import ValuePythonRepr
|
43
43
|
from langfun.core.structured.schema import schema_repr
|
44
|
+
from langfun.core.structured.schema import source_form
|
44
45
|
from langfun.core.structured.schema import value_repr
|
45
46
|
|
47
|
+
from langfun.core.structured.schema_generation import generate_class
|
48
|
+
from langfun.core.structured.schema_generation import classgen_example
|
49
|
+
from langfun.core.structured.schema_generation import default_classgen_examples
|
46
50
|
|
47
51
|
from langfun.core.structured.mapping import Mapping
|
48
52
|
from langfun.core.structured.mapping import MappingExample
|
@@ -68,8 +72,8 @@ from langfun.core.structured.scoring import score
|
|
68
72
|
|
69
73
|
# Expose default examples for structured operations so users could refer to
|
70
74
|
# them.
|
71
|
-
from langfun.core.structured.parsing import
|
72
|
-
from langfun.core.structured.description import
|
75
|
+
from langfun.core.structured.parsing import default_parse_examples
|
76
|
+
from langfun.core.structured.description import default_describe_examples
|
73
77
|
|
74
78
|
# Default examples.
|
75
79
|
|
@@ -106,58 +106,61 @@ def describe(
|
|
106
106
|
Returns:
|
107
107
|
The parsed result based on the schema.
|
108
108
|
"""
|
109
|
-
if examples is None:
|
110
|
-
examples = DEFAULT_DESCRIBE_EXAMPLES
|
111
109
|
return DescribeStructure(
|
112
|
-
input=value,
|
110
|
+
input=value,
|
111
|
+
context=context,
|
112
|
+
examples=examples or default_describe_examples(),
|
113
|
+
**kwargs,
|
113
114
|
)(lm=lm, cache_seed=cache_seed).text
|
114
115
|
|
115
116
|
|
116
|
-
|
117
|
-
"""
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
117
|
+
def default_describe_examples() -> list[mapping.MappingExample]:
|
118
|
+
"""Default describe examples."""
|
119
|
+
|
120
|
+
class Country(pg.Object):
|
121
|
+
"""A example dataclass for structured mapping."""
|
122
|
+
|
123
|
+
name: str
|
124
|
+
continents: list[
|
125
|
+
Literal[
|
126
|
+
'Africa',
|
127
|
+
'Asia',
|
128
|
+
'Europe',
|
129
|
+
'Oceania',
|
130
|
+
'North America',
|
131
|
+
'South America',
|
132
|
+
]
|
133
|
+
]
|
134
|
+
num_states: int
|
135
|
+
neighbor_countries: list[str]
|
136
|
+
population: int
|
137
|
+
capital: str | None
|
138
|
+
president: str | None
|
139
|
+
|
140
|
+
return [
|
141
|
+
mapping.MappingExample(
|
142
|
+
context='Brief intro to United States',
|
143
|
+
input=Country(
|
144
|
+
name='The United States of America',
|
145
|
+
continents=['North America'],
|
146
|
+
num_states=50,
|
147
|
+
neighbor_countries=[
|
148
|
+
'Canada',
|
149
|
+
'Mexico',
|
150
|
+
'Bahamas',
|
151
|
+
'Cuba',
|
152
|
+
'Russia',
|
153
|
+
],
|
154
|
+
population=333000000,
|
155
|
+
capital='Washington, D.C',
|
156
|
+
president=None,
|
157
|
+
),
|
158
|
+
output=inspect.cleandoc("""
|
159
|
+
The United States of America is a country primarily located in North America
|
160
|
+
consisting of fifty states. It shares land borders with Canada to its north
|
161
|
+
and with Mexico to its south and has maritime borders with the Bahamas, Cuba,
|
162
|
+
Russia, and other nations. With a population of over 333 million. The national
|
163
|
+
capital of the United States is Washington, D.C.
|
164
|
+
"""),
|
165
|
+
),
|
129
166
|
]
|
130
|
-
num_states: int
|
131
|
-
neighbor_countries: list[str]
|
132
|
-
population: int
|
133
|
-
capital: str | None
|
134
|
-
president: str | None
|
135
|
-
|
136
|
-
|
137
|
-
DEFAULT_DESCRIBE_EXAMPLES: list[mapping.MappingExample] = [
|
138
|
-
mapping.MappingExample(
|
139
|
-
context='Brief intro to United States',
|
140
|
-
input=_Country(
|
141
|
-
name='The United States of America',
|
142
|
-
continents=['North America'],
|
143
|
-
num_states=50,
|
144
|
-
neighbor_countries=[
|
145
|
-
'Canada',
|
146
|
-
'Mexico',
|
147
|
-
'Bahamas',
|
148
|
-
'Cuba',
|
149
|
-
'Russia',
|
150
|
-
],
|
151
|
-
population=333000000,
|
152
|
-
capital='Washington, D.C',
|
153
|
-
president=None,
|
154
|
-
),
|
155
|
-
output=inspect.cleandoc("""
|
156
|
-
The United States of America is a country primarily located in North America
|
157
|
-
consisting of fifty states. It shares land borders with Canada to its north
|
158
|
-
and with Mexico to its south and has maritime borders with the Bahamas, Cuba,
|
159
|
-
Russia, and other nations. With a population of over 333 million. The national
|
160
|
-
capital of the United States is Washington, D.C.
|
161
|
-
"""),
|
162
|
-
),
|
163
|
-
]
|
@@ -293,25 +293,27 @@ class Mapping(lf.LangFunc):
|
|
293
293
|
|
294
294
|
def transform_output(self, lm_output: lf.Message) -> lf.Message:
|
295
295
|
"""Transforms LM response into structure if schema is present."""
|
296
|
-
schema = self.mapping_request.schema
|
297
|
-
if schema is None:
|
298
|
-
return lm_output
|
299
|
-
|
300
296
|
try:
|
301
|
-
result =
|
302
|
-
lm_output.text,
|
303
|
-
protocol=self.protocol,
|
304
|
-
additional_context=self.globals(),
|
305
|
-
autofix=self.autofix,
|
306
|
-
autofix_lm=self.autofix_lm or self.lm,
|
307
|
-
)
|
308
|
-
lm_output.result = self.postprocess_result(result)
|
297
|
+
lm_output.result = self.postprocess_result(self.parse_result(lm_output))
|
309
298
|
except Exception as e: # pylint: disable=broad-exception-caught
|
310
299
|
if self.default == lf.RAISE_IF_HAS_ERROR:
|
311
300
|
raise e
|
312
301
|
lm_output.result = self.default
|
313
302
|
return lm_output
|
314
303
|
|
304
|
+
def parse_result(self, lm_output: lf.Message) -> Any:
  """Extracts the structured value from an LM response.

  Args:
    lm_output: Message returned by the language model; its `text` is the
      input handed to the schema parser.

  Returns:
    None when the mapping request has no schema; otherwise the value returned
    by `schema.parse` for the response text.
  """
  target_schema = self.mapping_request.schema
  if target_schema is not None:
    return target_schema.parse(
        lm_output.text,
        protocol=self.protocol,
        additional_context=self.globals(),
        autofix=self.autofix,
        # The mapping's own LM doubles as the autofix LM when none is set.
        autofix_lm=self.autofix_lm or self.lm,
    )
  # Without a schema there is nothing to parse.
  return None
|
316
|
+
|
315
317
|
def postprocess_result(self, result: Any) -> Any:
|
316
318
|
"""Post process structured output."""
|
317
319
|
return result
|
@@ -162,11 +162,11 @@ def parse(
|
|
162
162
|
message.source = lf.UserMessage(user_prompt, tags=['lm-input'])
|
163
163
|
context = getattr(message.lm_input, 'text', None) if include_context else None
|
164
164
|
|
165
|
-
if examples is None:
|
166
|
-
examples = DEFAULT_PARSE_EXAMPLES
|
167
|
-
|
168
165
|
t = _parse_structure_cls(protocol)(
|
169
|
-
schema=schema,
|
166
|
+
schema=schema,
|
167
|
+
context=context,
|
168
|
+
default=default,
|
169
|
+
examples=examples or default_parse_examples(),
|
170
170
|
)
|
171
171
|
|
172
172
|
# Setting up context.
|
@@ -296,17 +296,19 @@ def _parse_structure_cls(
|
|
296
296
|
raise ValueError(f'Unknown protocol: {protocol!r}.')
|
297
297
|
|
298
298
|
|
299
|
-
|
300
|
-
|
301
|
-
two_plus_two_equals: int | None
|
299
|
+
def default_parse_examples() -> list[mapping.MappingExample]:
|
300
|
+
"""Default parsing examples."""
|
302
301
|
|
302
|
+
class AdditionResults(pg.Object):
|
303
|
+
one_plus_one_equals: int | None
|
304
|
+
two_plus_two_equals: int | None
|
303
305
|
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
]
|
306
|
+
return [
|
307
|
+
mapping.MappingExample(
|
308
|
+
input='Two plus two equals four. Three plus three equals six.',
|
309
|
+
schema=AdditionResults,
|
310
|
+
output=AdditionResults(
|
311
|
+
one_plus_one_equals=None, two_plus_two_equals=4
|
312
|
+
),
|
313
|
+
),
|
314
|
+
]
|
@@ -301,23 +301,43 @@ class SchemaRepr(metaclass=abc.ABCMeta):
|
|
301
301
|
class SchemaPythonRepr(SchemaRepr):
|
302
302
|
"""Python-representation for a schema."""
|
303
303
|
|
304
|
-
def repr(
|
305
|
-
|
306
|
-
|
304
|
+
def repr(
|
305
|
+
self,
|
306
|
+
schema: Schema,
|
307
|
+
*,
|
308
|
+
include_result_definition: bool = True,
|
309
|
+
markdown: bool = True,
|
310
|
+
**kwargs,
|
311
|
+
) -> str:
|
312
|
+
ret = ''
|
313
|
+
if include_result_definition:
|
314
|
+
ret += self.result_definition(schema)
|
315
|
+
class_definition_str = self.class_definitions(
|
316
|
+
schema, markdown=markdown, **kwargs
|
317
|
+
)
|
307
318
|
if class_definition_str:
|
308
|
-
ret += f'\n\n
|
309
|
-
return ret
|
319
|
+
ret += f'\n\n{class_definition_str}'
|
320
|
+
return ret.strip()
|
310
321
|
|
311
|
-
def class_definitions(self, schema: Schema) -> str | None:
|
322
|
+
def class_definitions(self, schema: Schema, **kwargs) -> str | None:
|
312
323
|
deps = schema.class_dependencies(include_subclasses=True)
|
313
|
-
return class_definitions(deps)
|
324
|
+
return class_definitions(deps, **kwargs)
|
314
325
|
|
315
326
|
def result_definition(self, schema: Schema) -> str:
|
316
327
|
return annotation(schema.spec)
|
317
328
|
|
318
329
|
|
330
|
+
def source_form(value, markdown: bool = False) -> str:
  """Renders `value` in its Python source-code form.

  Args:
    value: The object or class to render.
    markdown: Forwarded unchanged to `ValuePythonRepr.repr`.

  Returns:
    The string produced by `ValuePythonRepr().repr(value, markdown=markdown)`.
  """
  python_repr = ValuePythonRepr()
  return python_repr.repr(value, markdown=markdown)
|
333
|
+
|
334
|
+
|
319
335
|
def class_definitions(
|
320
|
-
classes: Sequence[Type[Any]],
|
336
|
+
classes: Sequence[Type[Any]],
|
337
|
+
*,
|
338
|
+
include_pg_object_as_base: bool = False,
|
339
|
+
strict: bool = False,
|
340
|
+
markdown: bool = False,
|
321
341
|
) -> str | None:
|
322
342
|
"""Returns a str for class definitions."""
|
323
343
|
if not classes:
|
@@ -326,14 +346,22 @@ def class_definitions(
|
|
326
346
|
for i, cls in enumerate(classes):
|
327
347
|
if i > 0:
|
328
348
|
def_str.write('\n')
|
329
|
-
def_str.write(
|
349
|
+
def_str.write(
|
350
|
+
class_definition(
|
351
|
+
cls,
|
352
|
+
strict=strict,
|
353
|
+
include_pg_object_as_base=include_pg_object_as_base,
|
354
|
+
)
|
355
|
+
)
|
330
356
|
ret = def_str.getvalue()
|
331
357
|
if markdown and ret:
|
332
358
|
ret = f'```python\n{ret}```'
|
333
359
|
return ret
|
334
360
|
|
335
361
|
|
336
|
-
def class_definition(
|
362
|
+
def class_definition(
|
363
|
+
cls, strict: bool = False, include_pg_object_as_base: bool = False
|
364
|
+
) -> str:
|
337
365
|
"""Returns the Python class definition."""
|
338
366
|
out = io.StringIO()
|
339
367
|
if not issubclass(cls, pg.Object):
|
@@ -344,10 +372,9 @@ def class_definition(cls, strict: bool = False) -> str:
|
|
344
372
|
schema = cls.__schema__
|
345
373
|
eligible_bases = []
|
346
374
|
for base_cls in cls.__bases__:
|
347
|
-
if issubclass(base_cls, pg.
|
348
|
-
|
349
|
-
|
350
|
-
eligible_bases.append(base_cls.__name__)
|
375
|
+
if issubclass(base_cls, pg.Object):
|
376
|
+
if include_pg_object_as_base or base_cls is not pg.Object:
|
377
|
+
eligible_bases.append(base_cls.__name__)
|
351
378
|
if eligible_bases:
|
352
379
|
base_cls_str = ', '.join(eligible_bases)
|
353
380
|
out.write(f'class {cls.__name__}({base_cls_str}):\n')
|
@@ -547,8 +574,20 @@ class ValuePythonRepr(ValueRepr):
|
|
547
574
|
markdown: bool = True,
|
548
575
|
**kwargs) -> str:
|
549
576
|
del schema
|
550
|
-
|
551
|
-
|
577
|
+
if inspect.isclass(value):
|
578
|
+
cls_schema = Schema.from_value(value)
|
579
|
+
if isinstance(cls_schema.spec, pg.typing.Object):
|
580
|
+
object_code = SchemaPythonRepr().class_definitions(
|
581
|
+
cls_schema, markdown=markdown, include_pg_object_as_base=True
|
582
|
+
)
|
583
|
+
assert object_code is not None
|
584
|
+
return object_code
|
585
|
+
else:
|
586
|
+
object_code = SchemaPythonRepr().result_definition(cls_schema)
|
587
|
+
else:
|
588
|
+
object_code = pg.format(
|
589
|
+
value, compact=compact, verbose=verbose, python_format=True
|
590
|
+
)
|
552
591
|
if markdown:
|
553
592
|
return f'```python\n{ object_code }\n```'
|
554
593
|
return object_code
|
@@ -588,6 +627,7 @@ def structure_from_python(
|
|
588
627
|
global_vars = global_vars or {}
|
589
628
|
global_vars.update({
|
590
629
|
'pg': pg,
|
630
|
+
'Object': pg.Object,
|
591
631
|
'Any': typing.Any,
|
592
632
|
'List': typing.List,
|
593
633
|
'Tuple': typing.Tuple,
|
@@ -0,0 +1,175 @@
|
|
1
|
+
# Copyright 2023 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""LLM-based class generation."""
|
15
|
+
|
16
|
+
import typing
|
17
|
+
from typing import Any, Type
|
18
|
+
import langfun.core as lf
|
19
|
+
from langfun.core.coding.python import correction
|
20
|
+
from langfun.core.structured import mapping
|
21
|
+
import pyglove as pg
|
22
|
+
|
23
|
+
|
24
|
+
class GenerateClass(mapping.Mapping):
|
25
|
+
"""Python class generation."""
|
26
|
+
|
27
|
+
input_title = 'GENERATION_CONTEXT'
|
28
|
+
context_title = 'CLASS_NAME'
|
29
|
+
output_title = 'OUTPUT_CLASS'
|
30
|
+
|
31
|
+
preamble: lf.Template = lf.Template("""
|
32
|
+
Help generate a class based on the last {{ context_title }} and {{ input_title }}.
|
33
|
+
|
34
|
+
Instructions:
|
35
|
+
- Use `Object` as the base class for all generated classes
|
36
|
+
- Create auxillary classes for composition if needed.
|
37
|
+
- Use Python type annotation for declaraing fields:
|
38
|
+
(e.g. bool, str, int, float, Optional[str], List[int], Union[str, int])
|
39
|
+
- Do not use types that need import.
|
40
|
+
- Avoid self-referential types. e.g:
|
41
|
+
```
|
42
|
+
class Node(Object):
|
43
|
+
children: list[Node]
|
44
|
+
```
|
45
|
+
- Do not generate methods.
|
46
|
+
""")
|
47
|
+
|
48
|
+
def parse_result(self, lm_output: lf.Message) -> Type[Any]:
  """Runs the LM-written code and returns the class named by the context.

  Args:
    lm_output: Message returned by the language model; its `text` is the code
      passed to `correction.run_with_correction`.

  Returns:
    The object bound to the name `self.context` after the code has run.

  Raises:
    correction.errors.CodeError: If the executed code leaves no (non-None)
      variable named `self.context`.
  """
  namespace, executed_code = correction.run_with_correction(
      lm_output.text,
      global_vars=self.allowed_annotation_types,
      sandbox=False,
      max_attempts=self.autofix,
      lm=self.autofix_lm,
      returns_code=True,
      outputs_intermediate=True,
  )
  requested_name = self.context
  generated_cls = namespace.get(requested_name)
  if generated_cls is not None:
    return generated_cls
  # Report the code that actually ran so the failure can be inspected.
  raise correction.errors.CodeError(
      executed_code,
      TypeError(f'Class {requested_name} is absent from LLM output.'),
  )
|
66
|
+
|
67
|
+
@property
def allowed_annotation_types(self):
  """Names exposed as globals to the generated class code.

  Returns:
    A dict mapping each identifier the generated code may reference (`pg`,
    `Object` and a fixed set of `typing` constructs) to the object it stands
    for.
  """
  return dict(
      pg=pg,
      Any=typing.Any,
      Object=pg.Object,
      List=typing.List,
      # Previously bound to `typing.Tuple`, which made `Dict[...]` annotations
      # in generated code resolve to tuple types.
      Dict=typing.Dict,
      Tuple=typing.Tuple,
      Sequence=typing.Sequence,
      Optional=typing.Optional,
      Union=typing.Union,
  )
|
80
|
+
|
81
|
+
|
82
|
+
def generate_class(
    name: str,
    prompt: str | pg.Symbolic,
    *,
    lm: lf.LanguageModel | None = None,
    examples: list[mapping.MappingExample] | None = None,
    returns_message: bool = False,
    skip_lm: bool = False,
    **kwargs,
) -> Type[Any] | lf.Message:
  """Generate a class with specified name based on the prompt.

  Example:
    ```
    trip_cls = lf.generate_class(
        'Trip',
        'A trip plan to visit {{ city }}',
        city='San Francisco',
        lm=lf.llms.GeminiPro()
    )
    ```

  Args:
    name: Class name to be generated.
    prompt: A str (may contain {{}} as template) as natural language input, or a
      `pg.Symbolic` object as structured input as prompt to LLM.
    lm: The language model to use. If not specified, the language model from
      `lf.context` context manager will be used.
    examples: An optional list of fewshot examples for helping class generation.
      If None (or empty), the examples from `default_classgen_examples()` are
      used. Use `lf.structured.classgen_example` to generate example.
    returns_message: If True, returns `lf.Message` as the output, instead of
      returning the structured `message.result`.
    skip_lm: If True, returns the rendered prompt as a UserMessage object;
      otherwise returns the LLM response based on the rendered prompt.
    **kwargs: Template variables passed to `prompt` and keyword arguments passed
      to `lf.structured.GenerateClass`.

  Returns:
    The generated class, or the whole `lf.Message` when `returns_message` is
    True.

  Raises:
    CodeError: if generation failed.
  """
  # Normalize the prompt: plain strings become templates, existing templates
  # pick up the extra template variables.
  if isinstance(prompt, str):
    prompt = lf.Template(prompt, **kwargs)
  elif isinstance(prompt, lf.Template):
    prompt = prompt.rebind(**kwargs, raise_on_no_change=False)

  if isinstance(prompt, lf.Template):
    prompt = prompt.render(lm=lm)

  # Only forward `lm` when the caller supplied one.
  call_kwargs = dict(skip_lm=skip_lm)
  if lm is not None:
    call_kwargs['lm'] = lm
  message = GenerateClass(
      input=prompt,
      context=name,
      examples=examples or default_classgen_examples(),
      **kwargs,
  )(**call_kwargs)
  return message if returns_message else message.result
|
143
|
+
|
144
|
+
|
145
|
+
def classgen_example(
    class_name: str, prompt: str | pg.Symbolic, cls: Type[Any]
) -> mapping.MappingExample:
  """Builds one few-shot example for class generation.

  Args:
    class_name: Name of the class; stored as the example's context.
    prompt: The generation prompt. An `lf.Template` is rendered first; any
      other value is stored as-is.
    cls: The class stored as the example's output.

  Returns:
    A `mapping.MappingExample` with the prompt as input, `class_name` as
    context and `cls` as output.
  """
  if isinstance(prompt, lf.Template):
    example_input = prompt.render()
  else:
    example_input = prompt
  return mapping.MappingExample(
      input=example_input, context=class_name, output=cls
  )
|
156
|
+
|
157
|
+
|
158
|
+
def default_classgen_examples() -> list[mapping.MappingExample]:
  """Returns the built-in one-shot example for class generation.

  The example classes are declared inside the function, so each call builds
  them afresh.
  """

  class Step(pg.Object):
    description: str
    output: float

  class Solution(pg.Object):
    steps: list[Step]  # pytype: disable=invalid-annotation
    result: float

  solution_example = classgen_example(
      'Solution',
      'How to evaluate an arithmetic expression?',
      Solution,
  )
  return [solution_example]
|
@@ -0,0 +1,104 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import inspect
|
15
|
+
import unittest
|
16
|
+
|
17
|
+
import langfun.core.coding as lf_coding
|
18
|
+
from langfun.core.llms import fake
|
19
|
+
from langfun.core.structured import schema_generation
|
20
|
+
|
21
|
+
|
22
|
+
class GenerateClassTest(unittest.TestCase):
|
23
|
+
|
24
|
+
def test_generate_class_prompt(self):
|
25
|
+
input_message = schema_generation.generate_class(
|
26
|
+
'Trip',
|
27
|
+
'Generate a trip class',
|
28
|
+
skip_lm=True,
|
29
|
+
returns_message=True,
|
30
|
+
)
|
31
|
+
self.maxDiff = None
|
32
|
+
self.assertEqual(
|
33
|
+
input_message.text,
|
34
|
+
inspect.cleandoc("""
|
35
|
+
Help generate a class based on the last CLASS_NAME and GENERATION_CONTEXT.
|
36
|
+
|
37
|
+
Instructions:
|
38
|
+
- Use `Object` as the base class for all generated classes
|
39
|
+
- Create auxillary classes for composition if needed.
|
40
|
+
- Use Python type annotation for declaraing fields:
|
41
|
+
(e.g. bool, str, int, float, Optional[str], List[int], Union[str, int])
|
42
|
+
- Do not use types that need import.
|
43
|
+
- Avoid self-referential types. e.g:
|
44
|
+
```
|
45
|
+
class Node(Object):
|
46
|
+
children: list[Node]
|
47
|
+
```
|
48
|
+
- Do not generate methods.
|
49
|
+
|
50
|
+
CLASS_NAME:
|
51
|
+
Solution
|
52
|
+
|
53
|
+
GENERATION_CONTEXT:
|
54
|
+
How to evaluate an arithmetic expression?
|
55
|
+
|
56
|
+
OUTPUT_CLASS:
|
57
|
+
```python
|
58
|
+
class Step(Object):
|
59
|
+
description: str
|
60
|
+
output: float
|
61
|
+
|
62
|
+
class Solution(Object):
|
63
|
+
steps: list[Step]
|
64
|
+
result: float
|
65
|
+
```
|
66
|
+
|
67
|
+
|
68
|
+
CLASS_NAME:
|
69
|
+
Trip
|
70
|
+
|
71
|
+
GENERATION_CONTEXT:
|
72
|
+
Generate a trip class
|
73
|
+
|
74
|
+
OUTPUT_CLASS:
|
75
|
+
"""),
|
76
|
+
)
|
77
|
+
|
78
|
+
def test_generate_class(self):
  """Checks class lookup in the LM output, for a present and an absent name."""
  # NOTE(review): indentation inside this response literal was lost in the
  # diff rendering and is reconstructed here - confirm against the wheel.
  lm = fake.StaticResponse("""
      ```python
      class A(Object):
        x: int

      class B(Object):
        a: A
      ```
      """)
  cls = schema_generation.generate_class(
      'B',
      'Generate a B class with a field pointing to another class A',
      lm=lm,
  )
  # Compare by value: `assertIs` on strings only passes when both happen to
  # be the same interned object.
  self.assertEqual(cls.__name__, 'B')

  # The canned response defines no `Foo`, so the lookup must fail.
  with self.assertRaises(lf_coding.CodeError):
    schema_generation.generate_class(
        'Foo',
        'Generate a Foo class with a field pointing to another class A',
        lm=lm,
    )
|
101
|
+
|
102
|
+
|
103
|
+
if __name__ == '__main__':
|
104
|
+
unittest.main()
|
@@ -435,6 +435,10 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
435
435
|
schema_lib.class_definition(A),
|
436
436
|
'class A:\n pass\n',
|
437
437
|
)
|
438
|
+
self.assertEqual(
|
439
|
+
schema_lib.class_definition(A, include_pg_object_as_base=True),
|
440
|
+
'class A(Object):\n pass\n',
|
441
|
+
)
|
438
442
|
|
439
443
|
class B:
|
440
444
|
pass
|
@@ -520,6 +524,32 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
520
524
|
```
|
521
525
|
"""),
|
522
526
|
)
|
527
|
+
self.assertEqual(
|
528
|
+
schema_lib.SchemaPythonRepr().repr(
|
529
|
+
schema,
|
530
|
+
include_result_definition=False,
|
531
|
+
include_pg_object_as_base=True,
|
532
|
+
markdown=False,
|
533
|
+
),
|
534
|
+
inspect.cleandoc("""
|
535
|
+
class Foo(Object):
|
536
|
+
x: int
|
537
|
+
|
538
|
+
class A(Object):
|
539
|
+
foo: Foo
|
540
|
+
|
541
|
+
class Bar(Object):
|
542
|
+
y: str
|
543
|
+
|
544
|
+
class Baz(Bar):
|
545
|
+
y: str
|
546
|
+
|
547
|
+
class B(A):
|
548
|
+
foo: Foo
|
549
|
+
bar: Bar
|
550
|
+
foo2: Foo
|
551
|
+
"""),
|
552
|
+
)
|
523
553
|
|
524
554
|
|
525
555
|
class SchemaJsonReprTest(unittest.TestCase):
|
@@ -559,6 +589,20 @@ class ValuePythonReprTest(unittest.TestCase):
|
|
559
589
|
),
|
560
590
|
"A(foo=[Foo(x=1), Foo(x=2)], y='bar')",
|
561
591
|
)
|
592
|
+
self.assertEqual(
|
593
|
+
schema_lib.ValuePythonRepr().repr(A),
|
594
|
+
inspect.cleandoc("""
|
595
|
+
```python
|
596
|
+
class Foo(Object):
|
597
|
+
x: int
|
598
|
+
|
599
|
+
class A(Object):
|
600
|
+
foo: list[Foo]
|
601
|
+
y: str | None
|
602
|
+
```
|
603
|
+
"""),
|
604
|
+
)
|
605
|
+
self.assertEqual(schema_lib.source_form(int), 'int')
|
562
606
|
|
563
607
|
def test_parse(self):
|
564
608
|
class Foo(pg.Object):
|
@@ -1,4 +1,4 @@
|
|
1
|
-
langfun/__init__.py,sha256=
|
1
|
+
langfun/__init__.py,sha256=PqX3u18BC0szYIMu00j-RKxvwkNPwXtAFZ-96oxrQ0M,1841
|
2
2
|
langfun/core/__init__.py,sha256=sVcPl89lWYHQ1cUoaLaM8dErCovugJo5e2F3A_94Q3Y,4192
|
3
3
|
langfun/core/component.py,sha256=VRPfDB_2jEnxcB3-HoiVjG4ID-SMenNPIsytb0uXMPg,9674
|
4
4
|
langfun/core/component_test.py,sha256=VAPd6V_-odAe8rBvesW3ogYDd6OSqRq4FaPhfgOM4Zg,7949
|
@@ -69,19 +69,21 @@ langfun/core/modalities/mime.py,sha256=wVfaYflhGz1W4v3m972rAplW3OGOFtjFpHDYIaUD5
|
|
69
69
|
langfun/core/modalities/mime_test.py,sha256=cVHxRvJ1QXC1SVhBmWkJdWGpL9Xl0UNfTQq6j0OGGL4,1881
|
70
70
|
langfun/core/modalities/video.py,sha256=5-sIlzXb_ZY84RMFcpVD9ysP9GbcwbdKaZOEm3jECtc,1469
|
71
71
|
langfun/core/modalities/video_test.py,sha256=jYuI2m8S8zDCAVBPEUbbpP205dXAht90A2_PHWo4-r8,2039
|
72
|
-
langfun/core/structured/__init__.py,sha256=
|
72
|
+
langfun/core/structured/__init__.py,sha256=SpObW-HKpyKvkLlX8FV5ixz7CRm098j2aGfOguM3AUI,3462
|
73
73
|
langfun/core/structured/completion.py,sha256=skBxt6V_fv2TBUKnzFgnPMbVY8HSYn8sY04MLok2yvs,7299
|
74
74
|
langfun/core/structured/completion_test.py,sha256=98UCgA4gzfp6H6HgP2s2kcKs25YH3k4Nxj1rgAvmVBw,19249
|
75
|
-
langfun/core/structured/description.py,sha256=
|
75
|
+
langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXawvDBXMZFA,5211
|
76
76
|
langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
|
77
|
-
langfun/core/structured/mapping.py,sha256=
|
77
|
+
langfun/core/structured/mapping.py,sha256=tahkaAB-L6yKbYb7qjVI301-FfIARdw4w8nP3wqS2-k,10291
|
78
78
|
langfun/core/structured/mapping_test.py,sha256=07DDCGbwytQHSMm7fCi5-Ly-JNgdV4ubHZq0wthX4A4,3338
|
79
|
-
langfun/core/structured/parsing.py,sha256=
|
79
|
+
langfun/core/structured/parsing.py,sha256=yTKuezai5i-X9W-jU0DeEZzqHHbCFom0plj-D0bhp98,11436
|
80
80
|
langfun/core/structured/parsing_test.py,sha256=2_Uf3LYNRON1-5ysEr75xiG_cAxR3ZiixSfvUQu6mOQ,20846
|
81
81
|
langfun/core/structured/prompting.py,sha256=0xRPC0K_RaFRv-j52x8_-1n1eRFSomJEpdZApVXsCV0,6902
|
82
82
|
langfun/core/structured/prompting_test.py,sha256=SwoYbPyKhUT1H2QbqHvl93biCiE9Ttn1aWixoHH-v9Y,19129
|
83
|
-
langfun/core/structured/schema.py,sha256=
|
84
|
-
langfun/core/structured/
|
83
|
+
langfun/core/structured/schema.py,sha256=60griJ-yC1SExX6g-aOcAOo8yFh53CdwMV4EVK3ivug,25207
|
84
|
+
langfun/core/structured/schema_generation.py,sha256=Yv9flJ4GTtLw-bDB8S7A93G-z4gXsFMkMASkbiduT3E,5353
|
85
|
+
langfun/core/structured/schema_generation_test.py,sha256=cfZyP0gHno2fXy_c9vsVdvHmqKQSfuyUsCtfO3JFmYQ,2945
|
86
|
+
langfun/core/structured/schema_test.py,sha256=kMIgnAzm3f2O5ofn0pPKjT6H8hny4cWVaUVDOZuyjOQ,21987
|
85
87
|
langfun/core/structured/scoring.py,sha256=a3vfGnqf-DOWjD07MF54GCZTO_R1RTxTDVPzerXnU0s,2325
|
86
88
|
langfun/core/structured/scoring_test.py,sha256=TznLMl0x9QxzmhHz_3Vr44VOXuvFnUSeRQVhu33W5cA,1437
|
87
89
|
langfun/core/templates/__init__.py,sha256=bO0eMsVJbi7sxEB2YlInKRQ2EVP-RyyKUwcD-8msuN4,927
|
@@ -93,8 +95,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
93
95
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
94
96
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
95
97
|
langfun/core/templates/selfplay_test.py,sha256=IB5rWbjK_9CTkqEo1BclQPzFAKcIiusJckH8J19HFgI,2096
|
96
|
-
langfun-0.0.2.
|
97
|
-
langfun-0.0.2.
|
98
|
-
langfun-0.0.2.
|
99
|
-
langfun-0.0.2.
|
100
|
-
langfun-0.0.2.
|
98
|
+
langfun-0.0.2.dev20240316.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
99
|
+
langfun-0.0.2.dev20240316.dist-info/METADATA,sha256=rvpQMtNiFs55Okrd1TNlJOS8szUWshlCt5NFB_2vPfs,3405
|
100
|
+
langfun-0.0.2.dev20240316.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
101
|
+
langfun-0.0.2.dev20240316.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
102
|
+
langfun-0.0.2.dev20240316.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|