langfun 0.0.2.dev20240418__py3-none-any.whl → 0.0.2.dev20240420__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- langfun/core/language_model.py +39 -13
- langfun/core/llms/__init__.py +12 -2
- langfun/core/llms/anthropic.py +249 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/fake_test.py +4 -4
- langfun/core/llms/openai.py +33 -13
- langfun/core/llms/openai_test.py +13 -11
- {langfun-0.0.2.dev20240418.dist-info → langfun-0.0.2.dev20240420.dist-info}/METADATA +1 -1
- {langfun-0.0.2.dev20240418.dist-info → langfun-0.0.2.dev20240420.dist-info}/RECORD +12 -10
- {langfun-0.0.2.dev20240418.dist-info → langfun-0.0.2.dev20240420.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240418.dist-info → langfun-0.0.2.dev20240420.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240418.dist-info → langfun-0.0.2.dev20240420.dist-info}/top_level.txt +0 -0
langfun/core/language_model.py
CHANGED
@@ -440,7 +440,7 @@ class LanguageModel(component.Component):
     response.metadata.usage = result.usage
 
     elapse = time.time() - request_start
-    self._debug(prompt, response, call_counter, elapse)
+    self._debug(prompt, response, call_counter, result.usage, elapse)
     return response
 
   def _debug(
@@ -448,35 +448,51 @@ class LanguageModel(component.Component):
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
+      usage: LMSamplingUsage | None,
       elapse: float,
-  ):
+  ) -> None:
     """Outputs debugging information."""
     debug = self.debug
     if isinstance(debug, bool):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
 
     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, usage)
 
     if debug & LMDebugMode.PROMPT:
-      self._debug_prompt(prompt, call_counter)
+      self._debug_prompt(prompt, call_counter, usage)
 
     if debug & LMDebugMode.RESPONSE:
-      self._debug_response(response, call_counter, elapse)
+      self._debug_response(response, call_counter, usage, elapse)
 
-  def _debug_model_info(
+  def _debug_model_info(
+      self, call_counter: int, usage: LMSamplingUsage | None) -> None:
     """Outputs debugging information about the model."""
+    title_suffix = ''
+    if usage and usage.total_tokens != 0:
+      title_suffix = console.colored(
+          f' (total {usage.total_tokens} tokens)', 'red')
+
     console.write(
         self.format(compact=True, use_inferred=True),
-        title=f'[{call_counter}] LM INFO:',
+        title=f'[{call_counter}] LM INFO{title_suffix}:',
        color='magenta',
     )
 
-  def _debug_prompt(
+  def _debug_prompt(
+      self,
+      prompt: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+  ) -> None:
     """Outputs debugging information about the prompt."""
+    title_suffix = ''
+    if usage and usage.prompt_tokens != 0:
+      title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+
     console.write(
         prompt,
-        title=f'\n[{call_counter}] PROMPT SENT TO LM:',
+        title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
         color='green',
     )
     referred_modalities = prompt.referred_modalities()
@@ -490,12 +506,22 @@ class LanguageModel(component.Component):
     )
 
   def _debug_response(
-      self,
-
+      self,
+      response: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+      elapse: float
+  ) -> None:
     """Outputs debugging information about the response."""
+    title_suffix = ' ('
+    if usage and usage.completion_tokens != 0:
+      title_suffix += f'{usage.completion_tokens} tokens '
+    title_suffix += f'in {elapse:.2f} seconds)'
+    title_suffix = console.colored(title_suffix, 'red')
+
     console.write(
         str(response) + '\n',
-        title=f'\n[{call_counter}] LM RESPONSE
+        title=f'\n[{call_counter}] LM RESPONSE{title_suffix}:',
         color='blue',
     )
 
@@ -542,7 +568,7 @@ class LanguageModel(component.Component):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
 
     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, None)
 
     if debug & LMDebugMode.PROMPT:
       console.write(
langfun/core/llms/__init__.py
CHANGED
@@ -35,8 +35,12 @@ from langfun.core.llms.google_genai import Palm2_IT
 from langfun.core.llms.openai import OpenAI
 
 from langfun.core.llms.openai import Gpt4Turbo
-from langfun.core.llms.openai import
-from langfun.core.llms.openai import
+from langfun.core.llms.openai import Gpt4Turbo_20240409
+from langfun.core.llms.openai import Gpt4TurboPreview
+from langfun.core.llms.openai import Gpt4TurboPreview_0125
+from langfun.core.llms.openai import Gpt4TurboPreview_1106
+from langfun.core.llms.openai import Gpt4VisionPreview
+from langfun.core.llms.openai import Gpt4VisionPreview_1106
 from langfun.core.llms.openai import Gpt4
 from langfun.core.llms.openai import Gpt4_0613
 from langfun.core.llms.openai import Gpt4_32K
@@ -57,6 +61,12 @@ from langfun.core.llms.openai import Gpt3Curie
 from langfun.core.llms.openai import Gpt3Babbage
 from langfun.core.llms.openai import Gpt3Ada
 
+from langfun.core.llms.anthropic import Anthropic
+from langfun.core.llms.anthropic import Claude3Opus
+from langfun.core.llms.anthropic import Claude3Sonnet
+from langfun.core.llms.anthropic import Claude3Haiku
+
+
 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote
 
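Note: with these exports, the new model classes resolve directly from the `langfun.core.llms` namespace. A quick sketch (the API keys are placeholders):

from langfun.core import llms

gpt = llms.Gpt4Turbo_20240409(api_key='<openai-key>')   # new dated OpenAI alias
claude = llms.Claude3Sonnet(api_key='<anthropic-key>')  # new Anthropic family
print(gpt.model_id, claude.model_id)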
langfun/core/llms/anthropic.py
ADDED
@@ -0,0 +1,249 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Language models from Anthropic."""
+
+import base64
+import functools
+import os
+from typing import Annotated, Any
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+import pyglove as pg
+import requests
+
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # See https://docs.anthropic.com/claude/docs/models-overview
+    'claude-3-opus-20240229': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'claude-2.1': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'claude-2.0': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'claude-instant-1.2': pg.Dict(max_tokens=4096, max_concurrency=16),
+}
+
+
+class AnthropicError(Exception):  # pylint: disable=g-bad-exception-name
+  """Base class for Anthropic errors."""
+
+
+class RateLimitError(AnthropicError):
+  """Error for rate limit reached."""
+
+
+class OverloadedError(AnthropicError):
+  """Anthropic's server is temporarily overloaded."""
+
+
+_ANTHROPIC_MESSAGE_API_ENDPOINT = 'https://api.anthropic.com/v1/messages'
+_ANTHROPIC_API_VERSION = '2023-06-01'
+
+
+@lf.use_init_args(['model'])
+class Anthropic(lf.LanguageModel):
+  """Anthropic LLMs (Claude) through REST APIs.
+
+  See https://docs.anthropic.com/claude/reference/messages_post
+  """
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      'The name of the model to use.',
+  ]
+
+  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
+      True
+  )
+
+  api_key: Annotated[
+      str | None,
+      (
+          'API key. If None, the key will be read from environment variable '
+          "'ANTHROPIC_API_KEY'."
+      ),
+  ] = None
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._api_key = None
+    self.__dict__.pop('_api_initialized', None)
+
+  @functools.cached_property
+  def _api_initialized(self):
+    api_key = self.api_key or os.environ.get('ANTHROPIC_API_KEY', None)
+    if not api_key:
+      raise ValueError(
+          'Please specify `api_key` during `__init__` or set environment '
+          'variable `ANTHROPIC_API_KEY` with your Anthropic API key.'
+      )
+    self._api_key = api_key
+    return True
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model
+
+  @property
+  def max_concurrency(self) -> int:
+    return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
+
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    assert self._api_initialized
+    return self._parallel_execute_with_currency_control(
+        self._sample_single, prompts, retry_on_errors=(RateLimitError)
+    )
+
+  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    # Anthropic requires `max_tokens` to be specified.
+    max_tokens = (
+        options.max_tokens
+        or SUPPORTED_MODELS_AND_SETTINGS[self.model].max_tokens
+    )
+    args = dict(
+        model=self.model,
+        max_tokens=max_tokens,
+        stream=False,
+    )
+    if options.stop:
+      args['stop_sequences'] = options.stop
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    if options.top_k is not None:
+      args['top_k'] = options.top_k
+    if options.top_p is not None:
+      args['top_p'] = options.top_p
+    return args
+
+  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
+    """Converts a message to Anthropic's content protocol (list of dicts)."""
+    # Refer: https://docs.anthropic.com/claude/reference/messages-examples
+    if self.multimodal:
+      content = []
+      for chunk in prompt.chunk():
+        if isinstance(chunk, str):
+          item = dict(type='text', text=chunk)
+        elif isinstance(chunk, lf_modalities.Image):
+          # NOTE(daiyip): Anthropic only supports image content instead of URL.
+          item = dict(
+              type='image',
+              source=dict(
+                  type='base64',
+                  media_type=chunk.mime_type,
+                  data=base64.b64encode(chunk.to_bytes()).decode(),
+              ),
+          )
+        else:
+          raise ValueError(f'Unsupported modality object: {chunk!r}.')
+        content.append(item)
+      return content
+    else:
+      return [dict(type='text', text=prompt.text)]
+
+  def _message_from_content(self, content: list[dict[str, Any]]) -> lf.Message:
+    """Converts Anthropic's content protocol to message."""
+    # Refer: https://docs.anthropic.com/claude/reference/messages-examples
+    return lf.AIMessage.from_chunks(
+        [x['text'] for x in content if x['type'] == 'text']
+    )
+
+  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
+    """Parses Anthropic's response."""
+    # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
+    output = response.json()
+    if response.status_code == 200:
+      message = self._message_from_content(output['content'])
+      input_tokens = output['usage']['input_tokens']
+      output_tokens = output['usage']['output_tokens']
+      return lf.LMSamplingResult(
+          [lf.LMSample(message)],
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=input_tokens,
+              completion_tokens=output_tokens,
+              total_tokens=input_tokens + output_tokens,
+          ),
+      )
+    else:
+      if response.status_code == 429:
+        error_cls = RateLimitError
+      elif response.status_code == 529:
+        error_cls = OverloadedError
+      else:
+        error_cls = AnthropicError
+      error = output['error']
+      raise error_cls(f'{error["type"]}: {error["message"]}')
+
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    request = dict()
+    request.update(self._get_request_args(self.sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
+    )
+    response = requests.post(
+        _ANTHROPIC_MESSAGE_API_ENDPOINT,
+        json=request,
+        headers={
+            'x-api-key': self._api_key,
+            'anthropic-version': _ANTHROPIC_API_VERSION,
+            'content-type': 'application/json',
+        },
+        timeout=self.timeout,
+    )
+    return self._parse_response(response)
+
+
+class Claude3(Anthropic):
+  """Base class for Claude 3 models. 200K input tokens and 4K output tokens."""
+  multimodal = True
+
+
+class Claude3Opus(Claude3):
+  """Anthropic's most powerful model."""
+
+  model = 'claude-3-opus-20240229'
+
+
+class Claude3Sonnet(Claude3):
+  """A balance between Opus and Haiku."""
+
+  model = 'claude-3-sonnet-20240229'
+
+
+class Claude3Haiku(Claude3):
+  """Anthropic's most compact model."""
+
+  model = 'claude-3-haiku-20240307'
+
+
+class Claude2(Anthropic):
+  """Predecessor to Claude 3 with 100K context window."""
+  model = 'claude-2.0'
+
+
+class Claude21(Anthropic):
+  """Updated Claude 2 model with improved accuracy and 200K context window."""
+  model = 'claude-2.1'
+
+
+class ClaudeInstant(Anthropic):
+  """Cheapest small and fast model, 100K context window."""
+  model = 'claude-instant-1.2'
langfun/core/llms/anthropic_test.py
ADDED
@@ -0,0 +1,167 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for Anthropic models."""
+
+import base64
+import os
+from typing import Any
+import unittest
+from unittest import mock
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import anthropic
+import pyglove as pg
+import requests
+
+
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'content': [{
+          'type': 'text',
+          'text': (
+              f'hello with temperature={json.get("temperature")}, '
+              f'top_k={json.get("top_k")}, '
+              f'top_p={json.get("top_p")}, '
+              f'max_tokens={json.get("max_tokens")}, '
+              f'stop={json.get("stop_sequences")}.'
+          ),
+      }],
+      'usage': {
+          'input_tokens': 2,
+          'output_tokens': 1,
+      },
+  }).encode()
+  return response
+
+
+image_content = (
+    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
+    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
+    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
+    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
+    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
+    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
+    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
+    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
+)
+
+
+def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  v = json['messages'][0]['content'][0]
+  image = lf_modalities.Image.from_bytes(base64.b64decode(v['source']['data']))
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'content': [{
+          'type': 'text',
+          'text': f'{v["type"]}: {image.mime_type}',
+      }],
+      'usage': {
+          'input_tokens': 2,
+          'output_tokens': 1,
+      },
+  }).encode()
+  return response
+
+
+def mock_requests_post_error(status_code, error_type, error_message):
+  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
+    del url, json, kwargs
+    response = requests.Response()
+    response.status_code = status_code
+    response._content = pg.to_json_str(
+        {
+            'error': {
+                'type': error_type,
+                'message': error_message,
+            }
+        }
+    ).encode()
+    return response
+
+  return _mock_requests
+
+
+class AuthropicTest(unittest.TestCase):
+
+  def test_basics(self):
+    self.assertEqual(
+        anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
+    )
+    self.assertEqual(anthropic.Claude3Haiku().max_concurrency, 16)
+
+  def test_api_key(self):
+    lm = anthropic.Claude3Haiku()
+    with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+      lm('hi')
+
+    with mock.patch('requests.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+
+      lm = anthropic.Claude3Haiku(api_key='fake key')
+      self.assertRegex(lm('hi').text, 'hello.*')
+
+      os.environ['ANTHROPIC_API_KEY'] = 'abc'
+      lm = anthropic.Claude3Haiku()
+      self.assertRegex(lm('hi').text, 'hello.*')
+      del os.environ['ANTHROPIC_API_KEY']
+
+  def test_call(self):
+    with mock.patch('requests.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+      lm = anthropic.Claude3Haiku(api_key='fake_key')
+      response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
+      self.assertEqual(
+          response.text,
+          (
+              'hello with temperature=0.0, top_k=0.1, top_p=0.2, '
+              "max_tokens=4096, stop=['\\n']."
+          ),
+      )
+      self.assertIsNotNone(response.usage)
+      self.assertIsNotNone(response.usage.prompt_tokens, 2)
+      self.assertIsNotNone(response.usage.completion_tokens, 1)
+      self.assertIsNotNone(response.usage.total_tokens, 3)
+
+  def test_mm_call(self):
+    with mock.patch('requests.post') as mock_mm_request:
+      mock_mm_request.side_effect = mock_mm_requests_post
+      lm = anthropic.Claude3Haiku(api_key='fake_key')
+      response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
+      self.assertEqual(response.text, 'image: image/png')
+
+  def test_call_errors(self):
+    for status_code, error_type, error_message in [
+        (429, 'rate_limit', 'Rate limit exceeded.'),
+        (529, 'service_unavailable', 'Service unavailable.'),
+        (500, 'bad_request', 'Bad request.'),
+    ]:
+      with mock.patch('requests.post') as mock_mm_request:
+        mock_mm_request.side_effect = mock_requests_post_error(
+            status_code, error_type, error_message
+        )
+        lm = anthropic.Claude3Haiku(api_key='fake_key')
+        with self.assertRaisesRegex(
+            Exception, f'{error_type}: {error_message}'
+        ):
+          lm('hello', lm=lm, max_attempts=1)
+
+
+if __name__ == '__main__':
+  unittest.main()
langfun/core/llms/fake_test.py
CHANGED
@@ -39,8 +39,8 @@ class EchoTest(unittest.TestCase):
     with contextlib.redirect_stdout(string_io):
       self.assertEqual(lm('hi'), 'hi')
     debug_info = string_io.getvalue()
-    self.assertIn('[0] LM INFO
-    self.assertIn('[0] PROMPT SENT TO LM
+    self.assertIn('[0] LM INFO', debug_info)
+    self.assertIn('[0] PROMPT SENT TO LM', debug_info)
     self.assertIn('[0] LM RESPONSE', debug_info)
 
   def test_score(self):
@@ -84,8 +84,8 @@ class StaticResponseTest(unittest.TestCase):
     self.assertEqual(lm('hi'), canned_response)
 
     debug_info = string_io.getvalue()
-    self.assertIn('[0] LM INFO
-    self.assertIn('[0] PROMPT SENT TO LM
+    self.assertIn('[0] LM INFO', debug_info)
+    self.assertIn('[0] PROMPT SENT TO LM', debug_info)
     self.assertIn('[0] LM RESPONSE', debug_info)
 
 
langfun/core/llms/openai.py
CHANGED
@@ -31,10 +31,13 @@ SUPPORTED_MODELS_AND_SETTINGS = [
     # The concurrent requests is estimated by TPM/RPM from
     # https://platform.openai.com/account/limits
     # GPT-4 Turbo models.
-    ('gpt-4-turbo
-    ('gpt-4-
-    ('gpt-4-
-    ('gpt-4-
+    ('gpt-4-turbo', 8),  # GPT-4 Turbo with Vision
+    ('gpt-4-turbo-2024-04-09', 8),  # GPT-4-Turbo with Vision, 04/09/2024
+    ('gpt-4-turbo-preview', 8),  # GPT-4 Turbo Preview
+    ('gpt-4-0125-preview', 8),  # GPT-4 Turbo Preview, 01/25/2024
+    ('gpt-4-1106-preview', 8),  # GPT-4 Turbo Preview, 11/06/2023
+    ('gpt-4-vision-preview', 8),  # GPT-4 Turbo Vision Preview.
+    ('gpt-4-1106-vision-preview', 8),  # GPT-4 Turbo Vision Preview, 11/06/2023
     # GPT-4 models.
     ('gpt-4', 4),
     ('gpt-4-0613', 4),
@@ -284,26 +287,43 @@ class Gpt4(OpenAI):
 
 
 class Gpt4Turbo(Gpt4):
-  """GPT-4 Turbo with 128K context window
-  model = 'gpt-4-turbo
+  """GPT-4 Turbo with 128K context window. Knowledge up to Dec. 2023."""
+  model = 'gpt-4-turbo'
+  multimodal = True
 
 
-class
-  """GPT-4 Turbo with
-  model = 'gpt-4-
+class Gpt4Turbo_20240409(Gpt4Turbo):  # pylint:disable=invalid-name
+  """GPT-4 Turbo with 128K context window. Knowledge up to Dec. 2023."""
+  model = 'gpt-4-turbo-2024-04-09'
   multimodal = True
 
 
-class
-  """GPT-4 Turbo with
+class Gpt4TurboPreview(Gpt4):
+  """GPT-4 Turbo Preview with 128k context window. Knowledge up to Dec. 2023."""
+  model = 'gpt-4-turbo-preview'
+
+
+class Gpt4TurboPreview_0125(Gpt4TurboPreview):  # pylint: disable=invalid-name
+  """GPT-4 Turbo Preview with 128k context window. Knowledge up to Dec. 2023."""
   model = 'gpt-4-0125-preview'
 
 
-class
-  """GPT-4 Turbo
+class Gpt4TurboPreview_1106(Gpt4TurboPreview):  # pylint: disable=invalid-name
+  """GPT-4 Turbo Preview with 128k context window. Knowledge up to Apr. 2023."""
   model = 'gpt-4-1106-preview'
 
 
+class Gpt4VisionPreview(Gpt4):
+  """GPT-4 Turbo vision preview. 128k context window. Knowledge to Apr. 2023."""
+  model = 'gpt-4-vision-preview'
+  multimodal = True
+
+
+class Gpt4VisionPreview_1106(Gpt4):  # pylint: disable=invalid-name
+  """GPT-4 Turbo vision preview. 128k context window. Knowledge to Apr. 2023."""
+  model = 'gpt-4-1106-vision-preview'
+
+
 class Gpt4_0613(Gpt4):  # pylint:disable=invalid-name
   """GPT-4 @20230613. 8K context window. Knowledge up to 9-2021."""
   model = 'gpt-4-0613'
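Note: the reworked hierarchy mirrors OpenAI's model aliases: `Gpt4Turbo` now tracks the floating `gpt-4-turbo` pointer while the dated subclasses pin specific snapshots. A brief sketch (the key is a placeholder):

from langfun.core.llms import openai

lm = openai.Gpt4Turbo(api_key='<openai-key>')               # floating alias, multimodal
pinned = openai.Gpt4Turbo_20240409(api_key='<openai-key>')  # pinned snapshot
assert pinned.model == 'gpt-4-turbo-2024-04-09'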
langfun/core/llms/openai_test.py
CHANGED
@@ -157,17 +157,19 @@ class OpenaiTest(unittest.TestCase):
   def test_call_chat_completion_vision(self):
     with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
       mock_chat_completion.side_effect = mock_chat_completion_query_vision
-
-
-
-
-
-
-
-
-
-
-
+      lm_1 = openai.Gpt4Turbo(api_key='test_key')
+      lm_2 = openai.Gpt4VisionPreview(api_key='test_key')
+      for lm in (lm_1, lm_2):
+        self.assertEqual(
+            lm(
+                lf.UserMessage(
+                    'hello {{image}}',
+                    image=lf_modalities.Image.from_uri('https://fake/image')
+                ),
+                sampling_options=lf.LMSamplingOptions(n=2)
+            ),
+            'Sample 0 for message: https://fake/image',
+        )
 
   def test_sample_completion(self):
     with mock.patch('openai.Completion.create') as mock_completion:
{langfun-0.0.2.dev20240418.dist-info → langfun-0.0.2.dev20240420.dist-info}/RECORD
CHANGED
@@ -8,7 +8,7 @@ langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
 langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
 langfun/core/langfunc.py,sha256=WXdTc3QsmGD_n80KD9dFRr5MHpGZ9E_y_Rhtk4t9-3w,11852
 langfun/core/langfunc_test.py,sha256=sQaKuZpGGmG80GRifhbxkj7nfzQLJKj4Vuw5y1s1K3U,8378
-langfun/core/language_model.py,sha256=
+langfun/core/language_model.py,sha256=1_GO6oEm0wXnE7aRRLOdT-A4j_6YvRanS5oMgfobcIs,18331
 langfun/core/language_model_test.py,sha256=KvXXOr64TsSs3WkEALCLLZSlz09i7hBiHDOZ_8Eq8_o,13047
 langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
 langfun/core/message.py,sha256=QhvV9t5qaryPcruyxxcXi3gm9QDInkSldwTtK6sVJ3c,15734
@@ -46,15 +46,17 @@ langfun/core/eval/matching.py,sha256=aqNlYrlav7YmsB7rUlsdfoi1RLA5CYqn2RGPxRlPc78
 langfun/core/eval/matching_test.py,sha256=FFHYD7IDuKe5RMjkx74ksukiwUhO5a_SS340JaIPMws,4898
 langfun/core/eval/scoring.py,sha256=aKeanBJf1yO3Q9JEtgPWoiZk_3M_GiqwXVXX7x_g22w,6172
 langfun/core/eval/scoring_test.py,sha256=YH1cIxBWtfdKcAV9Fh10vLkV5J-gxk8b6nxW4Z2u5pk,4024
-langfun/core/llms/__init__.py,sha256=
+langfun/core/llms/__init__.py,sha256=c_9lVKzFjnxHKgRjY_dUiJzBmW1jWALy3mtYv0uMyl0,2953
+langfun/core/llms/anthropic.py,sha256=p-tjttvithBg2b4tgxIS2F-Zk5AYAh5e-lW-8e1p4wc,7865
+langfun/core/llms/anthropic_test.py,sha256=OuLDxeiPRdqsfKILS0R6jJLTRs3-1KCIotPPr7IbIDU,5502
 langfun/core/llms/fake.py,sha256=b-Xk5IPTbUt-elsyzd_i3n1tqzc_kgETXrEvgJruSMk,2824
-langfun/core/llms/fake_test.py,sha256=
+langfun/core/llms/fake_test.py,sha256=ZlDQgL41EX3eYTfBQNp2nB2LciqCmtoHgCsGvW4XhwI,4184
 langfun/core/llms/google_genai.py,sha256=n8zyJwh9UCTgb6-8LyvmjVNFGZQ4-zfzZ0ulkhHAnR8,8624
 langfun/core/llms/google_genai_test.py,sha256=_UcGTfl16-aDUlEWFC2W2F8y9jPUs53RBYA6MOCpGXw,7525
 langfun/core/llms/llama_cpp.py,sha256=Y_KkMUf3Xfac49koMUtUslKl3h-HWp3-ntq7Jaa3bdo,2385
 langfun/core/llms/llama_cpp_test.py,sha256=ZxC6defGd_HX9SFRU9U4cJiQnBKundbOrchbXuC1Z2M,1683
-langfun/core/llms/openai.py,sha256=
-langfun/core/llms/openai_test.py,sha256=
+langfun/core/llms/openai.py,sha256=Z_pujF3B2QMzWBgOdV67DKAfZ8Wmyeb_6F9BkcGHyaE,12344
+langfun/core/llms/openai_test.py,sha256=S83nVUq1Za15-rq-tCGOZPGPGByVgk0YdamoO7gnNpw,8270
 langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
 langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
 langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
@@ -97,8 +99,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
 langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
 langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
 langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
-langfun-0.0.2.
-langfun-0.0.2.
-langfun-0.0.2.
-langfun-0.0.2.
-langfun-0.0.2.
+langfun-0.0.2.dev20240420.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
+langfun-0.0.2.dev20240420.dist-info/METADATA,sha256=R4bRp7OO2PSjDyKe48YvIbMptLTkeqesP98ZxJ17woc,3405
+langfun-0.0.2.dev20240420.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+langfun-0.0.2.dev20240420.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
+langfun-0.0.2.dev20240420.dist-info/RECORD,,
{langfun-0.0.2.dev20240418.dist-info → langfun-0.0.2.dev20240420.dist-info}/LICENSE
File without changes
{langfun-0.0.2.dev20240418.dist-info → langfun-0.0.2.dev20240420.dist-info}/WHEEL
File without changes
{langfun-0.0.2.dev20240418.dist-info → langfun-0.0.2.dev20240420.dist-info}/top_level.txt
File without changes