langfun 0.1.2.dev202509120804__py3-none-any.whl → 0.1.2.dev202512040805__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. langfun/__init__.py +1 -1
  2. langfun/core/__init__.py +7 -1
  3. langfun/core/agentic/__init__.py +8 -1
  4. langfun/core/agentic/action.py +740 -112
  5. langfun/core/agentic/action_eval.py +9 -2
  6. langfun/core/agentic/action_test.py +189 -24
  7. langfun/core/async_support.py +104 -5
  8. langfun/core/async_support_test.py +23 -0
  9. langfun/core/coding/python/correction.py +19 -9
  10. langfun/core/coding/python/execution.py +14 -12
  11. langfun/core/coding/python/generation.py +21 -16
  12. langfun/core/coding/python/sandboxing.py +23 -3
  13. langfun/core/component.py +42 -3
  14. langfun/core/concurrent.py +70 -6
  15. langfun/core/concurrent_test.py +9 -2
  16. langfun/core/console.py +1 -1
  17. langfun/core/data/conversion/anthropic.py +12 -3
  18. langfun/core/data/conversion/anthropic_test.py +8 -6
  19. langfun/core/data/conversion/gemini.py +11 -2
  20. langfun/core/data/conversion/gemini_test.py +48 -9
  21. langfun/core/data/conversion/openai.py +145 -31
  22. langfun/core/data/conversion/openai_test.py +161 -17
  23. langfun/core/eval/base.py +48 -44
  24. langfun/core/eval/base_test.py +5 -5
  25. langfun/core/eval/matching.py +5 -2
  26. langfun/core/eval/patching.py +3 -3
  27. langfun/core/eval/scoring.py +4 -3
  28. langfun/core/eval/v2/__init__.py +2 -0
  29. langfun/core/eval/v2/checkpointing.py +76 -7
  30. langfun/core/eval/v2/checkpointing_test.py +9 -2
  31. langfun/core/eval/v2/config_saver.py +37 -0
  32. langfun/core/eval/v2/config_saver_test.py +36 -0
  33. langfun/core/eval/v2/eval_test_helper.py +104 -3
  34. langfun/core/eval/v2/evaluation.py +92 -17
  35. langfun/core/eval/v2/evaluation_test.py +9 -3
  36. langfun/core/eval/v2/example.py +50 -40
  37. langfun/core/eval/v2/example_test.py +16 -8
  38. langfun/core/eval/v2/experiment.py +84 -15
  39. langfun/core/eval/v2/experiment_test.py +19 -0
  40. langfun/core/eval/v2/metric_values.py +31 -3
  41. langfun/core/eval/v2/metric_values_test.py +32 -0
  42. langfun/core/eval/v2/metrics.py +157 -44
  43. langfun/core/eval/v2/metrics_test.py +39 -18
  44. langfun/core/eval/v2/progress.py +31 -1
  45. langfun/core/eval/v2/progress_test.py +27 -0
  46. langfun/core/eval/v2/progress_tracking.py +13 -5
  47. langfun/core/eval/v2/progress_tracking_test.py +9 -1
  48. langfun/core/eval/v2/reporting.py +90 -71
  49. langfun/core/eval/v2/reporting_test.py +24 -6
  50. langfun/core/eval/v2/runners/__init__.py +30 -0
  51. langfun/core/eval/v2/{runners.py → runners/base.py} +72 -180
  52. langfun/core/eval/v2/runners/beam.py +354 -0
  53. langfun/core/eval/v2/runners/beam_test.py +153 -0
  54. langfun/core/eval/v2/runners/ckpt_monitor.py +294 -0
  55. langfun/core/eval/v2/runners/ckpt_monitor_test.py +162 -0
  56. langfun/core/eval/v2/runners/debug.py +40 -0
  57. langfun/core/eval/v2/runners/debug_test.py +76 -0
  58. langfun/core/eval/v2/runners/parallel.py +243 -0
  59. langfun/core/eval/v2/runners/parallel_test.py +182 -0
  60. langfun/core/eval/v2/runners/sequential.py +47 -0
  61. langfun/core/eval/v2/runners/sequential_test.py +169 -0
  62. langfun/core/langfunc.py +45 -130
  63. langfun/core/langfunc_test.py +7 -5
  64. langfun/core/language_model.py +189 -36
  65. langfun/core/language_model_test.py +54 -3
  66. langfun/core/llms/__init__.py +12 -1
  67. langfun/core/llms/anthropic.py +157 -2
  68. langfun/core/llms/azure_openai.py +29 -17
  69. langfun/core/llms/cache/base.py +25 -3
  70. langfun/core/llms/cache/in_memory.py +48 -7
  71. langfun/core/llms/cache/in_memory_test.py +14 -4
  72. langfun/core/llms/compositional.py +25 -1
  73. langfun/core/llms/deepseek.py +30 -2
  74. langfun/core/llms/fake.py +32 -1
  75. langfun/core/llms/gemini.py +64 -12
  76. langfun/core/llms/gemini_test.py +110 -0
  77. langfun/core/llms/google_genai.py +34 -1
  78. langfun/core/llms/groq.py +28 -3
  79. langfun/core/llms/llama_cpp.py +23 -4
  80. langfun/core/llms/openai.py +120 -3
  81. langfun/core/llms/openai_compatible.py +148 -27
  82. langfun/core/llms/openai_compatible_test.py +207 -20
  83. langfun/core/llms/openai_test.py +0 -2
  84. langfun/core/llms/rest.py +16 -1
  85. langfun/core/llms/vertexai.py +58 -8
  86. langfun/core/logging.py +1 -1
  87. langfun/core/mcp/__init__.py +10 -0
  88. langfun/core/mcp/client.py +177 -0
  89. langfun/core/mcp/client_test.py +71 -0
  90. langfun/core/mcp/session.py +241 -0
  91. langfun/core/mcp/session_test.py +54 -0
  92. langfun/core/mcp/testing/simple_mcp_client.py +33 -0
  93. langfun/core/mcp/testing/simple_mcp_server.py +33 -0
  94. langfun/core/mcp/tool.py +254 -0
  95. langfun/core/mcp/tool_test.py +197 -0
  96. langfun/core/memory.py +1 -0
  97. langfun/core/message.py +160 -55
  98. langfun/core/message_test.py +65 -81
  99. langfun/core/modalities/__init__.py +8 -0
  100. langfun/core/modalities/audio.py +21 -1
  101. langfun/core/modalities/image.py +73 -3
  102. langfun/core/modalities/image_test.py +116 -0
  103. langfun/core/modalities/mime.py +64 -3
  104. langfun/core/modalities/mime_test.py +11 -0
  105. langfun/core/modalities/pdf.py +19 -1
  106. langfun/core/modalities/video.py +21 -1
  107. langfun/core/modality.py +167 -29
  108. langfun/core/modality_test.py +42 -12
  109. langfun/core/natural_language.py +1 -1
  110. langfun/core/sampling.py +4 -4
  111. langfun/core/sampling_test.py +20 -4
  112. langfun/core/structured/__init__.py +2 -24
  113. langfun/core/structured/completion.py +34 -44
  114. langfun/core/structured/completion_test.py +23 -43
  115. langfun/core/structured/description.py +54 -50
  116. langfun/core/structured/function_generation.py +29 -12
  117. langfun/core/structured/mapping.py +81 -37
  118. langfun/core/structured/parsing.py +95 -79
  119. langfun/core/structured/parsing_test.py +0 -3
  120. langfun/core/structured/querying.py +230 -154
  121. langfun/core/structured/querying_test.py +69 -33
  122. langfun/core/structured/schema/__init__.py +49 -0
  123. langfun/core/structured/schema/base.py +664 -0
  124. langfun/core/structured/schema/base_test.py +531 -0
  125. langfun/core/structured/schema/json.py +174 -0
  126. langfun/core/structured/schema/json_test.py +121 -0
  127. langfun/core/structured/schema/python.py +316 -0
  128. langfun/core/structured/schema/python_test.py +410 -0
  129. langfun/core/structured/schema_generation.py +33 -14
  130. langfun/core/structured/scoring.py +47 -36
  131. langfun/core/structured/tokenization.py +26 -11
  132. langfun/core/subscription.py +2 -2
  133. langfun/core/template.py +175 -50
  134. langfun/core/template_test.py +123 -17
  135. langfun/env/__init__.py +43 -0
  136. langfun/env/base_environment.py +827 -0
  137. langfun/env/base_environment_test.py +473 -0
  138. langfun/env/base_feature.py +304 -0
  139. langfun/env/base_feature_test.py +228 -0
  140. langfun/env/base_sandbox.py +842 -0
  141. langfun/env/base_sandbox_test.py +1235 -0
  142. langfun/env/event_handlers/__init__.py +14 -0
  143. langfun/env/event_handlers/chain.py +233 -0
  144. langfun/env/event_handlers/chain_test.py +253 -0
  145. langfun/env/event_handlers/event_logger.py +472 -0
  146. langfun/env/event_handlers/event_logger_test.py +304 -0
  147. langfun/env/event_handlers/metric_writer.py +726 -0
  148. langfun/env/event_handlers/metric_writer_test.py +214 -0
  149. langfun/env/interface.py +1640 -0
  150. langfun/env/interface_test.py +153 -0
  151. langfun/env/load_balancers.py +59 -0
  152. langfun/env/load_balancers_test.py +141 -0
  153. langfun/env/test_utils.py +507 -0
  154. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/METADATA +7 -3
  155. langfun-0.1.2.dev202512040805.dist-info/RECORD +217 -0
  156. langfun/core/eval/v2/runners_test.py +0 -343
  157. langfun/core/structured/schema.py +0 -987
  158. langfun/core/structured/schema_test.py +0 -982
  159. langfun-0.1.2.dev202509120804.dist-info/RECORD +0 -172
  160. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/WHEEL +0 -0
  161. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/licenses/LICENSE +0 -0
  162. {langfun-0.1.2.dev202509120804.dist-info → langfun-0.1.2.dev202512040805.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/metrics_test.py
@@ -25,15 +25,22 @@ class MatchTest(unittest.TestCase):
   def test_basic(self):
     m = metrics.Match()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.audit(Example(id=1, input=pg.Dict(groundtruth=1), output=1)),
-        dict(match=True)
+        m.update(Example(id=1, input=pg.Dict(groundtruth=1), output=1)),
+        dict(is_correct=True)
     )
     self.assertEqual(
-        m.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2)),
-        dict(mismatch=True)
+        m.update(
+            Example(
+                id=2,
+                input=pg.Dict(groundtruth=1),
+                output=2,
+                metric_metadata=dict(match=dict(is_correct=False, x=1))
+            )
+        ),
+        dict(is_correct=False, x=1)
     )
     self.assertEqual(
-        m.audit(
+        m.update(
             Example(
                 id=3,
                 input=pg.Dict(groundtruth=1),
@@ -47,7 +54,7 @@ class MatchTest(unittest.TestCase):
         dict(error='ValueError')
     )
     self.assertEqual(
-        m.audit(
+        m.update(
             Example(
                 id=3,
                 input=pg.Dict(groundtruth=1),
@@ -80,7 +87,7 @@ class MatchTest(unittest.TestCase):
   def test_bad_case(self):
     m = metrics.Match()  # pylint: disable=invalid-name
     with self.assertRaisesRegex(ValueError, '`groundtruth` is not present'):
-      m.audit(Example(id=1, input=pg.Dict(x=1), output=1))
+      m.update(Example(id=1, input=pg.Dict(x=1), output=1))

   def test_custom_metadata(self):

@@ -90,22 +97,36 @@ class MatchTest(unittest.TestCase):

     m = MyMatch()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
-        dict(match=True, x=1)
+        m.update(Example(id=1, input=pg.Dict(x=1), output=1)),
+        dict(is_correct=True, x=1)
     )
     self.assertEqual(m.matches, 1.0)

   def test_html_view(self):
     m = metrics.Match()  # pylint: disable=invalid-name
-    m.audit(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
+    m.update(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
     self.assertIn(
         '100.0%',
         m.to_html().content,
     )
     with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
-      m.audit(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
+      m.update(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
     self.assertEqual(len(scripts), 12)

+  def test_merge_from(self):
+    m1 = metrics.Match()
+    m1.update(Example(id=1, input=pg.Dict(groundtruth=1), output=1))
+    m2 = metrics.Match()
+    m2.update(Example(id=2, input=pg.Dict(groundtruth=1), output=2))
+    m1.merge_from(m2)
+    self.assertEqual(m1.matches, 0.5)
+    self.assertEqual(m1.mismatches, 0.5)
+    self.assertEqual(m1.oop_errors, 0.0)
+    self.assertEqual(m1.non_oop_errors, 0.0)
+    self.assertEqual(m1.matches.total, 2)
+    self.assertEqual(len(m1.matches.data_points), 1)
+    self.assertEqual(len(m1.mismatches.data_points), 1)
+

 class ScoreTest(unittest.TestCase):

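Note: taken together, the changes above document the metric API migration in this release: `Metric.audit()` is renamed to `Metric.update()`, per-example metadata keys move from `match`/`mismatch` to `is_correct`, previously computed metadata can be replayed through `Example.metric_metadata`, and per-shard metrics can be combined with `merge_from()`. A minimal sketch of the new surface, assuming the same imports the tests use:

import pyglove as pg
from langfun.core.eval.v2 import example as example_lib
from langfun.core.eval.v2 import metrics

Example = example_lib.Example

m = metrics.Match()
# update() replaces audit() and returns the per-example metric metadata.
assert m.update(
    Example(id=1, input=pg.Dict(groundtruth=1), output=1)
) == dict(is_correct=True)

# An example that already carries metadata (e.g. restored from a checkpoint)
# is replayed rather than re-judged; metadata is keyed by metric name.
ex = Example(
    id=2, input=pg.Dict(groundtruth=1), output=2,
    metric_metadata=dict(match=dict(is_correct=False, x=1)),
)
assert m.update(ex) == dict(is_correct=False, x=1)

# Metrics computed on separate shards can be folded into a global view.
other = metrics.Match()
other.update(Example(id=3, input=pg.Dict(groundtruth=1), output=1))
m.merge_from(other)
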
@@ -118,15 +139,15 @@ class ScoreTest(unittest.TestCase):

     m = MyScore()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
+        m.update(Example(id=1, input=pg.Dict(x=1), output=1)),
         dict(score=1 * 1)
     )
     self.assertEqual(
-        m.audit(Example(id=2, input=pg.Dict(x=2), output=2)),
+        m.update(Example(id=2, input=pg.Dict(x=2), output=2)),
         dict(score=2 * 2)
     )
     self.assertEqual(
-        m.audit(
+        m.update(
             Example(
                 id=3,
                 input=pg.Dict(x=1),
@@ -140,7 +161,7 @@ class ScoreTest(unittest.TestCase):
         dict(error='ValueError')
     )
     self.assertEqual(
-        m.audit(
+        m.update(
             Example(
                 id=3,
                 input=pg.Dict(x=1),
@@ -176,7 +197,7 @@ class ScoreTest(unittest.TestCase):

     m = MyScore()  # pylint: disable=invalid-name
     self.assertEqual(
-        m.audit(Example(id=1, input=pg.Dict(x=1), output=1)),
+        m.update(Example(id=1, input=pg.Dict(x=1), output=1)),
         dict(score=1 * 1, x=1)
     )
     self.assertEqual(m.average_score, 1.0)
@@ -189,13 +210,13 @@ class ScoreTest(unittest.TestCase):
         return example_input.x * output

     m = MyScore()  # pylint: disable=invalid-name
-    m.audit(Example(id=1, input=pg.Dict(x=1), output=2))
+    m.update(Example(id=1, input=pg.Dict(x=1), output=2))
     self.assertIn(
         '2.000',
         m.to_html().content,
     )
     with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
-      m.audit(Example(id=2, input=pg.Dict(x=1), output=2))
+      m.update(Example(id=2, input=pg.Dict(x=1), output=2))
     self.assertEqual(len(scripts), 9)


langfun/core/eval/v2/progress.py
@@ -21,7 +21,15 @@ import pyglove as pg


 class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
-  """Evaluation progress."""
+  """Represents and tracks the progress of an evaluation.
+
+  The `Progress` class maintains counts of processed, failed, and skipped
+  items in an evaluation, along with timing information (start time, stop time,
+  duration) and an execution summary. It provides properties to check the
+  status of the evaluation (e.g., `is_started`, `is_completed`) and methods
+  to update progress as items are evaluated.
+  It also supports HTML rendering as a progress bar for visualization.
+  """

   num_total: Annotated[
       int | None,
@@ -84,6 +92,7 @@ class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
         stop_time=None,
         execution_summary=pg.object_utils.TimeIt.StatusSummary(),
     )
+    self._progress_bar = None

   @property
   def num_completed(self) -> int:
@@ -216,6 +225,27 @@ class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
     """Overrides nondefault values so volatile values are not included."""
     return dict()

+  def merge_from(self, other: 'Progress') -> None:
+    """Merges the progress from another progress."""
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      if other.start_time is not None and (
+          self.start_time is None or self.start_time > other.start_time):
+        self.start_time = other.start_time
+
+      if other.stop_time is not None and (
+          self.stop_time is None or self.stop_time < other.stop_time):
+        self.stop_time = other.stop_time
+
+      if other.num_total is not None:
+        if self.num_total is None:
+          self.num_total = other.num_total
+        else:
+          assert self.num_total == other.num_total, (self, other)
+      self.num_processed += other.num_processed
+      self.num_failed += other.num_failed
+      self.num_skipped += other.num_skipped
+      self.execution_summary.aggregate(other.execution_summary.breakdown)
+
   #
   # HTML view.
   #
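Note: the merge semantics above (earliest start time, latest stop time, summed counters, with equal `num_total` asserted) make `merge_from` suitable for aggregating per-shard progress into one global view, e.g. when a run is split across workers. A sketch, assuming `Progress` is importable from the module shown above:

from langfun.core.eval.v2.progress import Progress  # assumed import path


def combined(shard_progresses: list[Progress]) -> Progress:
  # Fold per-shard progress into a single global Progress object.
  total = Progress()
  for p in shard_progresses:
    total.merge_from(p)
  return total
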
langfun/core/eval/v2/progress_test.py
@@ -77,6 +77,33 @@ class ProgressTest(unittest.TestCase):
     self.assertTrue(p.is_stopped)
     self.assertIsNotNone(p.stop_time_str)

+  def test_merge_from(self):
+    p1 = Progress()
+    p1.start(10)
+    p1.increment_processed()
+    p1.increment_failed()
+    p1.stop()
+
+    p2 = Progress()
+    p2.start(10)
+    p2.increment_skipped()
+    p2.stop()
+
+    with pg.allow_writable_accessors(True):
+      p1.start_time = 2.0
+      p1.stop_time = 4.0
+      p2.start_time = 1.0
+      p2.stop_time = 5.0
+
+    p1.merge_from(p2)
+    self.assertEqual(p1.num_total, 10)
+    self.assertEqual(p1.num_processed, 1)
+    self.assertEqual(p1.num_failed, 1)
+    self.assertEqual(p1.num_skipped, 1)
+    self.assertEqual(p1.num_completed, 3)
+    self.assertEqual(p1.start_time, 1.0)
+    self.assertEqual(p1.stop_time, 5.0)
+

 if __name__ == '__main__':
   unittest.main()
langfun/core/eval/v2/progress_tracking.py
@@ -14,6 +14,7 @@
 """Tracking evaluation run progress."""

 import os
+from typing import Literal
 import langfun.core as lf
 from langfun.core.eval.v2 import example as example_lib
 from langfun.core.eval.v2 import experiment as experiment_lib
@@ -24,16 +25,24 @@ Experiment = experiment_lib.Experiment
 Example = example_lib.Example


-def progress_tracker(tqdm: bool = False) -> experiment_lib.Plugin:
+def progress_tracker(
+    tracker_type: Literal['tqdm', 'html', 'auto'] = 'auto'
+) -> experiment_lib.Plugin:
   """Creates a progress tracker as a plugin.

   Args:
-    tqdm: If True, force using tqdm for progress update.
+    tracker_type: The type of progress tracker to use.
+      If `tqdm`, force using tqdm for progress update.
+      If `html`, force using html for progress update.
+      If `auto`, determine it automatically based on the running
+      environment (console vs. notebook)

   Returns:
     The progress tracker plugin.
   """
-  if tqdm or not lf.console.under_notebook():
+  if tracker_type == 'tqdm' or (
+      tracker_type == 'auto' and not lf.console.under_notebook()
+  ):
     return _TqdmProgressTracker()
   else:
     return _HtmlProgressTracker()
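Note: callers that previously forced tqdm with `progress_tracker(tqdm=True)` now select the tracker by name. A usage sketch, assuming an `experiment` and `root_dir` as used in this package's tests:

from langfun.core.eval.v2 import progress_tracking

# Force a tqdm progress bar even when running inside a notebook.
plugin = progress_tracking.progress_tracker(tracker_type='tqdm')
experiment.run(root_dir, 'new', plugins=[plugin])
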
@@ -88,8 +97,7 @@ class _TqdmProgressTracker(experiment_lib.Plugin):
     self._leaf_progresses = {
         leaf.id: lf.concurrent.ProgressBar.install(
            label=f'[#{i + 1} - {leaf.id}]',
-            total=(len(runner.current_run.example_ids)
-                   if runner.current_run.example_ids else leaf.num_examples),
+            total=len(runner.current_run.examples_to_evaluate(leaf)),
            color='cyan',
            status=None
        )
langfun/core/eval/v2/progress_tracking_test.py
@@ -14,12 +14,14 @@
 import contextlib
 import io
 import os
+import sys
 import tempfile
 import unittest

+from langfun.core import concurrent as lf_concurrent
 from langfun.core import console as lf_console
 from langfun.core.eval.v2 import eval_test_helper
-from langfun.core.eval.v2 import progress_tracking  # pylint: disable=unused-import
+from langfun.core.eval.v2 import progress_tracking
 from langfun.core.eval.v2 import runners as runners_lib  # pylint: disable=unused-import
 import pyglove as pg

@@ -31,6 +33,7 @@ class HtmlProgressTrackerTest(unittest.TestCase):
     def display(x):
       result['view'] = x.to_html()

+    self.assertFalse(progress_tracking._HtmlProgressTracker.is_per_example())
     lf_console._notebook = pg.Dict(
         display=display
     )
@@ -44,11 +47,14 @@
 class TqdmProgressTrackerTest(unittest.TestCase):

   def test_basic(self):
+    self.assertFalse(progress_tracking._TqdmProgressTracker.is_per_example())
     root_dir = os.path.join(tempfile.mkdtemp(), 'test_tqdm_progress_tracker')
     experiment = eval_test_helper.test_experiment()
     string_io = io.StringIO()
     with contextlib.redirect_stderr(string_io):
       _ = experiment.run(root_dir, 'new', plugins=[])
+      sys.stderr.flush()
+      lf_concurrent.ProgressBar.refresh()
     self.assertIn('All: 100%', string_io.getvalue())

   def test_with_example_ids(self):
@@ -59,6 +65,8 @@ class TqdmProgressTrackerTest(unittest.TestCase):
     string_io = io.StringIO()
     with contextlib.redirect_stderr(string_io):
       _ = experiment.run(root_dir, 'new', example_ids=[1], plugins=[])
+      sys.stderr.flush()
+      lf_concurrent.ProgressBar.refresh()
     self.assertIn('All: 100%', string_io.getvalue())


langfun/core/eval/v2/reporting.py
@@ -32,8 +32,97 @@ _SUMMARY_FILE = 'summary.html'
 _EVALULATION_DETAIL_FILE = 'index.html'


+class ExampleHtmlGenerator(experiment_lib.Plugin):
+  """Plugin for generating HTML views for each evaluation example."""
+
+  def on_example_complete(
+      self, runner: Runner, experiment: Experiment, example: Example
+  ):
+    self._save_example_html(runner, experiment, example)
+
+  def _save_example_html(
+      self, runner: Runner, experiment: Experiment, example: Example
+  ) -> None:
+    """Saves the example in HTML format."""
+    current_run = runner.current_run
+    def _generate():
+      try:
+        with pg.timeit() as t:
+          html = example.to_html(
+              collapse_level=None,
+              enable_summary_tooltip=False,
+              extra_flags=dict(
+                  # For properly rendering the next link.
+                  num_examples=getattr(experiment, 'num_examples', None)
+              ),
+          )
+          html.save(
+              runner.current_run.output_path_for(
+                  experiment, f'{example.id}.html'
+              )
+          )
+        experiment.info(
+            f'\'{example.id}.html\' generated in {t.elapse:.2f} seconds. '
+        )
+      except BaseException as e:  # pylint: disable=broad-except
+        experiment.error(
+            f'Failed to generate \'{example.id}.html\'. '
+            f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
+        )
+        raise e
+
+    def _copy():
+      src_file = current_run.input_path_for(experiment, f'{example.id}.html')
+      dest_file = current_run.output_path_for(experiment, f'{example.id}.html')
+
+      if src_file == dest_file:
+        return
+
+      if not pg.io.path_exists(src_file):
+        experiment.warning(
+            f'Skip copying \'{example.id}.html\' as '
+            f'{src_file!r} does not exist.'
+        )
+        return
+
+      try:
+        with pg.timeit() as t, pg.io.open(src_file, 'r') as src:
+          content = src.read()
+          with pg.io.open(dest_file, 'w') as dest:
+            dest.write(content)
+        experiment.info(
+            f'\'{example.id}.html\' copied in {t.elapse:.2f} seconds.'
+        )
+      except BaseException as e:  # pylint: disable=broad-except
+        experiment.error(
+            f'Failed to copy {src_file!r} to {dest_file!r}. Error: {e}.'
+        )
+        raise e
+
+    generate_example_html = current_run.generate_example_html
+    if (generate_example_html == 'all'
+        or (generate_example_html == 'new' and example.newly_processed)
+        or (isinstance(generate_example_html, list)
+            and example.id in generate_example_html)):
+      op = _generate
+    else:
+      op = _copy
+    runner.background_run(op)
+
+
 class HtmlReporter(experiment_lib.Plugin):
-  """Plugin for periodically generating HTML reports for the experiment."""
+  """Plugin for periodically generating HTML reports for the experiment.
+
+  The `HtmlReporter` plugin generates several HTML files during an experiment
+  run:
+  - A `summary.html` at the root of the run directory, summarizing all
+    evaluations in the experiment.
+  - An `index.html` for each leaf evaluation, detailing the evaluation
+    definition, metrics, and logs.
+
+  These reports are updated periodically in the background during the run,
+  allowing users to monitor progress in near real-time.
+  """

   summary_interval: Annotated[
       int,
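Note: per-example HTML generation moves out of `HtmlReporter` into the standalone `ExampleHtmlGenerator` plugin (its call is removed from `on_example_complete` in the next hunk). Runs that still want per-example HTML pages must now register both plugins, as `reporting_test.py` does below; a sketch assuming `experiment` and `root_dir` as in those tests:

from langfun.core.eval.v2 import reporting

run = experiment.run(
    root_dir,
    'new',
    plugins=[reporting.HtmlReporter(), reporting.ExampleHtmlGenerator()],
)
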
@@ -127,7 +216,6 @@ class HtmlReporter(experiment_lib.Plugin):
   def on_example_complete(
       self, runner: Runner, experiment: Experiment, example: Example
   ):
-    self._save_example_html(runner, experiment, example)
     self._maybe_update_experiment_html(runner, experiment)
     self._maybe_update_summary(runner)

@@ -197,72 +285,3 @@ class HtmlReporter(experiment_lib.Plugin):
       runner.background_run(_save)
     else:
       _save()
-
-  def _save_example_html(
-      self, runner: Runner, experiment: Experiment, example: Example
-  ) -> None:
-    """Saves the example in HTML format."""
-    current_run = runner.current_run
-    def _generate():
-      try:
-        with pg.timeit() as t:
-          html = example.to_html(
-              collapse_level=None,
-              enable_summary_tooltip=False,
-              extra_flags=dict(
-                  # For properly rendering the next link.
-                  num_examples=getattr(experiment, 'num_examples', None)
-              ),
-          )
-          html.save(
-              runner.current_run.output_path_for(
-                  experiment, f'{example.id}.html'
-              )
-          )
-        experiment.info(
-            f'\'{example.id}.html\' generated in {t.elapse:.2f} seconds. '
-        )
-      except BaseException as e:  # pylint: disable=broad-except
-        experiment.error(
-            f'Failed to generate \'{example.id}.html\'. '
-            f'Error: {e}, Stacktrace: \n{traceback.format_exc()}.',
-        )
-        raise e
-
-    def _copy():
-      src_file = current_run.input_path_for(experiment, f'{example.id}.html')
-      dest_file = current_run.output_path_for(experiment, f'{example.id}.html')
-
-      if src_file == dest_file:
-        return
-
-      if not pg.io.path_exists(src_file):
-        experiment.warning(
-            f'Skip copying \'{example.id}.html\' as '
-            f'{src_file!r} does not exist.'
-        )
-        return
-
-      try:
-        with pg.timeit() as t, pg.io.open(src_file, 'r') as src:
-          content = src.read()
-          with pg.io.open(dest_file, 'w') as dest:
-            dest.write(content)
-        experiment.info(
-            f'\'{example.id}.html\' copied in {t.elapse:.2f} seconds.'
-        )
-      except BaseException as e:  # pylint: disable=broad-except
-        experiment.error(
-            f'Failed to copy {src_file!r} to {dest_file!r}. Error: {e}.'
-        )
-        raise e
-
-    generate_example_html = current_run.generate_example_html
-    if (generate_example_html == 'all'
-        or (generate_example_html == 'new' and example.newly_processed)
-        or (isinstance(generate_example_html, list)
-            and example.id in generate_example_html)):
-      op = _generate
-    else:
-      op = _copy
-    runner.background_run(op)
langfun/core/eval/v2/reporting_test.py
@@ -29,7 +29,16 @@ class ReportingTest(unittest.TestCase):
     experiment = eval_test_helper.test_experiment()
     checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
     reporter = reporting.HtmlReporter()
-    run = experiment.run(root_dir, 'new', plugins=[checkpointer, reporter])
+    self.assertFalse(reporter.is_per_example())
+
+    example_html_generator = reporting.ExampleHtmlGenerator()
+    self.assertTrue(example_html_generator.is_per_example())
+
+    run = experiment.run(
+        root_dir,
+        'new',
+        plugins=[checkpointer, reporter, example_html_generator]
+    )
     self.assertTrue(
         pg.io.path_exists(os.path.join(run.output_root, 'summary.html'))
     )
@@ -52,8 +61,10 @@ class ReportingTest(unittest.TestCase):
     root_dir = os.path.join(tempfile.mkdtemp(), 'test_reporting2')
     experiment = eval_test_helper.test_experiment()
     run = experiment.run(
-        root_dir, 'new', plugins=[checkpointer, reporter],
-        warm_start_from=run.output_root
+        root_dir,
+        'new',
+        plugins=[checkpointer, reporter, example_html_generator],
+        warm_start_from=run.output_root,
     )
     self.assertTrue(
         pg.io.path_exists(os.path.join(run.output_root, 'summary.html'))
@@ -105,7 +116,12 @@ class ReportingTest(unittest.TestCase):
                   .test_experiment_with_example_html_generation_error())
     checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
     reporter = reporting.HtmlReporter()
-    run = experiment.run(root_dir, 'new', plugins=[checkpointer, reporter])
+    example_html_generator = reporting.ExampleHtmlGenerator()
+    run = experiment.run(
+        root_dir,
+        'new',
+        plugins=[checkpointer, reporter, example_html_generator]
+    )
     self.assertTrue(
         pg.io.path_exists(os.path.join(run.output_root, 'summary.html'))
     )
@@ -132,8 +148,10 @@ class ReportingTest(unittest.TestCase):
     experiment = (eval_test_helper
                   .test_experiment_with_example_html_generation_error())
     run = experiment.run(
-        root_dir, 'new', plugins=[checkpointer, reporter],
-        warm_start_from=run.output_root
+        root_dir,
+        'new',
+        plugins=[checkpointer, reporter, example_html_generator],
+        warm_start_from=run.output_root,
     )
     self.assertTrue(
         pg.io.path_exists(os.path.join(run.output_root, 'summary.html'))
langfun/core/eval/v2/runners/__init__.py (new file)
@@ -0,0 +1,30 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Langfun evaluation runners."""
+
+# pylint: disable=g-importing-member
+from langfun.core.eval.v2.runners.base import RunnerBase
+from langfun.core.eval.v2.runners.beam import BeamRunner
+from langfun.core.eval.v2.runners.debug import DebugRunner
+from langfun.core.eval.v2.runners.parallel import ParallelRunner
+from langfun.core.eval.v2.runners.sequential import SequentialRunner
+# pylint: enable=g-importing-member
+
+__all__ = [
+    'RunnerBase',
+    'BeamRunner',
+    'DebugRunner',
+    'ParallelRunner',
+    'SequentialRunner',
+]
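Note: `runners.py` becomes the `runners` package (entry 51 in the file list), with each runner in its own module plus the new Beam and checkpoint-monitor runners. The re-exports above keep package-level access working; a sketch of the assumed import surface:

from langfun.core.eval.v2 import runners

# Runner implementations are re-exported at the package level.
assert runners.ParallelRunner is not None
assert runners.SequentialRunner is not None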