langfun 0.1.2.dev202511160804__py3-none-any.whl → 0.1.2.dev202511270805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/__init__.py +1 -0
- langfun/core/agentic/__init__.py +4 -1
- langfun/core/agentic/action.py +340 -17
- langfun/core/agentic/action_test.py +124 -21
- langfun/core/eval/base_test.py +5 -5
- langfun/core/eval/v2/checkpointing.py +25 -1
- langfun/core/eval/v2/checkpointing_test.py +8 -1
- langfun/core/eval/v2/eval_test_helper.py +7 -2
- langfun/core/eval/v2/evaluation.py +4 -1
- langfun/core/eval/v2/example.py +5 -1
- langfun/core/eval/v2/example_test.py +13 -5
- langfun/core/eval/v2/experiment.py +23 -0
- langfun/core/eval/v2/experiment_test.py +19 -0
- langfun/core/eval/v2/progress_tracking.py +12 -3
- langfun/core/eval/v2/progress_tracking_test.py +3 -1
- langfun/core/eval/v2/reporting_test.py +4 -0
- langfun/core/eval/v2/runners/__init__.py +4 -0
- langfun/core/eval/v2/runners/base.py +40 -21
- langfun/core/eval/v2/runners/beam.py +341 -0
- langfun/core/eval/v2/runners/beam_test.py +131 -0
- langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
- langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
- langfun/core/eval/v2/runners/debug_test.py +1 -4
- langfun/core/eval/v2/runners/parallel_test.py +1 -4
- langfun/core/eval/v2/runners/sequential_test.py +1 -4
- langfun/core/langfunc_test.py +3 -3
- langfun/core/language_model.py +38 -5
- langfun/core/language_model_test.py +45 -0
- langfun/core/llms/__init__.py +2 -0
- langfun/core/llms/gemini.py +41 -8
- langfun/core/llms/gemini_test.py +84 -0
- langfun/core/llms/google_genai.py +5 -0
- langfun/core/llms/vertexai.py +7 -0
- langfun/core/modalities/mime.py +2 -0
- langfun/core/modalities/mime_test.py +11 -0
- langfun/core/structured/schema/__init__.py +1 -0
- {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/RECORD +41 -37
- {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202511160804.dist-info → langfun-0.1.2.dev202511270805.dist-info}/top_level.txt +0 -0
|
@@ -52,6 +52,7 @@ class Foo(action_lib.Action):
|
|
|
52
52
|
with session.track_phase('prepare'):
|
|
53
53
|
session.info('Begin Foo', x=1)
|
|
54
54
|
time.sleep(self.simulate_execution_time[0])
|
|
55
|
+
Bar()(session, lm=lm)
|
|
55
56
|
session.query(
|
|
56
57
|
'foo',
|
|
57
58
|
schema=int if self.simulate_query_error else None,
|
|
@@ -65,6 +66,7 @@ class Foo(action_lib.Action):
|
|
|
65
66
|
def _sub_task(i):
|
|
66
67
|
session.add_metadata(**{f'subtask_{i}': i})
|
|
67
68
|
time.sleep(self.simulate_execution_time[2])
|
|
69
|
+
Bar()(session, lm=lm)
|
|
68
70
|
return lf_structured.query(f'subtask_{i}', lm=lm)
|
|
69
71
|
|
|
70
72
|
self._state = []
|
|
@@ -88,6 +90,50 @@ class Foo(action_lib.Action):
|
|
|
88
90
|
lf_structured.query('additional query', lm=lm)
|
|
89
91
|
|
|
90
92
|
|
|
93
|
+
class ExecutionUnitPositionTest(unittest.TestCase):
|
|
94
|
+
|
|
95
|
+
def test_basics(self):
|
|
96
|
+
pos1 = action_lib.ExecutionUnit.Position(None, 0)
|
|
97
|
+
self.assertEqual(repr(pos1), 'Position(0)')
|
|
98
|
+
self.assertEqual(str(pos1), '')
|
|
99
|
+
self.assertIsNone(pos1.parent)
|
|
100
|
+
self.assertEqual(pos1.index, 0)
|
|
101
|
+
self.assertEqual(pos1.indices(), (0,))
|
|
102
|
+
self.assertEqual(pos1, (0,))
|
|
103
|
+
self.assertEqual(pos1, '')
|
|
104
|
+
self.assertEqual(pos1, action_lib.ExecutionUnit.Position(None, 0))
|
|
105
|
+
self.assertNotEqual(pos1, 1)
|
|
106
|
+
self.assertNotEqual(pos1, (1,))
|
|
107
|
+
self.assertNotEqual(pos1, action_lib.ExecutionUnit.Position(None, 1))
|
|
108
|
+
|
|
109
|
+
pos2 = action_lib.ExecutionUnit.Position(pos1, 0)
|
|
110
|
+
self.assertEqual(repr(pos2), 'Position(0, 0)')
|
|
111
|
+
self.assertEqual(str(pos2), '1')
|
|
112
|
+
self.assertEqual(pos2, '1')
|
|
113
|
+
self.assertEqual(pos2.parent, pos1)
|
|
114
|
+
self.assertEqual(pos2.index, 0)
|
|
115
|
+
self.assertEqual(pos2.indices(), (0, 0))
|
|
116
|
+
self.assertNotEqual(pos1, pos2)
|
|
117
|
+
self.assertLess(pos1, pos2)
|
|
118
|
+
self.assertGreater(pos2, pos1)
|
|
119
|
+
self.assertEqual(
|
|
120
|
+
hash(pos2),
|
|
121
|
+
hash(
|
|
122
|
+
action_lib.ExecutionUnit.Position(
|
|
123
|
+
action_lib.ExecutionUnit.Position(None, 0), 0
|
|
124
|
+
)
|
|
125
|
+
)
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
pos3 = action_lib.ExecutionUnit.Position(pos2, 0)
|
|
129
|
+
self.assertEqual(str(pos3), '1.1')
|
|
130
|
+
self.assertEqual(pos3, '1.1')
|
|
131
|
+
self.assertEqual(pos3.parent, pos2)
|
|
132
|
+
self.assertEqual(pos3.index, 0)
|
|
133
|
+
self.assertEqual(pos3.indices(), (0, 0, 0))
|
|
134
|
+
self.assertEqual(pos3.to_str(separator='>'), '1>1')
|
|
135
|
+
|
|
136
|
+
|
|
91
137
|
class ActionInvocationTest(unittest.TestCase):
|
|
92
138
|
|
|
93
139
|
def test_basics(self):
|
|
@@ -108,9 +154,7 @@ class ExecutionTraceTest(unittest.TestCase):
|
|
|
108
154
|
self.assertEqual(execution.id, '')
|
|
109
155
|
|
|
110
156
|
root = action_lib.ActionInvocation(action=action_lib.RootAction())
|
|
111
|
-
action_invocation = action_lib.ActionInvocation(
|
|
112
|
-
action=Foo(1)
|
|
113
|
-
)
|
|
157
|
+
action_invocation = action_lib.ActionInvocation(action=Foo(1))
|
|
114
158
|
root.execution.append(action_invocation)
|
|
115
159
|
self.assertEqual(action_invocation.execution.id, '/a1')
|
|
116
160
|
|
|
@@ -153,6 +197,7 @@ class SessionTest(unittest.TestCase):
|
|
|
153
197
|
|
|
154
198
|
self.assertIsInstance(session.root.action, action_lib.RootAction)
|
|
155
199
|
self.assertIs(session.current_action, session.root)
|
|
200
|
+
self.assertIs(session.metadata, session.root.metadata)
|
|
156
201
|
|
|
157
202
|
#
|
|
158
203
|
# Inspecting the root invocation.
|
|
@@ -175,20 +220,25 @@ class SessionTest(unittest.TestCase):
|
|
|
175
220
|
)
|
|
176
221
|
|
|
177
222
|
# The root space should have one action (foo), no queries, and no logs.
|
|
223
|
+
self.assertEqual(len(root.execution_units), 1)
|
|
178
224
|
self.assertEqual(len(root.actions), 1)
|
|
179
225
|
self.assertEqual(len(root.queries), 0)
|
|
180
226
|
self.assertEqual(len(root.logs), 0)
|
|
181
|
-
#
|
|
182
|
-
self.assertEqual(len(session.all_queries),
|
|
183
|
-
self.assertEqual(len(root.all_queries),
|
|
184
|
-
#
|
|
185
|
-
self.assertEqual(len(session.all_actions),
|
|
186
|
-
self.assertEqual(
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
227
|
+
# 2 query from Bar, 2 from Foo and 2 * 3 from parallel executions.
|
|
228
|
+
self.assertEqual(len(session.all_queries), 10)
|
|
229
|
+
self.assertEqual(len(root.all_queries), 10)
|
|
230
|
+
# 6 actions: Foo and 2 Bar, and 3 Bar from parallel executions.
|
|
231
|
+
self.assertEqual(len(session.all_actions), 6)
|
|
232
|
+
self.assertEqual(
|
|
233
|
+
[str(a.position) for a in session.all_actions],
|
|
234
|
+
['1', '1.1', '1.2.1.1', '1.2.2.1', '1.2.3.1', '1.3']
|
|
235
|
+
)
|
|
236
|
+
self.assertEqual(len(root.all_actions), 6)
|
|
237
|
+
# 1 log from Bar and 1 from Foo and 3 from Bar in parallel executions.
|
|
238
|
+
self.assertEqual(len(session.all_logs), 6)
|
|
239
|
+
self.assertEqual(len(root.all_logs), 6)
|
|
190
240
|
self.assertIs(session.usage_summary, root.usage_summary)
|
|
191
|
-
self.assertEqual(root.usage_summary.total.num_requests,
|
|
241
|
+
self.assertEqual(root.usage_summary.total.num_requests, 10)
|
|
192
242
|
|
|
193
243
|
# Inspecting the top-level action (Foo)
|
|
194
244
|
foo_invocation = root.execution[0]
|
|
@@ -200,15 +250,19 @@ class SessionTest(unittest.TestCase):
|
|
|
200
250
|
|
|
201
251
|
# Prepare phase.
|
|
202
252
|
prepare_phase = foo_invocation.execution[0]
|
|
253
|
+
self.assertIsNone(prepare_phase.position)
|
|
203
254
|
self.assertIsInstance(prepare_phase, action_lib.ExecutionTrace)
|
|
204
255
|
self.assertEqual(prepare_phase.id, 'agent@1:/a1/prepare')
|
|
205
|
-
self.assertEqual(len(prepare_phase.items),
|
|
256
|
+
self.assertEqual(len(prepare_phase.items), 3)
|
|
206
257
|
self.assertTrue(prepare_phase.has_started)
|
|
207
258
|
self.assertTrue(prepare_phase.has_stopped)
|
|
208
|
-
self.assertEqual(prepare_phase.usage_summary.total.num_requests,
|
|
259
|
+
self.assertEqual(prepare_phase.usage_summary.total.num_requests, 2)
|
|
209
260
|
self.assertIsInstance(prepare_phase.items[0], lf.logging.LogEntry)
|
|
210
|
-
self.assertIsInstance(prepare_phase.items[1],
|
|
211
|
-
self.
|
|
261
|
+
self.assertIsInstance(prepare_phase.items[1], action_lib.ActionInvocation)
|
|
262
|
+
self.assertIs(prepare_phase.items[1].parent_execution_unit, foo_invocation)
|
|
263
|
+
self.assertEqual(prepare_phase.items[1].id, 'agent@1:/a1/prepare/a1')
|
|
264
|
+
self.assertIsInstance(prepare_phase.items[2], lf_structured.QueryInvocation)
|
|
265
|
+
self.assertEqual(prepare_phase.items[2].id, 'agent@1:/a1/prepare/q1')
|
|
212
266
|
|
|
213
267
|
# Tracked queries.
|
|
214
268
|
query_invocation = foo_invocation.execution[1]
|
|
@@ -230,20 +284,44 @@ class SessionTest(unittest.TestCase):
|
|
|
230
284
|
|
|
231
285
|
# Tracked parallel executions.
|
|
232
286
|
parallel_executions = foo_invocation.execution[2]
|
|
287
|
+
# root (0) > foo (0) > parallel executions (1)
|
|
288
|
+
self.assertEqual(parallel_executions.position, (0, 0, 1))
|
|
233
289
|
self.assertEqual(parallel_executions.id, 'agent@1:/a1/p1')
|
|
234
290
|
self.assertIsInstance(parallel_executions, action_lib.ParallelExecutions)
|
|
291
|
+
self.assertIs(
|
|
292
|
+
parallel_executions.all_actions[0].parent_execution_unit,
|
|
293
|
+
parallel_executions
|
|
294
|
+
)
|
|
295
|
+
self.assertIs(
|
|
296
|
+
parallel_executions.all_actions[0].parent_action,
|
|
297
|
+
foo_invocation
|
|
298
|
+
)
|
|
235
299
|
self.assertEqual(len(parallel_executions), 3)
|
|
236
300
|
self.assertEqual(parallel_executions[0].id, 'agent@1:/a1/p1/b1')
|
|
237
301
|
self.assertEqual(parallel_executions[1].id, 'agent@1:/a1/p1/b2')
|
|
238
302
|
self.assertEqual(parallel_executions[2].id, 'agent@1:/a1/p1/b3')
|
|
303
|
+
self.assertEqual(len(parallel_executions[0].execution_units), 1)
|
|
304
|
+
self.assertEqual(len(parallel_executions[1].execution_units), 1)
|
|
305
|
+
self.assertEqual(len(parallel_executions[2].execution_units), 1)
|
|
239
306
|
self.assertEqual(len(parallel_executions[0].queries), 1)
|
|
307
|
+
self.assertEqual(len(parallel_executions[0].all_queries), 2)
|
|
240
308
|
self.assertEqual(len(parallel_executions[1].queries), 1)
|
|
309
|
+
self.assertEqual(len(parallel_executions[1].all_queries), 2)
|
|
241
310
|
self.assertEqual(len(parallel_executions[2].queries), 1)
|
|
311
|
+
self.assertEqual(len(parallel_executions[2].all_queries), 2)
|
|
312
|
+
self.assertEqual(len(parallel_executions.execution_units), 0)
|
|
313
|
+
self.assertEqual(len(parallel_executions.actions), 0)
|
|
314
|
+
self.assertEqual(len(parallel_executions.queries), 0)
|
|
315
|
+
self.assertEqual(len(parallel_executions.logs), 0)
|
|
316
|
+
self.assertEqual(len(parallel_executions.all_actions), 3)
|
|
317
|
+
self.assertEqual(len(parallel_executions.all_queries), 6)
|
|
318
|
+
self.assertEqual(len(parallel_executions.all_logs), 3)
|
|
242
319
|
|
|
243
320
|
# Invocation to Bar.
|
|
244
321
|
bar_invocation = foo_invocation.execution[3]
|
|
245
322
|
self.assertIs(bar_invocation.parent_action, foo_invocation)
|
|
246
|
-
self.
|
|
323
|
+
self.assertIs(bar_invocation.parent_execution_unit, foo_invocation)
|
|
324
|
+
self.assertEqual(bar_invocation.id, 'agent@1:/a1/a5')
|
|
247
325
|
self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
|
|
248
326
|
self.assertIsInstance(bar_invocation.action, Bar)
|
|
249
327
|
self.assertEqual(bar_invocation.result, 2)
|
|
@@ -497,26 +575,51 @@ class SessionTest(unittest.TestCase):
|
|
|
497
575
|
super()._on_bound()
|
|
498
576
|
self.progresses = []
|
|
499
577
|
|
|
578
|
+
def on_session_start(self, session):
|
|
579
|
+
session.add_metadata(progresses=pg.Ref(self.progresses))
|
|
580
|
+
|
|
500
581
|
def on_action_progress(self, session, action, title, **kwargs):
|
|
501
582
|
self.progresses.append((action.id, title))
|
|
502
583
|
|
|
503
584
|
handler = MyActionHandler()
|
|
585
|
+
self.assertIs(handler.get(MyActionHandler), handler)
|
|
586
|
+
self.assertIsNone(handler.get(action_lib.SessionLogging))
|
|
587
|
+
|
|
588
|
+
handler_chain = action_lib.SessionEventHandlerChain(
|
|
589
|
+
handlers=[handler, action_lib.SessionLogging()]
|
|
590
|
+
)
|
|
591
|
+
self.assertIs(handler_chain.get(MyActionHandler), handler)
|
|
592
|
+
self.assertIs(
|
|
593
|
+
handler_chain.get(action_lib.SessionLogging),
|
|
594
|
+
handler_chain.handlers[1]
|
|
595
|
+
)
|
|
596
|
+
|
|
504
597
|
session = action_lib.Session(
|
|
505
598
|
id='agent@1',
|
|
506
|
-
event_handler=
|
|
507
|
-
handlers=[handler, action_lib.SessionLogging()]
|
|
508
|
-
)
|
|
599
|
+
event_handler=handler_chain
|
|
509
600
|
)
|
|
510
601
|
bar = Bar()
|
|
511
602
|
with session:
|
|
512
603
|
bar(session, lm=fake.StaticResponse('lm response'))
|
|
513
604
|
session.update_progress('Trajectory completed')
|
|
514
605
|
|
|
606
|
+
self.assertIs(session.metadata['progresses'], handler.progresses)
|
|
515
607
|
self.assertEqual(handler.progresses, [
|
|
516
608
|
('agent@1:/a1', 'Query completed'),
|
|
517
609
|
('agent@1:', 'Trajectory completed'),
|
|
518
610
|
])
|
|
519
611
|
|
|
612
|
+
def test_clone(self):
|
|
613
|
+
event_handler = action_lib.SessionLogging()
|
|
614
|
+
session = action_lib.Session(event_handler=event_handler)
|
|
615
|
+
other = session.clone()
|
|
616
|
+
self.assertIsNot(session, other)
|
|
617
|
+
self.assertIs(other.event_handler, event_handler)
|
|
618
|
+
|
|
619
|
+
other = session.clone(deep=True)
|
|
620
|
+
self.assertIsNot(session, other)
|
|
621
|
+
self.assertIsNot(other.event_handler, session.event_handler)
|
|
622
|
+
|
|
520
623
|
def test_log(self):
|
|
521
624
|
session = action_lib.Session()
|
|
522
625
|
session.debug('hi', x=1, y=2)
|
langfun/core/eval/base_test.py
CHANGED
|
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
|
|
|
101
101
|
self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
|
|
102
102
|
self.assertEqual(s.hash, s.clone().hash)
|
|
103
103
|
# Test persistent hash.
|
|
104
|
-
self.assertEqual(s.hash, '
|
|
104
|
+
self.assertEqual(s.hash, '4dfe486a')
|
|
105
105
|
self.assertEqual(
|
|
106
106
|
s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
|
|
107
107
|
)
|
|
@@ -211,7 +211,7 @@ class EvaluationTest(unittest.TestCase):
|
|
|
211
211
|
s.result,
|
|
212
212
|
dict(
|
|
213
213
|
experiment_setup=dict(
|
|
214
|
-
id='Evaluation@
|
|
214
|
+
id='Evaluation@e028b6e6',
|
|
215
215
|
dir=s.dir,
|
|
216
216
|
model='StaticSequence',
|
|
217
217
|
prompt_template='{{example.question}}',
|
|
@@ -269,7 +269,7 @@ class EvaluationTest(unittest.TestCase):
|
|
|
269
269
|
s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
|
|
270
270
|
)
|
|
271
271
|
self.assertTrue(os.path.exists(summary_json))
|
|
272
|
-
summary = pg.load(summary_json,
|
|
272
|
+
summary = pg.load(summary_json, convert_unknown=True)
|
|
273
273
|
self.assertIn('Evaluation', summary)
|
|
274
274
|
self.assertEqual(len(summary['Evaluation']), 1)
|
|
275
275
|
self.assertIsNotNone(summary['Evaluation'][0].experiment)
|
|
@@ -376,7 +376,7 @@ class EvaluationTest(unittest.TestCase):
|
|
|
376
376
|
s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
|
|
377
377
|
)
|
|
378
378
|
# Test persistent hash.
|
|
379
|
-
self.assertEqual(s.hash, '
|
|
379
|
+
self.assertEqual(s.hash, 'fa8f5419')
|
|
380
380
|
|
|
381
381
|
summary = s.run(verbose=True)
|
|
382
382
|
self.assertEqual(len(summary.evaluations), 2)
|
|
@@ -526,7 +526,7 @@ class SuiteTest(unittest.TestCase):
|
|
|
526
526
|
lm=lm
|
|
527
527
|
)
|
|
528
528
|
# Test for persistent hash.
|
|
529
|
-
self.assertEqual(s.hash, '
|
|
529
|
+
self.assertEqual(s.hash, 'ec3901b8')
|
|
530
530
|
s.run()
|
|
531
531
|
expected = {
|
|
532
532
|
s.children[0].id: dict(
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
"""Checkpointing evaluation runs."""
|
|
15
15
|
import abc
|
|
16
|
+
import datetime
|
|
16
17
|
import re
|
|
17
18
|
import threading
|
|
18
19
|
import traceback
|
|
@@ -44,7 +45,12 @@ class Checkpointer(experiment_lib.Plugin):
|
|
|
44
45
|
checkpoint_filename: Annotated[
|
|
45
46
|
str,
|
|
46
47
|
'Checkpoint file pattern.'
|
|
47
|
-
] = 'checkpoint.
|
|
48
|
+
] = 'checkpoint.jsonl'
|
|
49
|
+
|
|
50
|
+
enable_inprogress_file: Annotated[
|
|
51
|
+
bool,
|
|
52
|
+
'If True, write file "<example_id>.inprogress" when example gets started.'
|
|
53
|
+
] = True
|
|
48
54
|
|
|
49
55
|
max_ckpt_loading_threads: Annotated[
|
|
50
56
|
int,
|
|
@@ -90,6 +96,24 @@ class Checkpointer(experiment_lib.Plugin):
|
|
|
90
96
|
f'scratch. Example IDs: {example_ids_to_evaluate}.'
|
|
91
97
|
)
|
|
92
98
|
|
|
99
|
+
def on_example_start(
|
|
100
|
+
self,
|
|
101
|
+
runner: Runner,
|
|
102
|
+
experiment: Experiment,
|
|
103
|
+
example: Example,
|
|
104
|
+
) -> None:
|
|
105
|
+
"""Saves the example to the checkpoint file."""
|
|
106
|
+
if self.enable_inprogress_file:
|
|
107
|
+
def _save_inprogress_file(example: Example):
|
|
108
|
+
inprogress_file = runner.current_run.output_path_for(
|
|
109
|
+
experiment, f'{example.id}.inprogress'
|
|
110
|
+
)
|
|
111
|
+
pg.io.writefile(
|
|
112
|
+
inprogress_file,
|
|
113
|
+
datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
|
114
|
+
)
|
|
115
|
+
runner.background_run(_save_inprogress_file, example)
|
|
116
|
+
|
|
93
117
|
def on_example_complete(
|
|
94
118
|
self,
|
|
95
119
|
runner: Runner,
|
|
@@ -90,7 +90,10 @@ class PerExampleCheckpointerTest(CheckpointerTest):
|
|
|
90
90
|
root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer')
|
|
91
91
|
experiment = eval_test_helper.test_experiment()
|
|
92
92
|
checkpoint_filename = 'checkpoint.jsonl'
|
|
93
|
-
checkpointer = checkpointing.PerExampleCheckpointer(
|
|
93
|
+
checkpointer = checkpointing.PerExampleCheckpointer(
|
|
94
|
+
checkpoint_filename,
|
|
95
|
+
enable_inprogress_file=True
|
|
96
|
+
)
|
|
94
97
|
collector = ExampleCollector()
|
|
95
98
|
run = experiment.run(
|
|
96
99
|
root_dir, 'new', runner='sequential', plugins=[checkpointer, collector]
|
|
@@ -102,6 +105,10 @@ class PerExampleCheckpointerTest(CheckpointerTest):
|
|
|
102
105
|
example = collector.examples[i + 1]
|
|
103
106
|
ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
|
|
104
107
|
self.assertTrue(pg.io.path_exists(ckpt))
|
|
108
|
+
inprogress_file = run.output_path_for(
|
|
109
|
+
leaf, f'{example.id}.inprogress'
|
|
110
|
+
)
|
|
111
|
+
self.assertTrue(pg.io.path_exists(inprogress_file))
|
|
105
112
|
with pg.io.open_sequence(ckpt) as f:
|
|
106
113
|
examples_from_ckpt = list(iter(f))
|
|
107
114
|
# `eval_test_helper.test_experiment` has two TestEvaluation with
|
|
@@ -127,11 +127,16 @@ class TestEvaluationWithIndexHtmlGenerationError(TestEvaluation):
|
|
|
127
127
|
raise ValueError('Cannot render HTML.')
|
|
128
128
|
|
|
129
129
|
|
|
130
|
+
def test_evaluation(offset: int | pg.hyper.OneOf = 0):
|
|
131
|
+
"""Returns a test evaluation."""
|
|
132
|
+
return TestEvaluation(lm=TestLLM(offset=offset))
|
|
133
|
+
|
|
134
|
+
|
|
130
135
|
def test_experiment():
|
|
131
136
|
"""Returns a test experiment."""
|
|
132
137
|
return Suite([
|
|
133
|
-
|
|
134
|
-
|
|
138
|
+
test_evaluation(),
|
|
139
|
+
test_evaluation(pg.oneof(range(5))),
|
|
135
140
|
])
|
|
136
141
|
|
|
137
142
|
|
|
@@ -880,8 +880,9 @@ class EvaluationState:
|
|
|
880
880
|
load_example_metadata: bool | Callable[
|
|
881
881
|
[example_lib.Example], bool] = True,
|
|
882
882
|
filter: Callable[[example_lib.Example], bool] | None = None, # pylint: disable=redefined-builtin
|
|
883
|
-
) ->
|
|
883
|
+
) -> list[example_lib.Example]:
|
|
884
884
|
"""Loads the state from the example sequence file."""
|
|
885
|
+
examples = []
|
|
885
886
|
for example in example_lib.Example.iter_ckpts(
|
|
886
887
|
state_file,
|
|
887
888
|
example_input_by_id=example_input_by_id,
|
|
@@ -891,6 +892,8 @@ class EvaluationState:
|
|
|
891
892
|
continue
|
|
892
893
|
example.newly_processed = False
|
|
893
894
|
self._ckpt_examples[example.id] = example
|
|
895
|
+
examples.append(example)
|
|
896
|
+
return examples
|
|
894
897
|
|
|
895
898
|
@property
|
|
896
899
|
def evaluation_status(self) -> dict[int, ExampleStatus]:
|
langfun/core/eval/v2/example.py
CHANGED
|
@@ -155,6 +155,8 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
|
|
|
155
155
|
ckpt_file: str | list[str],
|
|
156
156
|
example_input_by_id: Callable[[int], Any] | None = None,
|
|
157
157
|
load_example_metadata: bool = True,
|
|
158
|
+
convert_unknown: bool = True,
|
|
159
|
+
**kwargs
|
|
158
160
|
) -> Iterator['Example']:
|
|
159
161
|
"""Iterates Examples from the checkpoint files."""
|
|
160
162
|
ckpt_files = [ckpt_file] if isinstance(ckpt_file, str) else ckpt_file
|
|
@@ -164,7 +166,9 @@ class Example(pg.JSONConvertible, pg.views.HtmlTreeView.Extension):
|
|
|
164
166
|
example = pg.from_json_str(
|
|
165
167
|
record,
|
|
166
168
|
example_input_by_id=example_input_by_id,
|
|
167
|
-
load_example_metadata=load_example_metadata
|
|
169
|
+
load_example_metadata=load_example_metadata,
|
|
170
|
+
convert_unknown=convert_unknown,
|
|
171
|
+
**kwargs
|
|
168
172
|
)
|
|
169
173
|
assert isinstance(example, cls), example
|
|
170
174
|
yield example
|
|
@@ -94,15 +94,23 @@ class ExampleTest(unittest.TestCase):
|
|
|
94
94
|
pg.JSONConvertible._TYPE_REGISTRY._type_to_cls_map.pop(
|
|
95
95
|
inputs[0].b.__type_name__
|
|
96
96
|
)
|
|
97
|
-
v = pg.from_json_str(
|
|
98
|
-
|
|
99
|
-
|
|
97
|
+
v = pg.from_json_str(
|
|
98
|
+
json_str,
|
|
99
|
+
convert_unknown=True,
|
|
100
|
+
load_example_metadata=True
|
|
101
|
+
)
|
|
100
102
|
self.assertEqual(
|
|
101
103
|
v,
|
|
102
104
|
Example(
|
|
103
105
|
id=1,
|
|
104
|
-
output=pg.
|
|
105
|
-
|
|
106
|
+
output=pg.symbolic.UnknownTypedObject(
|
|
107
|
+
inputs[0].a.__type_name__, x=1
|
|
108
|
+
),
|
|
109
|
+
metadata=dict(
|
|
110
|
+
b=pg.symbolic.UnknownTypedObject(
|
|
111
|
+
inputs[0].b.__type_name__, x=1, y=2
|
|
112
|
+
)
|
|
113
|
+
),
|
|
106
114
|
)
|
|
107
115
|
)
|
|
108
116
|
# Serialize with input.
|
|
@@ -1055,6 +1055,29 @@ class Plugin(lf.Component):
|
|
|
1055
1055
|
or result processing.
|
|
1056
1056
|
"""
|
|
1057
1057
|
|
|
1058
|
+
@classmethod
|
|
1059
|
+
def is_per_example(cls) -> bool:
|
|
1060
|
+
"""Returns whether the plugin is per example only.
|
|
1061
|
+
|
|
1062
|
+
Per-example plugins can be installed on individual workers when examples
|
|
1063
|
+
are evaluated by multiple processes in parallel.
|
|
1064
|
+
"""
|
|
1065
|
+
|
|
1066
|
+
def same_code(method1, method2):
|
|
1067
|
+
return method1.__code__ == method2.__code__
|
|
1068
|
+
return all(
|
|
1069
|
+
same_code(method1, method2)
|
|
1070
|
+
for method1, method2 in [
|
|
1071
|
+
(Plugin.on_run_start, cls.on_run_start),
|
|
1072
|
+
(Plugin.on_run_complete, cls.on_run_complete),
|
|
1073
|
+
(Plugin.on_run_abort, cls.on_run_abort),
|
|
1074
|
+
(Plugin.on_experiment_start, cls.on_experiment_start),
|
|
1075
|
+
(Plugin.on_experiment_skipped, cls.on_experiment_skipped),
|
|
1076
|
+
(Plugin.on_experiment_complete, cls.on_experiment_complete),
|
|
1077
|
+
(Plugin.on_experiment_abort, cls.on_experiment_abort),
|
|
1078
|
+
]
|
|
1079
|
+
)
|
|
1080
|
+
|
|
1058
1081
|
def on_run_start(
|
|
1059
1082
|
self,
|
|
1060
1083
|
runner: Runner,
|
|
@@ -433,5 +433,24 @@ class RunnerTest(unittest.TestCase):
|
|
|
433
433
|
pass
|
|
434
434
|
|
|
435
435
|
|
|
436
|
+
class PluginTest(unittest.TestCase):
|
|
437
|
+
|
|
438
|
+
def test_per_example_only(self):
|
|
439
|
+
|
|
440
|
+
class PerExamplePlugin(experiment_lib.Plugin):
|
|
441
|
+
|
|
442
|
+
def on_example_complete(self, runner, experiment, example):
|
|
443
|
+
print('on_example_complete')
|
|
444
|
+
|
|
445
|
+
self.assertTrue(PerExamplePlugin.is_per_example())
|
|
446
|
+
|
|
447
|
+
class NonPerExamplePlugin(experiment_lib.Plugin):
|
|
448
|
+
|
|
449
|
+
def on_experiment_complete(self, runner, experiment):
|
|
450
|
+
print('on_example_complete')
|
|
451
|
+
|
|
452
|
+
self.assertFalse(NonPerExamplePlugin.is_per_example())
|
|
453
|
+
|
|
454
|
+
|
|
436
455
|
if __name__ == '__main__':
|
|
437
456
|
unittest.main()
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
"""Tracking evaluation run progress."""
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
+
from typing import Literal
|
|
17
18
|
import langfun.core as lf
|
|
18
19
|
from langfun.core.eval.v2 import example as example_lib
|
|
19
20
|
from langfun.core.eval.v2 import experiment as experiment_lib
|
|
@@ -24,16 +25,24 @@ Experiment = experiment_lib.Experiment
|
|
|
24
25
|
Example = example_lib.Example
|
|
25
26
|
|
|
26
27
|
|
|
27
|
-
def progress_tracker(
|
|
28
|
+
def progress_tracker(
|
|
29
|
+
tracker_type: Literal['tqdm', 'html', 'auto'] = 'auto'
|
|
30
|
+
) -> experiment_lib.Plugin:
|
|
28
31
|
"""Creates a progress tracker as a plugin.
|
|
29
32
|
|
|
30
33
|
Args:
|
|
31
|
-
|
|
34
|
+
tracker_type: The type of progress tracker to use.
|
|
35
|
+
If `tqdm`, force using tqdm for progress update.
|
|
36
|
+
If `html`, force using html for progress update.
|
|
37
|
+
If `auto`, determine it automatically based on the running
|
|
38
|
+
environment (console vs. notebook)
|
|
32
39
|
|
|
33
40
|
Returns:
|
|
34
41
|
The progress tracker plugin.
|
|
35
42
|
"""
|
|
36
|
-
if tqdm or
|
|
43
|
+
if tracker_type == 'tqdm' or (
|
|
44
|
+
tracker_type == 'auto' and not lf.console.under_notebook()
|
|
45
|
+
):
|
|
37
46
|
return _TqdmProgressTracker()
|
|
38
47
|
else:
|
|
39
48
|
return _HtmlProgressTracker()
|
|
@@ -21,7 +21,7 @@ import unittest
|
|
|
21
21
|
from langfun.core import concurrent as lf_concurrent
|
|
22
22
|
from langfun.core import console as lf_console
|
|
23
23
|
from langfun.core.eval.v2 import eval_test_helper
|
|
24
|
-
from langfun.core.eval.v2 import progress_tracking
|
|
24
|
+
from langfun.core.eval.v2 import progress_tracking
|
|
25
25
|
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
|
|
26
26
|
import pyglove as pg
|
|
27
27
|
|
|
@@ -33,6 +33,7 @@ class HtmlProgressTrackerTest(unittest.TestCase):
|
|
|
33
33
|
def display(x):
|
|
34
34
|
result['view'] = x.to_html()
|
|
35
35
|
|
|
36
|
+
self.assertFalse(progress_tracking._HtmlProgressTracker.is_per_example())
|
|
36
37
|
lf_console._notebook = pg.Dict(
|
|
37
38
|
display=display
|
|
38
39
|
)
|
|
@@ -46,6 +47,7 @@ class HtmlProgressTrackerTest(unittest.TestCase):
|
|
|
46
47
|
class TqdmProgressTrackerTest(unittest.TestCase):
|
|
47
48
|
|
|
48
49
|
def test_basic(self):
|
|
50
|
+
self.assertFalse(progress_tracking._TqdmProgressTracker.is_per_example())
|
|
49
51
|
root_dir = os.path.join(tempfile.mkdtemp(), 'test_tqdm_progress_tracker')
|
|
50
52
|
experiment = eval_test_helper.test_experiment()
|
|
51
53
|
string_io = io.StringIO()
|
|
@@ -29,7 +29,11 @@ class ReportingTest(unittest.TestCase):
|
|
|
29
29
|
experiment = eval_test_helper.test_experiment()
|
|
30
30
|
checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
|
|
31
31
|
reporter = reporting.HtmlReporter()
|
|
32
|
+
self.assertFalse(reporter.is_per_example())
|
|
33
|
+
|
|
32
34
|
example_html_generator = reporting.ExampleHtmlGenerator()
|
|
35
|
+
self.assertTrue(example_html_generator.is_per_example())
|
|
36
|
+
|
|
33
37
|
run = experiment.run(
|
|
34
38
|
root_dir,
|
|
35
39
|
'new',
|
|
@@ -13,13 +13,17 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
"""Langfun evaluation runners."""
|
|
15
15
|
|
|
16
|
+
# pylint: disable=g-importing-member
|
|
16
17
|
from langfun.core.eval.v2.runners.base import RunnerBase
|
|
18
|
+
from langfun.core.eval.v2.runners.beam import BeamRunner
|
|
17
19
|
from langfun.core.eval.v2.runners.debug import DebugRunner
|
|
18
20
|
from langfun.core.eval.v2.runners.parallel import ParallelRunner
|
|
19
21
|
from langfun.core.eval.v2.runners.sequential import SequentialRunner
|
|
22
|
+
# pylint: enable=g-importing-member
|
|
20
23
|
|
|
21
24
|
__all__ = [
|
|
22
25
|
'RunnerBase',
|
|
26
|
+
'BeamRunner',
|
|
23
27
|
'DebugRunner',
|
|
24
28
|
'ParallelRunner',
|
|
25
29
|
'SequentialRunner',
|