langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff covers publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/progress.py (new file)
@@ -0,0 +1,348 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Progress reporting for evaluation."""
+
+ import datetime
+ import threading
+ import time
+ from typing import Annotated, Any
+ import pyglove as pg
+
+
+ class Progress(pg.Object, pg.views.HtmlTreeView.Extension):
+   """Evaluation progress."""
+
+   num_total: Annotated[
+       int | None,
+       (
+           'Total number of items to be processed. '
+           'If None, the progress is not started.'
+       )
+   ] = None
+   num_processed: Annotated[
+       int,
+       (
+           'Number of items that have been processed without errors.'
+       )
+   ] = 0
+   num_failed: Annotated[
+       int,
+       (
+           'Number of items that have failed.'
+       )
+   ] = 0
+   num_skipped: Annotated[
+       int,
+       (
+           'Number of items that have been skipped.'
+       )
+   ] = 0
+   start_time: Annotated[
+       float | None,
+       (
+           'The start time of the progress. '
+           'If None, the progress is not started.'
+       )
+   ] = None
+   stop_time: Annotated[
+       float | None,
+       (
+           'The stop time of the progress. '
+           'If None, the progress is not stopped.'
+       )
+   ] = None
+   execution_summary: Annotated[
+       pg.object_utils.TimeIt.StatusSummary,
+       'The execution summary of the progress.'
+   ] = pg.object_utils.TimeIt.StatusSummary()
+
+   def _on_bound(self):
+     super()._on_bound()
+     self._progress_bar = None
+     self._time_label = None
+     self._lock = threading.Lock()
+
+   def reset(self) -> None:
+     """Resets the progress."""
+     self._sync_members(
+         num_total=None,
+         num_processed=0,
+         num_failed=0,
+         num_skipped=0,
+         start_time=None,
+         stop_time=None,
+         execution_summary=pg.object_utils.TimeIt.StatusSummary(),
+     )
+
+   @property
+   def num_completed(self) -> int:
+     """Returns the number of completed examples."""
+     return self.num_processed + self.num_failed + self.num_skipped
+
+   def __float__(self) -> float:
+     """Returns the complete rate in range [0, 1]."""
+     if self.num_total is None:
+       return float('nan')
+     return self.num_completed / self.num_total
+
+   @property
+   def is_started(self) -> bool:
+     """Returns whether the evaluation is started."""
+     return self.start_time is not None
+
+   @property
+   def is_stopped(self) -> bool:
+     """Returns whether the evaluation is stopped."""
+     return self.stop_time is not None
+
+   @property
+   def is_completed(self) -> bool:
+     """Returns whether the evaluation is completed."""
+     return (
+         self.num_total is not None
+         and self.num_completed == self.num_total
+     )
+
+   @property
+   def is_skipped(self) -> bool:
+     """Returns whether the evaluation is skipped."""
+     return (
+         self.num_total is not None
+         and self.num_skipped == self.num_total
+     )
+
+   @property
+   def is_failed(self) -> bool:
+     """Returns whether the evaluation is failed."""
+     return (
+         self.num_failed > 0
+         and self.num_failed + self.num_skipped == self.num_total
+     )
+
+   @property
+   def elapse(self) -> float | None:
+     """Returns the elapse time in seconds."""
+     if self.start_time is None:
+       return None
+     if self.stop_time is None:
+       return time.time() - self.start_time
+     return self.stop_time - self.start_time
+
+   @property
+   def start_time_str(self) -> str | None:
+     """Returns the start time string of the evaluation."""
+     if self.start_time is None:
+       return None
+     return time.strftime('%Y/%m/%d %H:%M:%S', time.localtime(self.start_time))
+
+   @property
+   def stop_time_str(self) -> str | None:
+     """Returns the complete time string of the evaluation."""
+     if self.stop_time is None:
+       return None
+     return time.strftime(
+         '%Y/%m/%d %H:%M:%S', time.localtime(self.stop_time)
+     )
+
+   def start(self, total: int) -> None:
+     """Marks the evaluation as started."""
+     assert self.start_time is None, self
+     self._sync_members(start_time=time.time(), num_total=total)
+     if self._progress_bar is not None:
+       self._progress_bar.update(total=total)
+     self._update_time_label()
+
+   def stop(self) -> None:
+     """Marks the evaluation as stopped."""
+     assert self.stop_time is None, self
+     self._sync_members(stop_time=time.time())
+     self._update_time_label()
+
+   def _sync_members(self, **kwargs: Any):
+     """Synchronizes the members of the progress."""
+     self.rebind(
+         **kwargs,
+         skip_notification=True,
+         raise_on_no_change=False,
+     )
+
+   def increment_processed(self, delta: int = 1) -> None:
+     """Updates the number of processed examples."""
+     assert self.is_started and not self.is_stopped, self
+     with self._lock:
+       self._sync_members(num_processed=self.num_processed + delta)
+       if self._progress_bar is not None:
+         self._progress_bar['Processed'].increment(delta)
+       self._update_time_label()
+
+   def increment_failed(self, delta: int = 1) -> None:
+     """Updates the number of failed examples."""
+     assert self.is_started and not self.is_stopped, self
+     with self._lock:
+       self._sync_members(num_failed=self.num_failed + delta)
+       if self._progress_bar is not None:
+         self._progress_bar['Failed'].increment(delta)
+       self._update_time_label()
+
+   def increment_skipped(self, delta: int = 1) -> None:
+     """Updates the number of skipped examples."""
+     assert self.is_started and not self.is_stopped, self
+     with self._lock:
+       self._sync_members(num_skipped=self.num_skipped + delta)
+       if self._progress_bar is not None:
+         self._progress_bar['Skipped'].increment(delta)
+       self._update_time_label()
+
+   def update_execution_summary(
+       self,
+       execution_status: dict[str, pg.object_utils.TimeIt.Status]
+   ) -> None:
+     """Updates the execution summary of the progress."""
+     with self._lock:
+       self.execution_summary.aggregate(execution_status)
+
+   def _sym_nondefault(self) -> dict[str, Any]:
+     """Overrides nondefault values so volatile values are not included."""
+     return dict()
+
+   #
+   # HTML view.
+   #
+
+   def _duration_text(self) -> str:
+     if self.start_time is None:
+       return '00:00:00'
+     return str(datetime.timedelta(seconds=self.elapse)).split('.')[0]
+
+   def _time_tooltip(self) -> pg.Html.WritableTypes:
+     time_info = pg.Dict(
+         duration=self._duration_text(),
+         last_update=(
+             time.strftime(  # pylint: disable=g-long-ternary
+                 '%Y/%m/%d %H:%M:%S',
+                 time.localtime(time.time())
+             ) if not self.is_stopped else self.stop_time_str
+         ),
+         start_time=self.start_time_str,
+         stop_time=self.stop_time_str,
+     )
+     if self.execution_summary:
+       time_info['execution'] = pg.Dict(
+           {
+               k: pg.Dict(
+                   num_started=v.num_started,
+                   num_ended=v.num_ended,
+                   num_failed=v.num_failed,
+                   avg_duration=round(v.avg_duration, 2),
+               ) for k, v in self.execution_summary.breakdown.items()
+           }
+       )
+     return pg.format(time_info, verbose=False)
+
+   def _html_tree_view(
+       self,
+       *,
+       view: pg.views.HtmlTreeView,
+       extra_flags: dict[str, Any] | None = None,
+       **kwargs
+   ) -> pg.Html:
+     """Renders the content of the progress bar."""
+     def _progress_bar():
+       return pg.views.html.controls.ProgressBar(
+           [
+               pg.views.html.controls.SubProgress(
+                   name='Skipped', value=self.num_skipped,
+               ),
+               pg.views.html.controls.SubProgress(
+                   name='Processed', value=self.num_processed,
+               ),
+               pg.views.html.controls.SubProgress(
+                   name='Failed', value=self.num_failed,
+               ),
+           ],
+           total=self.num_total,
+           interactive=interactive,
+       )
+
+     def _time_label():
+       css_class = 'not-started'
+       if self.is_started and not self.is_stopped:
+         css_class = 'started'
+       elif self.is_stopped:
+         css_class = 'stopped'
+       return pg.views.html.controls.Label(
+           self._duration_text(),
+           tooltip=self._time_tooltip(),
+           css_classes=[
+               'progress-time', css_class
+           ],
+           interactive=interactive,
+       )
+
+     extra_flags = extra_flags or {}
+     interactive = extra_flags.pop('interactive', True)
+     if interactive:
+       if self._progress_bar is None:
+         self._progress_bar = _progress_bar()
+       if self._time_label is None:
+         self._time_label = _time_label()
+       progress_bar = self._progress_bar
+       time_label = self._time_label
+     else:
+       progress_bar = _progress_bar()
+       time_label = _time_label()
+     return pg.Html.element(
+         'div', [progress_bar, time_label], css_classes=['eval-progress'],
+     )
+
+   def _update_time_label(self):
+     """Updates the time label of the progress."""
+     if self._time_label is None:
+       return
+     self._time_label.update(
+         text=self._duration_text(),
+         tooltip=self._time_tooltip(),
+         styles=dict(
+             color=(
+                 'dodgerblue' if self.is_started
+                 and not self.is_stopped else '#ccc'
+             ),
+         ),
+     )
+
+   @classmethod
+   def _html_tree_view_css_styles(cls) -> list[str]:
+     return super()._html_tree_view_css_styles() + [
+         """
+         .eval-progress {
+           display: inline-block;
+         }
+         .sub-progress.skipped {
+           background-color:yellow;
+         }
+         .sub-progress.processed {
+           background-color:#00B000;
+         }
+         .sub-progress.failed {
+           background-color:red;
+         }
+         .progress-time {
+           font-weight: normal;
+           margin-left: 10px;
+           border-radius: 5px;
+           color: #CCC;
+           padding: 5px;
+         }
+         """
+     ]
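
A minimal usage sketch of the Progress API added above (not part of the package diff; the item counts are illustrative):

    from langfun.core.eval.v2 import progress as progress_lib

    p = progress_lib.Progress()
    p.start(total=5)            # mark the evaluation as started with 5 items
    p.increment_processed(3)    # 3 items processed without errors
    p.increment_failed(1)       # 1 item failed
    p.increment_skipped(1)      # 1 item skipped
    assert p.is_completed       # 3 + 1 + 1 == num_total
    print(float(p))             # completion rate in [0, 1]; 1.0 here
    p.stop()
    print(p.elapse)             # seconds between start() and stop()

When rendered interactively (e.g. in a notebook), the same increment calls also update the HTML progress bar and time label defined in _html_tree_view.
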
langfun/core/eval/v2/progress_test.py (new file)
@@ -0,0 +1,82 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import math
+ import time
+ import unittest
+
+ from langfun.core.eval.v2 import progress as progress_lib
+ import pyglove as pg
+
+ Progress = progress_lib.Progress
+
+
+ class ProgressTest(unittest.TestCase):
+
+   def test_basic(self):
+     p = Progress()
+     self.assertFalse(p.is_started)
+     self.assertFalse(p.is_stopped)
+     self.assertFalse(p.is_completed)
+     self.assertFalse(p.is_skipped)
+     self.assertFalse(p.is_failed)
+     self.assertEqual(p.num_completed, 0)
+     self.assertIsNone(p.elapse)
+     self.assertIsNone(p.start_time_str)
+     self.assertIsNone(p.stop_time_str)
+     self.assertTrue(math.isnan(float(p)))
+
+     p.start(10)
+     self.assertEqual(p.num_total, 10)
+     self.assertTrue(p.is_started)
+     self.assertFalse(p.is_stopped)
+     self.assertIsNotNone(p.start_time_str)
+     self.assertIsNotNone(p.elapse)
+     self.assertEqual(float(p), 0.0)
+
+     with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
+       p.increment_processed()
+       p.increment_failed()
+       p.increment_skipped()
+       # Does not triggger scripts as progress is not rendered yet.
+       self.assertEqual(len(scripts), 0)
+     self.assertEqual(p.num_completed, 3)
+     self.assertIn(
+         '3/10',
+         p.to_html(extra_flags=dict(interactive=True)).content,
+     )
+     with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
+       p.increment_processed()
+       p.increment_failed()
+       p.increment_skipped()
+       self.assertEqual(len(scripts), 24)
+     self.assertEqual(p.num_completed, 6)
+     self.assertIn(
+         '6/10',
+         p.to_html(extra_flags=dict(interactive=False)).content,
+     )
+     p.update_execution_summary(
+         dict(
+             evaluate=pg.object_utils.TimeIt.Status(name='evaluate', elapse=1.0)
+         )
+     )
+     p.stop()
+     elapse1 = p.elapse
+     time.sleep(0.1)
+     self.assertEqual(p.elapse, elapse1)
+     self.assertTrue(p.is_stopped)
+     self.assertIsNotNone(p.stop_time_str)
+
+
+ if __name__ == '__main__':
+   unittest.main()
langfun/core/eval/v2/progress_tracking.py (new file)
@@ -0,0 +1,210 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tracking evaluation run progress."""
+
+ import langfun.core as lf
+ from langfun.core.eval.v2 import example as example_lib
+ from langfun.core.eval.v2 import experiment as experiment_lib
+ import pyglove as pg
+
+ Runner = experiment_lib.Runner
+ Experiment = experiment_lib.Experiment
+ Example = example_lib.Example
+
+
+ def progress_tracker(tqdm: bool = False) -> experiment_lib.Plugin:
+   """Creates a progress tracker as a plugin.
+
+   Args:
+     tqdm: If True, force using tqdm for progress update.
+
+   Returns:
+     The progress tracker plugin.
+   """
+   if tqdm or not lf.console.under_notebook():
+     return _TqdmProgressTracker()
+   else:
+     return _HtmlProgressTracker()
+
+
+ class _HtmlProgressTracker(experiment_lib.Plugin):
+   """HTML progress tracker plugin."""
+
+   def on_run_start(
+       self,
+       runner: Runner,
+       root: Experiment
+   ) -> None:
+     # Display the experiment if running under notebook and not using tqdm.
+     assert lf.console.under_notebook()
+     with pg.view_options(
+         collapse_level=None,
+         extra_flags=dict(
+             current_run=runner.current_run,
+             interactive=True,
+         )
+     ):
+       lf.console.display(runner.current_run.experiment)
+
+
+ ProgressBarId = int
+
+
+ class _TqdmProgressTracker(experiment_lib.Plugin):
+   """Tqdm process updater plugin."""
+
+   def _on_bound(self):
+     super()._on_bound()
+     self._overall_progress: ProgressBarId | None = None
+     self._leaf_progresses: dict[str, ProgressBarId] = {}
+
+   def experiment_progress(
+       self, experiment: Experiment) -> lf.concurrent.ProgressBar:
+     """Returns the progress of the experiment."""
+     assert experiment.is_leaf
+     return self._leaf_progresses[experiment.id]
+
+   def on_run_start(
+       self,
+       runner: Runner,
+       root: Experiment
+   ) -> None:
+     """Called when a runner is started."""
+     self._overall_progress = lf.concurrent.ProgressBar.install(
+         label='All', total=len(root.leaf_nodes), color='blue'
+     )
+     self._leaf_progresses = {
+         leaf.id: lf.concurrent.ProgressBar.install(
+             label=f'[#{i + 1} - {leaf.id}]',
+             total=(len(runner.current_run.example_ids)
+                    if runner.current_run.example_ids else leaf.num_examples),
+             color='cyan',
+             status=None
+         )
+         for i, leaf in enumerate(root.leaf_nodes)
+     }
+     summary_link = Experiment.link(
+         runner.current_run.output_path_for(root, 'summary.html')
+     )
+     lf.console.write(f'Summary: {summary_link}.', color='green')
+
+   def on_run_complete(
+       self,
+       runner: Runner,
+       root: Experiment
+   ) -> None:
+     """Called when a runner is complete."""
+     lf.concurrent.ProgressBar.update(
+         self._overall_progress,
+         color='green',
+         status='ALL COMPLETED.',
+     )
+     lf.concurrent.ProgressBar.uninstall(self._overall_progress)
+     self._overall_progress = None
+     for progress in self._leaf_progresses.values():
+       lf.concurrent.ProgressBar.uninstall(progress)
+     self._leaf_progresses = {}
+
+   def on_experiment_start(
+       self,
+       runner: Runner,
+       experiment: Experiment
+   ) -> None:
+     """Called when an evaluation is started."""
+
+   def on_experiment_skipped(
+       self,
+       runner: Runner,
+       experiment: Experiment
+   ) -> None:
+     """Called when an evaluation is skipped."""
+     if experiment.is_leaf:
+       lf.concurrent.ProgressBar.update(
+           self.experiment_progress(experiment),
+           delta=experiment.progress.num_total,
+           status='Skipped.',
+       )
+       lf.concurrent.ProgressBar.update(
+           self._overall_progress,
+           status=f'Skipped {experiment.id}.',
+       )
+
+   def on_experiment_complete(
+       self,
+       runner: Runner,
+       experiment: Experiment
+   ) -> None:
+     """Called when an evaluation is complete."""
+     if experiment.is_leaf:
+       lf.concurrent.ProgressBar.update(
+           self.experiment_progress(experiment),
+           color='green',
+       )
+       lf.concurrent.ProgressBar.update(
+           self._overall_progress,
+           delta=1,
+           status=f'{experiment.id} COMPLETED.',
+       )
+
+   def on_example_start(
+       self,
+       runner: Runner,
+       experiment: Experiment,
+       example: Example
+   ) -> None:
+     """Called when an evaluation example is started."""
+
+   def on_example_skipped(
+       self,
+       runner: Runner,
+       experiment: Experiment,
+       example: Example
+   ) -> None:
+     """Called when an evaluation example is skipped."""
+     del runner, example
+     lf.concurrent.ProgressBar.update(
+         self.experiment_progress(experiment),
+         delta=1,
+     )
+
+   def on_example_complete(
+       self,
+       runner: Runner,
+       experiment: Experiment,
+       example: Example
+   ) -> None:
+     """Called when an evaluation example is complete."""
+     lf.concurrent.ProgressBar.update(
+         self.experiment_progress(experiment),
+         delta=1,
+         status=self.status(experiment),
+     )
+
+   def status(self, experiment: Experiment) -> str:
+     """Returns the progress text of the evaluation."""
+     items = []
+     for metric in experiment.metrics:
+       for metric_value in metric.values():
+         items.append(
+             f'{metric_value.sym_path.key}={metric_value.format(verbose=True)}'
+         )
+     error_tags = {}
+     for entry in experiment.progress.execution_summary.breakdown.values():
+       error_tags.update(entry.error_tags)
+
+     if error_tags:
+       items.extend(
+           [f'{k}={v}' for k, v in error_tags.items() if v > 0]
+       )
+     return ', '.join(items)
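
A small usage sketch for the tracker above (not part of the diff): progress_tracker returns a Plugin that can be passed to an experiment run through the plugins argument shown in the tests below. The output directory is a placeholder, and tqdm=True forces the terminal tracker even outside a notebook:

    from langfun.core.eval.v2 import eval_test_helper
    from langfun.core.eval.v2 import progress_tracking
    from langfun.core.eval.v2 import runners as runners_lib  # imported for its side effects, as in the tests below

    experiment = eval_test_helper.test_experiment()
    # Force the tqdm-based tracker, e.g. when running outside a notebook.
    tracker = progress_tracking.progress_tracker(tqdm=True)
    _ = experiment.run('/tmp/my_eval_root', 'new', plugins=[tracker])
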
langfun/core/eval/v2/progress_tracking_test.py (new file)
@@ -0,0 +1,66 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import contextlib
+ import io
+ import os
+ import tempfile
+ import unittest
+
+ from langfun.core import console as lf_console
+ from langfun.core.eval.v2 import eval_test_helper
+ from langfun.core.eval.v2 import progress_tracking  # pylint: disable=unused-import
+ from langfun.core.eval.v2 import runners as runners_lib  # pylint: disable=unused-import
+ import pyglove as pg
+
+
+ class HtmlProgressTrackerTest(unittest.TestCase):
+
+   def test_track_progress(self):
+     result = pg.Dict()
+     def display(x):
+       result['view'] = x.to_html()
+
+     lf_console._notebook = pg.Dict(
+         display=display
+     )
+     root_dir = os.path.join(tempfile.gettempdir(), 'test_html_progress_tracker')
+     experiment = eval_test_helper.test_experiment()
+     _ = experiment.run(root_dir, 'new', plugins=[])
+     self.assertIsInstance(result['view'], pg.Html)
+     lf_console._notebook = None
+
+
+ class TqdmProgressTrackerTest(unittest.TestCase):
+
+   def test_basic(self):
+     root_dir = os.path.join(tempfile.gettempdir(), 'test_tqdm_progress_tracker')
+     experiment = eval_test_helper.test_experiment()
+     string_io = io.StringIO()
+     with contextlib.redirect_stderr(string_io):
+       _ = experiment.run(root_dir, 'new', plugins=[])
+     self.assertIn('All: 100%', string_io.getvalue())
+
+   def test_with_example_ids(self):
+     root_dir = os.path.join(
+         tempfile.gettempdir(), 'test_tqdm_progress_tracker_with_example_ids'
+     )
+     experiment = eval_test_helper.test_experiment()
+     string_io = io.StringIO()
+     with contextlib.redirect_stderr(string_io):
+       _ = experiment.run(root_dir, 'new', example_ids=[1], plugins=[])
+     self.assertIn('All: 100%', string_io.getvalue())
+
+
+ if __name__ == '__main__':
+   unittest.main()