langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,235 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for Anthropic models."""
15
+
16
+ import base64
17
+ import os
18
+ from typing import Any
19
+ import unittest
20
+ from unittest import mock
21
+
22
+ from google.auth import exceptions
23
+ from langfun.core import language_model
24
+ from langfun.core import message as lf_message
25
+ from langfun.core import modalities as lf_modalities
26
+ from langfun.core.llms import anthropic
27
+ import pyglove as pg
28
+ import requests
29
+
30
+
31
def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
  """Fakes a successful Anthropic API reply that echoes sampling options.

  Args:
    url: Request URL; ignored.
    json: Request payload. Its `temperature`, `top_k`, `top_p`, `max_tokens`
      and `stop_sequences` fields are echoed back in the reply text.
    **kwargs: Any other request arguments; ignored.

  Returns:
    A `requests.Response` with status 200 and a JSON body containing one text
    content chunk and fixed token usage (2 input, 1 output).
  """
  del url, kwargs
  echoed_text = (
      f'hello with temperature={json.get("temperature")}, '
      f'top_k={json.get("top_k")}, '
      f'top_p={json.get("top_p")}, '
      f'max_tokens={json.get("max_tokens")}, '
      f'stop={json.get("stop_sequences")}.'
  )
  body = {
      'content': [{'type': 'text', 'text': echoed_text}],
      'usage': {'input_tokens': 2, 'output_tokens': 1},
  }
  fake_response = requests.Response()
  fake_response.status_code = 200
  fake_response._content = pg.to_json_str(body).encode()
  return fake_response
53
+
54
+
55
# Raw bytes of a small PNG image (starts with the PNG signature `\x89PNG`).
# Used as the image payload in `AnthropicTest.test_mm_call`.
image_content = (
    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
)

# Raw bytes of a minimal one-page PDF document (`/Count 1`) that draws the
# text "Hello, PDF content!". Used as the payload in
# `AnthropicTest.test_pdf_call`.
pdf_content = (
    b'%PDF-1.4\n1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n2 0 obj\n<<'
    b' /Type /Pages /Count 1 /Kids [3 0 R] >>\nendobj\n3 0 obj\n<< /Type /Page'
    b' /Parent 2 0 R /MediaBox [0 0 612 792] /Contents 4 0 R >>\nendobj\n4 0'
    b' obj\n<< /Length 44 >>\nstream\nBT /F1 24 Tf 100 700 Td (Hello, PDF'
    b' content!) Tj ET\nendstream\nendobj\n5 0 obj\n<< /Type /Font /Subtype'
    b' /Type1 /BaseFont /Helvetica >>\nendobj\nxref\n0 6\n0000000000 65535 f'
    b' \n0000000010 00000 n \n0000000079 00000 n \n0000000178 00000 n'
    b' \n0000000278 00000 n \n0000000407 00000 n \ntrailer\n<< /Size 6 /Root 1'
    b' 0 R >>\nstartxref\n517\n%%EOF'
)
77
+
78
+
79
def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
  """Fakes an Anthropic reply describing the first multimodal content chunk.

  The reply text is `<chunk type>: <mime type>`. The chunk is the first content
  chunk of the first message in `json`. The MIME type is whatever
  `lf_modalities.Mime.from_bytes` reports for the base64-decoded
  `source.data` of that chunk.

  Args:
    url: Request URL; ignored.
    json: Request payload in the Anthropic messages format.
    **kwargs: Any other request arguments; ignored.

  Returns:
    A `requests.Response` with status 200 and a JSON body.
  """
  del url, kwargs
  first_chunk = json['messages'][0]['content'][0]
  decoded = base64.b64decode(first_chunk['source']['data'])
  detected = lf_modalities.Mime.from_bytes(decoded)

  fake_response = requests.Response()
  fake_response.status_code = 200
  reply_text = f'{first_chunk["type"]}: {detected.mime_type}'
  fake_response._content = pg.to_json_str({
      'content': [{'type': 'text', 'text': reply_text}],
      'usage': {'input_tokens': 2, 'output_tokens': 1},
  }).encode()
  return fake_response
97
+
98
+
99
def mock_requests_post_error(status_code, error_type, error_message):
  """Builds a fake `requests.Session.post` that always returns an API error.

  Args:
    status_code: HTTP status code set on every fake response.
    error_type: Value placed in the body's `error.type` field.
    error_message: Value placed in the body's `error.message` field.

  Returns:
    A callable accepting `(url, json, **kwargs)` that ignores its arguments and
    returns a `requests.Response` carrying the error payload.
  """
  error_payload = {'error': {'type': error_type, 'message': error_message}}

  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
    del url, json, kwargs
    failed = requests.Response()
    failed.status_code = status_code
    failed._content = pg.to_json_str(error_payload).encode()
    return failed

  return _mock_requests
115
+
116
+
117
class AnthropicTest(unittest.TestCase):
  """Tests for Anthropic models accessed through the Anthropic API."""

  def test_basics(self):
    self.assertEqual(
        anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
    )
    self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)

  def test_api_key(self):
    # Isolate this test from the caller's environment:
    # - `patch.dict` restores `os.environ` on exit, even if an assertion
    #   fails, so a temporary key never leaks into other tests.
    # - Any pre-existing ANTHROPIC_API_KEY is removed first, so the
    #   "missing key" check is reliable.
    with mock.patch.dict(os.environ):
      os.environ.pop('ANTHROPIC_API_KEY', None)

      lm = anthropic.Claude3Haiku()
      with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
        lm('hi')

      with mock.patch('requests.Session.post') as mock_request:
        mock_request.side_effect = mock_requests_post

        # Key passed explicitly.
        lm = anthropic.Claude3Haiku(api_key='fake key')
        self.assertRegex(lm('hi').text, 'hello.*')

        # Key picked up from the environment.
        os.environ['ANTHROPIC_API_KEY'] = 'abc'
        lm = anthropic.Claude3Haiku()
        self.assertRegex(lm('hi').text, 'hello.*')

  def test_call(self):
    with mock.patch('requests.Session.post') as mock_request:
      mock_request.side_effect = mock_requests_post
      lm = anthropic.Claude3Haiku(api_key='fake_key')
      response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
      self.assertEqual(
          response.text,
          (
              'hello with temperature=0.0, top_k=0.1, top_p=0.2, '
              "max_tokens=4096, stop=['\\n']."
          ),
      )
      # `mock_requests_post` reports 2 input tokens and 1 output token.
      # Use assertEqual: assertIsNotNone(x, 2) would treat 2 as the failure
      # message and never check the count.
      self.assertIsNotNone(response.usage)
      self.assertEqual(response.usage.prompt_tokens, 2)
      self.assertEqual(response.usage.completion_tokens, 1)
      self.assertEqual(response.usage.total_tokens, 3)

  def test_mm_call(self):
    with mock.patch('requests.Session.post') as mock_mm_request:
      mock_mm_request.side_effect = mock_mm_requests_post
      lm = anthropic.Claude3Haiku(api_key='fake_key')
      response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
      self.assertEqual(response.text, 'image: image/png')

  def test_pdf_call(self):
    with mock.patch('requests.Session.post') as mock_mm_request:
      mock_mm_request.side_effect = mock_mm_requests_post
      lm = anthropic.Claude3Haiku(api_key='fake_key')
      response = lm(lf_modalities.PDF.from_bytes(pdf_content), lm=lm)
      self.assertEqual(response.text, 'document: application/pdf')

  def test_call_errors(self):
    for status_code, error_type, error_message in [
        (429, 'rate_limit', 'Rate limit exceeded.'),
        (529, 'service_unavailable', 'Service unavailable.'),
        (500, 'bad_request', 'Bad request.'),
    ]:
      # subTest reports which error case failed instead of only the first.
      with self.subTest(status_code=status_code, error_type=error_type):
        with mock.patch('requests.Session.post') as mock_request:
          mock_request.side_effect = mock_requests_post_error(
              status_code, error_type, error_message
          )
          lm = anthropic.Claude3Haiku(api_key='fake_key')
          with self.assertRaisesRegex(
              Exception, f'.*{status_code}: .*{error_message}'
          ):
            lm('hello', max_attempts=1)
187
+
188
+
189
class VertexAIAnthropicTest(unittest.TestCase):
  """Tests for Anthropic models served through Vertex AI."""

  def test_basics(self):
    # Without a project, calling the model must fail with a clear error.
    with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
      unconfigured = anthropic.VertexAIClaude3_5_Sonnet_20241022()
      unconfigured('hi')

    claude = anthropic.VertexAIClaude3_5_Sonnet_20241022(project='langfun')

    # NOTE(daiyip): For OSS users, default credentials are not available unless
    # users have already set up their GCP project. Therefore we ignore the
    # exception here.
    try:
      claude._initialize()
    except exceptions.DefaultCredentialsError:
      pass

    expected_endpoint = (
        'https://us-east5-aiplatform.googleapis.com/v1/projects/'
        'langfun/locations/us-east5/publishers/anthropic/'
        'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
    )
    self.assertEqual(claude.api_endpoint, expected_endpoint)

    expected_request = {
        'anthropic_version': 'vertex-2023-10-16',
        'max_tokens': 8192,
        'messages': [
            {'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
        ],
        'stream': False,
        'temperature': 0.0,
        'top_k': 40,
    }
    actual_request = claude.request(
        lf_message.UserMessage('hi'),
        language_model.LMSamplingOptions(temperature=0.0),
    )
    self.assertEqual(actual_request, expected_request)
232
+
233
+
234
# Run the test cases in this module when it is executed directly.
if __name__ == '__main__':
  unittest.main()
@@ -60,13 +60,16 @@ class LMCacheBase(lf.LMCache):
60
60
  self, lm: lf.LanguageModel, prompt: lf.Message, seed: int
61
61
  ) -> lf.LMSamplingResult | None:
62
62
  """Gets the cached result of a prompt generated by a language model."""
63
- entry = self._get(lm.model_id, self._key(lm, prompt, seed))
63
+ key = self._key(lm, prompt, seed)
64
+ entry = self._get(lm.model_id, key)
64
65
  self._stats.num_queries += 1
65
66
  if entry is None:
66
67
  self._stats.num_misses += 1
67
68
  return None
68
69
  if entry.expire is not None and entry.expire < datetime.datetime.now():
69
70
  self._stats.num_hit_expires += 1
71
+ self._stats.num_deletes += 1
72
+ assert self._delete(lm.model_id, key)
70
73
  return None
71
74
  self._stats.num_hits += 1
72
75
  return entry.result
@@ -86,6 +89,18 @@ class LMCacheBase(lf.LMCache):
86
89
  self._put(lm.model_id, self._key(lm, prompt, seed), entry)
87
90
  self._stats.num_updates += 1
88
91
 
92
def delete(
    self,
    lm: lf.LanguageModel,
    prompt: lf.Message,
    seed: int,
) -> bool:
  """Removes the cached result for `(lm, prompt, seed)`, if present.

  Args:
    lm: The language model. Its `model_id` selects the cache namespace.
    prompt: The prompt whose cached result should be removed.
    seed: The cache seed, combined with `lm` and `prompt` into the cache key.

  Returns:
    The value returned by `_delete`: truthy when an entry was removed. In that
    case the `num_deletes` statistic is also incremented.
  """
  cache_key = self._key(lm, prompt, seed)
  was_removed = self._delete(lm.model_id, cache_key)
  if was_removed:
    # Only count deletions that actually removed an entry.
    self._stats.num_deletes += 1
  return was_removed
103
+
89
104
  @abc.abstractmethod
90
105
  def _get(self, model_id: str, key: str) -> LMCacheEntry | None:
91
106
  """Returns a LM cache entry associated with the key."""
@@ -94,6 +109,10 @@ class LMCacheBase(lf.LMCache):
94
109
  def _put(self, model_id: str, key: str, entry: LMCacheEntry) -> None:
95
110
  """Puts a LM cache entry associated with the key."""
96
111
 
112
@abc.abstractmethod
def _delete(self, model_id: str, key: str) -> bool:
  """Deletes a LM cache entry associated with the key.

  Args:
    model_id: The model id whose cache namespace holds the entry.
    key: The cache key of the entry to delete. NOTE(review): `default_key`
      returns a tuple, so keys may not actually be `str`; confirm the
      annotation.

  Returns:
    Whether an entry was actually deleted. `delete()` only increments the
    `num_deletes` statistic when this is truthy.
  """
115
+
97
116
  def _sym_clone(self, deep: bool, memo: Any = None) -> 'LMCacheBase':
98
117
  v = super()._sym_clone(deep, memo)
99
118
  v._stats = self._stats # pylint: disable=protected-access
@@ -102,4 +121,4 @@ class LMCacheBase(lf.LMCache):
102
121
 
103
122
  def default_key(lm: lf.LanguageModel, prompt: lf.Message, seed: int) -> Any:
104
123
  """Default key for LM cache."""
105
- return (prompt.text, lm.sampling_options.cache_key(), seed)
124
+ return (prompt.text_with_modality_hash, lm.sampling_options.cache_key(), seed)
@@ -15,6 +15,7 @@
15
15
 
16
16
  import collections
17
17
  import contextlib
18
+ import json
18
19
  from typing import Annotated, Any, Iterator
19
20
  import langfun.core as lf
20
21
  from langfun.core.llms.cache import base
@@ -49,6 +50,11 @@ class InMemory(base.LMCacheBase):
49
50
  "Creating a new cache as cache file '%s' does not exist.",
50
51
  self.filename,
51
52
  )
53
+ except json.JSONDecodeError:
54
+ pg.logging.warning(
55
+ "Creating a new cache as cache file '%s' is corrupted.",
56
+ self.filename,
57
+ )
52
58
 
53
59
  def model_ids(self) -> list[str]:
54
60
  """Returns the model ids of cached queires."""
@@ -99,6 +105,13 @@ class InMemory(base.LMCacheBase):
99
105
  """Puts a LM cache entry associated with the key."""
100
106
  self._cache[model_id][key] = entry
101
107
 
108
def _delete(self, model_id: str, key: str) -> bool:
  """Removes the entry stored under `key` in the cache of `model_id`.

  Returns:
    True if an entry was found and popped. False if the model has no cache
    (the cache is not created as a side effect) or the key is absent.
  """
  per_model = self._cache.get(model_id)
  return per_model is not None and per_model.pop(key, None) is not None
114
+
102
115
  def reset(self, model_id: str | None = None) -> None:
103
116
  """Resets the cache."""
104
117
  if model_id is not None:
@@ -44,28 +44,38 @@ class InMemoryLMCacheTest(unittest.TestCase):
44
44
  self.assertEqual(
45
45
  list(cache.keys()),
46
46
  [
47
- ('a', (0.0, 1024, 1, 40, None, None), 0),
48
- ('a', (0.0, 1024, 1, 40, None, None), 1),
49
- ('b', (0.0, 1024, 1, 40, None, None), 0),
50
- ('c', (0.0, 1024, 1, 40, None, None), 0),
47
+ ('a', (None, None, 1, 40, None, None), 0),
48
+ ('a', (None, None, 1, 40, None, None), 1),
49
+ ('b', (None, None, 1, 40, None, None), 0),
50
+ ('c', (None, None, 1, 40, None, None), 0),
51
51
  ],
52
52
  )
53
53
  self.assertEqual(
54
54
  list(cache.keys('StaticSequence')),
55
55
  [
56
- ('a', (0.0, 1024, 1, 40, None, None), 0),
57
- ('a', (0.0, 1024, 1, 40, None, None), 1),
58
- ('b', (0.0, 1024, 1, 40, None, None), 0),
59
- ('c', (0.0, 1024, 1, 40, None, None), 0),
56
+ ('a', (None, None, 1, 40, None, None), 0),
57
+ ('a', (None, None, 1, 40, None, None), 1),
58
+ ('b', (None, None, 1, 40, None, None), 0),
59
+ ('c', (None, None, 1, 40, None, None), 0),
60
60
  ],
61
61
  )
62
62
 
63
63
  def cache_entry(response_text, cache_seed=0):
64
64
  return base.LMCacheEntry(
65
- lf.LMSamplingResult([
66
- lf.LMSample(
67
- lf.AIMessage(response_text, cache_seed=cache_seed), score=1.0)
68
- ])
65
+ lf.LMSamplingResult(
66
+ [
67
+ lf.LMSample(
68
+ lf.AIMessage(response_text, cache_seed=cache_seed),
69
+ score=1.0,
70
+ )
71
+ ],
72
+ usage=lf.LMSamplingUsage(
73
+ 1,
74
+ len(response_text),
75
+ len(response_text) + 1,
76
+ ),
77
+ is_cached=True,
78
+ )
69
79
  )
70
80
 
71
81
  self.assertEqual(
@@ -90,19 +100,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
90
100
  list(cache.items()),
91
101
  [
92
102
  (
93
- ('a', (0.0, 1024, 1, 40, None, None), 0),
103
+ ('a', (None, None, 1, 40, None, None), 0),
94
104
  cache_entry('1'),
95
105
  ),
96
106
  (
97
- ('a', (0.0, 1024, 1, 40, None, None), 1),
107
+ ('a', (None, None, 1, 40, None, None), 1),
98
108
  cache_entry('2', 1),
99
109
  ),
100
110
  (
101
- ('b', (0.0, 1024, 1, 40, None, None), 0),
111
+ ('b', (None, None, 1, 40, None, None), 0),
102
112
  cache_entry('3'),
103
113
  ),
104
114
  (
105
- ('c', (0.0, 1024, 1, 40, None, None), 0),
115
+ ('c', (None, None, 1, 40, None, None), 0),
106
116
  cache_entry('4'),
107
117
  ),
108
118
  ],
@@ -111,19 +121,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
111
121
  list(cache.items('StaticSequence')),
112
122
  [
113
123
  (
114
- ('a', (0.0, 1024, 1, 40, None, None), 0),
124
+ ('a', (None, None, 1, 40, None, None), 0),
115
125
  cache_entry('1'),
116
126
  ),
117
127
  (
118
- ('a', (0.0, 1024, 1, 40, None, None), 1),
128
+ ('a', (None, None, 1, 40, None, None), 1),
119
129
  cache_entry('2', 1),
120
130
  ),
121
131
  (
122
- ('b', (0.0, 1024, 1, 40, None, None), 0),
132
+ ('b', (None, None, 1, 40, None, None), 0),
123
133
  cache_entry('3'),
124
134
  ),
125
135
  (
126
- ('c', (0.0, 1024, 1, 40, None, None), 0),
136
+ ('c', (None, None, 1, 40, None, None), 0),
127
137
  cache_entry('4'),
128
138
  ),
129
139
  ],
@@ -139,6 +149,50 @@ class InMemoryLMCacheTest(unittest.TestCase):
139
149
  self.assertIs(copy.deepcopy(cache)._cache, cache._cache)
140
150
  self.assertIs(copy.deepcopy(cache)._stats, cache._stats)
141
151
 
152
+ self.assertFalse(
153
+ cache.delete(fake.StaticResponse('hi'), lf.UserMessage('c'), seed=0)
154
+ )
155
+ self.assertFalse(cache.delete(lm, lf.UserMessage('c'), seed=1))
156
+ self.assertFalse(cache.delete(lm, lf.UserMessage('d'), seed=0))
157
+ self.assertTrue(cache.delete(lm, lf.UserMessage('c'), seed=0))
158
+ self.assertEqual(
159
+ list(cache.keys('StaticSequence')),
160
+ [
161
+ ('a', (None, None, 1, 40, None, None), 0),
162
+ ('a', (None, None, 1, 40, None, None), 1),
163
+ ('b', (None, None, 1, 40, None, None), 0),
164
+ ],
165
+ )
166
+ self.assertEqual(cache.stats.num_deletes, 1)
167
+
168
def test_cache_with_modalities(self):
  """Prompts that differ only in modality content get distinct cache keys."""

  class CustomModality(lf.Modality):
    content: str

    def to_bytes(self):
      return self.content.encode()

  cache = in_memory.InMemory()
  lm = fake.StaticSequence(['1', '2', '3', '4', '5', '6'], cache=cache)
  # Same prompt text, two different modality payloads.
  for payload in ('foo', 'bar'):
    lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality(payload)))

  default_options_key = (None, None, 1, 40, None, None)
  # Each key embeds a per-content tag, so the two prompts do not collide.
  self.assertEqual(
      list(cache.keys()),
      [
          ('hi <<[[image]]>><image>acbd18db</image>', default_options_key, 0),
          ('hi <<[[image]]>><image>37b51d19</image>', default_options_key, 0),
      ],
  )
195
+
142
196
  def test_ttl(self):
143
197
  cache = in_memory.InMemory(ttl=1)
144
198
  lm = fake.StaticSequence(['1', '2', '3'], cache=cache)
@@ -151,6 +205,7 @@ class InMemoryLMCacheTest(unittest.TestCase):
151
205
  self.assertEqual(cache.stats.num_hits, 1)
152
206
  self.assertEqual(cache.stats.num_hit_expires, 1)
153
207
  self.assertEqual(cache.stats.num_misses, 1)
208
+ self.assertEqual(cache.stats.num_deletes, 1)
154
209
 
155
210
  def test_different_sampling_options(self):
156
211
  cache = in_memory.InMemory()
@@ -161,15 +216,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
161
216
  self.assertEqual(
162
217
  list(cache.keys()),
163
218
  [
164
- ('a', (0.0, 1024, 1, 40, None, None), 0),
165
- ('a', (1.0, 1024, 1, 40, None, None), 0),
219
+ ('a', (None, None, 1, 40, None, None), 0),
220
+ ('a', (1.0, None, 1, 40, None, None), 0),
166
221
  ],
167
222
  )
168
223
 
169
224
  def test_different_model(self):
170
225
  cache = in_memory.InMemory()
171
- lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache)
172
- lm2 = fake.Echo(cache=cache)
226
+ lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache, temperature=0.0)
227
+ lm2 = fake.Echo(cache=cache, temperature=0.0)
173
228
 
174
229
  self.assertEqual(lm1('a'), '1')
175
230
  self.assertEqual(lm2('a'), 'a')
@@ -180,15 +235,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
180
235
  self.assertEqual(
181
236
  list(cache.keys('StaticSequence')),
182
237
  [
183
- ('a', (0.0, 1024, 1, 40, None, None), 0),
184
- ('b', (0.0, 1024, 1, 40, None, None), 0),
238
+ ('a', (0.0, None, 1, 40, None, None), 0),
239
+ ('b', (0.0, None, 1, 40, None, None), 0),
185
240
  ],
186
241
  )
187
242
  self.assertEqual(
188
243
  list(cache.keys('Echo')),
189
244
  [
190
- ('a', (0.0, 1024, 1, 40, None, None), 0),
191
- ('b', (0.0, 1024, 1, 40, None, None), 0),
245
+ ('a', (0.0, None, 1, 40, None, None), 0),
246
+ ('b', (0.0, None, 1, 40, None, None), 0),
192
247
  ],
193
248
  )
194
249
  self.assertEqual(len(cache), 4)
@@ -240,6 +295,11 @@ class InMemoryLMCacheTest(unittest.TestCase):
240
295
  self.assertEqual(cache2.stats.num_updates, 2)
241
296
  cache2.save()
242
297
 
298
+ # Corrupted file.
299
+ pg.io.writefile(path, 'bad_content')
300
+ cache3 = in_memory.InMemory(path)
301
+ self.assertEqual(len(cache3), 0)
302
+
243
303
 
244
304
  class LmCacheTest(unittest.TestCase):
245
305
 
@@ -0,0 +1,101 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Compositions of different LLM models."""
15
+ import random
16
+ from typing import Annotated
17
+
18
+ import langfun.core as lf
19
+ import pyglove as pg
20
+
21
+
22
@pg.use_init_args(['candidates', 'seed'])
class RandomChoice(lf.LanguageModel):
  """Random choice of a list of LLM models.

  Every call to `sample`, `__call__`, `score` or `tokenize` is delegated to a
  single candidate. The candidate is picked by a `random.Random` seeded with
  `seed`, so picks are deterministic for a given seed and call order.
  """

  candidates: Annotated[
      list[lf.LanguageModel],
      (
          'A list of LLMs as candidates to choose from.'
      )
  ]

  seed: Annotated[
      int,
      (
          'The random seed to use for the random choice.'
      )
  ] = 0

  def _on_bound(self):
    super()._on_bound()
    self._rand = random.Random(self.seed)
    # Apply the sampling options that differ from their defaults on this
    # model to every candidate. Note: this rebinds the candidates' own
    # `sampling_options` in place.
    parent_non_default = self.sampling_options.sym_nondefault()
    if parent_non_default:
      for c in self.candidates:
        c.sampling_options.rebind(
            parent_non_default, notify_parents=False, raise_on_no_change=False
        )

  @property
  def model_id(self) -> str:
    """Returns an id listing all candidate model ids in sorted order."""
    model_ids = ', '.join(
        sorted(c.model_id for c in self.candidates)
    )
    return f'RandomChoice({model_ids})'

  @property
  def resource_id(self) -> str:
    """Returns an id listing all candidate resource ids in sorted order."""
    resource_ids = ', '.join(
        sorted(c.resource_id for c in self.candidates)
    )
    return f'RandomChoice({resource_ids})'

  def _select_lm(self) -> lf.LanguageModel:
    """Selects a random LLM from the candidates."""
    return self._rand.choice(self.candidates)

  def sample(
      self,
      prompts: list[str | lf.Message],
      *,
      cache_seed: int = 0,
      **kwargs,
  ) -> list[lf.LMSamplingResult]:
    """Samples all `prompts` with one randomly selected candidate."""
    return self._select_lm().sample(
        prompts, cache_seed=cache_seed, **kwargs
    )

  def __call__(
      self, prompt: lf.Message, *, cache_seed: int = 0, **kwargs
  ) -> lf.Message:
    """Calls one randomly selected candidate with `prompt`."""
    return self._select_lm()(prompt, cache_seed=cache_seed, **kwargs)

  def score(
      self,
      prompt: str | lf.Message | list[lf.Message],
      completions: list[str | lf.Message],
      **kwargs,
  ) -> list[lf.LMScoringResult]:
    """Scores `completions` with one randomly selected candidate."""
    return self._select_lm().score(prompt, completions, **kwargs)

  def tokenize(
      self,
      prompt: str | lf.Message,
      **kwargs,
  ) -> list[tuple[str | bytes, int]]:
    """Tokenizes `prompt` with one randomly selected candidate."""
    return self._select_lm().tokenize(prompt, **kwargs)

  def _sample(self, *arg, **kwargs):
    """Never called: all public entry points delegate to a candidate."""
    # Raise explicitly instead of `assert False`. Assert statements are
    # stripped under `python -O`, which would make this silently return None.
    raise AssertionError('Should never trigger.')