langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/llms/anthropic_test.py
ADDED
@@ -0,0 +1,235 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for Anthropic models."""
+
+import base64
+import os
+from typing import Any
+import unittest
+from unittest import mock
+
+from google.auth import exceptions
+from langfun.core import language_model
+from langfun.core import message as lf_message
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import anthropic
+import pyglove as pg
+import requests
+
+
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'content': [{
+          'type': 'text',
+          'text': (
+              f'hello with temperature={json.get("temperature")}, '
+              f'top_k={json.get("top_k")}, '
+              f'top_p={json.get("top_p")}, '
+              f'max_tokens={json.get("max_tokens")}, '
+              f'stop={json.get("stop_sequences")}.'
+          ),
+      }],
+      'usage': {
+          'input_tokens': 2,
+          'output_tokens': 1,
+      },
+  }).encode()
+  return response
+
+
+image_content = (
+    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
+    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
+    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
+    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
+    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
+    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
+    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
+    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
+)
+
+pdf_content = (
+    b'%PDF-1.4\n1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n2 0 obj\n<<'
+    b' /Type /Pages /Count 1 /Kids [3 0 R] >>\nendobj\n3 0 obj\n<< /Type /Page'
+    b' /Parent 2 0 R /MediaBox [0 0 612 792] /Contents 4 0 R >>\nendobj\n4 0'
+    b' obj\n<< /Length 44 >>\nstream\nBT /F1 24 Tf 100 700 Td (Hello, PDF'
+    b' content!) Tj ET\nendstream\nendobj\n5 0 obj\n<< /Type /Font /Subtype'
+    b' /Type1 /BaseFont /Helvetica >>\nendobj\nxref\n0 6\n0000000000 65535 f'
+    b' \n0000000010 00000 n \n0000000079 00000 n \n0000000178 00000 n'
+    b' \n0000000278 00000 n \n0000000407 00000 n \ntrailer\n<< /Size 6 /Root 1'
+    b' 0 R >>\nstartxref\n517\n%%EOF'
+)
+
+
+def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  v = json['messages'][0]['content'][0]
+  content = lf_modalities.Mime.from_bytes(base64.b64decode(v['source']['data']))
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'content': [{
+          'type': 'text',
+          'text': f'{v["type"]}: {content.mime_type}',
+      }],
+      'usage': {
+          'input_tokens': 2,
+          'output_tokens': 1,
+      },
+  }).encode()
+  return response
+
+
+def mock_requests_post_error(status_code, error_type, error_message):
+  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
+    del url, json, kwargs
+    response = requests.Response()
+    response.status_code = status_code
+    response._content = pg.to_json_str(
+        {
+            'error': {
+                'type': error_type,
+                'message': error_message,
+            }
+        }
+    ).encode()
+    return response
+
+  return _mock_requests
+
+
+class AnthropicTest(unittest.TestCase):
+
+  def test_basics(self):
+    self.assertEqual(
+        anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
+    )
+    self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)
+
+  def test_api_key(self):
+    lm = anthropic.Claude3Haiku()
+    with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+      lm('hi')
+
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+
+      lm = anthropic.Claude3Haiku(api_key='fake key')
+      self.assertRegex(lm('hi').text, 'hello.*')
+
+      os.environ['ANTHROPIC_API_KEY'] = 'abc'
+      lm = anthropic.Claude3Haiku()
+      self.assertRegex(lm('hi').text, 'hello.*')
+      del os.environ['ANTHROPIC_API_KEY']
+
+  def test_call(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+      lm = anthropic.Claude3Haiku(api_key='fake_key')
+      response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
+      self.assertEqual(
+          response.text,
+          (
+              'hello with temperature=0.0, top_k=0.1, top_p=0.2, '
+              "max_tokens=4096, stop=['\\n']."
+          ),
+      )
+      self.assertIsNotNone(response.usage)
+      self.assertIsNotNone(response.usage.prompt_tokens, 2)
+      self.assertIsNotNone(response.usage.completion_tokens, 1)
+      self.assertIsNotNone(response.usage.total_tokens, 3)
+
+  def test_mm_call(self):
+    with mock.patch('requests.Session.post') as mock_mm_request:
+      mock_mm_request.side_effect = mock_mm_requests_post
+      lm = anthropic.Claude3Haiku(api_key='fake_key')
+      response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
+      self.assertEqual(response.text, 'image: image/png')
+
+  def test_pdf_call(self):
+    with mock.patch('requests.Session.post') as mock_mm_request:
+      mock_mm_request.side_effect = mock_mm_requests_post
+      lm = anthropic.Claude3Haiku(api_key='fake_key')
+      response = lm(lf_modalities.PDF.from_bytes(pdf_content), lm=lm)
+      self.assertEqual(response.text, 'document: application/pdf')
+
+  def test_call_errors(self):
+    for status_code, error_type, error_message in [
+        (429, 'rate_limit', 'Rate limit exceeded.'),
+        (529, 'service_unavailable', 'Service unavailable.'),
+        (500, 'bad_request', 'Bad request.'),
+    ]:
+      with mock.patch('requests.Session.post') as mock_mm_request:
+        mock_mm_request.side_effect = mock_requests_post_error(
+            status_code, error_type, error_message
+        )
+        lm = anthropic.Claude3Haiku(api_key='fake_key')
+        with self.assertRaisesRegex(
+            Exception, f'.*{status_code}: .*{error_message}'
+        ):
+          lm('hello', max_attempts=1)
+
+
+class VertexAIAnthropicTest(unittest.TestCase):
+  """Tests for VertexAI Anthropic models."""
+
+  def test_basics(self):
+    with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
+      lm = anthropic.VertexAIClaude3_5_Sonnet_20241022()
+      lm('hi')
+
+    model = anthropic.VertexAIClaude3_5_Sonnet_20241022(project='langfun')
+
+    # NOTE(daiyip): For OSS users, default credentials are not available unless
+    # users have already set up their GCP project. Therefore we ignore the
+    # exception here.
+    try:
+      model._initialize()
+    except exceptions.DefaultCredentialsError:
+      pass
+
+    self.assertEqual(
+        model.api_endpoint,
+        (
+            'https://us-east5-aiplatform.googleapis.com/v1/projects/'
+            'langfun/locations/us-east5/publishers/anthropic/'
+            'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
+        )
+    )
+    request = model.request(
+        lf_message.UserMessage('hi'),
+        language_model.LMSamplingOptions(temperature=0.0),
+    )
+    self.assertEqual(
+        request,
+        {
+            'anthropic_version': 'vertex-2023-10-16',
+            'max_tokens': 8192,
+            'messages': [
+                {'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
+            ],
+            'stream': False,
+            'temperature': 0.0,
+            'top_k': 40,
+        },
+    )
+
+
+if __name__ == '__main__':
+  unittest.main()
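The tests above never touch the network: every HTTP call is intercepted by patching requests.Session.post with a side effect that fabricates an Anthropic-style JSON response. The same pattern can be reused to exercise the model wrappers offline. A minimal sketch, assuming langfun and requests are installed and that the mock_requests_post helper above is importable from the shipped test module:

    from unittest import mock

    from langfun.core.llms import anthropic
    # Reuses the canned-response helper from anthropic_test.py above.
    from langfun.core.llms.anthropic_test import mock_requests_post

    # Patch the HTTP layer so no real Anthropic endpoint is contacted.
    with mock.patch('requests.Session.post') as mock_post:
      mock_post.side_effect = mock_requests_post
      lm = anthropic.Claude3Haiku(api_key='fake_key')
      # The canned response echoes the sampling options back as text.
      print(lm('hi', temperature=0.5).text)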
langfun/core/llms/cache/base.py
CHANGED
@@ -60,13 +60,16 @@ class LMCacheBase(lf.LMCache):
       self, lm: lf.LanguageModel, prompt: lf.Message, seed: int
   ) -> lf.LMSamplingResult | None:
     """Gets the cached result of a prompt generated by a language model."""
-    entry = self._get(lm.model_id, self._key(lm, prompt, seed))
+    key = self._key(lm, prompt, seed)
+    entry = self._get(lm.model_id, key)
     self._stats.num_queries += 1
     if entry is None:
       self._stats.num_misses += 1
       return None
     if entry.expire is not None and entry.expire < datetime.datetime.now():
       self._stats.num_hit_expires += 1
+      self._stats.num_deletes += 1
+      assert self._delete(lm.model_id, key)
       return None
     self._stats.num_hits += 1
     return entry.result
@@ -86,6 +89,18 @@ class LMCacheBase(lf.LMCache):
     self._put(lm.model_id, self._key(lm, prompt, seed), entry)
     self._stats.num_updates += 1
 
+  def delete(
+      self,
+      lm: lf.LanguageModel,
+      prompt: lf.Message,
+      seed: int,
+  ) -> bool:
+    """Deletes the result of a prompt generated by a language model in cache."""
+    deleted = self._delete(lm.model_id, self._key(lm, prompt, seed))
+    if deleted:
+      self._stats.num_deletes += 1
+    return deleted
+
   @abc.abstractmethod
   def _get(self, model_id: str, key: str) -> LMCacheEntry | None:
     """Returns a LM cache entry associated with the key."""
@@ -94,6 +109,10 @@ class LMCacheBase(lf.LMCache):
   def _put(self, model_id: str, key: str, entry: LMCacheEntry) -> None:
     """Puts a LM cache entry associated with the key."""
 
+  @abc.abstractmethod
+  def _delete(self, model_id: str, key: str) -> bool:
+    """Deletes a LM cache entry associated with the key."""
+
   def _sym_clone(self, deep: bool, memo: Any = None) -> 'LMCacheBase':
     v = super()._sym_clone(deep, memo)
     v._stats = self._stats  # pylint: disable=protected-access
@@ -102,4 +121,4 @@ class LMCacheBase(lf.LMCache):
 
 def default_key(lm: lf.LanguageModel, prompt: lf.Message, seed: int) -> Any:
   """Default key for LM cache."""
-  return (prompt.
+  return (prompt.text_with_modality_hash, lm.sampling_options.cache_key(), seed)
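Taken together, these hunks add explicit deletion to the cache protocol: a public delete() keyed on (lm, prompt, seed), a num_deletes statistic, eager eviction of expired entries on lookup, and a _delete() hook that concrete caches must implement. A minimal usage sketch, assuming the in-memory cache and fake models that ship with langfun:

    import langfun.core as lf
    from langfun.core.llms import fake
    from langfun.core.llms.cache import in_memory

    cache = in_memory.InMemory()
    lm = fake.StaticResponse('hello', cache=cache)

    lm('hi')  # Populates the cache under ('hi', sampling options, seed=0).

    # delete() returns True only when a matching entry was evicted.
    assert cache.delete(lm, lf.UserMessage('hi'), seed=0)
    assert not cache.delete(lm, lf.UserMessage('hi'), seed=0)  # Already gone.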
langfun/core/llms/cache/in_memory.py
CHANGED
@@ -15,6 +15,7 @@
 
 import collections
 import contextlib
+import json
 from typing import Annotated, Any, Iterator
 import langfun.core as lf
 from langfun.core.llms.cache import base
@@ -49,6 +50,11 @@ class InMemory(base.LMCacheBase):
           "Creating a new cache as cache file '%s' does not exist.",
           self.filename,
       )
+    except json.JSONDecodeError:
+      pg.logging.warning(
+          "Creating a new cache as cache file '%s' is corrupted.",
+          self.filename,
+      )
 
   def model_ids(self) -> list[str]:
     """Returns the model ids of cached queires."""
@@ -99,6 +105,13 @@ class InMemory(base.LMCacheBase):
     """Puts a LM cache entry associated with the key."""
     self._cache[model_id][key] = entry
 
+  def _delete(self, model_id: str, key: str) -> bool:
+    """Deletes a LM cache entry associated with the key."""
+    model_cache = self._cache.get(model_id, None)
+    if model_cache is None:
+      return False
+    return model_cache.pop(key, None) is not None
+
   def reset(self, model_id: str | None = None) -> None:
     """Resets the cache."""
     if model_id is not None:

langfun/core/llms/cache/in_memory_test.py
CHANGED
@@ -44,28 +44,38 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(
         list(cache.keys()),
         [
-            ('a', (
-            ('a', (
-            ('b', (
-            ('c', (
+            ('a', (None, None, 1, 40, None, None), 0),
+            ('a', (None, None, 1, 40, None, None), 1),
+            ('b', (None, None, 1, 40, None, None), 0),
+            ('c', (None, None, 1, 40, None, None), 0),
         ],
     )
     self.assertEqual(
         list(cache.keys('StaticSequence')),
         [
-            ('a', (
-            ('a', (
-            ('b', (
-            ('c', (
+            ('a', (None, None, 1, 40, None, None), 0),
+            ('a', (None, None, 1, 40, None, None), 1),
+            ('b', (None, None, 1, 40, None, None), 0),
+            ('c', (None, None, 1, 40, None, None), 0),
         ],
     )
 
     def cache_entry(response_text, cache_seed=0):
       return base.LMCacheEntry(
-          lf.LMSamplingResult(
-
-          lf.
-
+          lf.LMSamplingResult(
+              [
+                  lf.LMSample(
+                      lf.AIMessage(response_text, cache_seed=cache_seed),
+                      score=1.0,
+                  )
+              ],
+              usage=lf.LMSamplingUsage(
+                  1,
+                  len(response_text),
+                  len(response_text) + 1,
+              ),
+              is_cached=True,
+          )
       )
 
@@ -90,19 +100,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
         list(cache.items()),
         [
             (
-                ('a', (
+                ('a', (None, None, 1, 40, None, None), 0),
                 cache_entry('1'),
             ),
             (
-                ('a', (
+                ('a', (None, None, 1, 40, None, None), 1),
                 cache_entry('2', 1),
             ),
             (
-                ('b', (
+                ('b', (None, None, 1, 40, None, None), 0),
                 cache_entry('3'),
             ),
             (
-                ('c', (
+                ('c', (None, None, 1, 40, None, None), 0),
                 cache_entry('4'),
             ),
         ],
@@ -111,19 +121,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
         list(cache.items('StaticSequence')),
         [
             (
-                ('a', (
+                ('a', (None, None, 1, 40, None, None), 0),
                 cache_entry('1'),
             ),
             (
-                ('a', (
+                ('a', (None, None, 1, 40, None, None), 1),
                 cache_entry('2', 1),
             ),
             (
-                ('b', (
+                ('b', (None, None, 1, 40, None, None), 0),
                 cache_entry('3'),
             ),
             (
-                ('c', (
+                ('c', (None, None, 1, 40, None, None), 0),
                 cache_entry('4'),
             ),
         ],
@@ -139,6 +149,50 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertIs(copy.deepcopy(cache)._cache, cache._cache)
     self.assertIs(copy.deepcopy(cache)._stats, cache._stats)
 
+    self.assertFalse(
+        cache.delete(fake.StaticResponse('hi'), lf.UserMessage('c'), seed=0)
+    )
+    self.assertFalse(cache.delete(lm, lf.UserMessage('c'), seed=1))
+    self.assertFalse(cache.delete(lm, lf.UserMessage('d'), seed=0))
+    self.assertTrue(cache.delete(lm, lf.UserMessage('c'), seed=0))
+    self.assertEqual(
+        list(cache.keys('StaticSequence')),
+        [
+            ('a', (None, None, 1, 40, None, None), 0),
+            ('a', (None, None, 1, 40, None, None), 1),
+            ('b', (None, None, 1, 40, None, None), 0),
+        ],
+    )
+    self.assertEqual(cache.stats.num_deletes, 1)
+
+  def test_cache_with_modalities(self):
+
+    class CustomModality(lf.Modality):
+      content: str
+
+      def to_bytes(self):
+        return self.content.encode()
+
+    cache = in_memory.InMemory()
+    lm = fake.StaticSequence(['1', '2', '3', '4', '5', '6'], cache=cache)
+    lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality('foo')))
+    lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality('bar')))
+    self.assertEqual(
+        list(cache.keys()),
+        [
+            (
+                'hi <<[[image]]>><image>acbd18db</image>',
+                (None, None, 1, 40, None, None),
+                0,
+            ),
+            (
+                'hi <<[[image]]>><image>37b51d19</image>',
+                (None, None, 1, 40, None, None),
+                0,
+            ),
+        ],
+    )
+
   def test_ttl(self):
     cache = in_memory.InMemory(ttl=1)
     lm = fake.StaticSequence(['1', '2', '3'], cache=cache)
@@ -151,6 +205,7 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(cache.stats.num_hits, 1)
     self.assertEqual(cache.stats.num_hit_expires, 1)
     self.assertEqual(cache.stats.num_misses, 1)
+    self.assertEqual(cache.stats.num_deletes, 1)
 
   def test_different_sampling_options(self):
     cache = in_memory.InMemory()
@@ -161,15 +216,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(
         list(cache.keys()),
         [
-            ('a', (
-            ('a', (1.0,
+            ('a', (None, None, 1, 40, None, None), 0),
+            ('a', (1.0, None, 1, 40, None, None), 0),
         ],
     )
 
   def test_different_model(self):
     cache = in_memory.InMemory()
-    lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache)
-    lm2 = fake.Echo(cache=cache)
+    lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache, temperature=0.0)
+    lm2 = fake.Echo(cache=cache, temperature=0.0)
 
     self.assertEqual(lm1('a'), '1')
     self.assertEqual(lm2('a'), 'a')
@@ -180,15 +235,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(
         list(cache.keys('StaticSequence')),
         [
-            ('a', (0.0,
-            ('b', (0.0,
+            ('a', (0.0, None, 1, 40, None, None), 0),
+            ('b', (0.0, None, 1, 40, None, None), 0),
         ],
     )
     self.assertEqual(
         list(cache.keys('Echo')),
         [
-            ('a', (0.0,
-            ('b', (0.0,
+            ('a', (0.0, None, 1, 40, None, None), 0),
+            ('b', (0.0, None, 1, 40, None, None), 0),
         ],
     )
     self.assertEqual(len(cache), 4)
@@ -240,6 +295,11 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(cache2.stats.num_updates, 2)
     cache2.save()
 
+    # Corrupted file.
+    pg.io.writefile(path, 'bad_content')
+    cache3 = in_memory.InMemory(path)
+    self.assertEqual(len(cache3), 0)
+
 
 class LmCacheTest(unittest.TestCase):
 
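As the new test above shows, a corrupted cache file no longer raises on load: the json.JSONDecodeError handler logs a warning and starts an empty cache. A small sketch of that behavior, mirroring the test (the file path here is hypothetical):

    import pyglove as pg
    from langfun.core.llms.cache import in_memory

    path = '/tmp/lm_cache.json'  # Hypothetical path for illustration.
    pg.io.writefile(path, 'bad_content')  # Simulate a corrupted cache file.

    cache = in_memory.InMemory(path)  # Warns and starts with an empty cache.
    assert len(cache) == 0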
langfun/core/llms/compositional.py
ADDED
@@ -0,0 +1,101 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Compositions of different LLM models."""
+import random
+from typing import Annotated
+
+import langfun.core as lf
+import pyglove as pg
+
+
+@pg.use_init_args(['candidates', 'seed'])
+class RandomChoice(lf.LanguageModel):
+  """Random choice of a list of LLM models."""
+
+  candidates: Annotated[
+      list[lf.LanguageModel],
+      (
+          'A list of LLMs as candidates to choose from.'
+      )
+  ]
+
+  seed: Annotated[
+      int,
+      (
+          'The random seed to use for the random choice.'
+      )
+  ] = 0
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._rand = random.Random(self.seed)
+    # Applying sampling options to all candidates.
+    parent_non_default = self.sampling_options.sym_nondefault()
+    if parent_non_default:
+      for c in self.candidates:
+        c.sampling_options.rebind(
+            parent_non_default, notify_parents=False, raise_on_no_change=False
+        )
+
+  @property
+  def model_id(self) -> str:
+    model_ids = ', '.join(
+        sorted(c.model_id for c in self.candidates)
+    )
+    return f'RandomChoice({model_ids})'
+
+  @property
+  def resource_id(self) -> str:
+    resource_ids = ', '.join(
+        sorted(c.resource_id for c in self.candidates)
+    )
+    return f'RandomChoice({resource_ids})'
+
+  def _select_lm(self) -> lf.LanguageModel:
+    """Selects a random LLM from the candidates."""
+    return self._rand.choice(self.candidates)
+
+  def sample(
+      self,
+      prompts: list[str | lf.Message],
+      *,
+      cache_seed: int = 0,
+      **kwargs,
+  ) -> list[lf.LMSamplingResult]:
+    return self._select_lm().sample(
+        prompts, cache_seed=cache_seed, **kwargs
+    )
+
+  def __call__(
+      self, prompt: lf.Message, *, cache_seed: int = 0, **kwargs
+  ) -> lf.Message:
+    return self._select_lm()(prompt, cache_seed=cache_seed, **kwargs)
+
+  def score(
+      self,
+      prompt: str | lf.Message | list[lf.Message],
+      completions: list[str | lf.Message],
+      **kwargs,
+  ) -> list[lf.LMScoringResult]:
+    return self._select_lm().score(prompt, completions, **kwargs)
+
+  def tokenize(
+      self,
+      prompt: str | lf.Message,
+      **kwargs,
+  ) -> list[tuple[str | bytes, int]]:
+    return self._select_lm().tokenize(prompt, **kwargs)
+
+  def _sample(self, *arg, **kwargs):
+    assert False, 'Should never trigger.'