langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +7 -1
- langfun/core/agentic/__init__.py +8 -1
- langfun/core/agentic/action.py +740 -112
- langfun/core/agentic/action_eval.py +9 -2
- langfun/core/agentic/action_test.py +189 -24
- langfun/core/async_support.py +104 -5
- langfun/core/async_support_test.py +23 -0
- langfun/core/coding/python/correction.py +19 -9
- langfun/core/coding/python/execution.py +14 -12
- langfun/core/coding/python/generation.py +21 -16
- langfun/core/coding/python/sandboxing.py +23 -3
- langfun/core/component.py +42 -3
- langfun/core/concurrent.py +70 -6
- langfun/core/concurrent_test.py +9 -2
- langfun/core/console.py +1 -1
- langfun/core/data/conversion/anthropic.py +12 -3
- langfun/core/data/conversion/anthropic_test.py +8 -6
- langfun/core/data/conversion/gemini.py +11 -2
- langfun/core/data/conversion/gemini_test.py +48 -9
- langfun/core/data/conversion/openai.py +145 -31
- langfun/core/data/conversion/openai_test.py +161 -17
- langfun/core/eval/base.py +48 -44
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/matching.py +5 -2
- langfun/core/eval/patching.py +3 -3
- langfun/core/eval/scoring.py +4 -3
- langfun/core/eval/v2/__init__.py +3 -0
- langfun/core/eval/v2/checkpointing.py +148 -46
- langfun/core/eval/v2/checkpointing_test.py +9 -2
- langfun/core/eval/v2/config_saver.py +37 -0
- langfun/core/eval/v2/config_saver_test.py +36 -0
- langfun/core/eval/v2/eval_test_helper.py +104 -3
- langfun/core/eval/v2/evaluation.py +102 -19
- langfun/core/eval/v2/evaluation_test.py +9 -3
- langfun/core/eval/v2/example.py +50 -40
- langfun/core/eval/v2/example_test.py +16 -8
- langfun/core/eval/v2/experiment.py +95 -20
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/metric_values.py +31 -3
- langfun/core/eval/v2/metric_values_test.py +32 -0
- langfun/core/eval/v2/metrics.py +157 -44
- langfun/core/eval/v2/metrics_test.py +39 -18
- langfun/core/eval/v2/progress.py +31 -1
- langfun/core/eval/v2/progress_test.py +27 -0
- langfun/core/eval/v2/progress_tracking.py +13 -5
- langfun/core/eval/v2/progress_tracking_test.py +9 -1
- langfun/core/eval/v2/reporting.py +88 -71
- langfun/core/eval/v2/reporting_test.py +24 -6
- langfun/core/eval/v2/runners/__init__.py +30 -0
- langfun/core/eval/v2/{runners.py → runners/base.py} +73 -180
- langfun/core/eval/v2/runners/beam.py +354 -0
- langfun/core/eval/v2/runners/beam_test.py +153 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +350 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +213 -0
- langfun/core/eval/v2/runners/debug.py +40 -0
- langfun/core/eval/v2/runners/debug_test.py +76 -0
- langfun/core/eval/v2/runners/parallel.py +243 -0
- langfun/core/eval/v2/runners/parallel_test.py +182 -0
- langfun/core/eval/v2/runners/sequential.py +47 -0
- langfun/core/eval/v2/runners/sequential_test.py +169 -0
- langfun/core/langfunc.py +45 -130
- langfun/core/langfunc_test.py +7 -5
- langfun/core/language_model.py +189 -36
- langfun/core/language_model_test.py +54 -3
- langfun/core/llms/__init__.py +14 -1
- langfun/core/llms/anthropic.py +157 -2
- langfun/core/llms/azure_openai.py +29 -17
- langfun/core/llms/cache/base.py +25 -3
- langfun/core/llms/cache/in_memory.py +48 -7
- langfun/core/llms/cache/in_memory_test.py +14 -4
- langfun/core/llms/compositional.py +25 -1
- langfun/core/llms/deepseek.py +30 -2
- langfun/core/llms/fake.py +32 -1
- langfun/core/llms/gemini.py +90 -12
- langfun/core/llms/gemini_test.py +110 -0
- langfun/core/llms/google_genai.py +52 -1
- langfun/core/llms/groq.py +28 -3
- langfun/core/llms/llama_cpp.py +23 -4
- langfun/core/llms/openai.py +120 -3
- langfun/core/llms/openai_compatible.py +148 -27
- langfun/core/llms/openai_compatible_test.py +207 -20
- langfun/core/llms/openai_test.py +0 -2
- langfun/core/llms/rest.py +16 -1
- langfun/core/llms/vertexai.py +78 -8
- langfun/core/logging.py +1 -1
- langfun/core/mcp/__init__.py +10 -0
- langfun/core/mcp/client.py +177 -0
- langfun/core/mcp/client_test.py +71 -0
- langfun/core/mcp/session.py +241 -0
- langfun/core/mcp/session_test.py +54 -0
- langfun/core/mcp/testing/simple_mcp_client.py +33 -0
- langfun/core/mcp/testing/simple_mcp_server.py +33 -0
- langfun/core/mcp/tool.py +254 -0
- langfun/core/mcp/tool_test.py +197 -0
- langfun/core/memory.py +1 -0
- langfun/core/message.py +160 -55
- langfun/core/message_test.py +65 -81
- langfun/core/modalities/__init__.py +8 -0
- langfun/core/modalities/audio.py +21 -1
- langfun/core/modalities/image.py +73 -3
- langfun/core/modalities/image_test.py +116 -0
- langfun/core/modalities/mime.py +78 -4
- langfun/core/modalities/mime_test.py +59 -0
- langfun/core/modalities/pdf.py +19 -1
- langfun/core/modalities/video.py +21 -1
- langfun/core/modality.py +167 -29
- langfun/core/modality_test.py +42 -12
- langfun/core/natural_language.py +1 -1
- langfun/core/sampling.py +4 -4
- langfun/core/sampling_test.py +20 -4
- langfun/core/structured/__init__.py +2 -24
- langfun/core/structured/completion.py +34 -44
- langfun/core/structured/completion_test.py +23 -43
- langfun/core/structured/description.py +54 -50
- langfun/core/structured/function_generation.py +29 -12
- langfun/core/structured/mapping.py +81 -37
- langfun/core/structured/parsing.py +95 -79
- langfun/core/structured/parsing_test.py +0 -3
- langfun/core/structured/querying.py +230 -154
- langfun/core/structured/querying_test.py +69 -33
- langfun/core/structured/schema/__init__.py +49 -0
- langfun/core/structured/schema/base.py +664 -0
- langfun/core/structured/schema/base_test.py +531 -0
- langfun/core/structured/schema/json.py +174 -0
- langfun/core/structured/schema/json_test.py +121 -0
- langfun/core/structured/schema/python.py +316 -0
- langfun/core/structured/schema/python_test.py +410 -0
- langfun/core/structured/schema_generation.py +33 -14
- langfun/core/structured/scoring.py +47 -36
- langfun/core/structured/tokenization.py +26 -11
- langfun/core/subscription.py +2 -2
- langfun/core/template.py +175 -50
- langfun/core/template_test.py +123 -17
- langfun/env/__init__.py +43 -0
- langfun/env/base_environment.py +827 -0
- langfun/env/base_environment_test.py +473 -0
- langfun/env/base_feature.py +304 -0
- langfun/env/base_feature_test.py +228 -0
- langfun/env/base_sandbox.py +842 -0
- langfun/env/base_sandbox_test.py +1235 -0
- langfun/env/event_handlers/__init__.py +14 -0
- langfun/env/event_handlers/chain.py +233 -0
- langfun/env/event_handlers/chain_test.py +253 -0
- langfun/env/event_handlers/event_logger.py +472 -0
- langfun/env/event_handlers/event_logger_test.py +304 -0
- langfun/env/event_handlers/metric_writer.py +726 -0
- langfun/env/event_handlers/metric_writer_test.py +214 -0
- langfun/env/interface.py +1640 -0
- langfun/env/interface_test.py +153 -0
- langfun/env/load_balancers.py +59 -0
- langfun/env/load_balancers_test.py +141 -0
- langfun/env/test_utils.py +507 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/METADATA +7 -3
- langfun-0.1.2.dev202512150805.dist-info/RECORD +217 -0
- langfun/core/eval/v2/runners_test.py +0 -343
- langfun/core/structured/schema.py +0 -987
- langfun/core/structured/schema_test.py +0 -982
- langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512150805.dist-info}/top_level.txt +0 -0
|
@@ -24,7 +24,14 @@ import pyglove as pg
|
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class ActionEval(lf.eval.v2.Evaluation):
|
|
27
|
-
"""
|
|
27
|
+
"""Evaluation for agentic actions.
|
|
28
|
+
|
|
29
|
+
`ActionEval` is a specialized evaluation class for executing and evaluating
|
|
30
|
+
agentic actions based on provided inputs. Each input example is expected to
|
|
31
|
+
contain an `action` attribute. The `process` method executes the action
|
|
32
|
+
within a dedicated `Session`, captures the final result, and returns it
|
|
33
|
+
along with the session details in the metadata.
|
|
34
|
+
"""
|
|
28
35
|
|
|
29
36
|
action_args: Annotated[
|
|
30
37
|
dict[str, Any],
|
|
@@ -68,7 +75,7 @@ class ExampleView(pg.Object):
|
|
|
68
75
|
class ActionEvalV1(lf_eval.Matching):
|
|
69
76
|
"""Base class for action evaluations.
|
|
70
77
|
|
|
71
|
-
The input function should
|
|
78
|
+
The input function should return a list of pg.Dict, with `action` and
|
|
72
79
|
`groundtruth` fields.
|
|
73
80
|
"""
|
|
74
81
|
# We override the schema and prompt to dummy values since they are not used.
|
|
@@ -34,6 +34,7 @@ class Bar(action_lib.Action):
|
|
|
34
34
|
time.sleep(self.simulate_execution_time)
|
|
35
35
|
session.query('bar', lm=lm)
|
|
36
36
|
session.add_metadata(note='bar')
|
|
37
|
+
session.update_progress('Query completed')
|
|
37
38
|
if self.simulate_action_error:
|
|
38
39
|
raise ValueError('Bar error')
|
|
39
40
|
return 2 + pg.contextual_value('baz', 0)
|
|
@@ -51,6 +52,7 @@ class Foo(action_lib.Action):
|
|
|
51
52
|
with session.track_phase('prepare'):
|
|
52
53
|
session.info('Begin Foo', x=1)
|
|
53
54
|
time.sleep(self.simulate_execution_time[0])
|
|
55
|
+
Bar()(session, lm=lm)
|
|
54
56
|
session.query(
|
|
55
57
|
'foo',
|
|
56
58
|
schema=int if self.simulate_query_error else None,
|
|
@@ -64,14 +66,21 @@ class Foo(action_lib.Action):
|
|
|
64
66
|
def _sub_task(i):
|
|
65
67
|
session.add_metadata(**{f'subtask_{i}': i})
|
|
66
68
|
time.sleep(self.simulate_execution_time[2])
|
|
69
|
+
Bar()(session, lm=lm)
|
|
67
70
|
return lf_structured.query(f'subtask_{i}', lm=lm)
|
|
68
71
|
|
|
72
|
+
self._state = []
|
|
69
73
|
for i, output, error in session.concurrent_map(
|
|
70
|
-
_sub_task,
|
|
74
|
+
_sub_task,
|
|
75
|
+
range(3),
|
|
76
|
+
max_workers=2,
|
|
77
|
+
ordered=True,
|
|
78
|
+
silence_on_errors=None,
|
|
71
79
|
):
|
|
72
80
|
assert isinstance(i, int), i
|
|
73
81
|
assert isinstance(output, str), output
|
|
74
82
|
assert error is None, error
|
|
83
|
+
self._state.append(i)
|
|
75
84
|
return self.x + Bar(
|
|
76
85
|
simulate_action_error=self.simulate_action_error,
|
|
77
86
|
simulate_execution_time=self.simulate_execution_time[3]
|
|
@@ -81,6 +90,50 @@ class Foo(action_lib.Action):
|
|
|
81
90
|
lf_structured.query('additional query', lm=lm)
|
|
82
91
|
|
|
83
92
|
|
|
93
|
+
class ExecutionUnitPositionTest(unittest.TestCase):
|
|
94
|
+
|
|
95
|
+
def test_basics(self):
|
|
96
|
+
pos1 = action_lib.ExecutionUnit.Position(None, 0)
|
|
97
|
+
self.assertEqual(repr(pos1), 'Position(0)')
|
|
98
|
+
self.assertEqual(str(pos1), '')
|
|
99
|
+
self.assertIsNone(pos1.parent)
|
|
100
|
+
self.assertEqual(pos1.index, 0)
|
|
101
|
+
self.assertEqual(pos1.indices(), (0,))
|
|
102
|
+
self.assertEqual(pos1, (0,))
|
|
103
|
+
self.assertEqual(pos1, '')
|
|
104
|
+
self.assertEqual(pos1, action_lib.ExecutionUnit.Position(None, 0))
|
|
105
|
+
self.assertNotEqual(pos1, 1)
|
|
106
|
+
self.assertNotEqual(pos1, (1,))
|
|
107
|
+
self.assertNotEqual(pos1, action_lib.ExecutionUnit.Position(None, 1))
|
|
108
|
+
|
|
109
|
+
pos2 = action_lib.ExecutionUnit.Position(pos1, 0)
|
|
110
|
+
self.assertEqual(repr(pos2), 'Position(0, 0)')
|
|
111
|
+
self.assertEqual(str(pos2), '1')
|
|
112
|
+
self.assertEqual(pos2, '1')
|
|
113
|
+
self.assertEqual(pos2.parent, pos1)
|
|
114
|
+
self.assertEqual(pos2.index, 0)
|
|
115
|
+
self.assertEqual(pos2.indices(), (0, 0))
|
|
116
|
+
self.assertNotEqual(pos1, pos2)
|
|
117
|
+
self.assertLess(pos1, pos2)
|
|
118
|
+
self.assertGreater(pos2, pos1)
|
|
119
|
+
self.assertEqual(
|
|
120
|
+
hash(pos2),
|
|
121
|
+
hash(
|
|
122
|
+
action_lib.ExecutionUnit.Position(
|
|
123
|
+
action_lib.ExecutionUnit.Position(None, 0), 0
|
|
124
|
+
)
|
|
125
|
+
)
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
pos3 = action_lib.ExecutionUnit.Position(pos2, 0)
|
|
129
|
+
self.assertEqual(str(pos3), '1.1')
|
|
130
|
+
self.assertEqual(pos3, '1.1')
|
|
131
|
+
self.assertEqual(pos3.parent, pos2)
|
|
132
|
+
self.assertEqual(pos3.index, 0)
|
|
133
|
+
self.assertEqual(pos3.indices(), (0, 0, 0))
|
|
134
|
+
self.assertEqual(pos3.to_str(separator='>'), '1>1')
|
|
135
|
+
|
|
136
|
+
|
|
84
137
|
class ActionInvocationTest(unittest.TestCase):
|
|
85
138
|
|
|
86
139
|
def test_basics(self):
|
|
@@ -101,9 +154,7 @@ class ExecutionTraceTest(unittest.TestCase):
|
|
|
101
154
|
self.assertEqual(execution.id, '')
|
|
102
155
|
|
|
103
156
|
root = action_lib.ActionInvocation(action=action_lib.RootAction())
|
|
104
|
-
action_invocation = action_lib.ActionInvocation(
|
|
105
|
-
action=Foo(1)
|
|
106
|
-
)
|
|
157
|
+
action_invocation = action_lib.ActionInvocation(action=Foo(1))
|
|
107
158
|
root.execution.append(action_invocation)
|
|
108
159
|
self.assertEqual(action_invocation.execution.id, '/a1')
|
|
109
160
|
|
|
@@ -118,10 +169,11 @@ class SessionTest(unittest.TestCase):
|
|
|
118
169
|
foo = Foo(1)
|
|
119
170
|
self.assertIsNone(foo.session)
|
|
120
171
|
self.assertIsNone(foo.invocation)
|
|
172
|
+
self.assertIsNone(foo.state)
|
|
121
173
|
self.assertIsNone(foo.result)
|
|
122
174
|
self.assertIsNone(foo.metadata)
|
|
123
175
|
|
|
124
|
-
session = action_lib.Session(id='agent@1')
|
|
176
|
+
session = action_lib.Session(id='agent@1', verbose=True)
|
|
125
177
|
self.assertEqual(session.id, 'agent@1')
|
|
126
178
|
self.assertFalse(session.has_started)
|
|
127
179
|
self.assertFalse(session.has_stopped)
|
|
@@ -130,12 +182,14 @@ class SessionTest(unittest.TestCase):
|
|
|
130
182
|
_ = session.to_html()
|
|
131
183
|
|
|
132
184
|
with session:
|
|
133
|
-
result = foo(session, lm=lm
|
|
185
|
+
result = foo(session, lm=lm)
|
|
134
186
|
|
|
135
187
|
self.assertTrue(session.has_started)
|
|
136
188
|
self.assertTrue(session.has_stopped)
|
|
137
189
|
self.assertEqual(result, 3)
|
|
138
190
|
self.assertIsNone(foo.session)
|
|
191
|
+
self.assertEqual(foo.state, [0, 1, 2])
|
|
192
|
+
self.assertIs(foo.invocation.state, foo.state)
|
|
139
193
|
self.assertEqual(foo.result, 3)
|
|
140
194
|
self.assertEqual(
|
|
141
195
|
foo.metadata, dict(note='foo', subtask_0=0, subtask_1=1, subtask_2=2)
|
|
@@ -143,6 +197,7 @@ class SessionTest(unittest.TestCase):
|
|
|
143
197
|
|
|
144
198
|
self.assertIsInstance(session.root.action, action_lib.RootAction)
|
|
145
199
|
self.assertIs(session.current_action, session.root)
|
|
200
|
+
self.assertIs(session.metadata, session.root.metadata)
|
|
146
201
|
|
|
147
202
|
#
|
|
148
203
|
# Inspecting the root invocation.
|
|
@@ -165,20 +220,25 @@ class SessionTest(unittest.TestCase):
|
|
|
165
220
|
)
|
|
166
221
|
|
|
167
222
|
# The root space should have one action (foo), no queries, and no logs.
|
|
223
|
+
self.assertEqual(len(root.execution_units), 1)
|
|
168
224
|
self.assertEqual(len(root.actions), 1)
|
|
169
225
|
self.assertEqual(len(root.queries), 0)
|
|
170
226
|
self.assertEqual(len(root.logs), 0)
|
|
171
|
-
#
|
|
172
|
-
self.assertEqual(len(session.all_queries),
|
|
173
|
-
self.assertEqual(len(root.all_queries),
|
|
174
|
-
#
|
|
175
|
-
self.assertEqual(len(session.all_actions),
|
|
176
|
-
self.assertEqual(
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
227
|
+
# 2 queries from Bar, 2 from Foo and 2 * 3 from parallel executions.
|
|
228
|
+
self.assertEqual(len(session.all_queries), 10)
|
|
229
|
+
self.assertEqual(len(root.all_queries), 10)
|
|
230
|
+
# 6 actions: Foo and 2 Bar, and 3 Bar from parallel executions.
|
|
231
|
+
self.assertEqual(len(session.all_actions), 6)
|
|
232
|
+
self.assertEqual(
|
|
233
|
+
[str(a.position) for a in session.all_actions],
|
|
234
|
+
['1', '1.1', '1.2.1.1', '1.2.2.1', '1.2.3.1', '1.3']
|
|
235
|
+
)
|
|
236
|
+
self.assertEqual(len(root.all_actions), 6)
|
|
237
|
+
# 1 log from Bar and 1 from Foo and 3 from Bar in parallel executions.
|
|
238
|
+
self.assertEqual(len(session.all_logs), 6)
|
|
239
|
+
self.assertEqual(len(root.all_logs), 6)
|
|
180
240
|
self.assertIs(session.usage_summary, root.usage_summary)
|
|
181
|
-
self.assertEqual(root.usage_summary.total.num_requests,
|
|
241
|
+
self.assertEqual(root.usage_summary.total.num_requests, 10)
|
|
182
242
|
|
|
183
243
|
# Inspecting the top-level action (Foo)
|
|
184
244
|
foo_invocation = root.execution[0]
|
|
@@ -190,15 +250,19 @@ class SessionTest(unittest.TestCase):
|
|
|
190
250
|
|
|
191
251
|
# Prepare phase.
|
|
192
252
|
prepare_phase = foo_invocation.execution[0]
|
|
253
|
+
self.assertIsNone(prepare_phase.position)
|
|
193
254
|
self.assertIsInstance(prepare_phase, action_lib.ExecutionTrace)
|
|
194
255
|
self.assertEqual(prepare_phase.id, 'agent@1:/a1/prepare')
|
|
195
|
-
self.assertEqual(len(prepare_phase.items),
|
|
256
|
+
self.assertEqual(len(prepare_phase.items), 3)
|
|
196
257
|
self.assertTrue(prepare_phase.has_started)
|
|
197
258
|
self.assertTrue(prepare_phase.has_stopped)
|
|
198
|
-
self.assertEqual(prepare_phase.usage_summary.total.num_requests,
|
|
259
|
+
self.assertEqual(prepare_phase.usage_summary.total.num_requests, 2)
|
|
199
260
|
self.assertIsInstance(prepare_phase.items[0], lf.logging.LogEntry)
|
|
200
|
-
self.assertIsInstance(prepare_phase.items[1],
|
|
201
|
-
self.
|
|
261
|
+
self.assertIsInstance(prepare_phase.items[1], action_lib.ActionInvocation)
|
|
262
|
+
self.assertIs(prepare_phase.items[1].parent_execution_unit, foo_invocation)
|
|
263
|
+
self.assertEqual(prepare_phase.items[1].id, 'agent@1:/a1/prepare/a1')
|
|
264
|
+
self.assertIsInstance(prepare_phase.items[2], lf_structured.QueryInvocation)
|
|
265
|
+
self.assertEqual(prepare_phase.items[2].id, 'agent@1:/a1/prepare/q1')
|
|
202
266
|
|
|
203
267
|
# Tracked queries.
|
|
204
268
|
query_invocation = foo_invocation.execution[1]
|
|
@@ -220,20 +284,44 @@ class SessionTest(unittest.TestCase):
|
|
|
220
284
|
|
|
221
285
|
# Tracked parallel executions.
|
|
222
286
|
parallel_executions = foo_invocation.execution[2]
|
|
287
|
+
# root (0) > foo (0) > parallel executions (1)
|
|
288
|
+
self.assertEqual(parallel_executions.position, (0, 0, 1))
|
|
223
289
|
self.assertEqual(parallel_executions.id, 'agent@1:/a1/p1')
|
|
224
290
|
self.assertIsInstance(parallel_executions, action_lib.ParallelExecutions)
|
|
291
|
+
self.assertIs(
|
|
292
|
+
parallel_executions.all_actions[0].parent_execution_unit,
|
|
293
|
+
parallel_executions
|
|
294
|
+
)
|
|
295
|
+
self.assertIs(
|
|
296
|
+
parallel_executions.all_actions[0].parent_action,
|
|
297
|
+
foo_invocation
|
|
298
|
+
)
|
|
225
299
|
self.assertEqual(len(parallel_executions), 3)
|
|
226
300
|
self.assertEqual(parallel_executions[0].id, 'agent@1:/a1/p1/b1')
|
|
227
301
|
self.assertEqual(parallel_executions[1].id, 'agent@1:/a1/p1/b2')
|
|
228
302
|
self.assertEqual(parallel_executions[2].id, 'agent@1:/a1/p1/b3')
|
|
303
|
+
self.assertEqual(len(parallel_executions[0].execution_units), 1)
|
|
304
|
+
self.assertEqual(len(parallel_executions[1].execution_units), 1)
|
|
305
|
+
self.assertEqual(len(parallel_executions[2].execution_units), 1)
|
|
229
306
|
self.assertEqual(len(parallel_executions[0].queries), 1)
|
|
307
|
+
self.assertEqual(len(parallel_executions[0].all_queries), 2)
|
|
230
308
|
self.assertEqual(len(parallel_executions[1].queries), 1)
|
|
309
|
+
self.assertEqual(len(parallel_executions[1].all_queries), 2)
|
|
231
310
|
self.assertEqual(len(parallel_executions[2].queries), 1)
|
|
311
|
+
self.assertEqual(len(parallel_executions[2].all_queries), 2)
|
|
312
|
+
self.assertEqual(len(parallel_executions.execution_units), 0)
|
|
313
|
+
self.assertEqual(len(parallel_executions.actions), 0)
|
|
314
|
+
self.assertEqual(len(parallel_executions.queries), 0)
|
|
315
|
+
self.assertEqual(len(parallel_executions.logs), 0)
|
|
316
|
+
self.assertEqual(len(parallel_executions.all_actions), 3)
|
|
317
|
+
self.assertEqual(len(parallel_executions.all_queries), 6)
|
|
318
|
+
self.assertEqual(len(parallel_executions.all_logs), 3)
|
|
232
319
|
|
|
233
320
|
# Invocation to Bar.
|
|
234
321
|
bar_invocation = foo_invocation.execution[3]
|
|
235
322
|
self.assertIs(bar_invocation.parent_action, foo_invocation)
|
|
236
|
-
self.
|
|
323
|
+
self.assertIs(bar_invocation.parent_execution_unit, foo_invocation)
|
|
324
|
+
self.assertEqual(bar_invocation.id, 'agent@1:/a1/a5')
|
|
237
325
|
self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
|
|
238
326
|
self.assertIsInstance(bar_invocation.action, Bar)
|
|
239
327
|
self.assertEqual(bar_invocation.result, 2)
|
|
@@ -366,7 +454,7 @@ class SessionTest(unittest.TestCase):
|
|
|
366
454
|
self.assertFalse(session.has_stopped)
|
|
367
455
|
|
|
368
456
|
session.start()
|
|
369
|
-
result = foo(session, lm=lm
|
|
457
|
+
result = foo(session, lm=lm)
|
|
370
458
|
session.end(result)
|
|
371
459
|
|
|
372
460
|
self.assertTrue(session.has_started)
|
|
@@ -386,7 +474,7 @@ class SessionTest(unittest.TestCase):
|
|
|
386
474
|
session = action_lib.Session(id='agent@1')
|
|
387
475
|
with self.assertRaisesRegex(ValueError, 'Bar error'):
|
|
388
476
|
with session:
|
|
389
|
-
foo(session, lm=lm
|
|
477
|
+
foo(session, lm=lm)
|
|
390
478
|
self.assertTrue(session.has_started)
|
|
391
479
|
self.assertTrue(session.has_stopped)
|
|
392
480
|
self.assertTrue(session.has_error)
|
|
@@ -399,7 +487,7 @@ class SessionTest(unittest.TestCase):
|
|
|
399
487
|
foo = Foo(1, simulate_action_error=True)
|
|
400
488
|
session = action_lib.Session(id='agent@1')
|
|
401
489
|
with self.assertRaisesRegex(ValueError, 'Please call `Session.start'):
|
|
402
|
-
foo(session, lm=lm
|
|
490
|
+
foo(session, lm=lm)
|
|
403
491
|
|
|
404
492
|
def test_succeed_with_multiple_actions(self):
|
|
405
493
|
lm = fake.StaticResponse('lm response')
|
|
@@ -480,6 +568,58 @@ class SessionTest(unittest.TestCase):
|
|
|
480
568
|
):
|
|
481
569
|
foo(lm=lm, max_execution_time=1.0)
|
|
482
570
|
|
|
571
|
+
def test_event_handler(self):
|
|
572
|
+
|
|
573
|
+
class MyActionHandler(pg.Object, action_lib.SessionEventHandler):
|
|
574
|
+
def _on_bound(self):
|
|
575
|
+
super()._on_bound()
|
|
576
|
+
self.progresses = []
|
|
577
|
+
|
|
578
|
+
def on_session_start(self, session):
|
|
579
|
+
session.add_metadata(progresses=pg.Ref(self.progresses))
|
|
580
|
+
|
|
581
|
+
def on_action_progress(self, session, action, title, **kwargs):
|
|
582
|
+
self.progresses.append((action.id, title))
|
|
583
|
+
|
|
584
|
+
handler = MyActionHandler()
|
|
585
|
+
self.assertIs(handler.get(MyActionHandler), handler)
|
|
586
|
+
self.assertIsNone(handler.get(action_lib.SessionLogging))
|
|
587
|
+
|
|
588
|
+
handler_chain = action_lib.SessionEventHandlerChain(
|
|
589
|
+
handlers=[handler, action_lib.SessionLogging()]
|
|
590
|
+
)
|
|
591
|
+
self.assertIs(handler_chain.get(MyActionHandler), handler)
|
|
592
|
+
self.assertIs(
|
|
593
|
+
handler_chain.get(action_lib.SessionLogging),
|
|
594
|
+
handler_chain.handlers[1]
|
|
595
|
+
)
|
|
596
|
+
|
|
597
|
+
session = action_lib.Session(
|
|
598
|
+
id='agent@1',
|
|
599
|
+
event_handler=handler_chain
|
|
600
|
+
)
|
|
601
|
+
bar = Bar()
|
|
602
|
+
with session:
|
|
603
|
+
bar(session, lm=fake.StaticResponse('lm response'))
|
|
604
|
+
session.update_progress('Trajectory completed')
|
|
605
|
+
|
|
606
|
+
self.assertIs(session.metadata['progresses'], handler.progresses)
|
|
607
|
+
self.assertEqual(handler.progresses, [
|
|
608
|
+
('agent@1:/a1', 'Query completed'),
|
|
609
|
+
('agent@1:', 'Trajectory completed'),
|
|
610
|
+
])
|
|
611
|
+
|
|
612
|
+
def test_clone(self):
|
|
613
|
+
event_handler = action_lib.SessionLogging()
|
|
614
|
+
session = action_lib.Session(event_handler=event_handler)
|
|
615
|
+
other = session.clone()
|
|
616
|
+
self.assertIsNot(session, other)
|
|
617
|
+
self.assertIs(other.event_handler, event_handler)
|
|
618
|
+
|
|
619
|
+
other = session.clone(deep=True)
|
|
620
|
+
self.assertIsNot(session, other)
|
|
621
|
+
self.assertIsNot(other.event_handler, session.event_handler)
|
|
622
|
+
|
|
483
623
|
def test_log(self):
|
|
484
624
|
session = action_lib.Session()
|
|
485
625
|
session.debug('hi', x=1, y=2)
|
|
@@ -493,6 +633,31 @@ class SessionTest(unittest.TestCase):
|
|
|
493
633
|
self.assertIn('agent@', session.id)
|
|
494
634
|
self.assertIsInstance(session.as_message(), lf.AIMessage)
|
|
495
635
|
|
|
636
|
+
def test_query_with_track_if(self):
|
|
637
|
+
lm = fake.StaticResponse('lm response')
|
|
638
|
+
session = action_lib.Session()
|
|
639
|
+
|
|
640
|
+
# Render session to trigger javascript updates to the HTML when
|
|
641
|
+
# operating on the session.
|
|
642
|
+
_ = session.to_html()
|
|
643
|
+
with session:
|
|
644
|
+
# This query will succeed.
|
|
645
|
+
session.query(
|
|
646
|
+
'prompt1',
|
|
647
|
+
schema=None,
|
|
648
|
+
lm=lm,
|
|
649
|
+
track_if=lambda q: not q.has_error,
|
|
650
|
+
default=None)
|
|
651
|
+
# This query will fail during parsing.
|
|
652
|
+
session.query(
|
|
653
|
+
'prompt2',
|
|
654
|
+
schema=int,
|
|
655
|
+
lm=lm,
|
|
656
|
+
track_if=lambda q: not q.has_error,
|
|
657
|
+
default=None)
|
|
658
|
+
self.assertEqual(len(session.root.queries), 1)
|
|
659
|
+
self.assertIsNone(session.root.queries[0].error)
|
|
660
|
+
|
|
496
661
|
|
|
497
662
|
if __name__ == '__main__':
|
|
498
663
|
unittest.main()
|
langfun/core/async_support.py
CHANGED
|
@@ -11,18 +11,117 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
"""
|
|
14
|
+
"""Utilities for asynchronous programming in Langfun."""
|
|
15
15
|
|
|
16
16
|
import asyncio
|
|
17
|
-
|
|
17
|
+
import contextlib
|
|
18
|
+
from typing import Any, Awaitable, Callable, Iterator
|
|
19
|
+
import anyio
|
|
18
20
|
import pyglove as pg
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
async def invoke_async(
|
|
22
|
-
|
|
24
|
+
sync_callable: Callable[..., Any], *args, **kwargs
|
|
23
25
|
) -> Any:
|
|
24
|
-
"""Invokes a callable asynchronously
|
|
26
|
+
"""Invokes a sync callable asynchronously in a separate thread.
|
|
27
|
+
|
|
28
|
+
This is useful for wrapping a sync function into an async function,
|
|
29
|
+
allowing multiple calls of the sync function to run concurrently.
|
|
30
|
+
`lf.context` will be propagated to the thread that runs the sync callable.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
sync_callable: The sync callable to invoke.
|
|
34
|
+
*args: Positional arguments to pass to the callable.
|
|
35
|
+
**kwargs: Keyword arguments to pass to the callable.
|
|
36
|
+
|
|
37
|
+
Returns:
|
|
38
|
+
An awaitable that resolves to the return value of the sync_callable.
|
|
39
|
+
"""
|
|
25
40
|
return await asyncio.to_thread(
|
|
26
41
|
# Enable `lf.context` manager for async calls.
|
|
27
|
-
pg.with_contextual_override(
|
|
42
|
+
pg.with_contextual_override(sync_callable), *args, **kwargs
|
|
28
43
|
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def invoke_sync(
|
|
47
|
+
async_callable: Callable[..., Awaitable[Any]],
|
|
48
|
+
*args,
|
|
49
|
+
**kwargs
|
|
50
|
+
) -> Any:
|
|
51
|
+
"""Invokes an async callable synchronously.
|
|
52
|
+
|
|
53
|
+
This is useful for calling an async function from a sync context.
|
|
54
|
+
If there is an existing async event loop in current thread managed by
|
|
55
|
+
`lf.sync_context_manager`, it will be used for running the async callable.
|
|
56
|
+
Otherwise, `anyio.run` will be used to run the async callable in a new
|
|
57
|
+
event loop.
|
|
58
|
+
`lf.context` will be propagated to the async callable.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
async_callable: The async callable to invoke.
|
|
62
|
+
*args: Positional arguments to pass to the callable.
|
|
63
|
+
**kwargs: Keyword arguments to pass to the callable.
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
The return value of the async_callable.
|
|
67
|
+
"""
|
|
68
|
+
async def _invoke():
|
|
69
|
+
return await async_callable(*args, **kwargs)
|
|
70
|
+
invoke_fn = pg.with_contextual_override(_invoke)
|
|
71
|
+
blocking_portal = pg.utils.thread_local_get('__blocking_portal__', None)
|
|
72
|
+
if blocking_portal is None:
|
|
73
|
+
return anyio.run(invoke_fn)
|
|
74
|
+
return blocking_portal.call(invoke_fn)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@contextlib.contextmanager
|
|
78
|
+
def sync_context_manager(
|
|
79
|
+
async_context_manager: contextlib.AbstractAsyncContextManager[Any]
|
|
80
|
+
) -> Iterator[Any]:
|
|
81
|
+
"""Adapts an async context manager to a sync context manager.
|
|
82
|
+
|
|
83
|
+
sync_context_manager installs a blocking portal in current thread to run the
|
|
84
|
+
async context manager in a blocking way. It's useful for running async code in
|
|
85
|
+
sync context managers, e.g. `sync_context_manager` can be nested and share the
|
|
86
|
+
same event loop.
|
|
87
|
+
|
|
88
|
+
Example:
|
|
89
|
+
|
|
90
|
+
```python
|
|
91
|
+
@contextlib.asynccontextmanager
|
|
92
|
+
async def foo(x):
|
|
93
|
+
try:
|
|
94
|
+
yield x
|
|
95
|
+
finally:
|
|
96
|
+
pass
|
|
97
|
+
|
|
98
|
+
with lf.sync_context_manager(foo(x)) as x:
|
|
99
|
+
with lf.sync_context_manager(foo(y)) as y:
|
|
100
|
+
...
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
Args:
|
|
104
|
+
async_context_manager: The async context manager to adapt.
|
|
105
|
+
|
|
106
|
+
Yields:
|
|
107
|
+
The value yielded by the async context manager.
|
|
108
|
+
"""
|
|
109
|
+
blocking_portal = pg.utils.thread_local_get('__blocking_portal__', None)
|
|
110
|
+
portal_exit_stack = None
|
|
111
|
+
|
|
112
|
+
try:
|
|
113
|
+
if blocking_portal is None:
|
|
114
|
+
portal_exit_stack = contextlib.ExitStack()
|
|
115
|
+
blocking_portal = portal_exit_stack.enter_context(
|
|
116
|
+
anyio.from_thread.start_blocking_portal()
|
|
117
|
+
)
|
|
118
|
+
pg.utils.thread_local_set('__blocking_portal__', blocking_portal)
|
|
119
|
+
context_manager = blocking_portal.wrap_async_context_manager(
|
|
120
|
+
async_context_manager
|
|
121
|
+
)
|
|
122
|
+
with context_manager as value:
|
|
123
|
+
yield value
|
|
124
|
+
finally:
|
|
125
|
+
if portal_exit_stack is not None:
|
|
126
|
+
portal_exit_stack.close()
|
|
127
|
+
pg.utils.thread_local_del('__blocking_portal__')
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import asyncio
|
|
16
|
+
import contextlib
|
|
16
17
|
import time
|
|
17
18
|
import unittest
|
|
18
19
|
|
|
@@ -34,6 +35,28 @@ class AsyncSupportTest(unittest.TestCase):
|
|
|
34
35
|
with pg.contextual_override(z=3):
|
|
35
36
|
self.assertEqual(asyncio.run(r), 6)
|
|
36
37
|
|
|
38
|
+
def test_invoke_sync(self):
|
|
39
|
+
@contextlib.asynccontextmanager
|
|
40
|
+
async def bar(x):
|
|
41
|
+
try:
|
|
42
|
+
yield x
|
|
43
|
+
finally:
|
|
44
|
+
pass
|
|
45
|
+
|
|
46
|
+
async def foo(x, *, y):
|
|
47
|
+
time.sleep(2)
|
|
48
|
+
return x + y + pg.contextual_value('z', 0)
|
|
49
|
+
|
|
50
|
+
with pg.contextual_override(z=3):
|
|
51
|
+
with async_support.sync_context_manager(bar(1)) as x:
|
|
52
|
+
self.assertEqual(x, 1)
|
|
53
|
+
with async_support.sync_context_manager(bar(2)) as y:
|
|
54
|
+
self.assertEqual(y, 2)
|
|
55
|
+
self.assertEqual(async_support.invoke_sync(foo, 1, y=2), 6)
|
|
56
|
+
|
|
57
|
+
with pg.contextual_override(z=2):
|
|
58
|
+
self.assertEqual(async_support.invoke_sync(foo, 1, y=2), 5)
|
|
59
|
+
|
|
37
60
|
|
|
38
61
|
if __name__ == '__main__':
|
|
39
62
|
unittest.main()
|
|
@@ -19,13 +19,23 @@ import pyglove as pg
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class CodeWithError(pg.Object):
  """A structure representing Python code along with an execution error.

  This is used as input to a language model for error correction, providing
  the model with the code that failed and the error message it produced.
  """

  # The Python source code that failed.
  code: str
  # The error message produced by that code.
  error: str
|
|
26
30
|
|
|
27
31
|
|
|
28
32
|
class CorrectedCode(pg.Object):
  """A structure containing corrected Python code.

  This is used as the output schema when asking a language model to correct
  code, expecting the model to return the fixed code in the `corrected_code`
  field.
  """

  # The fixed code returned by the language model.
  corrected_code: str
|
|
30
40
|
|
|
31
41
|
|
|
@@ -49,7 +59,7 @@ def run_with_correction(
|
|
|
49
59
|
code: The source code that may or may not be problematic.
|
|
50
60
|
error: An optional initial error for `code` when it's problematic, usually
|
|
51
61
|
caught from elsewhere when it ran. If None, code will be executed once to
|
|
52
|
-
verify if
|
|
62
|
+
verify if it's good and obtain a feedback error message.
|
|
53
63
|
global_vars: A dict of str to value as the global variables that could be
|
|
54
64
|
accessed within the corrected code.
|
|
55
65
|
lm: Language model to be used. If not specified, it will try to use the `lm`
|
|
@@ -57,15 +67,15 @@ def run_with_correction(
|
|
|
57
67
|
max_attempts: Max number of attempts for the correction.
|
|
58
68
|
sandbox: If True, run code in sandbox; If False, run code in current
|
|
59
69
|
process. If None, run in sandbox first, if the output could not be
|
|
60
|
-
serialized and
|
|
70
|
+
serialized and passed to current process, run the code again in current
|
|
61
71
|
process.
|
|
62
72
|
permission: The permission to run the code.
|
|
63
73
|
timeout: The timeout for running the corrected code. If None, there is no
|
|
64
74
|
timeout. Applicable only when sandbox is set to True.
|
|
65
75
|
returns_code: If True, the return value is a tuple of (result, final code).
|
|
66
76
|
Otherwise the return value is the result only.
|
|
67
|
-
returns_stdout: If True, the stdout (a
|
|
68
|
-
outputs_intermediate: If True, intermediate output will be
|
|
77
|
+
returns_stdout: If True, the stdout (a string) will be returned.
|
|
78
|
+
outputs_intermediate: If True, intermediate output will be output as a
|
|
69
79
|
dict, with the last line's value accessible by key '__result__'. Otherwise
|
|
70
80
|
the value of the last line will be returned.
|
|
71
81
|
|
|
@@ -161,7 +171,7 @@ def correct(
|
|
|
161
171
|
code: The source code that may or may not be problematic.
|
|
162
172
|
error: An optional initial error for `code` when it's problematic, usually
|
|
163
173
|
caught from elsewhere when it ran. If None, code will be executed once to
|
|
164
|
-
verify if
|
|
174
|
+
verify if it's good and obtain a feedback error message.
|
|
165
175
|
global_vars: A dict of str to value as the global variables that could be
|
|
166
176
|
accessed within the corrected code.
|
|
167
177
|
lm: Language model to be used. If not specified, it will try to use the `lm`
|
|
@@ -169,7 +179,7 @@ def correct(
|
|
|
169
179
|
max_attempts: Max number of attempts for the correction.
|
|
170
180
|
sandbox: If True, run code in sandbox; If False, run code in current
|
|
171
181
|
process. If None, run in sandbox first, if the output could not be
|
|
172
|
-
serialized and
|
|
182
|
+
serialized and passed to current process, run the code again in current
|
|
173
183
|
process.
|
|
174
184
|
timeout: The timeout for running the corrected code. If None, there is no
|
|
175
185
|
timeout. Applicable only when sandbox is set to True.
|
|
@@ -193,7 +203,7 @@ def correct(
|
|
|
193
203
|
|
|
194
204
|
|
|
195
205
|
def _error_feedback_str(error: Exception) -> str:
|
|
196
|
-
"""Returns the error
|
|
206
|
+
"""Returns the error string for feedback."""
|
|
197
207
|
if isinstance(error, pg.coding.CodeError):
|
|
198
208
|
return pg.decolor(error.format(include_complete_code=False))
|
|
199
209
|
else:
|
|
@@ -201,7 +211,7 @@ def _error_feedback_str(error: Exception) -> str:
|
|
|
201
211
|
|
|
202
212
|
|
|
203
213
|
def _maybe_custom_validate(result: Any) -> Any:
|
|
204
|
-
"""
|
|
214
|
+
"""Applies custom validation through __validate__ method."""
|
|
205
215
|
if isinstance(result, dict) and "__result__" in result:
|
|
206
216
|
r = result["__result__"]
|
|
207
217
|
else:
|