langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -13,17 +13,30 @@
13
13
  # limitations under the License.
14
14
  """Utility library for handling concurrency in langfun."""
15
15
 
16
+ import abc
16
17
  import collections
18
+ from collections.abc import Mapping
17
19
  import concurrent.futures
18
20
  import dataclasses
21
+ import io
19
22
  import random
23
+ import sys
20
24
  import threading
21
25
  import time
22
- from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
26
+ from typing import Annotated, Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
23
27
 
24
28
  from langfun.core import component
25
29
  import pyglove as pg
26
- from tqdm import auto as tqdm
30
+
31
+
32
+ progress_bar: Literal['tqdm', 'console', None] = None
33
+
34
+ try:
35
+ from tqdm import auto as tqdm # pylint: disable=g-import-not-at-top
36
+ progress_bar = 'tqdm'
37
+ except ImportError:
38
+ progress_bar = 'console'
39
+ tqdm = None
27
40
 
28
41
 
29
42
  def with_context_access(func: Callable[..., Any]) -> Callable[..., Any]:
@@ -44,7 +57,7 @@ class RetryError(RuntimeError):
44
57
  def __init__(
45
58
  self,
46
59
  func: Callable[..., Any],
47
- errors: list[Exception],
60
+ errors: list[BaseException],
48
61
  wait_intervals: list[int],
49
62
  ):
50
63
  assert len(errors) == len(wait_intervals) + 1
@@ -99,12 +112,13 @@ class RetryError(RuntimeError):
99
112
  def with_retry(
100
113
  func: Callable[[Any], Any],
101
114
  retry_on_errors: Union[
102
- Union[Type[Exception], Tuple[Type[Exception], str]],
103
- Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
115
+ Union[Type[BaseException], Tuple[Type[BaseException], str]],
116
+ Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
104
117
  ],
105
118
  max_attempts: int,
106
119
  retry_interval: int | tuple[int, int] = (5, 60),
107
120
  exponential_backoff: bool = True,
121
+ max_retry_interval: int = 300,
108
122
  seed: int | None = None,
109
123
  ) -> Callable[..., Any]:
110
124
  """Derives a user function with retry on error.
@@ -120,6 +134,9 @@ def with_retry(
120
134
  of the tuple.
121
135
  exponential_backoff: If True, exponential wait time will be applied on top
122
136
  of the base retry interval.
137
+ max_retry_interval: The max retry interval in seconds. This is useful when
138
+ the retry interval is exponential, to avoid the wait time to grow
139
+ exponentially.
123
140
  seed: Random seed to generate retry interval. If None, the seed will be
124
141
  determined based on current time.
125
142
 
@@ -127,44 +144,33 @@ def with_retry(
127
144
  A function with the same signature of the input function, with the retry
128
145
  capability.
129
146
  """
130
- rand = random if seed is None else random.Random(seed)
131
-
132
- def _func(*args, **kwargs) -> Any:
133
- def base_interval() -> int:
134
- if isinstance(retry_interval, tuple):
135
- return rand.randint(retry_interval[0], retry_interval[1])
136
- else:
137
- assert isinstance(retry_interval, int)
138
- return retry_interval
139
147
 
140
- def next_wait_interval(attempt: int) -> float:
141
- if not exponential_backoff:
142
- attempt = 1
143
- return base_interval() * (2 ** (attempt - 1))
144
-
145
- wait_interval = None
146
- wait_intervals = []
147
- errors = []
148
- while True:
149
- with pg.catch_errors(retry_on_errors) as error_context:
150
- return func(*args, **kwargs)
148
+ def _func(*args, **kwargs):
149
+ job = Job(
150
+ func,
151
+ args,
152
+ kwargs,
153
+ retry_on_errors=retry_on_errors,
154
+ max_attempts=max_attempts,
155
+ retry_interval=retry_interval,
156
+ exponential_backoff=exponential_backoff,
157
+ max_retry_interval=max_retry_interval,
158
+ seed=seed,
159
+ )
160
+ job()
161
+ if job.error:
162
+ raise job.error
163
+ return job.result
151
164
 
152
- # Branch when errors are met for retry.
153
- errors.append(error_context.error)
154
- if len(errors) < max_attempts:
155
- wait_interval = next_wait_interval(len(errors))
156
- wait_intervals.append(wait_interval)
165
+ return _func
157
166
 
158
- pg.logging.warning(
159
- f'Calling {func!r} encountered {error_context.error!r} '
160
- f'(attempts={len(errors)}), retrying in {wait_interval} seconds...'
161
- )
162
167
 
163
- time.sleep(wait_interval)
164
- else:
165
- raise RetryError(func, errors, wait_intervals)
168
+ class RetryEntry(pg.Object):
169
+ """Retry entry."""
166
170
 
167
- return _func
171
+ call_interval: float
172
+ error: BaseException | None = None
173
+ wait_interval: float = 0.
168
174
 
169
175
 
170
176
  def concurrent_execute(
@@ -174,13 +180,15 @@ def concurrent_execute(
174
180
  executor: Union[concurrent.futures.ThreadPoolExecutor, str, None] = None,
175
181
  max_workers: int = 32,
176
182
  retry_on_errors: Union[
177
- Union[Type[Exception], Tuple[Type[Exception], str]],
178
- Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
183
+ Union[Type[BaseException], Tuple[Type[BaseException], str]],
184
+ Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
179
185
  None,
180
186
  ] = None,
181
187
  max_attempts: int = 5,
182
188
  retry_interval: int | tuple[int, int] = (5, 60),
183
189
  exponential_backoff: bool = True,
190
+ max_retry_interval: int = 300,
191
+ return_jobs: bool = False,
184
192
  ) -> list[Any]:
185
193
  """Executes a function concurrently under current component context.
186
194
 
@@ -201,31 +209,55 @@ def concurrent_execute(
201
209
  of the tuple.
202
210
  exponential_backoff: If True, exponential wait time will be applied on top
203
211
  of the base retry interval.
212
+ max_retry_interval: The max retry interval in seconds. This is useful when
213
+ the retry interval is exponential, to avoid the wait time to grow
214
+ exponentially.
215
+ return_jobs: If True, return a list of `Job` objects. Otherwise, return a
216
+ list of outputs.
204
217
 
205
218
  Returns:
206
219
  A list of ouputs. Each is the return value of `func` based on the input
207
220
  value. Order is preserved.
208
221
  """
209
- if retry_on_errors is not None:
210
- func = with_retry(
211
- func,
212
- retry_on_errors,
213
- max_attempts=max_attempts,
214
- retry_interval=retry_interval,
215
- exponential_backoff=exponential_backoff,
222
+ jobs = []
223
+ for inputs in parallel_inputs:
224
+ jobs.append(
225
+ Job(
226
+ func,
227
+ (inputs,),
228
+ retry_on_errors=retry_on_errors,
229
+ max_attempts=max_attempts,
230
+ retry_interval=retry_interval,
231
+ exponential_backoff=exponential_backoff,
232
+ max_retry_interval=max_retry_interval,
233
+ )
216
234
  )
217
235
 
218
236
  # NOTE(daiyip): when executor is not specified and max_worker is 1,
219
237
  # we don't need to create a executor pool. Instead, the inputs will be
220
238
  # processed by the user function in sequence within the current thread.
221
239
  if executor is None and max_workers == 1:
222
- return [func(i) for i in parallel_inputs]
240
+ for job in jobs:
241
+ job()
242
+ if job.error:
243
+ raise job.error
244
+ return jobs if return_jobs else [job.result for job in jobs]
223
245
 
224
246
  shutdown_after_finish = executor is None
225
247
  executor = _executor_pool.executor_from(executor, max_workers=max_workers)
226
248
 
227
249
  try:
228
- return list(executor.map(with_context_access(func), parallel_inputs))
250
+ executed_jobs = list(
251
+ executor.map(
252
+ lambda job: job(), [with_context_access(job) for job in jobs]
253
+ )
254
+ )
255
+ for job in executed_jobs:
256
+ if job.error:
257
+ raise job.error
258
+ return (
259
+ executed_jobs if return_jobs else [job.result for job in executed_jobs]
260
+ )
229
261
  finally:
230
262
  if shutdown_after_finish:
231
263
  # Do not wait threads to finish if they are timed out.
@@ -237,36 +269,139 @@ class Job:
237
269
  """Thread pool job."""
238
270
 
239
271
  func: Callable[[Any], Any]
240
- arg: Any
272
+ args: Sequence[Any] = ()
273
+ kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
274
+ _: dataclasses.KW_ONLY
275
+
241
276
  result: Any = pg.MISSING_VALUE
242
- error: Exception | None = None
243
- start_time: float | None = None
244
- end_time: float | None = None
277
+ error: Annotated[
278
+ BaseException | None,
279
+ 'The non-retryable error encountered during the job execution.',
280
+ ] = None
281
+ retry_entries: Annotated[
282
+ Sequence[RetryEntry], 'Records of retry attempts.'
283
+ ] = dataclasses.field(default_factory=list)
284
+
285
+ retry_on_errors: Annotated[
286
+ Sequence[Type[BaseException] | str],
287
+ (
288
+ 'A sequence of exception types or tuples of exception type and error '
289
+ 'messages (described in regular expression) as the desired exception '
290
+ 'types to retry.'
291
+ ),
292
+ ] = ()
293
+ max_attempts: Annotated[
294
+ int, 'Max number of attempts if an error to retry is encountered.'
295
+ ] = 5
296
+ retry_interval: Annotated[
297
+ int | tuple[int, int],
298
+ (
299
+ 'The (base) retry interval in seconds. If a tuple, the retry '
300
+ 'interval will be randomly chosen between the first and the second '
301
+ 'element of the tuple.'
302
+ ),
303
+ ] = (5, 60)
304
+ exponential_backoff: Annotated[
305
+ bool,
306
+ (
307
+ 'If True, exponential wait time will be applied on top of the base '
308
+ 'retry interval.'
309
+ ),
310
+ ] = True
311
+ max_retry_interval: Annotated[
312
+ int,
313
+ (
314
+ 'The max retry interval in seconds. This is useful when the retry '
315
+ 'interval is exponential, to avoid the wait time to grow '
316
+ 'exponentially.'
317
+ ),
318
+ ] = 300
319
+ seed: Annotated[
320
+ int | None,
321
+ (
322
+ 'Random seed to generate retry interval. If None, the seed will be'
323
+ ' determined based on current time.'
324
+ ),
325
+ ] = None
326
+
327
+ timeit: pg.object_utils.TimeIt = dataclasses.field(
328
+ default_factory=lambda: pg.object_utils.TimeIt('job')
329
+ )
330
+
331
+ @property
332
+ def elapse(self) -> float:
333
+ """Returns the running time in seconds since the job get started."""
334
+ return self.timeit.elapse
335
+
336
+ def _retry_call(self) -> 'Job':
337
+ """Retries func call on args."""
338
+ rand = random if self.seed is None else random.Random(self.seed)
339
+
340
+ def base_interval() -> int:
341
+ if isinstance(self.retry_interval, tuple):
342
+ return rand.randint(*self.retry_interval)
343
+ else:
344
+ assert isinstance(self.retry_interval, int)
345
+ return self.retry_interval
346
+
347
+ def next_wait_interval(attempt: int) -> float:
348
+ if not self.exponential_backoff:
349
+ attempt = 1
350
+ return min(
351
+ self.max_retry_interval, base_interval() * (2 ** (attempt - 1))
352
+ )
353
+
354
+ retry_entries = []
355
+ wait_interval = 0
356
+ while True:
357
+ with pg.catch_errors(self.retry_on_errors) as error_context:
358
+ begin_time = time.time()
359
+ self.result = self.func(*self.args, **self.kwargs)
360
+
361
+ end_time = time.time()
362
+ retry_entries.append(RetryEntry(
363
+ call_interval=end_time - begin_time,
364
+ wait_interval=wait_interval,
365
+ error=error_context.error,
366
+ ))
367
+ if error_context.error is None:
368
+ self.retry_entries = retry_entries
369
+ return self
370
+
371
+ # Branch when errors are met for retry.
372
+ if len(retry_entries) < self.max_attempts:
373
+ wait_interval = next_wait_interval(len(retry_entries))
374
+
375
+ pg.logging.warning(
376
+ f'Calling {self.func!r} encountered {error_context.error!r} '
377
+ f'(attempts={len(retry_entries)}), retrying in '
378
+ f'{wait_interval} seconds...'
379
+ )
245
380
 
246
- def __call__(self) -> Any:
247
- self.start_time = time.time()
381
+ time.sleep(wait_interval)
382
+ else:
383
+ errors = [e.error for e in retry_entries]
384
+ # First wait interval is 0.
385
+ wait_intervals = [e.wait_interval for e in retry_entries[1:]]
386
+ raise RetryError(self.func, errors, wait_intervals)
387
+
388
+ def __call__(self) -> 'Job':
389
+ if getattr(self, '_has_call', False):
390
+ raise ValueError('Job can only be called once.')
391
+ self._has_call = True
248
392
  try:
249
- self.result = self.func(self.arg)
250
- return self.result
251
- except Exception as e: # pylint: disable=broad-exception-caught
393
+ with self.timeit:
394
+ if self.retry_on_errors:
395
+ return self._retry_call()
396
+ self.result = self.func(*self.args, **self.kwargs)
397
+ except BaseException as e: # pylint: disable=broad-exception-caught
252
398
  self.error = e
253
- return e
254
- finally:
255
- self.end_time = time.time()
399
+ return self
256
400
 
257
- def mark_canceled(self, error: Exception) -> None:
401
+ def mark_canceled(self, error: BaseException) -> None:
258
402
  """Marks the job as canceled."""
403
+ self.timeit.end(error)
259
404
  self.error = error
260
- self.end_time = time.time()
261
-
262
- @property
263
- def elapse(self) -> float:
264
- """Returns the running time in seconds since the job get started."""
265
- if self.start_time is None:
266
- return 0.0
267
- if self.end_time is None:
268
- return time.time() - self.start_time
269
- return self.end_time - self.start_time
270
405
 
271
406
 
272
407
  @dataclasses.dataclass
@@ -276,9 +411,12 @@ class Progress:
276
411
 
277
412
  _succeeded: int = 0
278
413
  _failed: int = 0
279
- _last_error: Exception | None = None
414
+ _last_error: BaseException | None = None
280
415
  _total_duration: float = 0.0
281
416
  _job: Job | None = None
417
+ _timeit_summary: pg.object_utils.TimeIt.StatusSummary = dataclasses.field(
418
+ default_factory=pg.object_utils.TimeIt.StatusSummary
419
+ )
282
420
 
283
421
  @property
284
422
  def succeeded(self) -> int:
@@ -296,7 +434,7 @@ class Progress:
296
434
  return self.succeeded + self.failed
297
435
 
298
436
  @property
299
- def last_error(self) -> Exception | None:
437
+ def last_error(self) -> BaseException | None:
300
438
  """Returns last error."""
301
439
  return self._last_error
302
440
 
@@ -326,6 +464,28 @@ class Progress:
326
464
  return 0.0
327
465
  return self._total_duration / self.completed
328
466
 
467
+ @property
468
+ def timeit_summary(self) -> pg.object_utils.TimeIt.StatusSummary:
469
+ """Returns the aggregated summary for each `pg.timeit`."""
470
+ return self._timeit_summary
471
+
472
+ def timeit_summary_str(self) -> str | None:
473
+ if not self.timeit_summary:
474
+ return None
475
+ return ', '.join([
476
+ '%s (%.2fs, %d/%d)' % (
477
+ k.lstrip('job.'), v.avg_duration, v.num_ended, v.num_started
478
+ ) for k, v in self.timeit_summary.breakdown.items() if k != 'job'
479
+ ])
480
+
481
+ def last_error_str(self) -> str | None:
482
+ if self.last_error is None:
483
+ return None
484
+ error_text = repr(self.last_error)
485
+ if len(error_text) >= 64:
486
+ error_text = error_text[:64] + '...'
487
+ return error_text
488
+
329
489
  def update(self, job: Job) -> None:
330
490
  """Mark a job as completed."""
331
491
  self._job = job
@@ -335,6 +495,7 @@ class Progress:
335
495
  self._failed += 1
336
496
  self._last_error = job.error
337
497
  self._total_duration += job.elapse
498
+ self._timeit_summary.aggregate(job.timeit.status())
338
499
 
339
500
 
340
501
  class ProgressBar:
@@ -356,17 +517,17 @@ class ProgressBar:
356
517
  label: str | None
357
518
  total: int
358
519
  color: str | None = None
359
- postfix: dict[str, str] | None = None
520
+ status: dict[str, Any] | None = None
360
521
 
361
522
  @dataclasses.dataclass
362
523
  class Update:
363
524
  """Progress bar update."""
364
525
  bar_id: int
365
526
  delta: int
366
- postfix: Union[dict[str, str], str, None] = None
527
+ status: Union[dict[str, Any], str, None] = None
367
528
  color: str | None = None
368
529
 
369
- _progress_bars: dict[int, tqdm.tqdm] = {}
530
+ _progress_bars: dict[int, '_ProgressControl'] = {}
370
531
  _install_requests: list[tuple[int, Settings]] = []
371
532
  _updates: collections.deque[Update] = collections.deque()
372
533
  _uninstall_requests: list[int] = []
@@ -378,11 +539,11 @@ class ProgressBar:
378
539
  label: str | None,
379
540
  total: int,
380
541
  color: str | None = None,
381
- postfix: dict[str, str] | None = None,
542
+ status: dict[str, Any] | None = None,
382
543
  ) -> int:
383
544
  """Installs a progress bar and returns a reference id."""
384
545
  with cls._lock:
385
- settings = ProgressBar.Settings(label, total, color, postfix)
546
+ settings = ProgressBar.Settings(label, total, color, status)
386
547
  bar_id = id(settings)
387
548
  cls._install_requests.append((bar_id, settings))
388
549
  return bar_id
@@ -392,15 +553,17 @@ class ProgressBar:
392
553
  cls,
393
554
  bar_id: int,
394
555
  delta: int = 0,
395
- postfix: Union[dict[str, str], str, None] = None,
556
+ status: Union[dict[str, Any], str, None] = None,
396
557
  color: str | None = None,
397
558
  refresh: bool = True,
398
559
  ) -> None:
399
560
  """Report the progress for a label."""
561
+ if status is not None and not isinstance(status, (str, dict)):
562
+ raise ValueError(f'Unsupported status: {status}')
400
563
  with cls._lock:
401
564
  cls._updates.append(
402
565
  ProgressBar.Update(
403
- bar_id=bar_id, delta=delta, postfix=postfix, color=color,
566
+ bar_id=bar_id, delta=delta, status=status, color=color,
404
567
  )
405
568
  )
406
569
  if refresh:
@@ -422,11 +585,11 @@ class ProgressBar:
422
585
  # Process install requests.
423
586
  if cls._install_requests:
424
587
  for bar_id, settings in cls._install_requests:
425
- cls._progress_bars[bar_id] = tqdm.tqdm(
588
+ cls._progress_bars[bar_id] = _progress_control(
426
589
  total=settings.total,
427
- desc=settings.label,
428
- colour=settings.color,
429
- postfix=settings.postfix)
590
+ label=settings.label,
591
+ color=settings.color,
592
+ status=settings.status)
430
593
  cls._install_requests.clear()
431
594
 
432
595
  # Process updates.
@@ -441,15 +604,11 @@ class ProgressBar:
441
604
  if update.delta > 0:
442
605
  bar.update(update.delta)
443
606
 
444
- if isinstance(update.postfix, str):
445
- bar.set_postfix_str(update.postfix, refresh=False)
446
- elif isinstance(update.postfix, dict):
447
- bar.set_postfix(update.postfix, refresh=False)
448
- elif update.postfix is not None:
449
- raise ValueError(f'Unsupported postfix: {update.postfix}')
607
+ if update.status is not None:
608
+ bar.set_status(update.status)
450
609
 
451
610
  if update.color is not None:
452
- bar.colour = update.color
611
+ bar.set_color(update.color)
453
612
  updated_bars.add(bar)
454
613
 
455
614
  # Refresh each updated bar just once.
@@ -459,7 +618,9 @@ class ProgressBar:
459
618
  # Process uninstall requests.
460
619
  if cls._uninstall_requests:
461
620
  for bar_id in cls._uninstall_requests:
462
- cls._progress_bars.pop(bar_id, None)
621
+ bar = cls._progress_bars.pop(bar_id, None)
622
+ if bar is not None:
623
+ bar.close()
463
624
  cls._uninstall_requests.clear()
464
625
 
465
626
 
@@ -486,17 +647,18 @@ def concurrent_map(
486
647
  status_fn: Callable[[Progress], dict[str, Any]] | None = None,
487
648
  timeout: int | None = None,
488
649
  silence_on_errors: Union[
489
- Type[Exception], Tuple[Type[Exception], ...], None
650
+ Type[BaseException], Tuple[Type[BaseException], ...], None
490
651
  ] = Exception,
491
652
  retry_on_errors: Union[
492
- Type[Exception],
493
- Tuple[Type[Exception], ...],
653
+ Type[BaseException],
654
+ Tuple[Type[BaseException], ...],
494
655
  None,
495
656
  ] = None,
496
657
  max_attempts: int = 5,
497
658
  retry_interval: int | tuple[int, int] = (5, 60),
498
659
  exponential_backoff: bool = True,
499
- ) -> Iterator[tuple[Any, Any, Exception | None]]:
660
+ return_jobs: bool = False,
661
+ ) -> Iterator[Any]:
500
662
  """Maps inputs to outptus via func concurrently under current context.
501
663
 
502
664
  Args:
@@ -539,9 +701,10 @@ def concurrent_map(
539
701
  of the tuple.
540
702
  exponential_backoff: If True, exponential wait time will be applied on top
541
703
  of the base retry interval.
704
+ return_jobs: If True, the returned iterator will emit `Job` objects.
542
705
 
543
706
  Yields:
544
- An iterator of (input, output, error).
707
+ An iterator of (input, output, error) or Job object.
545
708
 
546
709
  Raises:
547
710
  Exception: Errors that are not in `silence_on_errors` or `retry_on_errors`,
@@ -551,15 +714,6 @@ def concurrent_map(
551
714
  """
552
715
  # Internal usage logging.
553
716
 
554
- if retry_on_errors:
555
- func = with_retry(
556
- func,
557
- retry_on_errors,
558
- max_attempts=max_attempts,
559
- retry_interval=retry_interval,
560
- exponential_backoff=exponential_backoff,
561
- )
562
-
563
717
  status_fn = status_fn or (lambda p: { # pylint: disable=g-long-lambda
564
718
  'Succeeded': '%.2f%% (%d/%d)' % (
565
719
  p.success_rate * 100, p.succeeded, p.completed),
@@ -574,7 +728,14 @@ def concurrent_map(
574
728
  pending_futures = []
575
729
  total = 0
576
730
  for inputs in parallel_inputs:
577
- job = Job(func, inputs)
731
+ job = Job(
732
+ func,
733
+ (inputs,),
734
+ retry_on_errors=retry_on_errors,
735
+ max_attempts=max_attempts,
736
+ retry_interval=retry_interval,
737
+ exponential_backoff=exponential_backoff,
738
+ )
578
739
  future = executor.submit(
579
740
  with_context_access(job),
580
741
  )
@@ -596,14 +757,14 @@ def concurrent_map(
596
757
  if show_progress:
597
758
  status = status_fn(progress)
598
759
  status.update({
599
- 'AvgDuration': '%.2f seconds' % progress.avg_duration
760
+ 'AvgDuration': '%.2fs' % progress.avg_duration
600
761
  })
601
762
  if progress.last_error is not None:
602
- error_text = repr(progress.last_error)
603
- if len(error_text) >= 64:
604
- error_text = error_text[:64] + '...'
605
- status['LastError'] = error_text
606
- ProgressBar.update(bar_id, delta=1, postfix=status)
763
+ status['LastError'] = progress.last_error_str()
764
+
765
+ if progress.timeit_summary:
766
+ status['TimeIt'] = progress.timeit_summary_str()
767
+ ProgressBar.update(bar_id, delta=1, status=status)
607
768
 
608
769
  try:
609
770
  if ordered:
@@ -627,7 +788,7 @@ def concurrent_map(
627
788
  silence_on_errors and isinstance(job.error, silence_on_errors)):
628
789
  raise job.error # pylint: disable=g-doc-exception
629
790
 
630
- yield job.arg, job.result, job.error
791
+ yield job if return_jobs else job.args[0], job.result, job.error
631
792
  progress.update(job)
632
793
  update_progress_bar(progress)
633
794
  ProgressBar.refresh()
@@ -648,7 +809,7 @@ def concurrent_map(
648
809
  if job.error is not None and not (
649
810
  silence_on_errors and isinstance(job.error, silence_on_errors)):
650
811
  raise job.error # pylint: disable=g-doc-exception
651
- yield job.arg, job.result, job.error
812
+ yield job if return_jobs else job.args[0], job.result, job.error
652
813
  progress.update(job)
653
814
  update_progress_bar(progress)
654
815
  completed_batch.add(future)
@@ -671,7 +832,7 @@ def concurrent_map(
671
832
  and isinstance(job.error, silence_on_errors)):
672
833
  raise job.error # pylint: disable=g-doc-exception
673
834
 
674
- yield job.arg, job.result, job.error
835
+ yield job.args[0], job.result, job.error
675
836
  progress.update(job)
676
837
  update_progress_bar(progress)
677
838
  else:
@@ -729,5 +890,141 @@ class ExecutorPool:
729
890
  raise ValueError(f'Unsupported value: {maybe_executor}.')
730
891
 
731
892
 
893
+ class _ProgressControl(pg.Object):
894
+ """Abstract progress control."""
895
+ # Disable symbolic comparison so the hash is based on object address.
896
+ use_symbolic_comparison = False
897
+
898
+ total: int
899
+ label: str | None
900
+ color: str | None
901
+ status: str | dict[str, Any] | None
902
+
903
+ def set_color(self, color: str | None):
904
+ with pg.notify_on_change(False):
905
+ self.rebind(color=color)
906
+
907
+ def set_status(self, status: str | dict[str, Any] | None):
908
+ with pg.notify_on_change(False):
909
+ self.rebind(status=status)
910
+
911
+ @abc.abstractmethod
912
+ def update(self, delta):
913
+ """Update progress."""
914
+
915
+ @abc.abstractmethod
916
+ def refresh(self) -> None:
917
+ """Refresh progress bar."""
918
+
919
+ @abc.abstractmethod
920
+ def close(self) -> None:
921
+ """Close progress bar."""
922
+
923
+
924
+ class _TqdmProgressControl(_ProgressControl):
925
+ """Tqdm-based progress control."""
926
+
927
+ def _on_bound(self):
928
+ super()._on_bound()
929
+ assert tqdm is not None
930
+ self._tqdm = tqdm.tqdm(
931
+ total=self.total,
932
+ desc=self.label,
933
+ colour=self.color,
934
+ postfix=self.status,
935
+ )
936
+
937
+ def update(self, delta: int) -> None:
938
+ self._tqdm.update(delta)
939
+
940
+ def refresh(self):
941
+ self._tqdm.set_description(self.label, refresh=False)
942
+ if isinstance(self.status, str):
943
+ self._tqdm.set_postfix_str(self.status, refresh=False)
944
+ else:
945
+ self._tqdm.set_postfix(self.status, refresh=False)
946
+ self._tqdm.colour = self.color
947
+ self._tqdm.refresh()
948
+
949
+ def close(self):
950
+ self._tqdm.close()
951
+
952
+
953
+ class _ConsoleProgressControl(_ProgressControl):
954
+ """Simple progress control by printing the status to the console."""
955
+
956
+ def _on_bound(self):
957
+ super()._on_bound()
958
+ self._progress = 0
959
+
960
+ def update(self, delta: int) -> None:
961
+ self._progress += delta
962
+
963
+ def refresh(self):
964
+ s = io.StringIO()
965
+ if self.label is not None:
966
+ s.write(pg.colored(self.label, 'red', styles=['bold']))
967
+ s.write(': ')
968
+ s.write(
969
+ pg.colored(
970
+ '%d%% (%d/%d)' %
971
+ (
972
+ self._progress * 100 // self.total,
973
+ self._progress,
974
+ self.total,
975
+ ),
976
+ color=self.color or 'green'
977
+ )
978
+ )
979
+ if self.status is not None:
980
+ status = repr(self.status) if isinstance(
981
+ self.status, dict) else self.status
982
+ s.write(f' : {status}')
983
+ sys.stderr.write(s.getvalue() + '\n')
984
+
985
+ def close(self):
986
+ sys.stderr.flush()
987
+
988
+
989
+ class _NoopProgressControl(_ProgressControl):
990
+ """No-op progress control."""
991
+
992
+ def update(self, delta: int) -> None:
993
+ pass
994
+
995
+ def refresh(self) -> None:
996
+ pass
997
+
998
+ def close(self) -> None:
999
+ pass
1000
+
1001
+
1002
+ def _progress_control(
1003
+ total: int,
1004
+ label: str | None,
1005
+ color: str | None,
1006
+ status: str | dict[str, Any] | None,
1007
+ ) -> _ProgressControl:
1008
+ """Creates a process control."""
1009
+ if progress_bar == 'tqdm':
1010
+ if not tqdm:
1011
+ raise RuntimeError(
1012
+ 'Please install package "tqdm" to use `tqdm` progress bar.'
1013
+ )
1014
+ return _TqdmProgressControl(total, label, color, status)
1015
+ elif progress_bar == 'console':
1016
+ return _ConsoleProgressControl(total, label, color, status)
1017
+ elif progress_bar is None:
1018
+ return _NoopProgressControl(total, label, color, status)
1019
+ else:
1020
+ raise ValueError(f'Unsupported progress bar type: {progress_bar}')
1021
+
1022
+
1023
+ def get_executor(
1024
+ resource_id: str,
1025
+ max_workers: int | None = None) -> concurrent.futures.ThreadPoolExecutor:
1026
+ """Gets a thread pool executor associated with a resource id."""
1027
+ return _executor_pool.get(resource_id, max_workers)
1028
+
732
1029
  # The global executor pool based on resource IDs.
733
1030
  _executor_pool = ExecutorPool()