langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of langfun has been flagged as potentially problematic.
- langfun/__init__.py +7 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +15 -0
- langfun/core/eval/base.py +665 -95
- langfun/core/eval/base_test.py +224 -53
- langfun/core/eval/matching.py +48 -30
- langfun/core/eval/matching_test.py +25 -3
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +19 -10
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/langfunc.py +1 -22
- langfun/core/langfunc_test.py +10 -4
- langfun/core/language_model.py +130 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +27 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +34 -25
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai.py +8 -0
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +100 -81
- langfun/core/llms/openai_test.py +287 -60
- langfun/core/llms/vertexai.py +291 -0
- langfun/core/llms/vertexai_test.py +233 -0
- langfun/core/modalities/image.py +1 -3
- langfun/core/modalities/mime.py +6 -0
- langfun/core/modalities/video.py +6 -5
- langfun/core/structured/__init__.py +5 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +61 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +61 -12
- langfun/core/structured/prompting_test.py +122 -12
- langfun/core/structured/schema.py +38 -6
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +36 -7
- langfun/core/structured/scoring.py +4 -1
- langfun/core/structured/scoring_test.py +6 -0
- langfun/core/template.py +147 -11
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
- langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/llms/anthropic.py
ADDED
@@ -0,0 +1,263 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Language models from Anthropic."""
+
+import base64
+import functools
+import os
+from typing import Annotated, Any
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+import pyglove as pg
+import requests
+
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # See https://docs.anthropic.com/claude/docs/models-overview
+    # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
+    # RPM/TPM for Claude-2.1, Claude-2.0, and Claude-Instant-1.2 estimated
+    # as RPM/TPM of the largest-available model (Claude-3-Opus).
+    'claude-3-opus-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-2.1': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-2.0': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-instant-1.2': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+}
+
+
+class AnthropicError(Exception):  # pylint: disable=g-bad-exception-name
+  """Base class for Anthropic errors."""
+
+
+class RateLimitError(AnthropicError):
+  """Error for rate limit reached."""
+
+
+class OverloadedError(AnthropicError):
+  """Anthropic's server is temporarily overloaded."""
+
+
+_ANTHROPIC_MESSAGE_API_ENDPOINT = 'https://api.anthropic.com/v1/messages'
+_ANTHROPIC_API_VERSION = '2023-06-01'
+
+
+@lf.use_init_args(['model'])
+class Anthropic(lf.LanguageModel):
+  """Anthropic LLMs (Claude) through REST APIs.
+
+  See https://docs.anthropic.com/claude/reference/messages_post
+  """
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      'The name of the model to use.',
+  ]
+
+  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
+      True
+  )
+
+  api_key: Annotated[
+      str | None,
+      (
+          'API key. If None, the key will be read from environment variable '
+          "'ANTHROPIC_API_KEY'."
+      ),
+  ] = None
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._api_key = None
+    self.__dict__.pop('_api_initialized', None)
+    self.__dict__.pop('_session', None)
+
+  @functools.cached_property
+  def _api_initialized(self):
+    api_key = self.api_key or os.environ.get('ANTHROPIC_API_KEY', None)
+    if not api_key:
+      raise ValueError(
+          'Please specify `api_key` during `__init__` or set environment '
+          'variable `ANTHROPIC_API_KEY` with your Anthropic API key.'
+      )
+    self._api_key = api_key
+    return True
+
+  @functools.cached_property
+  def _session(self) -> requests.Session:
+    assert self._api_initialized
+    s = requests.Session()
+    s.headers.update({
+        'x-api-key': self._api_key,
+        'anthropic-version': _ANTHROPIC_API_VERSION,
+        'content-type': 'application/json',
+    })
+    return s
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model
+
+  @property
+  def max_concurrency(self) -> int:
+    rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+    tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+    return self.rate_to_max_concurrency(
+        requests_per_min=rpm, tokens_per_min=tpm
+    )
+
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    assert self._api_initialized
+    return self._parallel_execute_with_currency_control(
+        self._sample_single, prompts, retry_on_errors=(RateLimitError)
+    )
+
+  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    # Authropic requires `max_tokens` to be specified.
+    max_tokens = (
+        options.max_tokens
+        or SUPPORTED_MODELS_AND_SETTINGS[self.model].max_tokens
+    )
+    args = dict(
+        model=self.model,
+        max_tokens=max_tokens,
+        stream=False,
+    )
+    if options.stop:
+      args['stop_sequences'] = options.stop
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    if options.top_k is not None:
+      args['top_k'] = options.top_k
+    if options.top_p is not None:
+      args['top_p'] = options.top_p
+    return args
+
+  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
+    """Converts an message to Anthropic's content protocol (list of dicts)."""
+    # Refer: https://docs.anthropic.com/claude/reference/messages-examples
+    if self.multimodal:
+      content = []
+      for chunk in prompt.chunk():
+        if isinstance(chunk, str):
+          item = dict(type='text', text=chunk)
+        elif isinstance(chunk, lf_modalities.Image):
+          # NOTE(daiyip): Anthropic only support image content instead of URL.
+          item = dict(
+              type='image',
+              source=dict(
+                  type='base64',
+                  media_type=chunk.mime_type,
+                  data=base64.b64encode(chunk.to_bytes()).decode(),
+              ),
+          )
+        else:
+          raise ValueError(f'Unsupported modality object: {chunk!r}.')
+        content.append(item)
+      return content
+    else:
+      return [dict(type='text', text=prompt.text)]
+
+  def _message_from_content(self, content: list[dict[str, Any]]) -> lf.Message:
+    """Converts Anthropic's content protocol to message."""
+    # Refer: https://docs.anthropic.com/claude/reference/messages-examples
+    return lf.AIMessage.from_chunks(
+        [x['text'] for x in content if x['type'] == 'text']
+    )
+
+  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
+    """Parses Anthropic's response."""
+    # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
+    if response.status_code == 200:
+      output = response.json()
+      message = self._message_from_content(output['content'])
+      input_tokens = output['usage']['input_tokens']
+      output_tokens = output['usage']['output_tokens']
+      return lf.LMSamplingResult(
+          [lf.LMSample(message)],
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=input_tokens,
+              completion_tokens=output_tokens,
+              total_tokens=input_tokens + output_tokens,
+          ),
+      )
+    else:
+      if response.status_code == 429:
+        error_cls = RateLimitError
+      elif response.status_code in (502, 529):
+        error_cls = OverloadedError
+      else:
+        error_cls = AnthropicError
+      raise error_cls(f'{response.status_code}: {response.content}')
+
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    request = dict()
+    request.update(self._get_request_args(self.sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
+    )
+    try:
+      response = self._session.post(
+          _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
+      )
+      return self._parse_response(response)
+    except ConnectionError as e:
+      raise OverloadedError(str(e)) from e
+
+
+class Claude3(Anthropic):
+  """Base class for Claude 3 models. 200K input tokens and 4K output tokens."""
+  multimodal = True
+
+
+class Claude3Opus(Claude3):
+  """Anthropic's most powerful model."""
+
+  model = 'claude-3-opus-20240229'
+
+
+class Claude3Sonnet(Claude3):
+  """A balance between between Opus and Haiku."""
+
+  model = 'claude-3-sonnet-20240229'
+
+
+class Claude3Haiku(Claude3):
+  """Anthropic's most compact model."""
+
+  model = 'claude-3-haiku-20240307'
+
+
+class Claude2(Anthropic):
+  """Predecessor to Claude 3 with 100K context window.."""
+  model = 'claude-2.0'
+
+
+class Claude21(Anthropic):
+  """Updated Claude 2 model with improved accuracy and 200K context window."""
+  model = 'claude-2.1'
+
+
+class ClaudeInstant(Anthropic):
+  """Cheapest small and fast model, 100K context window."""
+  model = 'claude-instant-1.2'
langfun/core/llms/anthropic_test.py
ADDED
@@ -0,0 +1,167 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for Anthropic models."""
+
+import base64
+import os
+from typing import Any
+import unittest
+from unittest import mock
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import anthropic
+import pyglove as pg
+import requests
+
+
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'content': [{
+          'type': 'text',
+          'text': (
+              f'hello with temperature={json.get("temperature")}, '
+              f'top_k={json.get("top_k")}, '
+              f'top_p={json.get("top_p")}, '
+              f'max_tokens={json.get("max_tokens")}, '
+              f'stop={json.get("stop_sequences")}.'
+          ),
+      }],
+      'usage': {
+          'input_tokens': 2,
+          'output_tokens': 1,
+      },
+  }).encode()
+  return response
+
+
+image_content = (
+    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
+    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
+    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
+    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
+    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
+    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
+    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
+    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
+)
+
+
+def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  v = json['messages'][0]['content'][0]
+  image = lf_modalities.Image.from_bytes(base64.b64decode(v['source']['data']))
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'content': [{
+          'type': 'text',
+          'text': f'{v["type"]}: {image.mime_type}',
+      }],
+      'usage': {
+          'input_tokens': 2,
+          'output_tokens': 1,
+      },
+  }).encode()
+  return response
+
+
+def mock_requests_post_error(status_code, error_type, error_message):
+  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
+    del url, json, kwargs
+    response = requests.Response()
+    response.status_code = status_code
+    response._content = pg.to_json_str(
+        {
+            'error': {
+                'type': error_type,
+                'message': error_message,
+            }
+        }
+    ).encode()
+    return response
+
+  return _mock_requests
+
+
+class AnthropicTest(unittest.TestCase):
+
+  def test_basics(self):
+    self.assertEqual(
+        anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
+    )
+    self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)
+
+  def test_api_key(self):
+    lm = anthropic.Claude3Haiku()
+    with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+      lm('hi')
+
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+
+      lm = anthropic.Claude3Haiku(api_key='fake key')
+      self.assertRegex(lm('hi').text, 'hello.*')
+
+      os.environ['ANTHROPIC_API_KEY'] = 'abc'
+      lm = anthropic.Claude3Haiku()
+      self.assertRegex(lm('hi').text, 'hello.*')
+      del os.environ['ANTHROPIC_API_KEY']
+
+  def test_call(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+      lm = anthropic.Claude3Haiku(api_key='fake_key')
+      response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
+      self.assertEqual(
+          response.text,
+          (
+              'hello with temperature=0.0, top_k=0.1, top_p=0.2, '
+              "max_tokens=4096, stop=['\\n']."
+          ),
+      )
+      self.assertIsNotNone(response.usage)
+      self.assertIsNotNone(response.usage.prompt_tokens, 2)
+      self.assertIsNotNone(response.usage.completion_tokens, 1)
+      self.assertIsNotNone(response.usage.total_tokens, 3)
+
+  def test_mm_call(self):
+    with mock.patch('requests.Session.post') as mock_mm_request:
+      mock_mm_request.side_effect = mock_mm_requests_post
+      lm = anthropic.Claude3Haiku(api_key='fake_key')
+      response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
+      self.assertEqual(response.text, 'image: image/png')
+
+  def test_call_errors(self):
+    for status_code, error_type, error_message in [
+        (429, 'rate_limit', 'Rate limit exceeded.'),
+        (529, 'service_unavailable', 'Service unavailable.'),
+        (500, 'bad_request', 'Bad request.'),
+    ]:
+      with mock.patch('requests.Session.post') as mock_mm_request:
+        mock_mm_request.side_effect = mock_requests_post_error(
+            status_code, error_type, error_message
+        )
+        lm = anthropic.Claude3Haiku(api_key='fake_key')
+        with self.assertRaisesRegex(
+            Exception, f'.*{status_code}: .*{error_message}'
+        ):
+          lm('hello', lm=lm, max_attempts=1)
+
+
+if __name__ == '__main__':
+  unittest.main()
langfun/core/llms/cache/in_memory_test.py
CHANGED
@@ -44,28 +44,37 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(
         list(cache.keys()),
         [
-            ('a', (
-            ('a', (
-            ('b', (
-            ('c', (
+            ('a', (None, None, 1, 40, None, None), 0),
+            ('a', (None, None, 1, 40, None, None), 1),
+            ('b', (None, None, 1, 40, None, None), 0),
+            ('c', (None, None, 1, 40, None, None), 0),
         ],
     )
     self.assertEqual(
         list(cache.keys('StaticSequence')),
         [
-            ('a', (
-            ('a', (
-            ('b', (
-            ('c', (
+            ('a', (None, None, 1, 40, None, None), 0),
+            ('a', (None, None, 1, 40, None, None), 1),
+            ('b', (None, None, 1, 40, None, None), 0),
+            ('c', (None, None, 1, 40, None, None), 0),
         ],
     )

     def cache_entry(response_text, cache_seed=0):
       return base.LMCacheEntry(
-          lf.LMSamplingResult(
-
-          lf.
-
+          lf.LMSamplingResult(
+              [
+                  lf.LMSample(
+                      lf.AIMessage(response_text, cache_seed=cache_seed),
+                      score=1.0
+                  )
+              ],
+              usage=lf.LMSamplingUsage(
+                  1,
+                  len(response_text),
+                  len(response_text) + 1,
+              )
+          )
       )

     self.assertEqual(
@@ -90,19 +99,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
         list(cache.items()),
         [
             (
-                ('a', (
+                ('a', (None, None, 1, 40, None, None), 0),
                 cache_entry('1'),
             ),
             (
-                ('a', (
+                ('a', (None, None, 1, 40, None, None), 1),
                 cache_entry('2', 1),
             ),
             (
-                ('b', (
+                ('b', (None, None, 1, 40, None, None), 0),
                 cache_entry('3'),
             ),
             (
-                ('c', (
+                ('c', (None, None, 1, 40, None, None), 0),
                 cache_entry('4'),
             ),
         ],
@@ -111,19 +120,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
         list(cache.items('StaticSequence')),
         [
             (
-                ('a', (
+                ('a', (None, None, 1, 40, None, None), 0),
                 cache_entry('1'),
             ),
             (
-                ('a', (
+                ('a', (None, None, 1, 40, None, None), 1),
                 cache_entry('2', 1),
             ),
             (
-                ('b', (
+                ('b', (None, None, 1, 40, None, None), 0),
                 cache_entry('3'),
             ),
             (
-                ('c', (
+                ('c', (None, None, 1, 40, None, None), 0),
                 cache_entry('4'),
             ),
         ],
@@ -161,15 +170,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(
         list(cache.keys()),
         [
-            ('a', (
-            ('a', (1.0,
+            ('a', (None, None, 1, 40, None, None), 0),
+            ('a', (1.0, None, 1, 40, None, None), 0),
         ],
     )

   def test_different_model(self):
     cache = in_memory.InMemory()
-    lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache)
-    lm2 = fake.Echo(cache=cache)
+    lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache, temperature=0.0)
+    lm2 = fake.Echo(cache=cache, temperature=0.0)

     self.assertEqual(lm1('a'), '1')
     self.assertEqual(lm2('a'), 'a')
@@ -180,15 +189,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(
         list(cache.keys('StaticSequence')),
         [
-            ('a', (0.0,
-            ('b', (0.0,
+            ('a', (0.0, None, 1, 40, None, None), 0),
+            ('b', (0.0, None, 1, 40, None, None), 0),
         ],
     )
     self.assertEqual(
         list(cache.keys('Echo')),
         [
-            ('a', (0.0,
-            ('b', (0.0,
+            ('a', (0.0, None, 1, 40, None, None), 0),
+            ('b', (0.0, None, 1, 40, None, None), 0),
         ],
     )
     self.assertEqual(len(cache), 4)
langfun/core/llms/fake.py
CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Fake LMs for testing."""

+import abc
 from typing import Annotated
 import langfun.core as lf

@@ -23,15 +24,32 @@ class Fake(lf.LanguageModel):
   def _score(self, prompt: lf.Message, completions: list[lf.Message]):
     return [lf.LMScoringResult(score=-i * 1.0) for i in range(len(completions))]

+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    results = []
+    for prompt in prompts:
+      response = self._response_from(prompt)
+      results.append(
+          lf.LMSamplingResult(
+              [lf.LMSample(response, 1.0)],
+              usage=lf.LMSamplingUsage(
+                  prompt_tokens=len(prompt.text),
+                  completion_tokens=len(response.text),
+                  total_tokens=len(prompt.text) + len(response.text),
+              )
+          )
+      )
+    return results
+
+  @abc.abstractmethod
+  def _response_from(self, prompt: lf.Message) -> lf.Message:
+    """Returns the response for the given prompt."""
+

 class Echo(Fake):
   """A simple echo language model for testing."""

-  def
-    return
-        lf.LMSamplingResult([lf.LMSample(prompt.text, 1.0)])
-        for prompt in prompts
-    ]
+  def _response_from(self, prompt: lf.Message) -> lf.Message:
+    return lf.AIMessage(prompt.text)


 @lf.use_init_args(['response'])
@@ -39,15 +57,12 @@ class StaticResponse(Fake):
   """Language model that always gives the same canned response."""

   response: Annotated[
-      str,
+      str | lf.Message,
       'A canned response that will be returned regardless of the prompt.'
   ]

-  def
-    return
-        lf.LMSamplingResult([lf.LMSample(self.response, 1.0)])
-        for _ in prompts
-    ]
+  def _response_from(self, prompt: lf.Message) -> lf.Message:
+    return lf.AIMessage.from_value(self.response)


 @lf.use_init_args(['mapping'])
@@ -55,15 +70,12 @@ class StaticMapping(Fake):
   """A static mapping from prompt to response."""

   mapping: Annotated[
-      dict[str, str],
+      dict[str, str | lf.Message],
       'A mapping from prompt to response.'
   ]

-  def
-    return [
-        lf.LMSamplingResult([lf.LMSample(self.mapping[prompt], 1.0)])
-        for prompt in prompts
-    ]
+  def _response_from(self, prompt: lf.Message) -> lf.Message:
+    return lf.AIMessage.from_value(self.mapping[prompt])


 @lf.use_init_args(['sequence'])
@@ -71,7 +83,7 @@ class StaticSequence(Fake):
   """A static sequence of responses to use."""

   sequence: Annotated[
-      list[str],
+      list[str | lf.Message],
       'A sequence of strings as the response.'
   ]

@@ -79,10 +91,7 @@ class StaticSequence(Fake):
     super()._on_bound()
     self._pos = 0

-  def
-
-
-
-        [lf.LMSample(self.sequence[self._pos], 1.0)]))
-    self._pos += 1
-    return results
+  def _response_from(self, prompt: lf.Message) -> lf.Message:
+    r = lf.AIMessage.from_value(self.sequence[self._pos])
+    self._pos += 1
+    return r