langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +17 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,433 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import datetime
15
+ import os
16
+ import tempfile
17
+ import unittest
18
+
19
+ from langfun.core.eval.v2 import evaluation as evaluation_lib
20
+ from langfun.core.eval.v2 import experiment as experiment_lib
21
+ from langfun.core.eval.v2 import metrics as metrics_lib
22
+
23
+ import pyglove as pg
24
+
25
# Short aliases for the classes exercised by the tests in this module.
Evaluation = evaluation_lib.Evaluation
Experiment = experiment_lib.Experiment
Run = experiment_lib.Run
RunId = experiment_lib.RunId
Runner = experiment_lib.Runner
Suite = experiment_lib.Suite
31
+
32
+
33
@pg.functor()
def sample_inputs(num_examples: int = 1):
  """Returns `num_examples` example inputs, each equal to `pg.Dict(x=1)`.

  Each example is a separate `pg.Dict` object. `[pg.Dict(x=1)] * n` would
  instead place the same mutable object in every slot, so a change made to
  one example would silently affect all of them.
  """
  return [pg.Dict(x=1) for _ in range(num_examples)]
38
+
39
+
40
class MyEvaluation(Evaluation):
  """Minimal evaluation used by the tests in this module; outputs 1 always."""

  NAME = 'my_eval'
  RUN_ARGS = {'runner': 'test'}

  replica_id: int = 0
  inputs = sample_inputs()
  metrics = [metrics_lib.Match()]

  def process(self, example):
    del example  # Unused: the output does not depend on the input.
    return 1
52
+
53
+
54
class ExperimentTest(unittest.TestCase):
  """Tests for experiment hierarchy, HTML rendering and registry lookup."""

  def _make_suite(self):
    """Builds a suite: a nested one-leaf suite plus a 5-way replica sweep."""
    return Suite([
        Suite([MyEvaluation(replica_id=0)]),
        MyEvaluation(replica_id=pg.oneof(range(5))),
    ])

  def test_hierarchy(self):
    exp = self._make_suite()

    # Root-level shape: 2 children, 6 leaves (1 + 5), 3 non-leaf nodes.
    self.assertIsNotNone(exp.id)
    self.assertTrue(exp.id.startswith('Suite@'))
    self.assertEqual(len(exp.children), 2)
    self.assertEqual(len(exp.leaf_nodes), 6)
    self.assertEqual(len(exp.nonleaf_nodes), 3)
    self.assertFalse(exp.is_leaf)
    self.assertFalse(exp.empty())
    self.assertEqual(len(exp.nodes), 9)

    nested_suite = exp.children[0]
    sweep = exp.children[1]
    single_leaf = nested_suite.children[0]

    self.assertTrue(single_leaf.id.startswith('MyEvaluation@'))
    self.assertTrue(single_leaf.is_leaf)
    self.assertEqual(len(single_leaf.leaf_nodes), 1)
    self.assertFalse(sweep.is_leaf)
    self.assertEqual(len(sweep.children), 5)
    self.assertEqual(len(sweep.leaf_nodes), 5)
    self.assertEqual(exp.leaf_nodes[-1].replica_id, 4)
    self.assertNotEqual(exp.leaf_nodes[1].hash, exp.leaf_nodes[2].hash)

    # Parent links and id-based lookup.
    self.assertIsNone(exp.parent)
    self.assertIs(nested_suite.parent, exp)
    self.assertIs(single_leaf.parent, nested_suite)
    self.assertIs(sweep.children[0].parent, sweep)
    last_leaf = exp.leaf_nodes[-1]
    self.assertIs(exp.get(last_leaf.id), last_leaf)

  def test_html_view(self):
    exp = self._make_suite()
    self.assertIn(exp.id, exp.to_html().content)

    run = Run('/root', RunId.from_id('20241102_0'), pg.Ref(exp))
    html = run.to_html(extra_flags=dict(current_run=run)).content
    self.assertIn(str(run.id), html)

  def test_find(self):
    for pattern in ('my_eval', '.*_eval'):
      self.assertIsInstance(Experiment.find(pattern), MyEvaluation)
    # An unmatched pattern yields an empty suite.
    self.assertTrue(pg.eq(Experiment.find('foo'), Suite([])))
113
+
114
+
115
class RunIdTest(unittest.TestCase):
  """Tests for `RunId` parsing, ordering, validation and discovery."""

  def setUp(self):
    super().setUp()
    # Give each test its own empty root directory, removed afterwards, so the
    # results never depend on run directories that earlier test runs or other
    # processes left in the shared system temp dir.
    temp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(temp_dir.cleanup)
    self.temp_dir = temp_dir.name

  def test_basic(self):
    rid = RunId.from_id('20241102_0')
    self.assertEqual(
        rid.dirname('/root'), os.path.join('/root', 'run_20241102_0')
    )
    self.assertEqual(str(rid), '20241102_0')
    self.assertEqual(rid.date, datetime.date(2024, 11, 2))
    self.assertEqual(rid.number, 0)

  def test_comparison(self):
    self.assertEqual(
        RunId.from_id('20241102_0'), RunId.from_id('20241102_0')
    )
    self.assertLess(
        RunId.from_id('20241102_0'), RunId.from_id('20241102_1')
    )
    self.assertLess(
        RunId.from_id('20241101_0'), RunId.from_id('20241102_1')
    )
    self.assertGreater(
        RunId.from_id('20241102_0'), RunId.from_id('20241101_0')
    )
    self.assertLessEqual(
        RunId.from_id('20241102_0'), RunId.from_id('20241102_0')
    )
    self.assertEqual(
        RunId.from_id('20241102_0').next(),
        RunId.from_id('20241102_1')
    )

  def test_get_latest(self):
    root_dir = os.path.join(self.temp_dir, 'test_eval')
    pg.io.mkdirs(os.path.join(root_dir, 'run_20241102_0'))
    pg.io.mkdirs(os.path.join(root_dir, 'run_20241101_0'))
    self.assertEqual(
        RunId.get_latest(root_dir),
        RunId.from_id('20241102_0')
    )
    self.assertIsNone(
        RunId.get_latest(os.path.join(self.temp_dir, 'notexist'))
    )
    # `temp_dir` holds only `test_eval`, which is not a `run_*` directory.
    self.assertIsNone(RunId.get_latest(self.temp_dir))

  def test_new(self):
    rid = RunId(date=datetime.date.today(), number=1)
    self.assertEqual(
        RunId.new(root_dir=os.path.join(self.temp_dir, 'test_new')),
        rid
    )
    root_dir = os.path.join(self.temp_dir, 'test_eval2')
    pg.io.mkdirs(rid.dirname(root_dir))
    self.assertEqual(RunId.new(root_dir), rid.next())

  def test_is_valid(self):
    self.assertTrue(RunId.is_valid('latest'))
    self.assertTrue(RunId.is_valid('new'))
    self.assertTrue(RunId.is_valid('20241102_0'))
    self.assertFalse(RunId.is_valid('20241102-0'))

  def test_from_id(self):
    with self.assertRaisesRegex(
        ValueError, '.* must be one of'
    ):
      RunId.from_id('abc')

    with self.assertRaisesRegex(
        ValueError, '`root_dir` must be provided'
    ):
      RunId.from_id('latest')

    # `temp_dir` is empty at this point, so there is no run to resolve.
    with self.assertRaisesRegex(
        ValueError, '.* no previous runs'
    ):
      RunId.from_id('latest', root_dir=self.temp_dir)

    self.assertEqual(
        RunId.from_id('20241102_1'),
        RunId(date=datetime.date(2024, 11, 2), number=1)
    )
    root_dir = os.path.join(self.temp_dir, 'test_eval3')
    rid = RunId.from_id('20241102_1')
    pg.io.mkdirs(rid.dirname(root_dir))
    self.assertEqual(
        RunId.from_id('latest', root_dir=root_dir), rid
    )
    self.assertEqual(
        RunId.from_id('new', root_dir=root_dir),
        RunId(datetime.date.today(), 1)
    )
    self.assertEqual(
        RunId.from_id(None, root_dir=root_dir), rid
    )
207
+
208
+
209
class RunTest(unittest.TestCase):
  """Tests for `Run` path resolution and per-example work selection."""

  def _make_run(self, evaluation=None, **run_kwargs):
    """Returns a run rooted at '/root' over a one-evaluation suite.

    Args:
      evaluation: The evaluation placed in the suite. When omitted, a
        `MyEvaluation` over 10 sample inputs is used.
      **run_kwargs: Extra keyword arguments forwarded to `Run`.
    """
    if evaluation is None:
      evaluation = MyEvaluation(replica_id=0, inputs=sample_inputs(10))
    return Run(
        '/root',
        RunId.from_id('20241102_0'),
        pg.Ref(Suite([evaluation])),
        **run_kwargs,
    )

  def _assert_example_sets(
      self, run, node, *, evaluate, reprocess, load, load_metadata
  ):
    """Checks the four example-id sets that `run` computes for `node`."""
    self.assertEqual(run.examples_to_evaluate(node), evaluate)
    self.assertEqual(run.examples_to_reprocess(node), reprocess)
    self.assertEqual(run.examples_to_load(node), load)
    self.assertEqual(run.examples_to_load_metadata(node), load_metadata)

  def test_input_output_paths(self):
    run = self._make_run(MyEvaluation(replica_id=0))
    leaf = run.experiment.leaf_nodes[0]
    self.assertEqual(run.output_root, '/root/run_20241102_0')
    self.assertEqual(run.input_root, '/root/run_20241102_0')
    self.assertEqual(
        run.output_dir(leaf),
        f'/root/run_20241102_0/MyEvaluation/{leaf.hash}',
    )
    self.assertEqual(
        run.input_path_for(run.experiment, 'a.txt'),
        '/root/run_20241102_0/a.txt',
    )
    self.assertEqual(
        run.input_path_for(leaf, 'a.txt'),
        f'/root/run_20241102_0/MyEvaluation/{leaf.hash}/a.txt',
    )

    # With `warm_start_from`, inputs resolve under the warm-start root while
    # outputs still go under this run's own directory.
    run = self._make_run(
        MyEvaluation(replica_id=0), warm_start_from='/root2/run_20241103_1'
    )
    leaf = run.experiment.leaf_nodes[0]
    self.assertEqual(run.output_root, '/root/run_20241102_0')
    self.assertEqual(run.input_root, '/root2/run_20241103_1')
    self.assertEqual(
        run.output_dir(leaf),
        f'/root/run_20241102_0/MyEvaluation/{leaf.hash}',
    )
    self.assertEqual(
        run.input_dir(leaf),
        f'/root2/run_20241103_1/MyEvaluation/{leaf.hash}',
    )
    self.assertEqual(
        run.input_path_for(run.experiment, 'a.txt'),
        '/root2/run_20241103_1/a.txt',
    )
    self.assertEqual(
        run.input_path_for(leaf, 'a.txt'),
        f'/root2/run_20241103_1/MyEvaluation/{leaf.hash}/a.txt',
    )

  def test_examples_start_from_scratch(self):
    run = self._make_run()
    root = run.experiment
    # The (non-leaf) root suite has no examples of its own.
    self._assert_example_sets(
        run, root,
        evaluate=set(), reprocess=set(), load=set(), load_metadata=set(),
    )
    self._assert_example_sets(
        run, root.leaf_nodes[0],
        evaluate=set(range(1, 11)), reprocess=set(),
        load=set(range(1, 11)), load_metadata=set(),
    )

  def test_examples_with_example_ids(self):
    run = self._make_run(example_ids=[1, 3, 5])
    self._assert_example_sets(
        run, run.experiment.leaf_nodes[0],
        evaluate={1, 3, 5}, reprocess=set(),
        load={1, 3, 5}, load_metadata=set(),
    )

  def test_examples_with_reprocess_all(self):
    run = self._make_run(example_ids=[1, 3, 5], reprocess=True)
    self._assert_example_sets(
        run, run.experiment.leaf_nodes[0],
        evaluate={1, 3, 5}, reprocess={1, 3, 5},
        load=set(), load_metadata=set(),
    )

  def test_examples_with_reprocess_some(self):
    run = self._make_run(example_ids=[1, 3, 5], reprocess=[1])
    self._assert_example_sets(
        run, run.experiment.leaf_nodes[0],
        evaluate={1, 3, 5}, reprocess={1},
        load={3, 5}, load_metadata=set(),
    )

  def test_examples_with_generate_example_html_all(self):
    run = self._make_run(
        example_ids=[1, 3, 5], reprocess=[1], generate_example_html='all'
    )
    self._assert_example_sets(
        run, run.experiment.leaf_nodes[0],
        evaluate={1, 3, 5}, reprocess={1},
        load={3, 5}, load_metadata={3, 5},
    )

  def test_examples_with_generate_example_html_new(self):
    run = self._make_run(
        example_ids=[1, 3, 5], reprocess=[1], generate_example_html='new'
    )
    self._assert_example_sets(
        run, run.experiment.leaf_nodes[0],
        evaluate={1, 3, 5}, reprocess={1},
        load={3, 5}, load_metadata=set(),
    )

  def test_examples_with_generate_example_html_some(self):
    run = self._make_run(
        example_ids=[1, 3, 5], reprocess=[1], generate_example_html=[1, 2, 3]
    )
    self._assert_example_sets(
        run, run.experiment.leaf_nodes[0],
        evaluate={1, 3, 5}, reprocess={1},
        load={2, 3, 5}, load_metadata={2, 3},
    )
390
+
391
+
392
class RunnerTest(unittest.TestCase):
  """Tests for runner registration, creation and invocation."""

  def test_basic(self):

    class TestRunner(Runner):
      NAME = 'test'

      def run(self):
        pass

    self.assertIsInstance(
        Runner.create(
            'test',
            current_run=Run(
                '/root',
                RunId.from_id('20241102_0'), pg.Ref(Suite([])),
            )
        ),
        TestRunner
    )

    # Run under a private directory that is removed afterwards, instead of a
    # fixed path in the shared system temp dir that outlives the test and can
    # collide with concurrent or earlier runs.
    with tempfile.TemporaryDirectory() as root_dir:
      # Test standard run.
      MyEvaluation(replica_id=0).run(
          root_dir, id='20241101_0', runner='test'
      )

      # Test run preconfigured.
      MyEvaluation(replica_id=0).run_preconfigured(
          root_dir=root_dir, id='20241101_1'
      )

    with self.assertRaisesRegex(
        ValueError, 'Runner class must define a NAME constant'
    ):
      class AnotherRunner(Runner):  # pylint: disable=unused-variable
        def run(self):
          pass
430
+
431
+
432
if __name__ == '__main__':
  # Execute every test case in this module when run as a script.
  unittest.main()
@@ -0,0 +1,156 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Common value types for evaluation metrics and metadata."""
15
+
16
+
17
+ import abc
18
+ from typing import Annotated, Any, Union
19
+ import pyglove as pg
20
+
21
+
22
class MetricValue(pg.Object):
  """Base class for metric values.

  A metric value accumulates `DataPoint`s (a per-example value and weight)
  together with the total number of examples evaluated, and reduces them to a
  single float through the subclass-defined `reduce`.

  Ordering comparisons (`<`, `<=`, `>`, `>=`) use the float value. `==`
  uses symbolic equality against another value of the same class, and float
  equality otherwise.
  """

  class DataPoint(pg.Object):
    """A data point for a metric value."""
    example_id: int
    value: float
    weight: float = 1.0

  # NOTE(daiyip): For evaluations, usually the number of examples is within 10K,
  # therefore it's beneficial to store all accumulated values with their example
  # IDs so we are able to track the individual examples that contributed to this
  # metric value. If this premise changes, we might consider using a more
  # efficient data structure.
  data_points: Annotated[
      list[DataPoint],
      'Accumulated computed values with example IDs and weights.'
  ] = []

  total: Annotated[
      int,
      'The total number of examples being evaluated. Including errors.'
  ] = 0

  def _on_bound(self):
    super()._on_bound()
    # Cache the weighted sum so `add` can update it incrementally instead of
    # re-summing every data point.
    self._weighted_sum = sum(dp.value * dp.weight for dp in self.data_points)

  def reset(self) -> None:
    """Resets the value to its initial state."""
    self._sync_members(data_points=[], total=0)
    # `_sync_members` skips notifications, so the cached sum is reset by hand.
    self._weighted_sum = 0.0

  def _sync_members(self, **kwargs) -> None:
    """Synchronizes the members of this object."""
    self.rebind(**kwargs, skip_notification=True, raise_on_no_change=False)

  def __float__(self) -> float:
    """Returns the float representation of this object (NaN if total is 0)."""
    if self.total == 0:
      return float('nan')
    return self.reduce()

  @abc.abstractmethod
  def reduce(self) -> float:
    """Reduces the accumulated values into a single value."""

  def increment_total(self, delta: int = 1) -> 'MetricValue':
    """Increments the total number of examples being evaluated.

    Args:
      delta: Amount to add to `total`.

    Returns:
      This object, for chaining.
    """
    self._sync_members(total=self.total + delta)
    return self

  def add(
      self,
      example_id: int,
      value: float,
      weight: float = 1.0,
      increment_total: bool = False,
  ) -> 'MetricValue':
    """Adds a value to the accumulated values.

    Args:
      example_id: ID of the example that produced `value`.
      value: The value to accumulate.
      weight: Weight applied to `value` in the weighted sum.
      increment_total: If True, also increments `total` by 1.

    Returns:
      This object, for chaining.
    """
    self._weighted_sum += value * weight
    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
      self.data_points.append(
          MetricValue.DataPoint(example_id, value, weight)
      )
    if increment_total:
      self.increment_total()
    return self

  def __gt__(self, other: Union['MetricValue', float]) -> bool:
    if isinstance(other, self.__class__):
      return float(self) > float(other)
    return float(self) > other

  def __ge__(self, other: Union['MetricValue', float]) -> bool:
    # Without this, `value >= x` raised TypeError: neither this class nor
    # `float` could handle the comparison.
    if isinstance(other, self.__class__):
      return float(self) >= float(other)
    return float(self) >= other

  def __lt__(self, other: Union['MetricValue', float]) -> bool:
    if isinstance(other, self.__class__):
      return float(self) < float(other)
    return float(self) < other

  def __le__(self, other: Union['MetricValue', float]) -> bool:
    if isinstance(other, self.__class__):
      return float(self) <= float(other)
    return float(self) <= other

  def __eq__(self, other: Union['MetricValue', float]) -> bool:
    if isinstance(other, self.__class__):
      return super().__eq__(other)
    return float(self) == other

  # NOTE(review): Python 3 consults `__bool__`, not `__nonzero__`, so this
  # method does not affect the truthiness of metric values. Aliasing it to
  # `__bool__` would make zero-valued metrics falsy; left unchanged pending a
  # review of callers that test metric values for truthiness.
  def __nonzero__(self) -> bool:
    return float(self) != 0

  def format(
      self,
      compact: bool = False,
      verbose: bool = True,
      *args,
      **kwargs
  ) -> str:
    """Formats this value.

    Args:
      compact: If True, use the symbolic `pg.Object` formatting.
      verbose: In non-compact mode, whether to append
        `(<number of data points>/<total>)` to the scalar representation.
        In compact mode it is forwarded to the symbolic formatter.
      *args: Forwarded to `pg.Object.format` in compact mode.
      **kwargs: Forwarded to `pg.Object.format` in compact mode.

    Returns:
      The symbolic representation in compact mode. Otherwise 'n/a' when
      `total` is 0, else the scalar representation.
    """
    if compact:
      # `verbose` is a named parameter here, so it must be passed through
      # explicitly. Otherwise it is silently dropped, and any extra positional
      # arguments would land in the base class's `verbose` slot.
      return super().format(compact, verbose, *args, **kwargs)
    if self.total == 0:
      return 'n/a'
    if verbose:
      return (
          f'{self.scalar_repr()} ({len(self.data_points)}/{self.total})'
      )
    return self.scalar_repr()

  @abc.abstractmethod
  def scalar_repr(self) -> str:
    """Returns the format string for the value."""

  def _sym_nondefault(self) -> dict[str, Any]:
    """Overrides nondefault values so volatile values are not included."""
    return dict()
133
+
134
+
135
class Rate(MetricValue):
  """Representing a rate in range [0, 1].

  The rate is the weighted sum of the data points divided by `total`, the
  number of examples evaluated.
  """

  def reduce(self) -> float:
    # Return NaN rather than raising ZeroDivisionError when called directly
    # before any example is counted. This matches `__float__` and
    # `Average.reduce`, which also yield NaN for empty values.
    if self.total == 0:
      return float('nan')
    return self._weighted_sum / self.total

  def scalar_repr(self):
    if self.total == 0:
      return 'n/a'
    return f'{self.reduce():.1%}'
145
+
146
+
147
class Average(MetricValue):
  """Average of aggregated values.

  `reduce` divides the sum of `value * weight` over all data points by the
  number of data points (not by the sum of weights). It returns NaN when there
  are no data points.
  """

  def reduce(self) -> float:
    num_points = len(self.data_points)
    return self._weighted_sum / num_points if num_points else float('nan')

  def scalar_repr(self):
    return '{:.3f}'.format(self.reduce())
@@ -0,0 +1,80 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ import unittest
16
+
17
+ from langfun.core.eval.v2 import metric_values
18
+ import pyglove as pg
19
+
20
+
21
class RateTest(unittest.TestCase):
  """Tests for `metric_values.Rate`."""

  def test_basic(self):
    data_point = metric_values.MetricValue.DataPoint
    rate = metric_values.Rate()

    # A fresh rate has no examples and reduces to NaN.
    self.assertEqual(rate.total, 0)
    self.assertTrue(math.isnan(float(rate)))
    self.assertEqual(pg.format(rate), 'n/a')

    # One example counted, no data point yet.
    rate.increment_total()
    self.assertEqual(rate.total, 1)
    self.assertEqual(float(rate), 0.0)

    # One full-weight data point of 1.0 over one example.
    rate.add(1, 1.0, 1.0)
    self.assertEqual(float(rate), 1.0)
    self.assertEqual(pg.format(rate, verbose=False), '100.0%')
    self.assertEqual(pg.format(rate, verbose=True), '100.0% (1/1)')
    self.assertEqual(rate.data_points, [data_point(1, 1.0, 1.0)])

    # Comparisons against floats and against other rates.
    self.assertEqual(rate, 1.0)
    self.assertGreater(rate, 0.5)
    self.assertLess(rate, 1.5)
    self.assertEqual(rate, metric_values.Rate([data_point(1, 1.0, 1.0)], 1))
    zero_rate = metric_values.Rate([], 1)
    self.assertGreater(rate, zero_rate)
    self.assertLess(zero_rate, rate)

    # Resetting returns to the empty state.
    rate.reset()
    self.assertEqual(rate.total, 0)
    self.assertTrue(math.isnan(float(rate)))
53
+
54
+
55
class AverageTest(unittest.TestCase):
  """Tests for `metric_values.Average`."""

  def test_basic(self):
    average = metric_values.Average()
    self.assertEqual(average.total, 0)
    self.assertTrue(math.isnan(float(average)))
    self.assertEqual(pg.format(average, verbose=False), 'n/a')

    # (1.0 * 0.5 + 0.0 * 1.0) / 2 data points == 0.25.
    average.add(1, 1.0, 0.5, increment_total=True)
    average.add(1, 0.0, 1.0, increment_total=True)
    self.assertEqual(average.total, 2)
    self.assertEqual(float(average), 0.25)
    self.assertEqual(pg.format(average, verbose=False), '0.250')
    self.assertEqual(pg.format(average, verbose=True), '0.250 (2/2)')
    self.assertEqual(
        average.data_points,
        [
            metric_values.MetricValue.DataPoint(1, 1.0, 0.5),
            metric_values.MetricValue.DataPoint(1, 0.0, 1.0),
        ]
    )

    # Resetting must fully clear the accumulated state, as checked for `Rate`.
    average.reset()
    self.assertEqual(average.total, 0)
    self.assertEqual(average.data_points, [])
    self.assertTrue(math.isnan(float(average)))
77
+
78
+
79
if __name__ == '__main__':
  # Execute every test case in this module when run as a script.
  unittest.main()