langfun 0.1.1.dev20240826__py3-none-any.whl → 0.1.1.dev202408282153__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/concurrent.py +148 -21
- langfun/core/concurrent_test.py +58 -8
- langfun/core/eval/base.py +8 -8
- langfun/core/eval/base_test.py +1 -1
- langfun/core/langfunc.py +3 -0
- langfun/core/langfunc_test.py +6 -3
- langfun/core/language_model.py +10 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +7 -0
- langfun/core/llms/cache/in_memory_test.py +45 -0
- langfun/core/llms/google_genai.py +40 -13
- langfun/core/llms/openai.py +34 -17
- langfun/core/llms/vertexai.py +30 -9
- langfun/core/message.py +10 -0
- langfun/core/message_test.py +14 -0
- langfun/core/modalities/image.py +17 -5
- langfun/core/modalities/mime.py +13 -4
- langfun/core/modalities/ms_office.py +17 -8
- langfun/core/modality.py +12 -0
- langfun/core/modality_test.py +1 -0
- langfun/core/structured/mapping.py +6 -0
- langfun/core/structured/prompting_test.py +12 -8
- langfun/core/text_formatting.py +7 -1
- langfun/core/text_formatting_test.py +18 -0
- {langfun-0.1.1.dev20240826.dist-info → langfun-0.1.1.dev202408282153.dist-info}/METADATA +78 -14
- {langfun-0.1.1.dev20240826.dist-info → langfun-0.1.1.dev202408282153.dist-info}/RECORD +29 -29
- {langfun-0.1.1.dev20240826.dist-info → langfun-0.1.1.dev202408282153.dist-info}/WHEEL +1 -1
- {langfun-0.1.1.dev20240826.dist-info → langfun-0.1.1.dev202408282153.dist-info}/LICENSE +0 -0
- {langfun-0.1.1.dev20240826.dist-info → langfun-0.1.1.dev202408282153.dist-info}/top_level.txt +0 -0
langfun/core/concurrent.py
CHANGED
@@ -13,17 +13,30 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Utility library for handling concurrency in langfun."""
|
15
15
|
|
16
|
+
import abc
|
16
17
|
import collections
|
17
18
|
import concurrent.futures
|
18
19
|
import dataclasses
|
20
|
+
import io
|
19
21
|
import random
|
22
|
+
import sys
|
20
23
|
import threading
|
21
24
|
import time
|
22
25
|
from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
|
23
26
|
|
24
27
|
from langfun.core import component
|
28
|
+
from langfun.core import text_formatting
|
25
29
|
import pyglove as pg
|
26
|
-
|
30
|
+
|
31
|
+
|
32
|
+
progress_bar: Literal['tqdm', 'console', None] = None
|
33
|
+
|
34
|
+
try:
|
35
|
+
from tqdm import auto as tqdm # pylint: disable=g-import-not-at-top
|
36
|
+
progress_bar = 'tqdm'
|
37
|
+
except ImportError as e:
|
38
|
+
progress_bar = 'console'
|
39
|
+
tqdm = None
|
27
40
|
|
28
41
|
|
29
42
|
def with_context_access(func: Callable[..., Any]) -> Callable[..., Any]:
|
@@ -142,7 +155,6 @@ def with_retry(
|
|
142
155
|
attempt = 1
|
143
156
|
return base_interval() * (2 ** (attempt - 1))
|
144
157
|
|
145
|
-
wait_interval = None
|
146
158
|
wait_intervals = []
|
147
159
|
errors = []
|
148
160
|
while True:
|
@@ -356,17 +368,17 @@ class ProgressBar:
|
|
356
368
|
label: str | None
|
357
369
|
total: int
|
358
370
|
color: str | None = None
|
359
|
-
|
371
|
+
status: dict[str, Any] | None = None
|
360
372
|
|
361
373
|
@dataclasses.dataclass
|
362
374
|
class Update:
|
363
375
|
"""Progress bar update."""
|
364
376
|
bar_id: int
|
365
377
|
delta: int
|
366
|
-
|
378
|
+
status: Union[dict[str, Any], str, None] = None
|
367
379
|
color: str | None = None
|
368
380
|
|
369
|
-
_progress_bars: dict[int,
|
381
|
+
_progress_bars: dict[int, '_ProgressControl'] = {}
|
370
382
|
_install_requests: list[tuple[int, Settings]] = []
|
371
383
|
_updates: collections.deque[Update] = collections.deque()
|
372
384
|
_uninstall_requests: list[int] = []
|
@@ -378,11 +390,11 @@ class ProgressBar:
|
|
378
390
|
label: str | None,
|
379
391
|
total: int,
|
380
392
|
color: str | None = None,
|
381
|
-
|
393
|
+
status: dict[str, Any] | None = None,
|
382
394
|
) -> int:
|
383
395
|
"""Installs a progress bar and returns a reference id."""
|
384
396
|
with cls._lock:
|
385
|
-
settings = ProgressBar.Settings(label, total, color,
|
397
|
+
settings = ProgressBar.Settings(label, total, color, status)
|
386
398
|
bar_id = id(settings)
|
387
399
|
cls._install_requests.append((bar_id, settings))
|
388
400
|
return bar_id
|
@@ -392,15 +404,17 @@ class ProgressBar:
|
|
392
404
|
cls,
|
393
405
|
bar_id: int,
|
394
406
|
delta: int = 0,
|
395
|
-
|
407
|
+
status: Union[dict[str, Any], str, None] = None,
|
396
408
|
color: str | None = None,
|
397
409
|
refresh: bool = True,
|
398
410
|
) -> None:
|
399
411
|
"""Report the progress for a label."""
|
412
|
+
if status is not None and not isinstance(status, (str, dict)):
|
413
|
+
raise ValueError(f'Unsupported status: {status}')
|
400
414
|
with cls._lock:
|
401
415
|
cls._updates.append(
|
402
416
|
ProgressBar.Update(
|
403
|
-
bar_id=bar_id, delta=delta,
|
417
|
+
bar_id=bar_id, delta=delta, status=status, color=color,
|
404
418
|
)
|
405
419
|
)
|
406
420
|
if refresh:
|
@@ -422,11 +436,11 @@ class ProgressBar:
|
|
422
436
|
# Process install requests.
|
423
437
|
if cls._install_requests:
|
424
438
|
for bar_id, settings in cls._install_requests:
|
425
|
-
cls._progress_bars[bar_id] =
|
439
|
+
cls._progress_bars[bar_id] = _progress_control(
|
426
440
|
total=settings.total,
|
427
|
-
|
428
|
-
|
429
|
-
|
441
|
+
label=settings.label,
|
442
|
+
color=settings.color,
|
443
|
+
status=settings.status)
|
430
444
|
cls._install_requests.clear()
|
431
445
|
|
432
446
|
# Process updates.
|
@@ -441,15 +455,11 @@ class ProgressBar:
|
|
441
455
|
if update.delta > 0:
|
442
456
|
bar.update(update.delta)
|
443
457
|
|
444
|
-
if
|
445
|
-
bar.
|
446
|
-
elif isinstance(update.postfix, dict):
|
447
|
-
bar.set_postfix(update.postfix, refresh=False)
|
448
|
-
elif update.postfix is not None:
|
449
|
-
raise ValueError(f'Unsupported postfix: {update.postfix}')
|
458
|
+
if update.status is not None:
|
459
|
+
bar.set_status(update.status)
|
450
460
|
|
451
461
|
if update.color is not None:
|
452
|
-
bar.
|
462
|
+
bar.set_color(update.color)
|
453
463
|
updated_bars.add(bar)
|
454
464
|
|
455
465
|
# Refresh each updated bar just once.
|
@@ -603,7 +613,7 @@ def concurrent_map(
|
|
603
613
|
if len(error_text) >= 64:
|
604
614
|
error_text = error_text[:64] + '...'
|
605
615
|
status['LastError'] = error_text
|
606
|
-
ProgressBar.update(bar_id, delta=1,
|
616
|
+
ProgressBar.update(bar_id, delta=1, status=status)
|
607
617
|
|
608
618
|
try:
|
609
619
|
if ordered:
|
@@ -729,5 +739,122 @@ class ExecutorPool:
|
|
729
739
|
raise ValueError(f'Unsupported value: {maybe_executor}.')
|
730
740
|
|
731
741
|
|
742
|
+
class _ProgressControl(pg.Object):
|
743
|
+
"""Abstract progress control."""
|
744
|
+
# Disable symbolic comparison so the hash is based on object address.
|
745
|
+
use_symbolic_comparison = False
|
746
|
+
|
747
|
+
total: int
|
748
|
+
label: str | None
|
749
|
+
color: str | None
|
750
|
+
status: str | dict[str, Any] | None
|
751
|
+
|
752
|
+
def set_color(self, color: str | None):
|
753
|
+
with pg.notify_on_change(False):
|
754
|
+
self.rebind(color=color)
|
755
|
+
|
756
|
+
def set_status(self, status: str | dict[str, Any] | None):
|
757
|
+
with pg.notify_on_change(False):
|
758
|
+
self.rebind(status=status)
|
759
|
+
|
760
|
+
@abc.abstractmethod
|
761
|
+
def update(self, delta):
|
762
|
+
"""Update progress."""
|
763
|
+
|
764
|
+
@abc.abstractmethod
|
765
|
+
def refresh(self) -> None:
|
766
|
+
"""Refresh progress bar."""
|
767
|
+
|
768
|
+
|
769
|
+
class _TqdmProgressControl(_ProgressControl):
|
770
|
+
"""Tqdm-based progress control."""
|
771
|
+
|
772
|
+
def _on_bound(self):
|
773
|
+
super()._on_bound()
|
774
|
+
assert tqdm is not None
|
775
|
+
self._tqdm = tqdm.tqdm(
|
776
|
+
total=self.total,
|
777
|
+
desc=self.label,
|
778
|
+
colour=self.color,
|
779
|
+
postfix=self.status,
|
780
|
+
)
|
781
|
+
|
782
|
+
def update(self, delta: int) -> None:
|
783
|
+
self._tqdm.update(delta)
|
784
|
+
|
785
|
+
def refresh(self):
|
786
|
+
self._tqdm.set_description(self.label, refresh=False)
|
787
|
+
if isinstance(self.status, str):
|
788
|
+
self._tqdm.set_postfix_str(self.status, refresh=False)
|
789
|
+
else:
|
790
|
+
self._tqdm.set_postfix(self.status, refresh=False)
|
791
|
+
self._tqdm.colour = self.color
|
792
|
+
self._tqdm.refresh()
|
793
|
+
|
794
|
+
|
795
|
+
class _ConsoleProgressControl(_ProgressControl):
|
796
|
+
"""Simple progress control by printing the status to the console."""
|
797
|
+
|
798
|
+
def _on_bound(self):
|
799
|
+
super()._on_bound()
|
800
|
+
self._progress = 0
|
801
|
+
|
802
|
+
def update(self, delta: int) -> None:
|
803
|
+
self._progress += delta
|
804
|
+
|
805
|
+
def refresh(self):
|
806
|
+
s = io.StringIO()
|
807
|
+
if self.label is not None:
|
808
|
+
s.write(text_formatting.colored(self.label, 'red', styles=['bold']))
|
809
|
+
s.write(': ')
|
810
|
+
s.write(
|
811
|
+
text_formatting.colored(
|
812
|
+
'%d%% (%d/%d)' %
|
813
|
+
(
|
814
|
+
self._progress * 100 // self.total,
|
815
|
+
self._progress,
|
816
|
+
self.total,
|
817
|
+
),
|
818
|
+
color=self.color or 'green'
|
819
|
+
)
|
820
|
+
)
|
821
|
+
if self.status is not None:
|
822
|
+
status = repr(self.status) if isinstance(
|
823
|
+
self.status, dict) else self.status
|
824
|
+
s.write(f' : {status}')
|
825
|
+
sys.stderr.write(s.getvalue() + '\n')
|
826
|
+
|
827
|
+
|
828
|
+
class _NoopProgressControl(_ProgressControl):
|
829
|
+
"""No-op progress control."""
|
830
|
+
|
831
|
+
def update(self, delta: int) -> None:
|
832
|
+
pass
|
833
|
+
|
834
|
+
def refresh(self) -> None:
|
835
|
+
pass
|
836
|
+
|
837
|
+
|
838
|
+
def _progress_control(
|
839
|
+
total: int,
|
840
|
+
label: str | None,
|
841
|
+
color: str | None,
|
842
|
+
status: str | dict[str, Any] | None,
|
843
|
+
) -> _ProgressControl:
|
844
|
+
"""Creates a process control."""
|
845
|
+
if progress_bar == 'tqdm':
|
846
|
+
if not tqdm:
|
847
|
+
raise RuntimeError(
|
848
|
+
'Please install package "tqdm" to use `tqdm` progress bar.'
|
849
|
+
)
|
850
|
+
return _TqdmProgressControl(total, label, color, status)
|
851
|
+
elif progress_bar == 'console':
|
852
|
+
return _ConsoleProgressControl(total, label, color, status)
|
853
|
+
elif progress_bar is None:
|
854
|
+
return _NoopProgressControl(total, label, color, status)
|
855
|
+
else:
|
856
|
+
raise ValueError(f'Unsupported progress bar type: {progress_bar}')
|
857
|
+
|
858
|
+
|
732
859
|
# The global executor pool based on resource IDs.
|
733
860
|
_executor_pool = ExecutorPool()
|
langfun/core/concurrent_test.py
CHANGED
@@ -233,6 +233,57 @@ class ProgressTest(unittest.TestCase):
|
|
233
233
|
self.assertIs(p.last_error, job2.error)
|
234
234
|
|
235
235
|
|
236
|
+
class ProgressControlTest(unittest.TestCase):
|
237
|
+
|
238
|
+
def test_noop(self):
|
239
|
+
concurrent.progress_bar = None
|
240
|
+
ctrl = concurrent._progress_control(100, 'noop', 'blue', None)
|
241
|
+
self.assertIsInstance(ctrl, concurrent._NoopProgressControl)
|
242
|
+
string_io = io.StringIO()
|
243
|
+
with contextlib.redirect_stderr(string_io):
|
244
|
+
ctrl.update(1)
|
245
|
+
ctrl.refresh()
|
246
|
+
self.assertEqual(string_io.getvalue(), '')
|
247
|
+
concurrent.progress_bar = 'tqdm'
|
248
|
+
|
249
|
+
def test_console(self):
|
250
|
+
concurrent.progress_bar = 'console'
|
251
|
+
ctrl = concurrent._progress_control(100, 'foo', 'blue', None)
|
252
|
+
self.assertIsInstance(ctrl, concurrent._ConsoleProgressControl)
|
253
|
+
string_io = io.StringIO()
|
254
|
+
with contextlib.redirect_stderr(string_io):
|
255
|
+
ctrl.set_status('bar')
|
256
|
+
ctrl.update(10)
|
257
|
+
ctrl.refresh()
|
258
|
+
self.assertEqual(
|
259
|
+
string_io.getvalue(),
|
260
|
+
'\x1b[1m\x1b[31mfoo\x1b[0m: \x1b[34m10% (10/100)\x1b[0m : bar\n'
|
261
|
+
)
|
262
|
+
concurrent.progress_bar = 'tqdm'
|
263
|
+
|
264
|
+
def test_tqdm(self):
|
265
|
+
concurrent.progress_bar = 'tqdm'
|
266
|
+
string_io = io.StringIO()
|
267
|
+
with contextlib.redirect_stderr(string_io):
|
268
|
+
ctrl = concurrent._progress_control(100, 'foo', 'blue', None)
|
269
|
+
self.assertIsInstance(ctrl, concurrent._TqdmProgressControl)
|
270
|
+
ctrl.update(10)
|
271
|
+
ctrl.refresh()
|
272
|
+
self.assertIn('10/100', string_io.getvalue())
|
273
|
+
|
274
|
+
tqdm = concurrent.tqdm
|
275
|
+
concurrent.tqdm = None
|
276
|
+
with self.assertRaisesRegex(RuntimeError, 'install package "tqdm"'):
|
277
|
+
_ = concurrent._progress_control(100, 'foo', 'blue', None)
|
278
|
+
concurrent.tqdm = tqdm
|
279
|
+
|
280
|
+
def test_unsupported(self):
|
281
|
+
concurrent.progress_bar = 'unknown'
|
282
|
+
with self.assertRaisesRegex(ValueError, 'Unsupported progress bar type'):
|
283
|
+
_ = concurrent._progress_control(100, 'foo', 'blue', None)
|
284
|
+
concurrent.progress_bar = 'tqdm'
|
285
|
+
|
286
|
+
|
236
287
|
class ProgressBarTest(unittest.TestCase):
|
237
288
|
|
238
289
|
def test_multithread_support(self):
|
@@ -241,13 +292,12 @@ class ProgressBarTest(unittest.TestCase):
|
|
241
292
|
bar_id = concurrent.ProgressBar.install(None, 5)
|
242
293
|
def fun(x):
|
243
294
|
del x
|
244
|
-
concurrent.ProgressBar.update(bar_id, 1,
|
295
|
+
concurrent.ProgressBar.update(bar_id, 1, status=None)
|
245
296
|
|
246
297
|
for _ in concurrent.concurrent_execute(fun, range(5)):
|
247
298
|
concurrent.ProgressBar.refresh()
|
248
299
|
concurrent.ProgressBar.uninstall(bar_id)
|
249
300
|
output_str = string_io.getvalue()
|
250
|
-
print(output_str)
|
251
301
|
self.assertIn('100%', output_str)
|
252
302
|
self.assertIn('5/5', output_str)
|
253
303
|
|
@@ -255,12 +305,12 @@ class ProgressBarTest(unittest.TestCase):
|
|
255
305
|
string_io = io.StringIO()
|
256
306
|
with contextlib.redirect_stderr(string_io):
|
257
307
|
bar_id = concurrent.ProgressBar.install(None, 4)
|
258
|
-
concurrent.ProgressBar.update(bar_id, 1,
|
259
|
-
concurrent.ProgressBar.update(bar_id, 1,
|
260
|
-
concurrent.ProgressBar.update(bar_id, color='
|
261
|
-
concurrent.ProgressBar.update(bar_id, 2,
|
262
|
-
with self.assertRaisesRegex(ValueError, 'Unsupported
|
263
|
-
concurrent.ProgressBar.update(bar_id, 0,
|
308
|
+
concurrent.ProgressBar.update(bar_id, 1, status=None)
|
309
|
+
concurrent.ProgressBar.update(bar_id, 1, status='hello')
|
310
|
+
concurrent.ProgressBar.update(bar_id, color='green')
|
311
|
+
concurrent.ProgressBar.update(bar_id, 2, status=dict(x=1))
|
312
|
+
with self.assertRaisesRegex(ValueError, 'Unsupported status'):
|
313
|
+
concurrent.ProgressBar.update(bar_id, 0, status=1)
|
264
314
|
concurrent.ProgressBar.uninstall(bar_id)
|
265
315
|
self.assertIn('1/4', string_io.getvalue())
|
266
316
|
self.assertIn('2/4', string_io.getvalue())
|
langfun/core/eval/base.py
CHANGED
@@ -242,7 +242,7 @@ class Evaluable(lf.Component):
|
|
242
242
|
):
|
243
243
|
if show_progress:
|
244
244
|
lf.concurrent.ProgressBar.update(
|
245
|
-
progress_bar,
|
245
|
+
progress_bar, status='LOADING SAVED RESULTS...', color='yellow'
|
246
246
|
)
|
247
247
|
if self.try_load_result():
|
248
248
|
run_status = 'CACHED'
|
@@ -271,7 +271,7 @@ class Evaluable(lf.Component):
|
|
271
271
|
if should_save:
|
272
272
|
if show_progress:
|
273
273
|
lf.concurrent.ProgressBar.update(
|
274
|
-
progress_bar,
|
274
|
+
progress_bar, status='SAVING RESULTS...', color='yellow'
|
275
275
|
)
|
276
276
|
|
277
277
|
# Save evaluation results.
|
@@ -284,7 +284,7 @@ class Evaluable(lf.Component):
|
|
284
284
|
if show_progress:
|
285
285
|
lf.concurrent.ProgressBar.update(
|
286
286
|
progress_bar,
|
287
|
-
|
287
|
+
status=self._completion_status(run_status),
|
288
288
|
color='green',
|
289
289
|
)
|
290
290
|
else:
|
@@ -340,7 +340,7 @@ class Evaluable(lf.Component):
|
|
340
340
|
f'[#{leaf.index} - {leaf.node.id}]',
|
341
341
|
total=leaf.node.num_examples if leaf.enabled else 0,
|
342
342
|
color='cyan' if leaf.enabled else 'yellow',
|
343
|
-
|
343
|
+
status=None if leaf.enabled else 'SKIPPED.')
|
344
344
|
|
345
345
|
# Run leaf groups in parallel.
|
346
346
|
try:
|
@@ -354,7 +354,7 @@ class Evaluable(lf.Component):
|
|
354
354
|
# Save results for non-leaf nodes.
|
355
355
|
lf.concurrent.ProgressBar.update(
|
356
356
|
overview_bar,
|
357
|
-
|
357
|
+
status='SAVING RESULTS...',
|
358
358
|
color='yellow')
|
359
359
|
|
360
360
|
for node in self.nonleaf_nodes:
|
@@ -364,7 +364,7 @@ class Evaluable(lf.Component):
|
|
364
364
|
|
365
365
|
if should_save and summary:
|
366
366
|
lf.concurrent.ProgressBar.update(
|
367
|
-
overview_bar,
|
367
|
+
overview_bar, status='FINALIZING SUMMARY...'
|
368
368
|
)
|
369
369
|
|
370
370
|
summary.save(os.path.join(self.root_dir, Evaluable.SUMMARY_HTML))
|
@@ -378,7 +378,7 @@ class Evaluable(lf.Component):
|
|
378
378
|
# Signal all task completed by making the bar green.
|
379
379
|
lf.concurrent.ProgressBar.update(
|
380
380
|
overview_bar,
|
381
|
-
|
381
|
+
status='COMPLETED',
|
382
382
|
color='green')
|
383
383
|
|
384
384
|
finally:
|
@@ -1261,7 +1261,7 @@ class Evaluation(Evaluable):
|
|
1261
1261
|
|
1262
1262
|
def finalize(self) -> pg.Dict:
|
1263
1263
|
"""Finalizes the evaluation result."""
|
1264
|
-
if self.cache:
|
1264
|
+
if self.cache is not None:
|
1265
1265
|
cache_stats = dict(
|
1266
1266
|
use_cache=True,
|
1267
1267
|
num_queries=self.cache.stats.num_queries,
|
langfun/core/eval/base_test.py
CHANGED
@@ -564,7 +564,7 @@ class SuiteTest(unittest.TestCase):
|
|
564
564
|
schema_fn='answer_schema()',
|
565
565
|
),
|
566
566
|
cache_stats=dict(
|
567
|
-
use_cache=True, num_queries=4, num_hits=
|
567
|
+
use_cache=True, num_queries=4, num_hits=0, num_updates=4
|
568
568
|
),
|
569
569
|
metrics=dict(
|
570
570
|
total=2,
|
langfun/core/langfunc.py
CHANGED
@@ -269,6 +269,9 @@ class LangFunc(
|
|
269
269
|
# Send rendered text to LM.
|
270
270
|
lm_output = self.lm(lm_input, cache_seed=cache_seed)
|
271
271
|
|
272
|
+
# Attach cache seed.
|
273
|
+
lm_input.metadata.cache_seed = cache_seed
|
274
|
+
|
272
275
|
# Transform the output message.
|
273
276
|
lm_output = self.transform_output(lm_output)
|
274
277
|
lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
|
langfun/core/langfunc_test.py
CHANGED
@@ -94,11 +94,13 @@ class LangFuncCallTest(unittest.TestCase):
|
|
94
94
|
)
|
95
95
|
)
|
96
96
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
97
|
-
self.assertEqual(
|
97
|
+
self.assertEqual(
|
98
|
+
r.source,
|
99
|
+
message.UserMessage('Hello', metadata=dict(cache_seed=0))
|
100
|
+
)
|
98
101
|
self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
|
99
102
|
|
100
103
|
self.assertEqual(str(l), 'Hello')
|
101
|
-
print(repr(l))
|
102
104
|
self.assertEqual(
|
103
105
|
repr(l),
|
104
106
|
"LangFunc(template_str='Hello', clean=True,"
|
@@ -114,7 +116,7 @@ class LangFuncCallTest(unittest.TestCase):
|
|
114
116
|
self.assertEqual(l, 'Hello')
|
115
117
|
self.assertEqual(l.natural_language_format(), 'Hello')
|
116
118
|
self.assertEqual(l.render(), 'Hello')
|
117
|
-
r = l()
|
119
|
+
r = l(cache_seed=1)
|
118
120
|
self.assertEqual(
|
119
121
|
r,
|
120
122
|
message.AIMessage(
|
@@ -123,6 +125,7 @@ class LangFuncCallTest(unittest.TestCase):
|
|
123
125
|
)
|
124
126
|
)
|
125
127
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
128
|
+
self.assertEqual(r.source.metadata.cache_seed, 1)
|
126
129
|
|
127
130
|
self.assertEqual(str(l), 'Hello')
|
128
131
|
|
langfun/core/language_model.py
CHANGED
@@ -234,6 +234,7 @@ class LMCache(pg.Object):
|
|
234
234
|
num_hit_expires: int = 0
|
235
235
|
num_misses: int = 0
|
236
236
|
num_updates: int = 0
|
237
|
+
num_deletes: int = 0
|
237
238
|
|
238
239
|
@abc.abstractmethod
|
239
240
|
def get(
|
@@ -251,6 +252,15 @@ class LMCache(pg.Object):
|
|
251
252
|
) -> None:
|
252
253
|
"""Puts the result of a prompt generated by a language model in cache."""
|
253
254
|
|
255
|
+
@abc.abstractmethod
|
256
|
+
def delete(
|
257
|
+
self,
|
258
|
+
lm: 'LanguageModel',
|
259
|
+
prompt: message_lib.Message,
|
260
|
+
seed: int,
|
261
|
+
) -> bool:
|
262
|
+
"""Deletes the result of a prompt generated by a language model in cache."""
|
263
|
+
|
254
264
|
@property
|
255
265
|
@abc.abstractmethod
|
256
266
|
def stats(self) -> Stats:
|
langfun/core/llms/cache/base.py
CHANGED
@@ -60,13 +60,16 @@ class LMCacheBase(lf.LMCache):
|
|
60
60
|
self, lm: lf.LanguageModel, prompt: lf.Message, seed: int
|
61
61
|
) -> lf.LMSamplingResult | None:
|
62
62
|
"""Gets the cached result of a prompt generated by a language model."""
|
63
|
-
|
63
|
+
key = self._key(lm, prompt, seed)
|
64
|
+
entry = self._get(lm.model_id, key)
|
64
65
|
self._stats.num_queries += 1
|
65
66
|
if entry is None:
|
66
67
|
self._stats.num_misses += 1
|
67
68
|
return None
|
68
69
|
if entry.expire is not None and entry.expire < datetime.datetime.now():
|
69
70
|
self._stats.num_hit_expires += 1
|
71
|
+
self._stats.num_deletes += 1
|
72
|
+
assert self._delete(lm.model_id, key)
|
70
73
|
return None
|
71
74
|
self._stats.num_hits += 1
|
72
75
|
return entry.result
|
@@ -86,6 +89,18 @@ class LMCacheBase(lf.LMCache):
|
|
86
89
|
self._put(lm.model_id, self._key(lm, prompt, seed), entry)
|
87
90
|
self._stats.num_updates += 1
|
88
91
|
|
92
|
+
def delete(
|
93
|
+
self,
|
94
|
+
lm: lf.LanguageModel,
|
95
|
+
prompt: lf.Message,
|
96
|
+
seed: int,
|
97
|
+
) -> bool:
|
98
|
+
"""Deletes the result of a prompt generated by a language model in cache."""
|
99
|
+
deleted = self._delete(lm.model_id, self._key(lm, prompt, seed))
|
100
|
+
if deleted:
|
101
|
+
self._stats.num_deletes += 1
|
102
|
+
return deleted
|
103
|
+
|
89
104
|
@abc.abstractmethod
|
90
105
|
def _get(self, model_id: str, key: str) -> LMCacheEntry | None:
|
91
106
|
"""Returns a LM cache entry associated with the key."""
|
@@ -94,6 +109,10 @@ class LMCacheBase(lf.LMCache):
|
|
94
109
|
def _put(self, model_id: str, key: str, entry: LMCacheEntry) -> None:
|
95
110
|
"""Puts a LM cache entry associated with the key."""
|
96
111
|
|
112
|
+
@abc.abstractmethod
|
113
|
+
def _delete(self, model_id: str, key: str) -> bool:
|
114
|
+
"""Deletes a LM cache entry associated with the key."""
|
115
|
+
|
97
116
|
def _sym_clone(self, deep: bool, memo: Any = None) -> 'LMCacheBase':
|
98
117
|
v = super()._sym_clone(deep, memo)
|
99
118
|
v._stats = self._stats # pylint: disable=protected-access
|
@@ -102,4 +121,4 @@ class LMCacheBase(lf.LMCache):
|
|
102
121
|
|
103
122
|
def default_key(lm: lf.LanguageModel, prompt: lf.Message, seed: int) -> Any:
|
104
123
|
"""Default key for LM cache."""
|
105
|
-
return (prompt.
|
124
|
+
return (prompt.text_with_modality_hash, lm.sampling_options.cache_key(), seed)
|
@@ -99,6 +99,13 @@ class InMemory(base.LMCacheBase):
|
|
99
99
|
"""Puts a LM cache entry associated with the key."""
|
100
100
|
self._cache[model_id][key] = entry
|
101
101
|
|
102
|
+
def _delete(self, model_id: str, key: str) -> bool:
|
103
|
+
"""Deletes a LM cache entry associated with the key."""
|
104
|
+
model_cache = self._cache.get(model_id, None)
|
105
|
+
if model_cache is None:
|
106
|
+
return False
|
107
|
+
return model_cache.pop(key, None) is not None
|
108
|
+
|
102
109
|
def reset(self, model_id: str | None = None) -> None:
|
103
110
|
"""Resets the cache."""
|
104
111
|
if model_id is not None:
|
@@ -148,6 +148,50 @@ class InMemoryLMCacheTest(unittest.TestCase):
|
|
148
148
|
self.assertIs(copy.deepcopy(cache)._cache, cache._cache)
|
149
149
|
self.assertIs(copy.deepcopy(cache)._stats, cache._stats)
|
150
150
|
|
151
|
+
self.assertFalse(
|
152
|
+
cache.delete(fake.StaticResponse('hi'), lf.UserMessage('c'), seed=0)
|
153
|
+
)
|
154
|
+
self.assertFalse(cache.delete(lm, lf.UserMessage('c'), seed=1))
|
155
|
+
self.assertFalse(cache.delete(lm, lf.UserMessage('d'), seed=0))
|
156
|
+
self.assertTrue(cache.delete(lm, lf.UserMessage('c'), seed=0))
|
157
|
+
self.assertEqual(
|
158
|
+
list(cache.keys('StaticSequence')),
|
159
|
+
[
|
160
|
+
('a', (None, None, 1, 40, None, None), 0),
|
161
|
+
('a', (None, None, 1, 40, None, None), 1),
|
162
|
+
('b', (None, None, 1, 40, None, None), 0),
|
163
|
+
],
|
164
|
+
)
|
165
|
+
self.assertEqual(cache.stats.num_deletes, 1)
|
166
|
+
|
167
|
+
def test_cache_with_modalities(self):
|
168
|
+
|
169
|
+
class CustomModality(lf.Modality):
|
170
|
+
content: str
|
171
|
+
|
172
|
+
def to_bytes(self):
|
173
|
+
return self.content.encode()
|
174
|
+
|
175
|
+
cache = in_memory.InMemory()
|
176
|
+
lm = fake.StaticSequence(['1', '2', '3', '4', '5', '6'], cache=cache)
|
177
|
+
lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality('foo')))
|
178
|
+
lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality('bar')))
|
179
|
+
self.assertEqual(
|
180
|
+
list(cache.keys()),
|
181
|
+
[
|
182
|
+
(
|
183
|
+
'hi <<[[image]]>><image>acbd18db</image>',
|
184
|
+
(None, None, 1, 40, None, None),
|
185
|
+
0,
|
186
|
+
),
|
187
|
+
(
|
188
|
+
'hi <<[[image]]>><image>37b51d19</image>',
|
189
|
+
(None, None, 1, 40, None, None),
|
190
|
+
0,
|
191
|
+
),
|
192
|
+
],
|
193
|
+
)
|
194
|
+
|
151
195
|
def test_ttl(self):
|
152
196
|
cache = in_memory.InMemory(ttl=1)
|
153
197
|
lm = fake.StaticSequence(['1', '2', '3'], cache=cache)
|
@@ -160,6 +204,7 @@ class InMemoryLMCacheTest(unittest.TestCase):
|
|
160
204
|
self.assertEqual(cache.stats.num_hits, 1)
|
161
205
|
self.assertEqual(cache.stats.num_hit_expires, 1)
|
162
206
|
self.assertEqual(cache.stats.num_misses, 1)
|
207
|
+
self.assertEqual(cache.stats.num_deletes, 1)
|
163
208
|
|
164
209
|
def test_different_sampling_options(self):
|
165
210
|
cache = in_memory.InMemory()
|