langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/concurrent.py
CHANGED
@@ -13,17 +13,30 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Utility library for handling concurrency in langfun."""
|
15
15
|
|
16
|
+
import abc
|
16
17
|
import collections
|
18
|
+
from collections.abc import Mapping
|
17
19
|
import concurrent.futures
|
18
20
|
import dataclasses
|
21
|
+
import io
|
19
22
|
import random
|
23
|
+
import sys
|
20
24
|
import threading
|
21
25
|
import time
|
22
|
-
from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
|
26
|
+
from typing import Annotated, Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
|
23
27
|
|
24
28
|
from langfun.core import component
|
25
29
|
import pyglove as pg
|
26
|
-
|
30
|
+
|
31
|
+
|
32
|
+
progress_bar: Literal['tqdm', 'console', None] = None
|
33
|
+
|
34
|
+
try:
|
35
|
+
from tqdm import auto as tqdm # pylint: disable=g-import-not-at-top
|
36
|
+
progress_bar = 'tqdm'
|
37
|
+
except ImportError:
|
38
|
+
progress_bar = 'console'
|
39
|
+
tqdm = None
|
27
40
|
|
28
41
|
|
29
42
|
def with_context_access(func: Callable[..., Any]) -> Callable[..., Any]:
|
@@ -44,7 +57,7 @@ class RetryError(RuntimeError):
|
|
44
57
|
def __init__(
|
45
58
|
self,
|
46
59
|
func: Callable[..., Any],
|
47
|
-
errors: list[
|
60
|
+
errors: list[BaseException],
|
48
61
|
wait_intervals: list[int],
|
49
62
|
):
|
50
63
|
assert len(errors) == len(wait_intervals) + 1
|
@@ -99,12 +112,13 @@ class RetryError(RuntimeError):
|
|
99
112
|
def with_retry(
|
100
113
|
func: Callable[[Any], Any],
|
101
114
|
retry_on_errors: Union[
|
102
|
-
Union[Type[
|
103
|
-
Sequence[Union[Type[
|
115
|
+
Union[Type[BaseException], Tuple[Type[BaseException], str]],
|
116
|
+
Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
|
104
117
|
],
|
105
118
|
max_attempts: int,
|
106
119
|
retry_interval: int | tuple[int, int] = (5, 60),
|
107
120
|
exponential_backoff: bool = True,
|
121
|
+
max_retry_interval: int = 300,
|
108
122
|
seed: int | None = None,
|
109
123
|
) -> Callable[..., Any]:
|
110
124
|
"""Derives a user function with retry on error.
|
@@ -120,6 +134,9 @@ def with_retry(
|
|
120
134
|
of the tuple.
|
121
135
|
exponential_backoff: If True, exponential wait time will be applied on top
|
122
136
|
of the base retry interval.
|
137
|
+
max_retry_interval: The max retry interval in seconds. This is useful when
|
138
|
+
the retry interval is exponential, to avoid the wait time to grow
|
139
|
+
exponentially.
|
123
140
|
seed: Random seed to generate retry interval. If None, the seed will be
|
124
141
|
determined based on current time.
|
125
142
|
|
@@ -127,44 +144,33 @@ def with_retry(
|
|
127
144
|
A function with the same signature of the input function, with the retry
|
128
145
|
capability.
|
129
146
|
"""
|
130
|
-
rand = random if seed is None else random.Random(seed)
|
131
|
-
|
132
|
-
def _func(*args, **kwargs) -> Any:
|
133
|
-
def base_interval() -> int:
|
134
|
-
if isinstance(retry_interval, tuple):
|
135
|
-
return rand.randint(retry_interval[0], retry_interval[1])
|
136
|
-
else:
|
137
|
-
assert isinstance(retry_interval, int)
|
138
|
-
return retry_interval
|
139
147
|
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
148
|
+
def _func(*args, **kwargs):
|
149
|
+
job = Job(
|
150
|
+
func,
|
151
|
+
args,
|
152
|
+
kwargs,
|
153
|
+
retry_on_errors=retry_on_errors,
|
154
|
+
max_attempts=max_attempts,
|
155
|
+
retry_interval=retry_interval,
|
156
|
+
exponential_backoff=exponential_backoff,
|
157
|
+
max_retry_interval=max_retry_interval,
|
158
|
+
seed=seed,
|
159
|
+
)
|
160
|
+
job()
|
161
|
+
if job.error:
|
162
|
+
raise job.error
|
163
|
+
return job.result
|
151
164
|
|
152
|
-
|
153
|
-
errors.append(error_context.error)
|
154
|
-
if len(errors) < max_attempts:
|
155
|
-
wait_interval = next_wait_interval(len(errors))
|
156
|
-
wait_intervals.append(wait_interval)
|
165
|
+
return _func
|
157
166
|
|
158
|
-
pg.logging.warning(
|
159
|
-
f'Calling {func!r} encountered {error_context.error!r} '
|
160
|
-
f'(attempts={len(errors)}), retrying in {wait_interval} seconds...'
|
161
|
-
)
|
162
167
|
|
163
|
-
|
164
|
-
|
165
|
-
raise RetryError(func, errors, wait_intervals)
|
168
|
+
class RetryEntry(pg.Object):
|
169
|
+
"""Retry entry."""
|
166
170
|
|
167
|
-
|
171
|
+
call_interval: float
|
172
|
+
error: BaseException | None = None
|
173
|
+
wait_interval: float = 0.
|
168
174
|
|
169
175
|
|
170
176
|
def concurrent_execute(
|
@@ -174,13 +180,15 @@ def concurrent_execute(
|
|
174
180
|
executor: Union[concurrent.futures.ThreadPoolExecutor, str, None] = None,
|
175
181
|
max_workers: int = 32,
|
176
182
|
retry_on_errors: Union[
|
177
|
-
Union[Type[
|
178
|
-
Sequence[Union[Type[
|
183
|
+
Union[Type[BaseException], Tuple[Type[BaseException], str]],
|
184
|
+
Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
|
179
185
|
None,
|
180
186
|
] = None,
|
181
187
|
max_attempts: int = 5,
|
182
188
|
retry_interval: int | tuple[int, int] = (5, 60),
|
183
189
|
exponential_backoff: bool = True,
|
190
|
+
max_retry_interval: int = 300,
|
191
|
+
return_jobs: bool = False,
|
184
192
|
) -> list[Any]:
|
185
193
|
"""Executes a function concurrently under current component context.
|
186
194
|
|
@@ -201,31 +209,55 @@ def concurrent_execute(
|
|
201
209
|
of the tuple.
|
202
210
|
exponential_backoff: If True, exponential wait time will be applied on top
|
203
211
|
of the base retry interval.
|
212
|
+
max_retry_interval: The max retry interval in seconds. This is useful when
|
213
|
+
the retry interval is exponential, to avoid the wait time to grow
|
214
|
+
exponentially.
|
215
|
+
return_jobs: If True, return a list of `Job` objects. Otherwise, return a
|
216
|
+
list of outputs.
|
204
217
|
|
205
218
|
Returns:
|
206
219
|
A list of ouputs. Each is the return value of `func` based on the input
|
207
220
|
value. Order is preserved.
|
208
221
|
"""
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
222
|
+
jobs = []
|
223
|
+
for inputs in parallel_inputs:
|
224
|
+
jobs.append(
|
225
|
+
Job(
|
226
|
+
func,
|
227
|
+
(inputs,),
|
228
|
+
retry_on_errors=retry_on_errors,
|
229
|
+
max_attempts=max_attempts,
|
230
|
+
retry_interval=retry_interval,
|
231
|
+
exponential_backoff=exponential_backoff,
|
232
|
+
max_retry_interval=max_retry_interval,
|
233
|
+
)
|
216
234
|
)
|
217
235
|
|
218
236
|
# NOTE(daiyip): when executor is not specified and max_worker is 1,
|
219
237
|
# we don't need to create a executor pool. Instead, the inputs will be
|
220
238
|
# processed by the user function in sequence within the current thread.
|
221
239
|
if executor is None and max_workers == 1:
|
222
|
-
|
240
|
+
for job in jobs:
|
241
|
+
job()
|
242
|
+
if job.error:
|
243
|
+
raise job.error
|
244
|
+
return jobs if return_jobs else [job.result for job in jobs]
|
223
245
|
|
224
246
|
shutdown_after_finish = executor is None
|
225
247
|
executor = _executor_pool.executor_from(executor, max_workers=max_workers)
|
226
248
|
|
227
249
|
try:
|
228
|
-
|
250
|
+
executed_jobs = list(
|
251
|
+
executor.map(
|
252
|
+
lambda job: job(), [with_context_access(job) for job in jobs]
|
253
|
+
)
|
254
|
+
)
|
255
|
+
for job in executed_jobs:
|
256
|
+
if job.error:
|
257
|
+
raise job.error
|
258
|
+
return (
|
259
|
+
executed_jobs if return_jobs else [job.result for job in executed_jobs]
|
260
|
+
)
|
229
261
|
finally:
|
230
262
|
if shutdown_after_finish:
|
231
263
|
# Do not wait threads to finish if they are timed out.
|
@@ -237,36 +269,139 @@ class Job:
|
|
237
269
|
"""Thread pool job."""
|
238
270
|
|
239
271
|
func: Callable[[Any], Any]
|
240
|
-
|
272
|
+
args: Sequence[Any] = ()
|
273
|
+
kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
|
274
|
+
_: dataclasses.KW_ONLY
|
275
|
+
|
241
276
|
result: Any = pg.MISSING_VALUE
|
242
|
-
error:
|
243
|
-
|
244
|
-
|
277
|
+
error: Annotated[
|
278
|
+
BaseException | None,
|
279
|
+
'The non-retryable error encountered during the job execution.',
|
280
|
+
] = None
|
281
|
+
retry_entries: Annotated[
|
282
|
+
Sequence[RetryEntry], 'Records of retry attempts.'
|
283
|
+
] = dataclasses.field(default_factory=list)
|
284
|
+
|
285
|
+
retry_on_errors: Annotated[
|
286
|
+
Sequence[Type[BaseException] | str],
|
287
|
+
(
|
288
|
+
'A sequence of exception types or tuples of exception type and error '
|
289
|
+
'messages (described in regular expression) as the desired exception '
|
290
|
+
'types to retry.'
|
291
|
+
),
|
292
|
+
] = ()
|
293
|
+
max_attempts: Annotated[
|
294
|
+
int, 'Max number of attempts if an error to retry is encountered.'
|
295
|
+
] = 5
|
296
|
+
retry_interval: Annotated[
|
297
|
+
int | tuple[int, int],
|
298
|
+
(
|
299
|
+
'The (base) retry interval in seconds. If a tuple, the retry '
|
300
|
+
'interval will be randomly chosen between the first and the second '
|
301
|
+
'element of the tuple.'
|
302
|
+
),
|
303
|
+
] = (5, 60)
|
304
|
+
exponential_backoff: Annotated[
|
305
|
+
bool,
|
306
|
+
(
|
307
|
+
'If True, exponential wait time will be applied on top of the base '
|
308
|
+
'retry interval.'
|
309
|
+
),
|
310
|
+
] = True
|
311
|
+
max_retry_interval: Annotated[
|
312
|
+
int,
|
313
|
+
(
|
314
|
+
'The max retry interval in seconds. This is useful when the retry '
|
315
|
+
'interval is exponential, to avoid the wait time to grow '
|
316
|
+
'exponentially.'
|
317
|
+
),
|
318
|
+
] = 300
|
319
|
+
seed: Annotated[
|
320
|
+
int | None,
|
321
|
+
(
|
322
|
+
'Random seed to generate retry interval. If None, the seed will be'
|
323
|
+
' determined based on current time.'
|
324
|
+
),
|
325
|
+
] = None
|
326
|
+
|
327
|
+
timeit: pg.object_utils.TimeIt = dataclasses.field(
|
328
|
+
default_factory=lambda: pg.object_utils.TimeIt('job')
|
329
|
+
)
|
330
|
+
|
331
|
+
@property
|
332
|
+
def elapse(self) -> float:
|
333
|
+
"""Returns the running time in seconds since the job get started."""
|
334
|
+
return self.timeit.elapse
|
335
|
+
|
336
|
+
def _retry_call(self) -> 'Job':
|
337
|
+
"""Retries func call on args."""
|
338
|
+
rand = random if self.seed is None else random.Random(self.seed)
|
339
|
+
|
340
|
+
def base_interval() -> int:
|
341
|
+
if isinstance(self.retry_interval, tuple):
|
342
|
+
return rand.randint(*self.retry_interval)
|
343
|
+
else:
|
344
|
+
assert isinstance(self.retry_interval, int)
|
345
|
+
return self.retry_interval
|
346
|
+
|
347
|
+
def next_wait_interval(attempt: int) -> float:
|
348
|
+
if not self.exponential_backoff:
|
349
|
+
attempt = 1
|
350
|
+
return min(
|
351
|
+
self.max_retry_interval, base_interval() * (2 ** (attempt - 1))
|
352
|
+
)
|
353
|
+
|
354
|
+
retry_entries = []
|
355
|
+
wait_interval = 0
|
356
|
+
while True:
|
357
|
+
with pg.catch_errors(self.retry_on_errors) as error_context:
|
358
|
+
begin_time = time.time()
|
359
|
+
self.result = self.func(*self.args, **self.kwargs)
|
360
|
+
|
361
|
+
end_time = time.time()
|
362
|
+
retry_entries.append(RetryEntry(
|
363
|
+
call_interval=end_time - begin_time,
|
364
|
+
wait_interval=wait_interval,
|
365
|
+
error=error_context.error,
|
366
|
+
))
|
367
|
+
if error_context.error is None:
|
368
|
+
self.retry_entries = retry_entries
|
369
|
+
return self
|
370
|
+
|
371
|
+
# Branch when errors are met for retry.
|
372
|
+
if len(retry_entries) < self.max_attempts:
|
373
|
+
wait_interval = next_wait_interval(len(retry_entries))
|
374
|
+
|
375
|
+
pg.logging.warning(
|
376
|
+
f'Calling {self.func!r} encountered {error_context.error!r} '
|
377
|
+
f'(attempts={len(retry_entries)}), retrying in '
|
378
|
+
f'{wait_interval} seconds...'
|
379
|
+
)
|
245
380
|
|
246
|
-
|
247
|
-
|
381
|
+
time.sleep(wait_interval)
|
382
|
+
else:
|
383
|
+
errors = [e.error for e in retry_entries]
|
384
|
+
# First wait interval is 0.
|
385
|
+
wait_intervals = [e.wait_interval for e in retry_entries[1:]]
|
386
|
+
raise RetryError(self.func, errors, wait_intervals)
|
387
|
+
|
388
|
+
def __call__(self) -> 'Job':
|
389
|
+
if getattr(self, '_has_call', False):
|
390
|
+
raise ValueError('Job can only be called once.')
|
391
|
+
self._has_call = True
|
248
392
|
try:
|
249
|
-
|
250
|
-
|
251
|
-
|
393
|
+
with self.timeit:
|
394
|
+
if self.retry_on_errors:
|
395
|
+
return self._retry_call()
|
396
|
+
self.result = self.func(*self.args, **self.kwargs)
|
397
|
+
except BaseException as e: # pylint: disable=broad-exception-caught
|
252
398
|
self.error = e
|
253
|
-
|
254
|
-
finally:
|
255
|
-
self.end_time = time.time()
|
399
|
+
return self
|
256
400
|
|
257
|
-
def mark_canceled(self, error:
|
401
|
+
def mark_canceled(self, error: BaseException) -> None:
|
258
402
|
"""Marks the job as canceled."""
|
403
|
+
self.timeit.end(error)
|
259
404
|
self.error = error
|
260
|
-
self.end_time = time.time()
|
261
|
-
|
262
|
-
@property
|
263
|
-
def elapse(self) -> float:
|
264
|
-
"""Returns the running time in seconds since the job get started."""
|
265
|
-
if self.start_time is None:
|
266
|
-
return 0.0
|
267
|
-
if self.end_time is None:
|
268
|
-
return time.time() - self.start_time
|
269
|
-
return self.end_time - self.start_time
|
270
405
|
|
271
406
|
|
272
407
|
@dataclasses.dataclass
|
@@ -276,9 +411,12 @@ class Progress:
|
|
276
411
|
|
277
412
|
_succeeded: int = 0
|
278
413
|
_failed: int = 0
|
279
|
-
_last_error:
|
414
|
+
_last_error: BaseException | None = None
|
280
415
|
_total_duration: float = 0.0
|
281
416
|
_job: Job | None = None
|
417
|
+
_timeit_summary: pg.object_utils.TimeIt.StatusSummary = dataclasses.field(
|
418
|
+
default_factory=pg.object_utils.TimeIt.StatusSummary
|
419
|
+
)
|
282
420
|
|
283
421
|
@property
|
284
422
|
def succeeded(self) -> int:
|
@@ -296,7 +434,7 @@ class Progress:
|
|
296
434
|
return self.succeeded + self.failed
|
297
435
|
|
298
436
|
@property
|
299
|
-
def last_error(self) ->
|
437
|
+
def last_error(self) -> BaseException | None:
|
300
438
|
"""Returns last error."""
|
301
439
|
return self._last_error
|
302
440
|
|
@@ -326,6 +464,28 @@ class Progress:
|
|
326
464
|
return 0.0
|
327
465
|
return self._total_duration / self.completed
|
328
466
|
|
467
|
+
@property
|
468
|
+
def timeit_summary(self) -> pg.object_utils.TimeIt.StatusSummary:
|
469
|
+
"""Returns the aggregated summary for each `pg.timeit`."""
|
470
|
+
return self._timeit_summary
|
471
|
+
|
472
|
+
def timeit_summary_str(self) -> str | None:
|
473
|
+
if not self.timeit_summary:
|
474
|
+
return None
|
475
|
+
return ', '.join([
|
476
|
+
'%s (%.2fs, %d/%d)' % (
|
477
|
+
k.lstrip('job.'), v.avg_duration, v.num_ended, v.num_started
|
478
|
+
) for k, v in self.timeit_summary.breakdown.items() if k != 'job'
|
479
|
+
])
|
480
|
+
|
481
|
+
def last_error_str(self) -> str | None:
|
482
|
+
if self.last_error is None:
|
483
|
+
return None
|
484
|
+
error_text = repr(self.last_error)
|
485
|
+
if len(error_text) >= 64:
|
486
|
+
error_text = error_text[:64] + '...'
|
487
|
+
return error_text
|
488
|
+
|
329
489
|
def update(self, job: Job) -> None:
|
330
490
|
"""Mark a job as completed."""
|
331
491
|
self._job = job
|
@@ -335,6 +495,7 @@ class Progress:
|
|
335
495
|
self._failed += 1
|
336
496
|
self._last_error = job.error
|
337
497
|
self._total_duration += job.elapse
|
498
|
+
self._timeit_summary.aggregate(job.timeit.status())
|
338
499
|
|
339
500
|
|
340
501
|
class ProgressBar:
|
@@ -356,17 +517,17 @@ class ProgressBar:
|
|
356
517
|
label: str | None
|
357
518
|
total: int
|
358
519
|
color: str | None = None
|
359
|
-
|
520
|
+
status: dict[str, Any] | None = None
|
360
521
|
|
361
522
|
@dataclasses.dataclass
|
362
523
|
class Update:
|
363
524
|
"""Progress bar update."""
|
364
525
|
bar_id: int
|
365
526
|
delta: int
|
366
|
-
|
527
|
+
status: Union[dict[str, Any], str, None] = None
|
367
528
|
color: str | None = None
|
368
529
|
|
369
|
-
_progress_bars: dict[int,
|
530
|
+
_progress_bars: dict[int, '_ProgressControl'] = {}
|
370
531
|
_install_requests: list[tuple[int, Settings]] = []
|
371
532
|
_updates: collections.deque[Update] = collections.deque()
|
372
533
|
_uninstall_requests: list[int] = []
|
@@ -378,11 +539,11 @@ class ProgressBar:
|
|
378
539
|
label: str | None,
|
379
540
|
total: int,
|
380
541
|
color: str | None = None,
|
381
|
-
|
542
|
+
status: dict[str, Any] | None = None,
|
382
543
|
) -> int:
|
383
544
|
"""Installs a progress bar and returns a reference id."""
|
384
545
|
with cls._lock:
|
385
|
-
settings = ProgressBar.Settings(label, total, color,
|
546
|
+
settings = ProgressBar.Settings(label, total, color, status)
|
386
547
|
bar_id = id(settings)
|
387
548
|
cls._install_requests.append((bar_id, settings))
|
388
549
|
return bar_id
|
@@ -392,15 +553,17 @@ class ProgressBar:
|
|
392
553
|
cls,
|
393
554
|
bar_id: int,
|
394
555
|
delta: int = 0,
|
395
|
-
|
556
|
+
status: Union[dict[str, Any], str, None] = None,
|
396
557
|
color: str | None = None,
|
397
558
|
refresh: bool = True,
|
398
559
|
) -> None:
|
399
560
|
"""Report the progress for a label."""
|
561
|
+
if status is not None and not isinstance(status, (str, dict)):
|
562
|
+
raise ValueError(f'Unsupported status: {status}')
|
400
563
|
with cls._lock:
|
401
564
|
cls._updates.append(
|
402
565
|
ProgressBar.Update(
|
403
|
-
bar_id=bar_id, delta=delta,
|
566
|
+
bar_id=bar_id, delta=delta, status=status, color=color,
|
404
567
|
)
|
405
568
|
)
|
406
569
|
if refresh:
|
@@ -422,11 +585,11 @@ class ProgressBar:
|
|
422
585
|
# Process install requests.
|
423
586
|
if cls._install_requests:
|
424
587
|
for bar_id, settings in cls._install_requests:
|
425
|
-
cls._progress_bars[bar_id] =
|
588
|
+
cls._progress_bars[bar_id] = _progress_control(
|
426
589
|
total=settings.total,
|
427
|
-
|
428
|
-
|
429
|
-
|
590
|
+
label=settings.label,
|
591
|
+
color=settings.color,
|
592
|
+
status=settings.status)
|
430
593
|
cls._install_requests.clear()
|
431
594
|
|
432
595
|
# Process updates.
|
@@ -441,15 +604,11 @@ class ProgressBar:
|
|
441
604
|
if update.delta > 0:
|
442
605
|
bar.update(update.delta)
|
443
606
|
|
444
|
-
if
|
445
|
-
bar.
|
446
|
-
elif isinstance(update.postfix, dict):
|
447
|
-
bar.set_postfix(update.postfix, refresh=False)
|
448
|
-
elif update.postfix is not None:
|
449
|
-
raise ValueError(f'Unsupported postfix: {update.postfix}')
|
607
|
+
if update.status is not None:
|
608
|
+
bar.set_status(update.status)
|
450
609
|
|
451
610
|
if update.color is not None:
|
452
|
-
bar.
|
611
|
+
bar.set_color(update.color)
|
453
612
|
updated_bars.add(bar)
|
454
613
|
|
455
614
|
# Refresh each updated bar just once.
|
@@ -459,7 +618,9 @@ class ProgressBar:
|
|
459
618
|
# Process uninstall requests.
|
460
619
|
if cls._uninstall_requests:
|
461
620
|
for bar_id in cls._uninstall_requests:
|
462
|
-
cls._progress_bars.pop(bar_id, None)
|
621
|
+
bar = cls._progress_bars.pop(bar_id, None)
|
622
|
+
if bar is not None:
|
623
|
+
bar.close()
|
463
624
|
cls._uninstall_requests.clear()
|
464
625
|
|
465
626
|
|
@@ -486,17 +647,18 @@ def concurrent_map(
|
|
486
647
|
status_fn: Callable[[Progress], dict[str, Any]] | None = None,
|
487
648
|
timeout: int | None = None,
|
488
649
|
silence_on_errors: Union[
|
489
|
-
Type[
|
650
|
+
Type[BaseException], Tuple[Type[BaseException], ...], None
|
490
651
|
] = Exception,
|
491
652
|
retry_on_errors: Union[
|
492
|
-
Type[
|
493
|
-
Tuple[Type[
|
653
|
+
Type[BaseException],
|
654
|
+
Tuple[Type[BaseException], ...],
|
494
655
|
None,
|
495
656
|
] = None,
|
496
657
|
max_attempts: int = 5,
|
497
658
|
retry_interval: int | tuple[int, int] = (5, 60),
|
498
659
|
exponential_backoff: bool = True,
|
499
|
-
|
660
|
+
return_jobs: bool = False,
|
661
|
+
) -> Iterator[Any]:
|
500
662
|
"""Maps inputs to outptus via func concurrently under current context.
|
501
663
|
|
502
664
|
Args:
|
@@ -539,9 +701,10 @@ def concurrent_map(
|
|
539
701
|
of the tuple.
|
540
702
|
exponential_backoff: If True, exponential wait time will be applied on top
|
541
703
|
of the base retry interval.
|
704
|
+
return_jobs: If True, the returned iterator will emit `Job` objects.
|
542
705
|
|
543
706
|
Yields:
|
544
|
-
An iterator of (input, output, error).
|
707
|
+
An iterator of (input, output, error) or Job object.
|
545
708
|
|
546
709
|
Raises:
|
547
710
|
Exception: Errors that are not in `silence_on_errors` or `retry_on_errors`,
|
@@ -551,15 +714,6 @@ def concurrent_map(
|
|
551
714
|
"""
|
552
715
|
# Internal usage logging.
|
553
716
|
|
554
|
-
if retry_on_errors:
|
555
|
-
func = with_retry(
|
556
|
-
func,
|
557
|
-
retry_on_errors,
|
558
|
-
max_attempts=max_attempts,
|
559
|
-
retry_interval=retry_interval,
|
560
|
-
exponential_backoff=exponential_backoff,
|
561
|
-
)
|
562
|
-
|
563
717
|
status_fn = status_fn or (lambda p: { # pylint: disable=g-long-lambda
|
564
718
|
'Succeeded': '%.2f%% (%d/%d)' % (
|
565
719
|
p.success_rate * 100, p.succeeded, p.completed),
|
@@ -574,7 +728,14 @@ def concurrent_map(
|
|
574
728
|
pending_futures = []
|
575
729
|
total = 0
|
576
730
|
for inputs in parallel_inputs:
|
577
|
-
job = Job(
|
731
|
+
job = Job(
|
732
|
+
func,
|
733
|
+
(inputs,),
|
734
|
+
retry_on_errors=retry_on_errors,
|
735
|
+
max_attempts=max_attempts,
|
736
|
+
retry_interval=retry_interval,
|
737
|
+
exponential_backoff=exponential_backoff,
|
738
|
+
)
|
578
739
|
future = executor.submit(
|
579
740
|
with_context_access(job),
|
580
741
|
)
|
@@ -596,14 +757,14 @@ def concurrent_map(
|
|
596
757
|
if show_progress:
|
597
758
|
status = status_fn(progress)
|
598
759
|
status.update({
|
599
|
-
'AvgDuration': '%.
|
760
|
+
'AvgDuration': '%.2fs' % progress.avg_duration
|
600
761
|
})
|
601
762
|
if progress.last_error is not None:
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
status['
|
606
|
-
ProgressBar.update(bar_id, delta=1,
|
763
|
+
status['LastError'] = progress.last_error_str()
|
764
|
+
|
765
|
+
if progress.timeit_summary:
|
766
|
+
status['TimeIt'] = progress.timeit_summary_str()
|
767
|
+
ProgressBar.update(bar_id, delta=1, status=status)
|
607
768
|
|
608
769
|
try:
|
609
770
|
if ordered:
|
@@ -627,7 +788,7 @@ def concurrent_map(
|
|
627
788
|
silence_on_errors and isinstance(job.error, silence_on_errors)):
|
628
789
|
raise job.error # pylint: disable=g-doc-exception
|
629
790
|
|
630
|
-
yield job.
|
791
|
+
yield job if return_jobs else job.args[0], job.result, job.error
|
631
792
|
progress.update(job)
|
632
793
|
update_progress_bar(progress)
|
633
794
|
ProgressBar.refresh()
|
@@ -648,7 +809,7 @@ def concurrent_map(
|
|
648
809
|
if job.error is not None and not (
|
649
810
|
silence_on_errors and isinstance(job.error, silence_on_errors)):
|
650
811
|
raise job.error # pylint: disable=g-doc-exception
|
651
|
-
yield job.
|
812
|
+
yield job if return_jobs else job.args[0], job.result, job.error
|
652
813
|
progress.update(job)
|
653
814
|
update_progress_bar(progress)
|
654
815
|
completed_batch.add(future)
|
@@ -671,7 +832,7 @@ def concurrent_map(
|
|
671
832
|
and isinstance(job.error, silence_on_errors)):
|
672
833
|
raise job.error # pylint: disable=g-doc-exception
|
673
834
|
|
674
|
-
yield job.
|
835
|
+
yield job.args[0], job.result, job.error
|
675
836
|
progress.update(job)
|
676
837
|
update_progress_bar(progress)
|
677
838
|
else:
|
@@ -729,5 +890,141 @@ class ExecutorPool:
|
|
729
890
|
raise ValueError(f'Unsupported value: {maybe_executor}.')
|
730
891
|
|
731
892
|
|
893
|
+
class _ProgressControl(pg.Object):
|
894
|
+
"""Abstract progress control."""
|
895
|
+
# Disable symbolic comparison so the hash is based on object address.
|
896
|
+
use_symbolic_comparison = False
|
897
|
+
|
898
|
+
total: int
|
899
|
+
label: str | None
|
900
|
+
color: str | None
|
901
|
+
status: str | dict[str, Any] | None
|
902
|
+
|
903
|
+
def set_color(self, color: str | None):
|
904
|
+
with pg.notify_on_change(False):
|
905
|
+
self.rebind(color=color)
|
906
|
+
|
907
|
+
def set_status(self, status: str | dict[str, Any] | None):
|
908
|
+
with pg.notify_on_change(False):
|
909
|
+
self.rebind(status=status)
|
910
|
+
|
911
|
+
@abc.abstractmethod
|
912
|
+
def update(self, delta):
|
913
|
+
"""Update progress."""
|
914
|
+
|
915
|
+
@abc.abstractmethod
|
916
|
+
def refresh(self) -> None:
|
917
|
+
"""Refresh progress bar."""
|
918
|
+
|
919
|
+
@abc.abstractmethod
|
920
|
+
def close(self) -> None:
|
921
|
+
"""Close progress bar."""
|
922
|
+
|
923
|
+
|
924
|
+
class _TqdmProgressControl(_ProgressControl):
|
925
|
+
"""Tqdm-based progress control."""
|
926
|
+
|
927
|
+
def _on_bound(self):
|
928
|
+
super()._on_bound()
|
929
|
+
assert tqdm is not None
|
930
|
+
self._tqdm = tqdm.tqdm(
|
931
|
+
total=self.total,
|
932
|
+
desc=self.label,
|
933
|
+
colour=self.color,
|
934
|
+
postfix=self.status,
|
935
|
+
)
|
936
|
+
|
937
|
+
def update(self, delta: int) -> None:
|
938
|
+
self._tqdm.update(delta)
|
939
|
+
|
940
|
+
def refresh(self):
|
941
|
+
self._tqdm.set_description(self.label, refresh=False)
|
942
|
+
if isinstance(self.status, str):
|
943
|
+
self._tqdm.set_postfix_str(self.status, refresh=False)
|
944
|
+
else:
|
945
|
+
self._tqdm.set_postfix(self.status, refresh=False)
|
946
|
+
self._tqdm.colour = self.color
|
947
|
+
self._tqdm.refresh()
|
948
|
+
|
949
|
+
def close(self):
|
950
|
+
self._tqdm.close()
|
951
|
+
|
952
|
+
|
953
|
+
class _ConsoleProgressControl(_ProgressControl):
|
954
|
+
"""Simple progress control by printing the status to the console."""
|
955
|
+
|
956
|
+
def _on_bound(self):
|
957
|
+
super()._on_bound()
|
958
|
+
self._progress = 0
|
959
|
+
|
960
|
+
def update(self, delta: int) -> None:
|
961
|
+
self._progress += delta
|
962
|
+
|
963
|
+
def refresh(self):
|
964
|
+
s = io.StringIO()
|
965
|
+
if self.label is not None:
|
966
|
+
s.write(pg.colored(self.label, 'red', styles=['bold']))
|
967
|
+
s.write(': ')
|
968
|
+
s.write(
|
969
|
+
pg.colored(
|
970
|
+
'%d%% (%d/%d)' %
|
971
|
+
(
|
972
|
+
self._progress * 100 // self.total,
|
973
|
+
self._progress,
|
974
|
+
self.total,
|
975
|
+
),
|
976
|
+
color=self.color or 'green'
|
977
|
+
)
|
978
|
+
)
|
979
|
+
if self.status is not None:
|
980
|
+
status = repr(self.status) if isinstance(
|
981
|
+
self.status, dict) else self.status
|
982
|
+
s.write(f' : {status}')
|
983
|
+
sys.stderr.write(s.getvalue() + '\n')
|
984
|
+
|
985
|
+
def close(self):
|
986
|
+
sys.stderr.flush()
|
987
|
+
|
988
|
+
|
989
|
+
class _NoopProgressControl(_ProgressControl):
|
990
|
+
"""No-op progress control."""
|
991
|
+
|
992
|
+
def update(self, delta: int) -> None:
|
993
|
+
pass
|
994
|
+
|
995
|
+
def refresh(self) -> None:
|
996
|
+
pass
|
997
|
+
|
998
|
+
def close(self) -> None:
|
999
|
+
pass
|
1000
|
+
|
1001
|
+
|
1002
|
+
def _progress_control(
|
1003
|
+
total: int,
|
1004
|
+
label: str | None,
|
1005
|
+
color: str | None,
|
1006
|
+
status: str | dict[str, Any] | None,
|
1007
|
+
) -> _ProgressControl:
|
1008
|
+
"""Creates a process control."""
|
1009
|
+
if progress_bar == 'tqdm':
|
1010
|
+
if not tqdm:
|
1011
|
+
raise RuntimeError(
|
1012
|
+
'Please install package "tqdm" to use `tqdm` progress bar.'
|
1013
|
+
)
|
1014
|
+
return _TqdmProgressControl(total, label, color, status)
|
1015
|
+
elif progress_bar == 'console':
|
1016
|
+
return _ConsoleProgressControl(total, label, color, status)
|
1017
|
+
elif progress_bar is None:
|
1018
|
+
return _NoopProgressControl(total, label, color, status)
|
1019
|
+
else:
|
1020
|
+
raise ValueError(f'Unsupported progress bar type: {progress_bar}')
|
1021
|
+
|
1022
|
+
|
1023
|
+
def get_executor(
|
1024
|
+
resource_id: str,
|
1025
|
+
max_workers: int | None = None) -> concurrent.futures.ThreadPoolExecutor:
|
1026
|
+
"""Gets a thread pool executor associated with a resource id."""
|
1027
|
+
return _executor_pool.get(resource_id, max_workers)
|
1028
|
+
|
732
1029
|
# The global executor pool based on resource IDs.
|
733
1030
|
_executor_pool = ExecutorPool()
|