langfun 0.1.2.dev202504280818__py3-none-any.whl → 0.1.2.dev202504300804__py3-none-any.whl
This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Potentially problematic release: this version of langfun might be problematic. More details are linked from the original page.
- langfun/__init__.py +3 -0
- langfun/core/agentic/action.py +253 -105
- langfun/core/agentic/action_eval.py +10 -3
- langfun/core/agentic/action_test.py +173 -47
- langfun/core/eval/base_test.py +4 -4
- langfun/core/eval/v2/evaluation.py +78 -12
- langfun/core/eval/v2/evaluation_test.py +2 -0
- langfun/core/structured/__init__.py +3 -0
- langfun/core/structured/mapping.py +5 -2
- langfun/core/structured/parsing_test.py +1 -1
- langfun/core/structured/querying.py +205 -18
- langfun/core/structured/querying_test.py +286 -47
- {langfun-0.1.2.dev202504280818.dist-info → langfun-0.1.2.dev202504300804.dist-info}/METADATA +29 -6
- {langfun-0.1.2.dev202504280818.dist-info → langfun-0.1.2.dev202504300804.dist-info}/RECORD +17 -17
- {langfun-0.1.2.dev202504280818.dist-info → langfun-0.1.2.dev202504300804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202504280818.dist-info → langfun-0.1.2.dev202504300804.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202504280818.dist-info → langfun-0.1.2.dev202504300804.dist-info}/top_level.txt +0 -0
@@ -22,54 +22,108 @@ import langfun.core.structured as lf_structured
|
|
22
22
|
import pyglove as pg
|
23
23
|
|
24
24
|
|
25
|
-
class
|
25
|
+
class Bar(action_lib.Action):
|
26
|
+
simulate_action_error: bool = False
|
27
|
+
|
28
|
+
def call(self, session, *, lm, **kwargs):
|
29
|
+
assert session.current_action.action is self
|
30
|
+
session.info('Begin Bar')
|
31
|
+
session.query('bar', lm=lm)
|
32
|
+
session.add_metadata(note='bar')
|
33
|
+
if self.simulate_action_error:
|
34
|
+
raise ValueError('Bar error')
|
35
|
+
return 2
|
36
|
+
|
37
|
+
|
38
|
+
class Foo(action_lib.Action):
|
39
|
+
x: int
|
40
|
+
simulate_action_error: bool = False
|
41
|
+
simulate_query_error: bool = False
|
42
|
+
|
43
|
+
def call(self, session, *, lm, **kwargs):
|
44
|
+
assert session.current_action.action is self
|
45
|
+
with session.track_phase('prepare'):
|
46
|
+
session.info('Begin Foo', x=1)
|
47
|
+
session.query(
|
48
|
+
'foo',
|
49
|
+
schema=int if self.simulate_query_error else None,
|
50
|
+
lm=lm
|
51
|
+
)
|
52
|
+
with session.track_queries():
|
53
|
+
self.make_additional_query(lm)
|
54
|
+
session.add_metadata(note='foo')
|
55
|
+
|
56
|
+
def _sub_task(i):
|
57
|
+
session.add_metadata(**{f'subtask_{i}': i})
|
58
|
+
return lf_structured.query(f'subtask_{i}', lm=lm)
|
59
|
+
|
60
|
+
for i, output, error in session.concurrent_map(
|
61
|
+
_sub_task, range(3), max_workers=2, silence_on_errors=None,
|
62
|
+
):
|
63
|
+
assert isinstance(i, int), i
|
64
|
+
assert isinstance(output, str), output
|
65
|
+
assert error is None, error
|
66
|
+
return self.x + Bar(
|
67
|
+
simulate_action_error=self.simulate_action_error
|
68
|
+
)(session, lm=lm)
|
69
|
+
|
70
|
+
def make_additional_query(self, lm):
|
71
|
+
lf_structured.query('additional query', lm=lm)
|
72
|
+
|
73
|
+
|
74
|
+
class ActionInvocationTest(unittest.TestCase):
|
75
|
+
|
76
|
+
def test_basics(self):
|
77
|
+
action_invocation = action_lib.ActionInvocation(
|
78
|
+
action=Foo(1)
|
79
|
+
)
|
80
|
+
self.assertEqual(action_invocation.id, '')
|
81
|
+
root = action_lib.ActionInvocation(action=action_lib.RootAction())
|
82
|
+
root.execution.append(action_invocation)
|
83
|
+
self.assertIs(action_invocation.parent_action, root)
|
84
|
+
self.assertEqual(action_invocation.id, '/a1')
|
85
|
+
|
86
|
+
|
87
|
+
class ExecutionTraceTest(unittest.TestCase):
|
26
88
|
|
27
89
|
def test_basics(self):
|
28
|
-
|
29
|
-
|
30
|
-
class Bar(action_lib.Action):
|
31
|
-
|
32
|
-
def call(self, session, *, lm, **kwargs):
|
33
|
-
test.assertIs(session.current_action.action, self)
|
34
|
-
session.info('Begin Bar')
|
35
|
-
session.query('bar', lm=lm)
|
36
|
-
session.add_metadata(note='bar')
|
37
|
-
return 2
|
38
|
-
|
39
|
-
class Foo(action_lib.Action):
|
40
|
-
x: int
|
41
|
-
|
42
|
-
def call(self, session, *, lm, **kwargs):
|
43
|
-
test.assertIs(session.current_action.action, self)
|
44
|
-
with session.track_phase('prepare'):
|
45
|
-
session.info('Begin Foo', x=1)
|
46
|
-
session.query('foo', lm=lm)
|
47
|
-
with session.track_queries():
|
48
|
-
self.make_additional_query(lm)
|
49
|
-
session.add_metadata(note='foo')
|
50
|
-
|
51
|
-
def _sub_task(i):
|
52
|
-
session.add_metadata(**{f'subtask_{i}': i})
|
53
|
-
return lf_structured.query(f'subtask_{i}', lm=lm)
|
54
|
-
|
55
|
-
for i, output, error in session.concurrent_map(
|
56
|
-
_sub_task, range(3), max_workers=2, silence_on_errors=None,
|
57
|
-
):
|
58
|
-
assert isinstance(i, int), i
|
59
|
-
assert isinstance(output, str), output
|
60
|
-
assert error is None, error
|
61
|
-
return self.x + Bar()(session, lm=lm)
|
62
|
-
|
63
|
-
def make_additional_query(self, lm):
|
64
|
-
lf_structured.query('additional query', lm=lm)
|
90
|
+
execution = action_lib.ExecutionTrace()
|
91
|
+
self.assertEqual(execution.id, '')
|
65
92
|
|
93
|
+
root = action_lib.ActionInvocation(action=action_lib.RootAction())
|
94
|
+
action_invocation = action_lib.ActionInvocation(
|
95
|
+
action=Foo(1)
|
96
|
+
)
|
97
|
+
root.execution.append(action_invocation)
|
98
|
+
self.assertEqual(action_invocation.execution.id, '/a1')
|
99
|
+
|
100
|
+
root.execution.reset()
|
101
|
+
self.assertEqual(len(root.execution), 0)
|
102
|
+
|
103
|
+
|
104
|
+
class SessionTest(unittest.TestCase):
|
105
|
+
|
106
|
+
def test_succeeded_trajectory(self):
|
66
107
|
lm = fake.StaticResponse('lm response')
|
67
108
|
foo = Foo(1)
|
68
|
-
self.
|
109
|
+
self.assertIsNone(foo.session)
|
110
|
+
self.assertIsNone(foo.result)
|
111
|
+
self.assertIsNone(foo.metadata)
|
112
|
+
|
113
|
+
session = action_lib.Session(id='agent@1')
|
114
|
+
self.assertEqual(session.id, 'agent@1')
|
115
|
+
|
116
|
+
# Render HTML view to trigger dynamic update during execution.
|
117
|
+
_ = session.to_html()
|
118
|
+
|
119
|
+
self.assertEqual(foo(session, lm=lm, verbose=True), 3)
|
120
|
+
|
121
|
+
self.assertIsNone(foo.session)
|
122
|
+
self.assertEqual(foo.result, 3)
|
123
|
+
self.assertEqual(
|
124
|
+
foo.metadata, dict(note='foo', subtask_0=0, subtask_1=1, subtask_2=2)
|
125
|
+
)
|
69
126
|
|
70
|
-
session = foo.session
|
71
|
-
self.assertIn('session@', session.id)
|
72
|
-
self.assertIsNotNone(session)
|
73
127
|
self.assertIsInstance(session.root.action, action_lib.RootAction)
|
74
128
|
self.assertIs(session.current_action, session.root)
|
75
129
|
|
@@ -78,6 +132,9 @@ class SessionTest(unittest.TestCase):
|
|
78
132
|
#
|
79
133
|
|
80
134
|
root = session.root
|
135
|
+
self.assertIsNone(root.parent_action)
|
136
|
+
self.assertEqual(root.id, 'agent@1:')
|
137
|
+
self.assertEqual(root.execution.id, 'agent@1:')
|
81
138
|
self.assertEqual(len(root.execution.items), 1)
|
82
139
|
self.assertIs(root.execution.items[0].action, foo)
|
83
140
|
|
@@ -104,33 +161,57 @@ class SessionTest(unittest.TestCase):
|
|
104
161
|
|
105
162
|
# Inspecting the top-level action (Foo)
|
106
163
|
foo_invocation = root.execution.items[0]
|
164
|
+
self.assertIs(foo_invocation.parent_action, root)
|
165
|
+
self.assertEqual(foo_invocation.id, 'agent@1:/a1')
|
166
|
+
self.assertEqual(foo_invocation.execution.id, 'agent@1:/a1')
|
107
167
|
self.assertEqual(len(foo_invocation.execution.items), 4)
|
108
168
|
|
109
169
|
# Prepare phase.
|
110
170
|
prepare_phase = foo_invocation.execution.items[0]
|
111
|
-
self.assertIsInstance(
|
112
|
-
|
113
|
-
)
|
171
|
+
self.assertIsInstance(prepare_phase, action_lib.ExecutionTrace)
|
172
|
+
self.assertEqual(prepare_phase.id, 'agent@1:/a1/prepare')
|
114
173
|
self.assertEqual(len(prepare_phase.items), 2)
|
115
174
|
self.assertTrue(prepare_phase.has_started)
|
116
175
|
self.assertTrue(prepare_phase.has_stopped)
|
117
176
|
self.assertEqual(prepare_phase.usage_summary.total.num_requests, 1)
|
177
|
+
self.assertIsInstance(prepare_phase.items[0], lf.logging.LogEntry)
|
178
|
+
self.assertIsInstance(prepare_phase.items[1], lf_structured.QueryInvocation)
|
179
|
+
self.assertEqual(prepare_phase.items[1].id, 'agent@1:/a1/prepare/q1')
|
118
180
|
|
119
181
|
# Tracked queries.
|
120
182
|
query_invocation = foo_invocation.execution.items[1]
|
121
183
|
self.assertIsInstance(query_invocation, lf_structured.QueryInvocation)
|
184
|
+
self.assertEqual(query_invocation.id, 'agent@1:/a1/q2')
|
122
185
|
self.assertIs(query_invocation.lm, lm)
|
186
|
+
self.assertEqual(
|
187
|
+
foo_invocation.execution.indexof(
|
188
|
+
query_invocation, lf_structured.QueryInvocation
|
189
|
+
),
|
190
|
+
1
|
191
|
+
)
|
192
|
+
self.assertEqual(
|
193
|
+
root.execution.indexof(
|
194
|
+
query_invocation, lf_structured.QueryInvocation
|
195
|
+
),
|
196
|
+
-1
|
197
|
+
)
|
123
198
|
|
124
199
|
# Tracked parallel executions.
|
125
200
|
parallel_executions = foo_invocation.execution.items[2]
|
201
|
+
self.assertEqual(parallel_executions.id, 'agent@1:/a1/p1')
|
126
202
|
self.assertIsInstance(parallel_executions, action_lib.ParallelExecutions)
|
127
203
|
self.assertEqual(len(parallel_executions), 3)
|
204
|
+
self.assertEqual(parallel_executions[0].id, 'agent@1:/a1/p1/b1')
|
205
|
+
self.assertEqual(parallel_executions[1].id, 'agent@1:/a1/p1/b2')
|
206
|
+
self.assertEqual(parallel_executions[2].id, 'agent@1:/a1/p1/b3')
|
128
207
|
self.assertEqual(len(parallel_executions[0].queries), 1)
|
129
208
|
self.assertEqual(len(parallel_executions[1].queries), 1)
|
130
209
|
self.assertEqual(len(parallel_executions[2].queries), 1)
|
131
210
|
|
132
211
|
# Invocation to Bar.
|
133
212
|
bar_invocation = foo_invocation.execution.items[3]
|
213
|
+
self.assertIs(bar_invocation.parent_action, foo_invocation)
|
214
|
+
self.assertEqual(bar_invocation.id, 'agent@1:/a1/a1')
|
134
215
|
self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
|
135
216
|
self.assertIsInstance(bar_invocation.action, Bar)
|
136
217
|
self.assertEqual(bar_invocation.result, 2)
|
@@ -144,6 +225,51 @@ class SessionTest(unittest.TestCase):
|
|
144
225
|
json_str = session.to_json_str(save_ref_value=True)
|
145
226
|
self.assertIsInstance(pg.from_json_str(json_str), action_lib.Session)
|
146
227
|
|
228
|
+
def test_failed_action(self):
|
229
|
+
lm = fake.StaticResponse('lm response')
|
230
|
+
foo = Foo(1, simulate_action_error=True)
|
231
|
+
with self.assertRaisesRegex(ValueError, 'Bar error'):
|
232
|
+
foo(lm=lm)
|
233
|
+
|
234
|
+
session = foo.session
|
235
|
+
self.assertIsNotNone(session)
|
236
|
+
self.assertIsInstance(session.root.action, action_lib.RootAction)
|
237
|
+
self.assertIs(session.current_action, session.root)
|
238
|
+
|
239
|
+
# Inspecting the root invocation.
|
240
|
+
root = session.root
|
241
|
+
self.assertRegex(root.id, 'agent@.*:')
|
242
|
+
self.assertTrue(root.has_error)
|
243
|
+
foo_invocation = root.execution.items[0]
|
244
|
+
self.assertIsInstance(foo_invocation, action_lib.ActionInvocation)
|
245
|
+
self.assertTrue(foo_invocation.has_error)
|
246
|
+
bar_invocation = foo_invocation.execution.items[3]
|
247
|
+
self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
|
248
|
+
self.assertTrue(bar_invocation.has_error)
|
249
|
+
|
250
|
+
# Save to HTML
|
251
|
+
self.assertIn('error', session.to_html().content)
|
252
|
+
|
253
|
+
def test_failed_query(self):
|
254
|
+
lm = fake.StaticResponse('lm response')
|
255
|
+
foo = Foo(1, simulate_query_error=True)
|
256
|
+
with self.assertRaisesRegex(lf_structured.MappingError, 'SyntaxError'):
|
257
|
+
foo(lm=lm)
|
258
|
+
|
259
|
+
session = foo.session
|
260
|
+
self.assertIsNotNone(session)
|
261
|
+
self.assertIsInstance(session.root.action, action_lib.RootAction)
|
262
|
+
self.assertIs(session.current_action, session.root)
|
263
|
+
|
264
|
+
# Inspecting the root invocation.
|
265
|
+
root = session.root
|
266
|
+
self.assertRegex(root.id, 'agent@.*:')
|
267
|
+
self.assertTrue(root.has_error)
|
268
|
+
foo_invocation = root.execution.items[0]
|
269
|
+
self.assertIsInstance(foo_invocation, action_lib.ActionInvocation)
|
270
|
+
self.assertTrue(foo_invocation.has_error)
|
271
|
+
self.assertEqual(len(foo_invocation.execution.items), 2)
|
272
|
+
|
147
273
|
def test_log(self):
|
148
274
|
session = action_lib.Session()
|
149
275
|
session.debug('hi', x=1, y=2)
|
@@ -153,8 +279,8 @@ class SessionTest(unittest.TestCase):
|
|
153
279
|
session.fatal('hi', x=1, y=2)
|
154
280
|
|
155
281
|
def test_as_message(self):
|
156
|
-
session = action_lib.Session(
|
157
|
-
self.
|
282
|
+
session = action_lib.Session()
|
283
|
+
self.assertIn('agent@', session.id)
|
158
284
|
self.assertIsInstance(session.as_message(), lf.AIMessage)
|
159
285
|
|
160
286
|
|
langfun/core/eval/base_test.py
CHANGED
@@ -195,7 +195,7 @@ class EvaluationTest(unittest.TestCase):
|
|
195
195
|
score=1.0,
|
196
196
|
logprobs=None,
|
197
197
|
is_cached=False,
|
198
|
-
usage=lf.LMSamplingUsage(
|
198
|
+
usage=lf.LMSamplingUsage(428, 24, 452),
|
199
199
|
tags=['lm-response', 'lm-output', 'transformed'],
|
200
200
|
),
|
201
201
|
)
|
@@ -234,12 +234,12 @@ class EvaluationTest(unittest.TestCase):
|
|
234
234
|
}
|
235
235
|
),
|
236
236
|
usage=dict(
|
237
|
-
total_prompt_tokens=
|
237
|
+
total_prompt_tokens=856,
|
238
238
|
total_completion_tokens=25,
|
239
239
|
num_usages=2,
|
240
|
-
average_prompt_tokens=
|
240
|
+
average_prompt_tokens=428,
|
241
241
|
average_completion_tokens=12,
|
242
|
-
average_total_tokens=
|
242
|
+
average_total_tokens=440,
|
243
243
|
),
|
244
244
|
),
|
245
245
|
)
|
@@ -167,6 +167,8 @@ class Evaluation(experiment_lib.Experiment):
|
|
167
167
|
example.input = self.example_input_by_id(example.id)
|
168
168
|
|
169
169
|
checkpointed = self._state.ckpt_example(example.id)
|
170
|
+
self._state.update(example, in_progress=True)
|
171
|
+
|
170
172
|
with pg.timeit('evaluate') as timeit, lf.track_usages() as usage_summary:
|
171
173
|
if checkpointed is None or checkpointed.has_error:
|
172
174
|
if checkpointed is None:
|
@@ -221,7 +223,7 @@ class Evaluation(experiment_lib.Experiment):
|
|
221
223
|
if example.newly_processed:
|
222
224
|
example.end_time = time.time()
|
223
225
|
|
224
|
-
self._state.update(example)
|
226
|
+
self._state.update(example, in_progress=False)
|
225
227
|
return example
|
226
228
|
|
227
229
|
def _process(
|
@@ -501,6 +503,21 @@ class Evaluation(experiment_lib.Experiment):
|
|
501
503
|
)
|
502
504
|
)
|
503
505
|
|
506
|
+
def _in_progress_tab() -> pg.views.html.controls.Tab | None:
|
507
|
+
"""Renders a tab for the in progress examples."""
|
508
|
+
if not self.state.in_progress_examples:
|
509
|
+
return None
|
510
|
+
return pg.views.html.controls.Tab(
|
511
|
+
label='In Progress',
|
512
|
+
content=pg.Html.element(
|
513
|
+
'div', [
|
514
|
+
self._in_progress_view(
|
515
|
+
list(self.state.in_progress_examples.values())
|
516
|
+
)
|
517
|
+
]
|
518
|
+
)
|
519
|
+
)
|
520
|
+
|
504
521
|
def _metric_tab(metric: metrics_lib.Metric) -> pg.views.html.controls.Tab:
|
505
522
|
"""Renders a tab for a metric (group)."""
|
506
523
|
return pg.views.html.controls.Tab(
|
@@ -571,10 +588,9 @@ class Evaluation(experiment_lib.Experiment):
|
|
571
588
|
pg.views.html.controls.TabControl(
|
572
589
|
[
|
573
590
|
_definition_tab(),
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
_logs_tab()
|
591
|
+
[_metric_tab(m) for m in self.metrics],
|
592
|
+
_in_progress_tab(),
|
593
|
+
_logs_tab(),
|
578
594
|
],
|
579
595
|
selected=1,
|
580
596
|
)
|
@@ -598,6 +614,27 @@ class Evaluation(experiment_lib.Experiment):
|
|
598
614
|
css_classes=['eval-details'],
|
599
615
|
)
|
600
616
|
|
617
|
+
def _in_progress_view(
|
618
|
+
self, in_progress_examples: list[example_lib.Example]
|
619
|
+
) -> pg.Html:
|
620
|
+
"""Renders a HTML view for the in-progress examples."""
|
621
|
+
current_time = time.time()
|
622
|
+
logs = [f'(Total {len(in_progress_examples)} examples in progress)']
|
623
|
+
for example in in_progress_examples:
|
624
|
+
if example.newly_processed:
|
625
|
+
logs.append(
|
626
|
+
f'Example {example.id}: In progress for '
|
627
|
+
f'{current_time - example.start_time:.2f} seconds.'
|
628
|
+
)
|
629
|
+
else:
|
630
|
+
logs.append(f'Example {example.id}: Recomputing metrics...')
|
631
|
+
return pg.Html.element(
|
632
|
+
'textarea',
|
633
|
+
[pg.Html.escape('\n'.join(logs))],
|
634
|
+
readonly=True,
|
635
|
+
css_classes=['logs-textarea'],
|
636
|
+
)
|
637
|
+
|
601
638
|
def _html_tree_view_config(self) -> dict[str, Any]:
|
602
639
|
return dict(
|
603
640
|
css_classes=['eval-card'] if self.is_leaf else None
|
@@ -716,14 +753,27 @@ class EvaluationState:
|
|
716
753
|
'Whether the example is evaluated.'
|
717
754
|
] = False
|
718
755
|
|
756
|
+
in_progress: Annotated[
|
757
|
+
bool,
|
758
|
+
(
|
759
|
+
'Whether the example is in progress. '
|
760
|
+
)
|
761
|
+
] = False
|
762
|
+
|
719
763
|
newly_processed: Annotated[
|
720
764
|
bool,
|
721
|
-
|
765
|
+
(
|
766
|
+
'Whether the example is newly processed. '
|
767
|
+
'Applicable only when evaluated is True.'
|
768
|
+
)
|
722
769
|
] = False
|
723
770
|
|
724
771
|
has_error: Annotated[
|
725
772
|
bool,
|
726
|
-
|
773
|
+
(
|
774
|
+
'Whether the example has error. '
|
775
|
+
'Applicable only when evaluated is True.'
|
776
|
+
)
|
727
777
|
] = False
|
728
778
|
|
729
779
|
def __init__(self):
|
@@ -732,6 +782,7 @@ class EvaluationState:
|
|
732
782
|
self._evaluation_status: dict[
|
733
783
|
int, EvaluationState.ExampleStatus
|
734
784
|
] = {}
|
785
|
+
self._in_progress_examples: dict[int, example_lib.Example] = {}
|
735
786
|
|
736
787
|
def load(
|
737
788
|
self,
|
@@ -758,6 +809,11 @@ class EvaluationState:
|
|
758
809
|
"""Returns the evaluation status of the examples."""
|
759
810
|
return self._evaluation_status
|
760
811
|
|
812
|
+
@property
|
813
|
+
def in_progress_examples(self) -> dict[int, example_lib.Example]:
|
814
|
+
"""Returns the in-progress examples."""
|
815
|
+
return self._in_progress_examples
|
816
|
+
|
761
817
|
@property
|
762
818
|
def ckpt_examples(self) -> dict[int, example_lib.Example]:
|
763
819
|
"""Returns the unevaluated examples from checkpoints."""
|
@@ -773,17 +829,27 @@ class EvaluationState:
|
|
773
829
|
example_id, EvaluationState.ExampleStatus()
|
774
830
|
)
|
775
831
|
|
776
|
-
def update(self, example: example_lib.Example) -> None:
|
832
|
+
def update(self, example: example_lib.Example, in_progress: bool) -> None:
|
777
833
|
"""Updates the state with the given example."""
|
778
|
-
self._update_status(example)
|
779
|
-
|
780
|
-
|
834
|
+
self._update_status(example, in_progress)
|
835
|
+
|
836
|
+
if in_progress:
|
837
|
+
self._in_progress_examples[example.id] = example
|
838
|
+
else:
|
839
|
+
self._in_progress_examples.pop(example.id, None)
|
840
|
+
# Processed examples will be removed once it's done.
|
841
|
+
self._ckpt_examples.pop(example.id, None)
|
781
842
|
|
782
|
-
def _update_status(
|
843
|
+
def _update_status(
|
844
|
+
self,
|
845
|
+
example: example_lib.Example,
|
846
|
+
in_progress: bool
|
847
|
+
) -> None:
|
783
848
|
"""Updates the evaluation status of the example."""
|
784
849
|
self._evaluation_status[example.id] = (
|
785
850
|
EvaluationState.ExampleStatus(
|
786
851
|
evaluated=example.output != pg.MISSING_VALUE,
|
852
|
+
in_progress=in_progress,
|
787
853
|
newly_processed=example.newly_processed,
|
788
854
|
has_error=example.has_error,
|
789
855
|
)
|
@@ -79,8 +79,10 @@ class EvaluationTest(unittest.TestCase):
|
|
79
79
|
exp = eval_test_helper.TestEvaluation()
|
80
80
|
example = exp.evaluate(Example(id=3))
|
81
81
|
self.assertTrue(exp.state.get_status(3).evaluated)
|
82
|
+
self.assertFalse(exp.state.get_status(3).in_progress)
|
82
83
|
self.assertTrue(exp.state.get_status(3).newly_processed)
|
83
84
|
self.assertFalse(exp.state.get_status(3).has_error)
|
85
|
+
self.assertEqual(exp.state.in_progress_examples, {})
|
84
86
|
self.assertTrue(example.newly_processed)
|
85
87
|
self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
|
86
88
|
self.assertEqual(example.output, 6)
|
@@ -55,8 +55,11 @@ from langfun.core.structured.parsing import call
|
|
55
55
|
|
56
56
|
from langfun.core.structured.querying import track_queries
|
57
57
|
from langfun.core.structured.querying import QueryInvocation
|
58
|
+
|
59
|
+
from langfun.core.structured.querying import LfQuery
|
58
60
|
from langfun.core.structured.querying import query
|
59
61
|
from langfun.core.structured.querying import query_and_reduce
|
62
|
+
from langfun.core.structured.querying import query_protocol
|
60
63
|
|
61
64
|
from langfun.core.structured.querying import query_prompt
|
62
65
|
from langfun.core.structured.querying import query_output
|
@@ -340,8 +340,11 @@ class Mapping(lf.LangFunc):
|
|
340
340
|
schema_title: Annotated[str, 'The section title for schema.'] = 'SCHEMA'
|
341
341
|
|
342
342
|
protocol: Annotated[
|
343
|
-
|
344
|
-
|
343
|
+
str,
|
344
|
+
(
|
345
|
+
'A string representing the protocol for formatting the prompt. '
|
346
|
+
'Built-in Langfun protocols are: `python` and `json`.'
|
347
|
+
),
|
345
348
|
] = 'python'
|
346
349
|
|
347
350
|
#
|