langfun 0.1.1.dev20240826__py3-none-any.whl → 0.1.1.dev202408282153__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,17 +13,30 @@
13
13
  # limitations under the License.
14
14
  """Utility library for handling concurrency in langfun."""
15
15
 
16
+ import abc
16
17
  import collections
17
18
  import concurrent.futures
18
19
  import dataclasses
20
+ import io
19
21
  import random
22
+ import sys
20
23
  import threading
21
24
  import time
22
25
  from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
23
26
 
24
27
  from langfun.core import component
28
+ from langfun.core import text_formatting
25
29
  import pyglove as pg
26
- from tqdm import auto as tqdm
30
+
31
+
32
+ progress_bar: Literal['tqdm', 'console', None] = None
33
+
34
+ try:
35
+ from tqdm import auto as tqdm # pylint: disable=g-import-not-at-top
36
+ progress_bar = 'tqdm'
37
+ except ImportError as e:
38
+ progress_bar = 'console'
39
+ tqdm = None
27
40
 
28
41
 
29
42
  def with_context_access(func: Callable[..., Any]) -> Callable[..., Any]:
@@ -142,7 +155,6 @@ def with_retry(
142
155
  attempt = 1
143
156
  return base_interval() * (2 ** (attempt - 1))
144
157
 
145
- wait_interval = None
146
158
  wait_intervals = []
147
159
  errors = []
148
160
  while True:
@@ -356,17 +368,17 @@ class ProgressBar:
356
368
  label: str | None
357
369
  total: int
358
370
  color: str | None = None
359
- postfix: dict[str, str] | None = None
371
+ status: dict[str, Any] | None = None
360
372
 
361
373
  @dataclasses.dataclass
362
374
  class Update:
363
375
  """Progress bar update."""
364
376
  bar_id: int
365
377
  delta: int
366
- postfix: Union[dict[str, str], str, None] = None
378
+ status: Union[dict[str, Any], str, None] = None
367
379
  color: str | None = None
368
380
 
369
- _progress_bars: dict[int, tqdm.tqdm] = {}
381
+ _progress_bars: dict[int, '_ProgressControl'] = {}
370
382
  _install_requests: list[tuple[int, Settings]] = []
371
383
  _updates: collections.deque[Update] = collections.deque()
372
384
  _uninstall_requests: list[int] = []
@@ -378,11 +390,11 @@ class ProgressBar:
378
390
  label: str | None,
379
391
  total: int,
380
392
  color: str | None = None,
381
- postfix: dict[str, str] | None = None,
393
+ status: dict[str, Any] | None = None,
382
394
  ) -> int:
383
395
  """Installs a progress bar and returns a reference id."""
384
396
  with cls._lock:
385
- settings = ProgressBar.Settings(label, total, color, postfix)
397
+ settings = ProgressBar.Settings(label, total, color, status)
386
398
  bar_id = id(settings)
387
399
  cls._install_requests.append((bar_id, settings))
388
400
  return bar_id
@@ -392,15 +404,17 @@ class ProgressBar:
392
404
  cls,
393
405
  bar_id: int,
394
406
  delta: int = 0,
395
- postfix: Union[dict[str, str], str, None] = None,
407
+ status: Union[dict[str, Any], str, None] = None,
396
408
  color: str | None = None,
397
409
  refresh: bool = True,
398
410
  ) -> None:
399
411
  """Report the progress for a label."""
412
+ if status is not None and not isinstance(status, (str, dict)):
413
+ raise ValueError(f'Unsupported status: {status}')
400
414
  with cls._lock:
401
415
  cls._updates.append(
402
416
  ProgressBar.Update(
403
- bar_id=bar_id, delta=delta, postfix=postfix, color=color,
417
+ bar_id=bar_id, delta=delta, status=status, color=color,
404
418
  )
405
419
  )
406
420
  if refresh:
@@ -422,11 +436,11 @@ class ProgressBar:
422
436
  # Process install requests.
423
437
  if cls._install_requests:
424
438
  for bar_id, settings in cls._install_requests:
425
- cls._progress_bars[bar_id] = tqdm.tqdm(
439
+ cls._progress_bars[bar_id] = _progress_control(
426
440
  total=settings.total,
427
- desc=settings.label,
428
- colour=settings.color,
429
- postfix=settings.postfix)
441
+ label=settings.label,
442
+ color=settings.color,
443
+ status=settings.status)
430
444
  cls._install_requests.clear()
431
445
 
432
446
  # Process updates.
@@ -441,15 +455,11 @@ class ProgressBar:
441
455
  if update.delta > 0:
442
456
  bar.update(update.delta)
443
457
 
444
- if isinstance(update.postfix, str):
445
- bar.set_postfix_str(update.postfix, refresh=False)
446
- elif isinstance(update.postfix, dict):
447
- bar.set_postfix(update.postfix, refresh=False)
448
- elif update.postfix is not None:
449
- raise ValueError(f'Unsupported postfix: {update.postfix}')
458
+ if update.status is not None:
459
+ bar.set_status(update.status)
450
460
 
451
461
  if update.color is not None:
452
- bar.colour = update.color
462
+ bar.set_color(update.color)
453
463
  updated_bars.add(bar)
454
464
 
455
465
  # Refresh each updated bar just once.
@@ -603,7 +613,7 @@ def concurrent_map(
603
613
  if len(error_text) >= 64:
604
614
  error_text = error_text[:64] + '...'
605
615
  status['LastError'] = error_text
606
- ProgressBar.update(bar_id, delta=1, postfix=status)
616
+ ProgressBar.update(bar_id, delta=1, status=status)
607
617
 
608
618
  try:
609
619
  if ordered:
@@ -729,5 +739,122 @@ class ExecutorPool:
729
739
  raise ValueError(f'Unsupported value: {maybe_executor}.')
730
740
 
731
741
 
742
+ class _ProgressControl(pg.Object):
743
+ """Abstract progress control."""
744
+ # Disable symbolic comparison so the hash is based on object address.
745
+ use_symbolic_comparison = False
746
+
747
+ total: int
748
+ label: str | None
749
+ color: str | None
750
+ status: str | dict[str, Any] | None
751
+
752
+ def set_color(self, color: str | None):
753
+ with pg.notify_on_change(False):
754
+ self.rebind(color=color)
755
+
756
+ def set_status(self, status: str | dict[str, Any] | None):
757
+ with pg.notify_on_change(False):
758
+ self.rebind(status=status)
759
+
760
+ @abc.abstractmethod
761
+ def update(self, delta):
762
+ """Update progress."""
763
+
764
+ @abc.abstractmethod
765
+ def refresh(self) -> None:
766
+ """Refresh progress bar."""
767
+
768
+
769
+ class _TqdmProgressControl(_ProgressControl):
770
+ """Tqdm-based progress control."""
771
+
772
+ def _on_bound(self):
773
+ super()._on_bound()
774
+ assert tqdm is not None
775
+ self._tqdm = tqdm.tqdm(
776
+ total=self.total,
777
+ desc=self.label,
778
+ colour=self.color,
779
+ postfix=self.status,
780
+ )
781
+
782
+ def update(self, delta: int) -> None:
783
+ self._tqdm.update(delta)
784
+
785
+ def refresh(self):
786
+ self._tqdm.set_description(self.label, refresh=False)
787
+ if isinstance(self.status, str):
788
+ self._tqdm.set_postfix_str(self.status, refresh=False)
789
+ else:
790
+ self._tqdm.set_postfix(self.status, refresh=False)
791
+ self._tqdm.colour = self.color
792
+ self._tqdm.refresh()
793
+
794
+
795
+ class _ConsoleProgressControl(_ProgressControl):
796
+ """Simple progress control by printing the status to the console."""
797
+
798
+ def _on_bound(self):
799
+ super()._on_bound()
800
+ self._progress = 0
801
+
802
+ def update(self, delta: int) -> None:
803
+ self._progress += delta
804
+
805
+ def refresh(self):
806
+ s = io.StringIO()
807
+ if self.label is not None:
808
+ s.write(text_formatting.colored(self.label, 'red', styles=['bold']))
809
+ s.write(': ')
810
+ s.write(
811
+ text_formatting.colored(
812
+ '%d%% (%d/%d)' %
813
+ (
814
+ self._progress * 100 // self.total,
815
+ self._progress,
816
+ self.total,
817
+ ),
818
+ color=self.color or 'green'
819
+ )
820
+ )
821
+ if self.status is not None:
822
+ status = repr(self.status) if isinstance(
823
+ self.status, dict) else self.status
824
+ s.write(f' : {status}')
825
+ sys.stderr.write(s.getvalue() + '\n')
826
+
827
+
828
+ class _NoopProgressControl(_ProgressControl):
829
+ """No-op progress control."""
830
+
831
+ def update(self, delta: int) -> None:
832
+ pass
833
+
834
+ def refresh(self) -> None:
835
+ pass
836
+
837
+
838
+ def _progress_control(
839
+ total: int,
840
+ label: str | None,
841
+ color: str | None,
842
+ status: str | dict[str, Any] | None,
843
+ ) -> _ProgressControl:
844
+ """Creates a process control."""
845
+ if progress_bar == 'tqdm':
846
+ if not tqdm:
847
+ raise RuntimeError(
848
+ 'Please install package "tqdm" to use `tqdm` progress bar.'
849
+ )
850
+ return _TqdmProgressControl(total, label, color, status)
851
+ elif progress_bar == 'console':
852
+ return _ConsoleProgressControl(total, label, color, status)
853
+ elif progress_bar is None:
854
+ return _NoopProgressControl(total, label, color, status)
855
+ else:
856
+ raise ValueError(f'Unsupported progress bar type: {progress_bar}')
857
+
858
+
732
859
  # The global executor pool based on resource IDs.
733
860
  _executor_pool = ExecutorPool()
@@ -233,6 +233,57 @@ class ProgressTest(unittest.TestCase):
233
233
  self.assertIs(p.last_error, job2.error)
234
234
 
235
235
 
236
+ class ProgressControlTest(unittest.TestCase):
237
+
238
+ def test_noop(self):
239
+ concurrent.progress_bar = None
240
+ ctrl = concurrent._progress_control(100, 'noop', 'blue', None)
241
+ self.assertIsInstance(ctrl, concurrent._NoopProgressControl)
242
+ string_io = io.StringIO()
243
+ with contextlib.redirect_stderr(string_io):
244
+ ctrl.update(1)
245
+ ctrl.refresh()
246
+ self.assertEqual(string_io.getvalue(), '')
247
+ concurrent.progress_bar = 'tqdm'
248
+
249
+ def test_console(self):
250
+ concurrent.progress_bar = 'console'
251
+ ctrl = concurrent._progress_control(100, 'foo', 'blue', None)
252
+ self.assertIsInstance(ctrl, concurrent._ConsoleProgressControl)
253
+ string_io = io.StringIO()
254
+ with contextlib.redirect_stderr(string_io):
255
+ ctrl.set_status('bar')
256
+ ctrl.update(10)
257
+ ctrl.refresh()
258
+ self.assertEqual(
259
+ string_io.getvalue(),
260
+ '\x1b[1m\x1b[31mfoo\x1b[0m: \x1b[34m10% (10/100)\x1b[0m : bar\n'
261
+ )
262
+ concurrent.progress_bar = 'tqdm'
263
+
264
+ def test_tqdm(self):
265
+ concurrent.progress_bar = 'tqdm'
266
+ string_io = io.StringIO()
267
+ with contextlib.redirect_stderr(string_io):
268
+ ctrl = concurrent._progress_control(100, 'foo', 'blue', None)
269
+ self.assertIsInstance(ctrl, concurrent._TqdmProgressControl)
270
+ ctrl.update(10)
271
+ ctrl.refresh()
272
+ self.assertIn('10/100', string_io.getvalue())
273
+
274
+ tqdm = concurrent.tqdm
275
+ concurrent.tqdm = None
276
+ with self.assertRaisesRegex(RuntimeError, 'install package "tqdm"'):
277
+ _ = concurrent._progress_control(100, 'foo', 'blue', None)
278
+ concurrent.tqdm = tqdm
279
+
280
+ def test_unsupported(self):
281
+ concurrent.progress_bar = 'unknown'
282
+ with self.assertRaisesRegex(ValueError, 'Unsupported progress bar type'):
283
+ _ = concurrent._progress_control(100, 'foo', 'blue', None)
284
+ concurrent.progress_bar = 'tqdm'
285
+
286
+
236
287
  class ProgressBarTest(unittest.TestCase):
237
288
 
238
289
  def test_multithread_support(self):
@@ -241,13 +292,12 @@ class ProgressBarTest(unittest.TestCase):
241
292
  bar_id = concurrent.ProgressBar.install(None, 5)
242
293
  def fun(x):
243
294
  del x
244
- concurrent.ProgressBar.update(bar_id, 1, postfix=None)
295
+ concurrent.ProgressBar.update(bar_id, 1, status=None)
245
296
 
246
297
  for _ in concurrent.concurrent_execute(fun, range(5)):
247
298
  concurrent.ProgressBar.refresh()
248
299
  concurrent.ProgressBar.uninstall(bar_id)
249
300
  output_str = string_io.getvalue()
250
- print(output_str)
251
301
  self.assertIn('100%', output_str)
252
302
  self.assertIn('5/5', output_str)
253
303
 
@@ -255,12 +305,12 @@ class ProgressBarTest(unittest.TestCase):
255
305
  string_io = io.StringIO()
256
306
  with contextlib.redirect_stderr(string_io):
257
307
  bar_id = concurrent.ProgressBar.install(None, 4)
258
- concurrent.ProgressBar.update(bar_id, 1, postfix=None)
259
- concurrent.ProgressBar.update(bar_id, 1, postfix='hello')
260
- concurrent.ProgressBar.update(bar_id, color='lightgreen')
261
- concurrent.ProgressBar.update(bar_id, 2, postfix=dict(x=1))
262
- with self.assertRaisesRegex(ValueError, 'Unsupported postfix'):
263
- concurrent.ProgressBar.update(bar_id, 0, postfix=1)
308
+ concurrent.ProgressBar.update(bar_id, 1, status=None)
309
+ concurrent.ProgressBar.update(bar_id, 1, status='hello')
310
+ concurrent.ProgressBar.update(bar_id, color='green')
311
+ concurrent.ProgressBar.update(bar_id, 2, status=dict(x=1))
312
+ with self.assertRaisesRegex(ValueError, 'Unsupported status'):
313
+ concurrent.ProgressBar.update(bar_id, 0, status=1)
264
314
  concurrent.ProgressBar.uninstall(bar_id)
265
315
  self.assertIn('1/4', string_io.getvalue())
266
316
  self.assertIn('2/4', string_io.getvalue())
langfun/core/eval/base.py CHANGED
@@ -242,7 +242,7 @@ class Evaluable(lf.Component):
242
242
  ):
243
243
  if show_progress:
244
244
  lf.concurrent.ProgressBar.update(
245
- progress_bar, postfix='LOADING SAVED RESULTS...', color='yellow'
245
+ progress_bar, status='LOADING SAVED RESULTS...', color='yellow'
246
246
  )
247
247
  if self.try_load_result():
248
248
  run_status = 'CACHED'
@@ -271,7 +271,7 @@ class Evaluable(lf.Component):
271
271
  if should_save:
272
272
  if show_progress:
273
273
  lf.concurrent.ProgressBar.update(
274
- progress_bar, postfix='SAVING RESULTS...', color='yellow'
274
+ progress_bar, status='SAVING RESULTS...', color='yellow'
275
275
  )
276
276
 
277
277
  # Save evaluation results.
@@ -284,7 +284,7 @@ class Evaluable(lf.Component):
284
284
  if show_progress:
285
285
  lf.concurrent.ProgressBar.update(
286
286
  progress_bar,
287
- postfix=self._completion_status(run_status),
287
+ status=self._completion_status(run_status),
288
288
  color='green',
289
289
  )
290
290
  else:
@@ -340,7 +340,7 @@ class Evaluable(lf.Component):
340
340
  f'[#{leaf.index} - {leaf.node.id}]',
341
341
  total=leaf.node.num_examples if leaf.enabled else 0,
342
342
  color='cyan' if leaf.enabled else 'yellow',
343
- postfix=None if leaf.enabled else 'SKIPPED.')
343
+ status=None if leaf.enabled else 'SKIPPED.')
344
344
 
345
345
  # Run leaf groups in parallel.
346
346
  try:
@@ -354,7 +354,7 @@ class Evaluable(lf.Component):
354
354
  # Save results for non-leaf nodes.
355
355
  lf.concurrent.ProgressBar.update(
356
356
  overview_bar,
357
- postfix='SAVING RESULTS...',
357
+ status='SAVING RESULTS...',
358
358
  color='yellow')
359
359
 
360
360
  for node in self.nonleaf_nodes:
@@ -364,7 +364,7 @@ class Evaluable(lf.Component):
364
364
 
365
365
  if should_save and summary:
366
366
  lf.concurrent.ProgressBar.update(
367
- overview_bar, postfix='FINALIZING SUMMARY...'
367
+ overview_bar, status='FINALIZING SUMMARY...'
368
368
  )
369
369
 
370
370
  summary.save(os.path.join(self.root_dir, Evaluable.SUMMARY_HTML))
@@ -378,7 +378,7 @@ class Evaluable(lf.Component):
378
378
  # Signal all task completed by making the bar green.
379
379
  lf.concurrent.ProgressBar.update(
380
380
  overview_bar,
381
- postfix='COMPLETED',
381
+ status='COMPLETED',
382
382
  color='green')
383
383
 
384
384
  finally:
@@ -1261,7 +1261,7 @@ class Evaluation(Evaluable):
1261
1261
 
1262
1262
  def finalize(self) -> pg.Dict:
1263
1263
  """Finalizes the evaluation result."""
1264
- if self.cache:
1264
+ if self.cache is not None:
1265
1265
  cache_stats = dict(
1266
1266
  use_cache=True,
1267
1267
  num_queries=self.cache.stats.num_queries,
@@ -564,7 +564,7 @@ class SuiteTest(unittest.TestCase):
564
564
  schema_fn='answer_schema()',
565
565
  ),
566
566
  cache_stats=dict(
567
- use_cache=True, num_queries=4, num_hits=1, num_updates=3
567
+ use_cache=True, num_queries=4, num_hits=0, num_updates=4
568
568
  ),
569
569
  metrics=dict(
570
570
  total=2,
langfun/core/langfunc.py CHANGED
@@ -269,6 +269,9 @@ class LangFunc(
269
269
  # Send rendered text to LM.
270
270
  lm_output = self.lm(lm_input, cache_seed=cache_seed)
271
271
 
272
+ # Attach cache seed.
273
+ lm_input.metadata.cache_seed = cache_seed
274
+
272
275
  # Transform the output message.
273
276
  lm_output = self.transform_output(lm_output)
274
277
  lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
@@ -94,11 +94,13 @@ class LangFuncCallTest(unittest.TestCase):
94
94
  )
95
95
  )
96
96
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
97
- self.assertEqual(r.source, message.UserMessage('Hello'))
97
+ self.assertEqual(
98
+ r.source,
99
+ message.UserMessage('Hello', metadata=dict(cache_seed=0))
100
+ )
98
101
  self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
99
102
 
100
103
  self.assertEqual(str(l), 'Hello')
101
- print(repr(l))
102
104
  self.assertEqual(
103
105
  repr(l),
104
106
  "LangFunc(template_str='Hello', clean=True,"
@@ -114,7 +116,7 @@ class LangFuncCallTest(unittest.TestCase):
114
116
  self.assertEqual(l, 'Hello')
115
117
  self.assertEqual(l.natural_language_format(), 'Hello')
116
118
  self.assertEqual(l.render(), 'Hello')
117
- r = l()
119
+ r = l(cache_seed=1)
118
120
  self.assertEqual(
119
121
  r,
120
122
  message.AIMessage(
@@ -123,6 +125,7 @@ class LangFuncCallTest(unittest.TestCase):
123
125
  )
124
126
  )
125
127
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
128
+ self.assertEqual(r.source.metadata.cache_seed, 1)
126
129
 
127
130
  self.assertEqual(str(l), 'Hello')
128
131
 
@@ -234,6 +234,7 @@ class LMCache(pg.Object):
234
234
  num_hit_expires: int = 0
235
235
  num_misses: int = 0
236
236
  num_updates: int = 0
237
+ num_deletes: int = 0
237
238
 
238
239
  @abc.abstractmethod
239
240
  def get(
@@ -251,6 +252,15 @@ class LMCache(pg.Object):
251
252
  ) -> None:
252
253
  """Puts the result of a prompt generated by a language model in cache."""
253
254
 
255
+ @abc.abstractmethod
256
+ def delete(
257
+ self,
258
+ lm: 'LanguageModel',
259
+ prompt: message_lib.Message,
260
+ seed: int,
261
+ ) -> bool:
262
+ """Deletes the result of a prompt generated by a language model in cache."""
263
+
254
264
  @property
255
265
  @abc.abstractmethod
256
266
  def stats(self) -> Stats:
@@ -60,13 +60,16 @@ class LMCacheBase(lf.LMCache):
60
60
  self, lm: lf.LanguageModel, prompt: lf.Message, seed: int
61
61
  ) -> lf.LMSamplingResult | None:
62
62
  """Gets the cached result of a prompt generated by a language model."""
63
- entry = self._get(lm.model_id, self._key(lm, prompt, seed))
63
+ key = self._key(lm, prompt, seed)
64
+ entry = self._get(lm.model_id, key)
64
65
  self._stats.num_queries += 1
65
66
  if entry is None:
66
67
  self._stats.num_misses += 1
67
68
  return None
68
69
  if entry.expire is not None and entry.expire < datetime.datetime.now():
69
70
  self._stats.num_hit_expires += 1
71
+ self._stats.num_deletes += 1
72
+ assert self._delete(lm.model_id, key)
70
73
  return None
71
74
  self._stats.num_hits += 1
72
75
  return entry.result
@@ -86,6 +89,18 @@ class LMCacheBase(lf.LMCache):
86
89
  self._put(lm.model_id, self._key(lm, prompt, seed), entry)
87
90
  self._stats.num_updates += 1
88
91
 
92
+ def delete(
93
+ self,
94
+ lm: lf.LanguageModel,
95
+ prompt: lf.Message,
96
+ seed: int,
97
+ ) -> bool:
98
+ """Deletes the result of a prompt generated by a language model in cache."""
99
+ deleted = self._delete(lm.model_id, self._key(lm, prompt, seed))
100
+ if deleted:
101
+ self._stats.num_deletes += 1
102
+ return deleted
103
+
89
104
  @abc.abstractmethod
90
105
  def _get(self, model_id: str, key: str) -> LMCacheEntry | None:
91
106
  """Returns a LM cache entry associated with the key."""
@@ -94,6 +109,10 @@ class LMCacheBase(lf.LMCache):
94
109
  def _put(self, model_id: str, key: str, entry: LMCacheEntry) -> None:
95
110
  """Puts a LM cache entry associated with the key."""
96
111
 
112
+ @abc.abstractmethod
113
+ def _delete(self, model_id: str, key: str) -> bool:
114
+ """Deletes a LM cache entry associated with the key."""
115
+
97
116
  def _sym_clone(self, deep: bool, memo: Any = None) -> 'LMCacheBase':
98
117
  v = super()._sym_clone(deep, memo)
99
118
  v._stats = self._stats # pylint: disable=protected-access
@@ -102,4 +121,4 @@ class LMCacheBase(lf.LMCache):
102
121
 
103
122
  def default_key(lm: lf.LanguageModel, prompt: lf.Message, seed: int) -> Any:
104
123
  """Default key for LM cache."""
105
- return (prompt.text, lm.sampling_options.cache_key(), seed)
124
+ return (prompt.text_with_modality_hash, lm.sampling_options.cache_key(), seed)
@@ -99,6 +99,13 @@ class InMemory(base.LMCacheBase):
99
99
  """Puts a LM cache entry associated with the key."""
100
100
  self._cache[model_id][key] = entry
101
101
 
102
+ def _delete(self, model_id: str, key: str) -> bool:
103
+ """Deletes a LM cache entry associated with the key."""
104
+ model_cache = self._cache.get(model_id, None)
105
+ if model_cache is None:
106
+ return False
107
+ return model_cache.pop(key, None) is not None
108
+
102
109
  def reset(self, model_id: str | None = None) -> None:
103
110
  """Resets the cache."""
104
111
  if model_id is not None:
@@ -148,6 +148,50 @@ class InMemoryLMCacheTest(unittest.TestCase):
148
148
  self.assertIs(copy.deepcopy(cache)._cache, cache._cache)
149
149
  self.assertIs(copy.deepcopy(cache)._stats, cache._stats)
150
150
 
151
+ self.assertFalse(
152
+ cache.delete(fake.StaticResponse('hi'), lf.UserMessage('c'), seed=0)
153
+ )
154
+ self.assertFalse(cache.delete(lm, lf.UserMessage('c'), seed=1))
155
+ self.assertFalse(cache.delete(lm, lf.UserMessage('d'), seed=0))
156
+ self.assertTrue(cache.delete(lm, lf.UserMessage('c'), seed=0))
157
+ self.assertEqual(
158
+ list(cache.keys('StaticSequence')),
159
+ [
160
+ ('a', (None, None, 1, 40, None, None), 0),
161
+ ('a', (None, None, 1, 40, None, None), 1),
162
+ ('b', (None, None, 1, 40, None, None), 0),
163
+ ],
164
+ )
165
+ self.assertEqual(cache.stats.num_deletes, 1)
166
+
167
+ def test_cache_with_modalities(self):
168
+
169
+ class CustomModality(lf.Modality):
170
+ content: str
171
+
172
+ def to_bytes(self):
173
+ return self.content.encode()
174
+
175
+ cache = in_memory.InMemory()
176
+ lm = fake.StaticSequence(['1', '2', '3', '4', '5', '6'], cache=cache)
177
+ lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality('foo')))
178
+ lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality('bar')))
179
+ self.assertEqual(
180
+ list(cache.keys()),
181
+ [
182
+ (
183
+ 'hi <<[[image]]>><image>acbd18db</image>',
184
+ (None, None, 1, 40, None, None),
185
+ 0,
186
+ ),
187
+ (
188
+ 'hi <<[[image]]>><image>37b51d19</image>',
189
+ (None, None, 1, 40, None, None),
190
+ 0,
191
+ ),
192
+ ],
193
+ )
194
+
151
195
  def test_ttl(self):
152
196
  cache = in_memory.InMemory(ttl=1)
153
197
  lm = fake.StaticSequence(['1', '2', '3'], cache=cache)
@@ -160,6 +204,7 @@ class InMemoryLMCacheTest(unittest.TestCase):
160
204
  self.assertEqual(cache.stats.num_hits, 1)
161
205
  self.assertEqual(cache.stats.num_hit_expires, 1)
162
206
  self.assertEqual(cache.stats.num_misses, 1)
207
+ self.assertEqual(cache.stats.num_deletes, 1)
163
208
 
164
209
  def test_different_sampling_options(self):
165
210
  cache = in_memory.InMemory()