langfun 0.1.2.dev202501080804__py3-none-any.whl → 0.1.2.dev202501240804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. langfun/core/__init__.py +1 -6
  2. langfun/core/coding/python/__init__.py +5 -11
  3. langfun/core/coding/python/correction.py +4 -7
  4. langfun/core/coding/python/correction_test.py +2 -3
  5. langfun/core/coding/python/execution.py +22 -211
  6. langfun/core/coding/python/execution_test.py +11 -90
  7. langfun/core/coding/python/generation.py +3 -2
  8. langfun/core/coding/python/generation_test.py +2 -2
  9. langfun/core/coding/python/parsing.py +108 -194
  10. langfun/core/coding/python/parsing_test.py +2 -105
  11. langfun/core/component.py +11 -273
  12. langfun/core/component_test.py +2 -29
  13. langfun/core/concurrent.py +187 -82
  14. langfun/core/concurrent_test.py +28 -19
  15. langfun/core/console.py +7 -3
  16. langfun/core/eval/base.py +2 -3
  17. langfun/core/eval/v2/evaluation.py +3 -1
  18. langfun/core/eval/v2/reporting.py +8 -4
  19. langfun/core/language_model.py +84 -8
  20. langfun/core/language_model_test.py +84 -29
  21. langfun/core/llms/__init__.py +46 -11
  22. langfun/core/llms/anthropic.py +1 -123
  23. langfun/core/llms/anthropic_test.py +0 -48
  24. langfun/core/llms/deepseek.py +117 -0
  25. langfun/core/llms/deepseek_test.py +61 -0
  26. langfun/core/llms/gemini.py +1 -1
  27. langfun/core/llms/groq.py +12 -99
  28. langfun/core/llms/groq_test.py +31 -137
  29. langfun/core/llms/llama_cpp.py +17 -54
  30. langfun/core/llms/llama_cpp_test.py +2 -34
  31. langfun/core/llms/openai.py +9 -147
  32. langfun/core/llms/openai_compatible.py +179 -0
  33. langfun/core/llms/openai_compatible_test.py +495 -0
  34. langfun/core/llms/openai_test.py +13 -423
  35. langfun/core/llms/rest_test.py +1 -1
  36. langfun/core/llms/vertexai.py +387 -18
  37. langfun/core/llms/vertexai_test.py +52 -0
  38. langfun/core/message_test.py +3 -3
  39. langfun/core/modalities/mime.py +8 -0
  40. langfun/core/modalities/mime_test.py +19 -4
  41. langfun/core/modality_test.py +0 -1
  42. langfun/core/structured/mapping.py +13 -13
  43. langfun/core/structured/mapping_test.py +2 -2
  44. langfun/core/structured/schema.py +16 -8
  45. langfun/core/structured/schema_generation.py +1 -1
  46. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/METADATA +13 -2
  47. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/RECORD +50 -52
  48. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/WHEEL +1 -1
  49. langfun/core/coding/python/errors.py +0 -108
  50. langfun/core/coding/python/errors_test.py +0 -99
  51. langfun/core/coding/python/permissions.py +0 -90
  52. langfun/core/coding/python/permissions_test.py +0 -86
  53. langfun/core/text_formatting.py +0 -168
  54. langfun/core/text_formatting_test.py +0 -65
  55. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/LICENSE +0 -0
  56. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/top_level.txt +0 -0
langfun/core/concurrent.py CHANGED
@@ -15,6 +15,7 @@
15
15
 
16
16
  import abc
17
17
  import collections
18
+ from collections.abc import Mapping
18
19
  import concurrent.futures
19
20
  import dataclasses
20
21
  import io
@@ -22,10 +23,8 @@ import random
22
23
  import sys
23
24
  import threading
24
25
  import time
25
- from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
26
+ from typing import Annotated, Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
26
27
 
27
- from langfun.core import component
28
- from langfun.core import text_formatting
29
28
  import pyglove as pg
30
29
 
31
30
 
@@ -39,18 +38,6 @@ except ImportError:
39
38
  tqdm = None
40
39
 
41
40
 
42
- def with_context_access(func: Callable[..., Any]) -> Callable[..., Any]:
43
- """Derives a user function with the access to the current context."""
44
- with component.context() as current_context:
45
- pass
46
-
47
- def _func(*args, **kwargs) -> Any:
48
- with component.context(**current_context):
49
- return func(*args, **kwargs)
50
-
51
- return _func
52
-
53
-
54
41
  class RetryError(RuntimeError):
55
42
  """Retry error."""
56
43
 
@@ -144,43 +131,33 @@ def with_retry(
144
131
  A function with the same signature of the input function, with the retry
145
132
  capability.
146
133
  """
147
- rand = random if seed is None else random.Random(seed)
148
-
149
- def _func(*args, **kwargs) -> Any:
150
- def base_interval() -> int:
151
- if isinstance(retry_interval, tuple):
152
- return rand.randint(retry_interval[0], retry_interval[1])
153
- else:
154
- assert isinstance(retry_interval, int)
155
- return retry_interval
156
134
 
157
- def next_wait_interval(attempt: int) -> float:
158
- if not exponential_backoff:
159
- attempt = 1
160
- return min(max_retry_interval, base_interval() * (2 ** (attempt - 1)))
161
-
162
- wait_intervals = []
163
- errors = []
164
- while True:
165
- with pg.catch_errors(retry_on_errors) as error_context:
166
- return func(*args, **kwargs)
135
+ def _func(*args, **kwargs):
136
+ job = Job(
137
+ func,
138
+ args,
139
+ kwargs,
140
+ retry_on_errors=retry_on_errors,
141
+ max_attempts=max_attempts,
142
+ retry_interval=retry_interval,
143
+ exponential_backoff=exponential_backoff,
144
+ max_retry_interval=max_retry_interval,
145
+ seed=seed,
146
+ )
147
+ job()
148
+ if job.error:
149
+ raise job.error
150
+ return job.result
167
151
 
168
- # Branch when errors are met for retry.
169
- errors.append(error_context.error)
170
- if len(errors) < max_attempts:
171
- wait_interval = next_wait_interval(len(errors))
172
- wait_intervals.append(wait_interval)
152
+ return _func
173
153
 
174
- pg.logging.warning(
175
- f'Calling {func!r} encountered {error_context.error!r} '
176
- f'(attempts={len(errors)}), retrying in {wait_interval} seconds...'
177
- )
178
154
 
179
- time.sleep(wait_interval)
180
- else:
181
- raise RetryError(func, errors, wait_intervals)
155
+ class RetryEntry(pg.Object):
156
+ """Retry entry."""
182
157
 
183
- return _func
158
+ call_interval: float
159
+ error: BaseException | None = None
160
+ wait_interval: float = 0.
184
161
 
185
162
 
186
163
  def concurrent_execute(
@@ -198,6 +175,7 @@ def concurrent_execute(
198
175
  retry_interval: int | tuple[int, int] = (5, 60),
199
176
  exponential_backoff: bool = True,
200
177
  max_retry_interval: int = 300,
178
+ return_jobs: bool = False,
201
179
  ) -> list[Any]:
202
180
  """Executes a function concurrently under current component context.
203
181
 
@@ -221,32 +199,53 @@ def concurrent_execute(
221
199
  max_retry_interval: The max retry interval in seconds. This is useful when
222
200
  the retry interval is exponential, to avoid the wait time to grow
223
201
  exponentially.
202
+ return_jobs: If True, return a list of `Job` objects. Otherwise, return a
203
+ list of outputs.
224
204
 
225
205
  Returns:
226
206
  A list of ouputs. Each is the return value of `func` based on the input
227
207
  value. Order is preserved.
228
208
  """
229
- if retry_on_errors is not None:
230
- func = with_retry(
231
- func,
232
- retry_on_errors,
233
- max_attempts=max_attempts,
234
- retry_interval=retry_interval,
235
- exponential_backoff=exponential_backoff,
236
- max_retry_interval=max_retry_interval,
209
+ jobs = []
210
+ for inputs in parallel_inputs:
211
+ jobs.append(
212
+ Job(
213
+ func,
214
+ (inputs,),
215
+ retry_on_errors=retry_on_errors,
216
+ max_attempts=max_attempts,
217
+ retry_interval=retry_interval,
218
+ exponential_backoff=exponential_backoff,
219
+ max_retry_interval=max_retry_interval,
220
+ )
237
221
  )
238
222
 
239
223
  # NOTE(daiyip): when executor is not specified and max_worker is 1,
240
224
  # we don't need to create a executor pool. Instead, the inputs will be
241
225
  # processed by the user function in sequence within the current thread.
242
226
  if executor is None and max_workers == 1:
243
- return [func(i) for i in parallel_inputs]
227
+ for job in jobs:
228
+ job()
229
+ if job.error:
230
+ raise job.error
231
+ return jobs if return_jobs else [job.result for job in jobs]
244
232
 
245
233
  shutdown_after_finish = executor is None
246
234
  executor = _executor_pool.executor_from(executor, max_workers=max_workers)
247
235
 
248
236
  try:
249
- return list(executor.map(with_context_access(func), parallel_inputs))
237
+ executed_jobs = list(
238
+ executor.map(
239
+ lambda job: job(),
240
+ [pg.with_contextual_override(job) for job in jobs]
241
+ )
242
+ )
243
+ for job in executed_jobs:
244
+ if job.error:
245
+ raise job.error
246
+ return (
247
+ executed_jobs if return_jobs else [job.result for job in executed_jobs]
248
+ )
250
249
  finally:
251
250
  if shutdown_after_finish:
252
251
  # Do not wait threads to finish if they are timed out.
@@ -258,9 +257,61 @@ class Job:
258
257
  """Thread pool job."""
259
258
 
260
259
  func: Callable[[Any], Any]
261
- arg: Any
260
+ args: Sequence[Any] = ()
261
+ kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
262
+ _: dataclasses.KW_ONLY
263
+
262
264
  result: Any = pg.MISSING_VALUE
263
- error: BaseException | None = None
265
+ error: Annotated[
266
+ BaseException | None,
267
+ 'The non-retryable error encountered during the job execution.',
268
+ ] = None
269
+ retry_entries: Annotated[
270
+ Sequence[RetryEntry], 'Records of retry attempts.'
271
+ ] = dataclasses.field(default_factory=list)
272
+
273
+ retry_on_errors: Annotated[
274
+ Sequence[Type[BaseException] | str],
275
+ (
276
+ 'A sequence of exception types or tuples of exception type and error '
277
+ 'messages (described in regular expression) as the desired exception '
278
+ 'types to retry.'
279
+ ),
280
+ ] = ()
281
+ max_attempts: Annotated[
282
+ int, 'Max number of attempts if an error to retry is encountered.'
283
+ ] = 5
284
+ retry_interval: Annotated[
285
+ int | tuple[int, int],
286
+ (
287
+ 'The (base) retry interval in seconds. If a tuple, the retry '
288
+ 'interval will be randomly chosen between the first and the second '
289
+ 'element of the tuple.'
290
+ ),
291
+ ] = (5, 60)
292
+ exponential_backoff: Annotated[
293
+ bool,
294
+ (
295
+ 'If True, exponential wait time will be applied on top of the base '
296
+ 'retry interval.'
297
+ ),
298
+ ] = True
299
+ max_retry_interval: Annotated[
300
+ int,
301
+ (
302
+ 'The max retry interval in seconds. This is useful when the retry '
303
+ 'interval is exponential, to avoid the wait time to grow '
304
+ 'exponentially.'
305
+ ),
306
+ ] = 300
307
+ seed: Annotated[
308
+ int | None,
309
+ (
310
+ 'Random seed to generate retry interval. If None, the seed will be'
311
+ ' determined based on current time.'
312
+ ),
313
+ ] = None
314
+
264
315
  timeit: pg.object_utils.TimeIt = dataclasses.field(
265
316
  default_factory=lambda: pg.object_utils.TimeIt('job')
266
317
  )
@@ -270,14 +321,70 @@ class Job:
270
321
  """Returns the running time in seconds since the job get started."""
271
322
  return self.timeit.elapse
272
323
 
273
- def __call__(self) -> Any:
324
+ def _retry_call(self) -> 'Job':
325
+ """Retries func call on args."""
326
+ rand = random if self.seed is None else random.Random(self.seed)
327
+
328
+ def base_interval() -> int:
329
+ if isinstance(self.retry_interval, tuple):
330
+ return rand.randint(*self.retry_interval)
331
+ else:
332
+ assert isinstance(self.retry_interval, int)
333
+ return self.retry_interval
334
+
335
+ def next_wait_interval(attempt: int) -> float:
336
+ if not self.exponential_backoff:
337
+ attempt = 1
338
+ return min(
339
+ self.max_retry_interval, base_interval() * (2 ** (attempt - 1))
340
+ )
341
+
342
+ retry_entries = []
343
+ wait_interval = 0
344
+ while True:
345
+ with pg.catch_errors(self.retry_on_errors) as error_context:
346
+ begin_time = time.time()
347
+ self.result = self.func(*self.args, **self.kwargs)
348
+
349
+ end_time = time.time()
350
+ retry_entries.append(RetryEntry(
351
+ call_interval=end_time - begin_time,
352
+ wait_interval=wait_interval,
353
+ error=error_context.error,
354
+ ))
355
+ if error_context.error is None:
356
+ self.retry_entries = retry_entries
357
+ return self
358
+
359
+ # Branch when errors are met for retry.
360
+ if len(retry_entries) < self.max_attempts:
361
+ wait_interval = next_wait_interval(len(retry_entries))
362
+
363
+ pg.logging.warning(
364
+ f'Calling {self.func!r} encountered {error_context.error!r} '
365
+ f'(attempts={len(retry_entries)}), retrying in '
366
+ f'{wait_interval} seconds...'
367
+ )
368
+
369
+ time.sleep(wait_interval)
370
+ else:
371
+ errors = [e.error for e in retry_entries]
372
+ # First wait interval is 0.
373
+ wait_intervals = [e.wait_interval for e in retry_entries[1:]]
374
+ raise RetryError(self.func, errors, wait_intervals)
375
+
376
+ def __call__(self) -> 'Job':
377
+ if getattr(self, '_has_call', False):
378
+ raise ValueError('Job can only be called once.')
379
+ self._has_call = True
274
380
  try:
275
381
  with self.timeit:
276
- self.result = self.func(self.arg)
277
- return self.result
382
+ if self.retry_on_errors:
383
+ return self._retry_call()
384
+ self.result = self.func(*self.args, **self.kwargs)
278
385
  except BaseException as e: # pylint: disable=broad-exception-caught
279
386
  self.error = e
280
- return e
387
+ return self
281
388
 
282
389
  def mark_canceled(self, error: BaseException) -> None:
283
390
  """Marks the job as canceled."""
@@ -538,7 +645,8 @@ def concurrent_map(
538
645
  max_attempts: int = 5,
539
646
  retry_interval: int | tuple[int, int] = (5, 60),
540
647
  exponential_backoff: bool = True,
541
- ) -> Iterator[tuple[Any, Any, BaseException | None]]:
648
+ return_jobs: bool = False,
649
+ ) -> Iterator[Any]:
542
650
  """Maps inputs to outptus via func concurrently under current context.
543
651
 
544
652
  Args:
@@ -581,9 +689,10 @@ def concurrent_map(
581
689
  of the tuple.
582
690
  exponential_backoff: If True, exponential wait time will be applied on top
583
691
  of the base retry interval.
692
+ return_jobs: If True, the returned iterator will emit `Job` objects.
584
693
 
585
694
  Yields:
586
- An iterator of (input, output, error).
695
+ An iterator of (input, output, error) or Job object.
587
696
 
588
697
  Raises:
589
698
  Exception: Errors that are not in `silence_on_errors` or `retry_on_errors`,
@@ -593,15 +702,6 @@ def concurrent_map(
593
702
  """
594
703
  # Internal usage logging.
595
704
 
596
- if retry_on_errors:
597
- func = with_retry(
598
- func,
599
- retry_on_errors,
600
- max_attempts=max_attempts,
601
- retry_interval=retry_interval,
602
- exponential_backoff=exponential_backoff,
603
- )
604
-
605
705
  status_fn = status_fn or (lambda p: { # pylint: disable=g-long-lambda
606
706
  'Succeeded': '%.2f%% (%d/%d)' % (
607
707
  p.success_rate * 100, p.succeeded, p.completed),
@@ -616,10 +716,15 @@ def concurrent_map(
616
716
  pending_futures = []
617
717
  total = 0
618
718
  for inputs in parallel_inputs:
619
- job = Job(func, inputs)
620
- future = executor.submit(
621
- with_context_access(job),
719
+ job = Job(
720
+ func,
721
+ (inputs,),
722
+ retry_on_errors=retry_on_errors,
723
+ max_attempts=max_attempts,
724
+ retry_interval=retry_interval,
725
+ exponential_backoff=exponential_backoff,
622
726
  )
727
+ future = executor.submit(pg.with_contextual_override(job))
623
728
  pending_futures.append(future)
624
729
  future_to_job[future] = job
625
730
  total += 1
@@ -669,7 +774,7 @@ def concurrent_map(
669
774
  silence_on_errors and isinstance(job.error, silence_on_errors)):
670
775
  raise job.error # pylint: disable=g-doc-exception
671
776
 
672
- yield job.arg, job.result, job.error
777
+ yield job if return_jobs else job.args[0], job.result, job.error
673
778
  progress.update(job)
674
779
  update_progress_bar(progress)
675
780
  ProgressBar.refresh()
@@ -690,7 +795,7 @@ def concurrent_map(
690
795
  if job.error is not None and not (
691
796
  silence_on_errors and isinstance(job.error, silence_on_errors)):
692
797
  raise job.error # pylint: disable=g-doc-exception
693
- yield job.arg, job.result, job.error
798
+ yield job if return_jobs else job.args[0], job.result, job.error
694
799
  progress.update(job)
695
800
  update_progress_bar(progress)
696
801
  completed_batch.add(future)
@@ -713,7 +818,7 @@ def concurrent_map(
713
818
  and isinstance(job.error, silence_on_errors)):
714
819
  raise job.error # pylint: disable=g-doc-exception
715
820
 
716
- yield job.arg, job.result, job.error
821
+ yield job.args[0], job.result, job.error
717
822
  progress.update(job)
718
823
  update_progress_bar(progress)
719
824
  else:
@@ -844,10 +949,10 @@ class _ConsoleProgressControl(_ProgressControl):
844
949
  def refresh(self):
845
950
  s = io.StringIO()
846
951
  if self.label is not None:
847
- s.write(text_formatting.colored(self.label, 'red', styles=['bold']))
952
+ s.write(pg.colored(self.label, 'red', styles=['bold']))
848
953
  s.write(': ')
849
954
  s.write(
850
- text_formatting.colored(
955
+ pg.colored(
851
956
  '%d%% (%d/%d)' %
852
957
  (
853
958
  self._progress * 100 // self.total,
langfun/core/concurrent_test.py CHANGED
@@ -29,22 +29,6 @@ class A(component.Component):
29
29
  y: int = component.contextual()
30
30
 
31
31
 
32
- class WithContextAccessTest(unittest.TestCase):
33
-
34
- def test_context_access(self):
35
- inputs = [A(1), A(2)]
36
- with futures.ThreadPoolExecutor() as executor:
37
- with component.context(y=3):
38
- self.assertEqual(
39
- list(
40
- executor.map(
41
- concurrent.with_context_access(lambda x: x.y), inputs
42
- )
43
- ),
44
- [3, 3],
45
- )
46
-
47
-
48
32
  class RetryErrorTest(unittest.TestCase):
49
33
 
50
34
  def test_basics(self):
@@ -94,7 +78,7 @@ class RetryErrorTest(unittest.TestCase):
94
78
  )
95
79
 
96
80
 
97
- class WithRetryTest(unittest.TestCase):
81
+ class RetryTest(unittest.TestCase):
98
82
 
99
83
  def assert_retry(self, func, expected_attempts, expected_wait_intervals):
100
84
  with pg.catch_errors(concurrent.RetryError) as error_context:
@@ -162,6 +146,31 @@ class WithRetryTest(unittest.TestCase):
162
146
  with self.assertRaises(ValueError):
163
147
  foo_with_retry()
164
148
 
149
+ def test_retry_with_job(self):
150
+ count = 0
151
+
152
+ def foo():
153
+ nonlocal count
154
+ count += 1
155
+ if count < 3:
156
+ raise ValueError('Foo temporary error.')
157
+ return 'Success'
158
+
159
+ job = concurrent.Job(
160
+ foo,
161
+ retry_on_errors=ValueError,
162
+ retry_interval=1,
163
+ )
164
+ job()
165
+ self.assertEqual(job.result, 'Success')
166
+ self.assertEqual(
167
+ [retry_entry.wait_interval for retry_entry in job.retry_entries],
168
+ [0, 1, 2],
169
+ )
170
+ self.assertIsInstance(job.retry_entries[0].error, ValueError)
171
+ self.assertIsInstance(job.retry_entries[1].error, ValueError)
172
+ self.assertIsNone(job.retry_entries[2].error)
173
+
165
174
 
166
175
  class ConcurrentExecuteTest(unittest.TestCase):
167
176
 
@@ -217,8 +226,8 @@ class ProgressTest(unittest.TestCase):
217
226
  def fun2(unused_x):
218
227
  raise ValueError('Intentional error.')
219
228
 
220
- job1 = concurrent.Job(fun, 1)
221
- job2 = concurrent.Job(fun2, 2)
229
+ job1 = concurrent.Job(fun, (1,))
230
+ job2 = concurrent.Job(fun2, (2,))
222
231
  job1()
223
232
  job2()
224
233
 
langfun/core/console.py CHANGED
@@ -15,7 +15,7 @@
15
15
 
16
16
  import sys
17
17
  from typing import Any
18
- from langfun.core.text_formatting import colored
18
+ import pyglove as pg
19
19
 
20
20
 
21
21
  def write(
@@ -42,10 +42,14 @@ def write(
42
42
  """
43
43
  # Print title if present.
44
44
  if title is not None:
45
- print(colored(title, styles=['bold']))
45
+ print(pg.colored(title, styles=['bold']))
46
46
 
47
47
  # Print body.
48
- print(colored(str(value), color=color, background=background, styles=styles))
48
+ print(
49
+ pg.colored(
50
+ str(value), color=color, background=background, styles=styles
51
+ )
52
+ )
49
53
 
50
54
 
51
55
  try:
langfun/core/eval/base.py CHANGED
@@ -1298,7 +1298,7 @@ class Evaluation(Evaluable):
1298
1298
  id=self.id,
1299
1299
  dir=self.dir,
1300
1300
  model=self.lm.model_id,
1301
- prompt_template=lf.text_formatting.decolored(str(self.prompt)),
1301
+ prompt_template=pg.decolor(str(self.prompt)),
1302
1302
  method=self.method,
1303
1303
  schema_fn=str(self.schema_fn),
1304
1304
  ),
@@ -2110,8 +2110,7 @@ class Summary(pg.Object):
2110
2110
 
2111
2111
  def _format_error(error: Exception):
2112
2112
  """Formats an error into a string."""
2113
- return (f'({error.__class__.__name__}) '
2114
- + lf.text_formatting.decolored(str(error)))
2113
+ return (f'({error.__class__.__name__}) ' + pg.decolor(str(error)))
2115
2114
 
2116
2115
 
2117
2116
  def _error_key(error: Exception) -> str:
langfun/core/eval/v2/evaluation.py CHANGED
@@ -516,7 +516,9 @@ class Evaluation(experiment_lib.Experiment):
516
516
  target='example-view',
517
517
  css_classes=['example-link'],
518
518
  )
519
- for dp in metric_value.data_points
519
+ for dp in sorted(
520
+ metric_value.data_points, key=lambda dp: dp.example_id
521
+ )
520
522
  ]
521
523
  )
522
524
  )
langfun/core/eval/v2/reporting.py CHANGED
@@ -51,6 +51,7 @@ class HtmlReporter(experiment_lib.Plugin):
51
51
  self._update_thread = None
52
52
  self._stop_update = False
53
53
  self._stop_update_experiment_ids = set()
54
+ self._summary_lock = None
54
55
  self._experiment_index_lock = None
55
56
 
56
57
  def on_run_start(
@@ -62,6 +63,7 @@ class HtmlReporter(experiment_lib.Plugin):
62
63
  self._last_experiment_report_time = {leaf.id: 0 for leaf in root.leaf_nodes}
63
64
  self._stop_update = False
64
65
  self._stop_update_experiment_ids = set()
66
+ self._summary_lock = threading.Lock()
65
67
  self._experiment_index_lock = {
66
68
  leaf.id: threading.Lock() for leaf in root.leaf_nodes
67
69
  }
@@ -137,21 +139,23 @@ class HtmlReporter(experiment_lib.Plugin):
137
139
  """Maybe update the summary of current run."""
138
140
  run = runner.current_run
139
141
  def _summary():
140
- run.experiment.to_html(
142
+ html = run.experiment.to_html(
141
143
  collapse_level=None,
142
144
  extra_flags=dict(
143
145
  current_run=run, interactive=False, card_view=True,
144
146
  )
145
- ).save(
146
- run.output_path_for(run.experiment, _SUMMARY_FILE)
147
147
  )
148
+ with self._summary_lock:
149
+ html.save(
150
+ run.output_path_for(run.experiment, _SUMMARY_FILE)
151
+ )
148
152
 
149
153
  if force or (time.time() - self._last_summary_time > self.summary_interval):
154
+ self._last_summary_time = time.time()
150
155
  if background:
151
156
  runner.background_run(_summary)
152
157
  else:
153
158
  _summary()
154
- self._last_summary_time = time.time()
155
159
 
156
160
  def _maybe_update_experiment_html(
157
161
  self,