langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +2 -0
- langfun/core/eval/v2/checkpointing.py +76 -7
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +92 -17
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +84 -15
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +90 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +12 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +64 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +34 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +58 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +64 -3
- langfun/core/modalities/mime_test.py +11 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import unittest
|
|
15
|
+
from langfun.core.structured.schema import base
|
|
16
|
+
from langfun.core.structured.schema import json
|
|
17
|
+
import pyglove as pg
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Test fixture: minimal symbolic class. Tests refer to it through its
# `__type_name__`, which appears as the `_type` value in JSON.
class Activity(pg.Object):
  description: str
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Test fixture: symbolic class with a constrained int, an enum, a list of
# nested objects and an optional annotated string field. The expected schema
# string in SchemaReprTest lists these fields in declaration order.
class Itinerary(pg.Object):
  """A travel itinerary for a day."""

  day: pg.typing.Int[1, None]
  type: pg.typing.Enum['daytime', 'nighttime']
  activities: list[Activity]
  hotel: pg.typing.Annotated[
      pg.typing.Str['.*Hotel'] | None,
      'Hotel to stay if applicable.'
  ]


# Short serialization key: SchemaReprTest expects `"_type": "Itinerary"`
# here, whereas Activity is rendered with its full `__type_name__`.
Itinerary.__serialization_key__ = 'Itinerary'
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Test fixture: self-referencing symbolic class (forward reference by name).
# NOTE(review): not referenced by the tests visible in this file.
class Node(pg.Object):
  children: list['Node']
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class SchemaReprTest(unittest.TestCase):
  """Checks the JSON-protocol rendering of a schema."""

  def test_repr(self):
    itinerary_list_schema = base.Schema([{'x': Itinerary}])
    # Itinerary is emitted under its short serialization key, while the
    # nested Activity is emitted under its full type name.
    expected = (
        '{"result": [{"x": {"_type": "Itinerary", "day": int(min=1), '
        '"type": "daytime" | "nighttime", "activities": [{"_type": "'
        + Activity.__type_name__
        + '", "description": str}], "hotel": str(regex=.*Hotel) | None}}]}'
    )
    self.assertEqual(
        base.schema_repr(itinerary_list_schema, protocol='json'), expected
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ValueReprTest(unittest.TestCase):
  """Checks JSON-protocol value rendering and parsing."""

  def test_value_repr(self):
    rendered = base.value_repr(1, protocol='json')
    self.assertEqual(rendered, '{"result": 1}')

  def assert_parse_value(self, inputs, output) -> None:
    """Asserts that `inputs` parses (JSON protocol) to `output`."""
    parsed = base.parse_value(inputs, protocol='json')
    self.assertEqual(parsed, output)

  def test_parse_basics(self):
    activity_json = (
        '{"result": {"_type": "%s", "description": "play"}}'
        % Activity.__type_name__
    )
    well_formed_cases = [
        ('{"result": 1}', 1),
        ('{"result": "\\"}ab{"}', '"}ab{'),
        ('{"result": {"x": true, "y": null}}', {'x': True, 'y': None}),
        (activity_json, Activity('play')),
    ]
    for text, expected in well_formed_cases:
      with self.subTest(text=text):
        self.assert_parse_value(text, expected)

    # Syntactically invalid JSON.
    with self.assertRaisesRegex(json.JsonError, 'JSONDecodeError'):
      base.parse_value('{"abc", 1}', protocol='json')

    # Valid JSON whose root dict lacks the `result` key.
    with self.assertRaisesRegex(
        json.JsonError,
        'The root node of the JSON must be a dict with key `result`'
    ):
      base.parse_value('{"abc": 1}', protocol='json')

  def test_parse_with_surrounding_texts(self):
    # Text around the JSON dict does not prevent parsing.
    self.assert_parse_value('The answer is {"result": 1}.', 1)

  def test_parse_with_new_lines(self):
    # A raw line break inside a JSON string value is accepted and parsed as
    # a '\n' character.
    llm_output = """
{
"result": [
"foo
bar"]
}
"""
    self.assert_parse_value(llm_output, ['foo\nbar'])

  def test_parse_with_malformated_json(self):
    # No JSON dict at all.
    with self.assertRaisesRegex(json.JsonError, 'No JSON dict in the output'):
      base.parse_value('The answer is 1.', protocol='json')

    # Unbalanced curly braces.
    with self.assertRaisesRegex(
        json.JsonError, 'Malformated JSON: missing .* closing curly braces'
    ):
      base.parse_value('{"result": 1', protocol='json')
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# Allow running this test module directly as a script.
if __name__ == '__main__':
  unittest.main()
|
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Python-based prompting protocol."""
|
|
15
|
+
|
|
16
|
+
import inspect
|
|
17
|
+
import io
|
|
18
|
+
import re
|
|
19
|
+
import sys
|
|
20
|
+
import textwrap
|
|
21
|
+
import typing
|
|
22
|
+
from typing import Any, Sequence, Type
|
|
23
|
+
import langfun.core as lf
|
|
24
|
+
from langfun.core.coding.python import correction
|
|
25
|
+
from langfun.core.structured.schema import base
|
|
26
|
+
import pyglove as pg
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class PythonPromptingProtocol(base.PromptingProtocol):
  """Python-based prompting protocol.

  Schemas are rendered as Python type annotations and class definitions,
  values are rendered as Python source code, and responses are parsed by
  evaluating Python code (see `structure_from_python`).
  """

  NAME = 'python'

  def schema_repr(
      self,
      schema: base.Schema,
      *,
      include_result_definition: bool = True,
      markdown: bool = True,
      **kwargs,
  ) -> str:
    """Returns the Python representation of `schema`.

    Args:
      schema: The schema to render.
      include_result_definition: If True, start with the annotation of the
        schema's root value spec (see `result_definition`).
      markdown: Forwarded to `class_definitions`; when True, the class
        definitions are wrapped in a ```python fence.
      **kwargs: Forwarded to `class_definitions`.

    Returns:
      The result annotation (if requested) followed by the class definitions,
      separated by a blank line, with surrounding whitespace stripped.
    """
    ret = ''
    if include_result_definition:
      ret += self.result_definition(schema)
    class_definition_str = self.class_definitions(
        schema, markdown=markdown, **kwargs
    )
    if class_definition_str:
      ret += f'\n\n{class_definition_str}'
    return ret.strip()

  def class_definitions(
      self,
      schema: base.Schema,
      additional_dependencies: list[Type[Any]] | None = None,
      **kwargs
  ) -> str | None:
    """Returns the definitions of the classes that `schema` depends on.

    Args:
      schema: The schema whose class dependencies (including subclasses,
        excluding base classes) are rendered.
      additional_dependencies: Extra classes that may be referenced by name
        (e.g. as bases) in the rendered definitions, without being rendered
        themselves.
      **kwargs: Forwarded to the module-level `class_definitions`.

    Returns:
      The rendered definitions, or None if there are no class dependencies.
    """
    deps = schema.class_dependencies(
        include_base_classes=False, include_subclasses=True
    )
    allowed_dependencies = set(deps)
    if additional_dependencies:
      allowed_dependencies.update(additional_dependencies)
    return class_definitions(
        deps, allowed_dependencies=allowed_dependencies, **kwargs
    )

  def result_definition(self, schema: base.Schema) -> str:
    """Returns the Python annotation of the schema's root value spec."""
    return base.annotation(schema.spec)

  def value_repr(
      self,
      value: Any,
      schema: base.Schema | None = None,
      *,
      compact: bool = True,
      verbose: bool = False,
      markdown: bool = True,
      assign_to_var: str | None = None,
      **kwargs) -> str:
    """Returns `value` rendered as Python code.

    * A class whose schema spec is a `pg.typing.Object` is rendered as class
      definitions and returned as-is (`assign_to_var` is not applied).
    * Any other class is rendered as its type annotation.
    * An `lf.Template` is rendered with `str()` and returned as-is.
    * Anything else is rendered with `pg.format(..., python_format=True)`.

    Args:
      value: The value to render.
      schema: Unused.
      compact: Forwarded to `pg.format`.
      verbose: Forwarded to `pg.format`.
      markdown: If True, wrap the code in a ```python fence.
      assign_to_var: If set, render as `<assign_to_var> = <code>`.
      **kwargs: Unused.

    Returns:
      The rendered code.
    """
    del schema
    if inspect.isclass(value):
      cls_schema = base.Schema.from_value(value)
      if isinstance(cls_schema.spec, pg.typing.Object):
        object_code = self.class_definitions(
            cls_schema,
            markdown=markdown,
            # We add `pg.Object` as additional dependencies to the class
            # definition so exemplars for class generation could show
            # pg.Object as their bases.
            additional_dependencies=[pg.Object]
        )
        assert object_code is not None
        return object_code
      else:
        object_code = self.result_definition(cls_schema)
    elif isinstance(value, lf.Template):
      return str(value)
    else:
      object_code = pg.format(
          value, compact=compact, verbose=verbose, python_format=True
      )
    if assign_to_var is not None:
      object_code = f'{assign_to_var} = {object_code}'
    if markdown:
      return f'```python\n{object_code}\n```'
    return object_code

  def parse_value(
      self,
      text: str,
      schema: base.Schema | None = None,
      *,
      additional_context: dict[str, Type[Any]] | None = None,
      permission: pg.coding.CodePermission = (
          pg.coding.CodePermission.ASSIGN | pg.coding.CodePermission.CALL
      ),
      autofix=0,
      autofix_lm: lf.LanguageModel = lf.contextual(),
      **kwargs,
  ) -> Any:
    """Parses a Python string into a structured object.

    Args:
      text: Python code to evaluate.
      schema: If given, its class dependencies are made available to `text`
        under their class names.
      additional_context: Extra symbols made available to `text`. The
        caller's dict is not modified.
      permission: Code permission used when evaluating `text`.
      autofix: Forwarded to `structure_from_python`.
      autofix_lm: Forwarded to `structure_from_python`.
      **kwargs: Ignored.

    Returns:
      The value produced by `structure_from_python`.
    """
    del kwargs
    # Copy instead of updating `additional_context` in place, so the schema
    # dependencies added below (and anything added downstream) do not leak
    # into the caller's dict.
    global_vars = dict(additional_context or {})
    if schema is not None:
      dependencies = schema.class_dependencies()
      global_vars.update({d.__name__: d for d in dependencies})
    return structure_from_python(
        text,
        global_vars=global_vars,
        autofix=autofix,
        autofix_lm=autofix_lm,
        permission=permission,
    )
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def structure_from_python(
    code: str,
    *,
    global_vars: dict[str, Any] | None = None,
    permission: pg.coding.CodePermission = (
        pg.coding.CodePermission.ASSIGN | pg.coding.CodePermission.CALL
    ),
    autofix=0,
    autofix_lm: lf.LanguageModel = lf.contextual(),
) -> Any:
  """Evaluates structure from Python code with access to symbols.

  Args:
    code: Python code to evaluate.
    global_vars: Extra symbols made available to `code`. The caller's dict is
      copied, not modified. The built-in symbols below (`pg`, `Object`,
      typing aliases, `UNKNOWN`) take precedence over same-named entries.
    permission: Code permission forwarded to the code runner.
    autofix: Forwarded as `max_attempts` to `correction.run_with_correction`.
    autofix_lm: Language model used by `correction.run_with_correction`.

  Returns:
    The value returned by `correction.run_with_correction`.
  """
  # Work on a copy so the symbols added below do not leak into the caller's
  # dict.
  global_vars = dict(global_vars or {})
  global_vars.update({
      'pg': pg,
      'Object': pg.Object,
      'Any': typing.Any,
      'List': typing.List,
      'Tuple': typing.Tuple,
      'Dict': typing.Dict,
      'Sequence': typing.Sequence,
      'Optional': typing.Optional,
      'Union': typing.Union,
      # Special value markers.
      'UNKNOWN': base.UNKNOWN,
  })
  # We are creating objects here, so we execute the code without a sandbox.
  return correction.run_with_correction(
      code,
      global_vars=global_vars,
      sandbox=False,
      max_attempts=autofix,
      lm=autofix_lm,
      permission=permission,
  )
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def source_form(value, compact: bool = True, markdown: bool = False) -> str:
  """Returns `value` rendered as Python source code.

  A thin wrapper over `PythonPromptingProtocol.value_repr`. Unlike that
  method, `markdown` defaults to False here (no code fence).
  """
  protocol = PythonPromptingProtocol()
  return protocol.value_repr(value, compact=compact, markdown=markdown)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def include_method_in_prompt(method):
  """Decorator that marks a method for inclusion in prompt class definitions.

  The method itself is returned unchanged, except that it carries a
  `__show_in_prompt__ = True` marker, which
  `should_include_method_in_prompt` reads.
  """
  method.__show_in_prompt__ = True
  return method
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def should_include_method_in_prompt(method):
  """Returns the `__show_in_prompt__` marker of `method`, or False if absent."""
  try:
    return method.__show_in_prompt__
  except AttributeError:
    return False
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def class_definition(
    cls,
    strict: bool = False,
    allowed_dependencies: set[Type[Any]] | None = None,
) -> str:
  """Returns the Python class definition of `cls` as shown in prompts.

  The output consists of:
    * a `class` header listing the direct bases of `cls`, excluding `object`
      and (when `allowed_dependencies` is given) any base not in it;
    * the class docstring, if non-empty;
    * one annotated line per symbolic field, preceded by its description as
      `#` comments. Non-constant keys and fields whose metadata sets
      `exclude_from_prompt` are skipped;
    * the source of methods marked with `include_method_in_prompt` that are
      not inherited from an eligible base, with that decorator removed;
    * `pass` if nothing else was written into the body.

  Args:
    cls: The class to render; `pg.schema(cls)` is used to read its fields.
    strict: Forwarded to `base.annotation` for each field.
    allowed_dependencies: If not None, the classes that may be referenced by
      name in the rendered definition.

  Returns:
    The class definition, ending with a newline.
  """
  out = io.StringIO()
  schema = pg.schema(cls)

  eligible_bases = []
  for base_cls in cls.__bases__:
    if base_cls is not object:
      if allowed_dependencies is None or base_cls in allowed_dependencies:
        eligible_bases.append(base_cls.__name__)

  if eligible_bases:
    base_cls_str = ', '.join(eligible_bases)
    out.write(f'class {cls.__name__}({base_cls_str}):\n')
  else:
    out.write(f'class {cls.__name__}:\n')

  doc = (cls.__doc__ or '').strip()
  if doc:
    doc_lines = doc.split('\n')
    if len(doc_lines) == 1:
      # Write the stripped line. Writing the raw `cls.__doc__` would carry
      # surrounding newlines/indentation (e.g. from `"""\n  Foo.\n  """`)
      # into the rendered docstring.
      out.write(f'  """{doc_lines[0]}"""\n')
    else:
      out.write('  """')

      # Since Python 3.13, the indentation of docstring lines is removed.
      # Therefore, we add two spaces to each non-empty line to keep the
      # indentation consistent with the class definition.
      if sys.version_info >= (3, 13):
        for i in range(1, len(doc_lines)):
          if doc_lines[i]:
            doc_lines[i] = ' ' * 2 + doc_lines[i]

      for line in doc_lines:
        out.write(line)
        out.write('\n')
      out.write('  """\n')

  empty_class = True
  if schema.fields:
    for key, field in schema.items():
      if not isinstance(key, pg.typing.ConstStrKey):
        pg.logging.warning(
            'Variable-length keyword arguments is not supported in '
            f'structured parsing or query. Encountered: {cls}, Schema: {schema}'
        )
        continue

      # Skip fields that are marked as excluded from the prompt sent to LLM
      # for OOP.
      if field.metadata.get('exclude_from_prompt', False):
        continue

      # Write field doc string as comments before the field definition.
      if field.description:
        for line in field.description.split('\n'):
          if line:
            out.write('  # ')
            out.write(line)
            out.write('\n')

      annotation_str = base.annotation(
          field.value, strict=strict, allowed_dependencies=allowed_dependencies
      )
      out.write(f'  {field.key}: {annotation_str}')
      out.write('\n')
      empty_class = False

  for method in _iter_newly_defined_methods(cls, allowed_dependencies):
    source = inspect.getsource(method)
    # Remove `include_method_in_prompt` decorator lines. The pattern is
    # anchored at line start (MULTILINE) and stops at the line's own newline,
    # so a preceding line (another decorator or a comment) is never joined
    # with the `def` line. The module qualifier is optional so that a bare
    # `@include_method_in_prompt` is removed as well.
    source = re.sub(
        r'^[ \t]*@(?:[\w.]+\.)?include_method_in_prompt\b.*\n',
        '',
        source,
        flags=re.MULTILINE,
    )
    out.write('\n')
    out.write(
        textwrap.indent(
            inspect.cleandoc('\n' + source), ' ' * 2)
    )
    out.write('\n')
    empty_class = False

  if empty_class:
    out.write('  pass\n')
  return out.getvalue()
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def _iter_newly_defined_methods(
    cls, allowed_dependencies: set[Type[Any]] | None):
  """Yields prompt-visible methods of `cls` not inherited from eligible bases.

  A direct base is eligible when `allowed_dependencies` is None or contains
  it. Every attribute name visible on an eligible base is treated as
  inherited and skipped. The remaining attributes of `cls` are visited in
  `dir(cls)` order and yielded when they are callable and marked by
  `include_method_in_prompt`.
  """
  inherited_names = set()
  for base_cls in cls.__bases__:
    if allowed_dependencies is not None and (
        base_cls not in allowed_dependencies):
      continue
    inherited_names.update(dir(base_cls))

  for name in dir(cls):
    if name in inherited_names:
      continue
    attr = getattr(cls, name)
    if callable(attr) and should_include_method_in_prompt(attr):
      yield attr
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def class_definitions(
|
|
293
|
+
classes: Sequence[Type[Any]],
|
|
294
|
+
*,
|
|
295
|
+
allowed_dependencies: set[Type[Any]] | None = None,
|
|
296
|
+
strict: bool = False,
|
|
297
|
+
markdown: bool = False,
|
|
298
|
+
) -> str | None:
|
|
299
|
+
"""Returns a string for class definitions."""
|
|
300
|
+
if not classes:
|
|
301
|
+
return None
|
|
302
|
+
def_str = io.StringIO()
|
|
303
|
+
for i, cls in enumerate(classes):
|
|
304
|
+
if i > 0:
|
|
305
|
+
def_str.write('\n')
|
|
306
|
+
def_str.write(
|
|
307
|
+
class_definition(
|
|
308
|
+
cls,
|
|
309
|
+
strict=strict,
|
|
310
|
+
allowed_dependencies=allowed_dependencies,
|
|
311
|
+
)
|
|
312
|
+
)
|
|
313
|
+
ret = def_str.getvalue()
|
|
314
|
+
if markdown and ret:
|
|
315
|
+
ret = f'```python\n{ret}```'
|
|
316
|
+
return ret
|