langfun 0.0.2.dev20240418__py3-none-any.whl → 0.0.2.dev20240420__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
--- a/langfun/core/language_model.py
+++ b/langfun/core/language_model.py
@@ -440,7 +440,7 @@ class LanguageModel(component.Component):
       response.metadata.usage = result.usage

     elapse = time.time() - request_start
-    self._debug(prompt, response, call_counter, elapse)
+    self._debug(prompt, response, call_counter, result.usage, elapse)
     return response

   def _debug(
@@ -448,35 +448,51 @@ class LanguageModel(component.Component):
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
+      usage: LMSamplingUsage | None,
       elapse: float,
-  ):
+  ) -> None:
     """Outputs debugging information."""
     debug = self.debug
     if isinstance(debug, bool):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, usage)

     if debug & LMDebugMode.PROMPT:
-      self._debug_prompt(prompt, call_counter)
+      self._debug_prompt(prompt, call_counter, usage)

     if debug & LMDebugMode.RESPONSE:
-      self._debug_response(response, call_counter, elapse)
+      self._debug_response(response, call_counter, usage, elapse)

-  def _debug_model_info(self, call_counter: int):
+  def _debug_model_info(
+      self, call_counter: int, usage: LMSamplingUsage | None) -> None:
     """Outputs debugging information about the model."""
+    title_suffix = ''
+    if usage and usage.total_tokens != 0:
+      title_suffix = console.colored(
+          f' (total {usage.total_tokens} tokens)', 'red')
+
     console.write(
         self.format(compact=True, use_inferred=True),
-        title=f'[{call_counter}] LM INFO:',
+        title=f'[{call_counter}] LM INFO{title_suffix}:',
         color='magenta',
     )

-  def _debug_prompt(self, prompt: message_lib.Message, call_counter: int):
+  def _debug_prompt(
+      self,
+      prompt: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+  ) -> None:
     """Outputs debugging information about the prompt."""
+    title_suffix = ''
+    if usage and usage.prompt_tokens != 0:
+      title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+
     console.write(
         prompt,
-        title=f'\n[{call_counter}] PROMPT SENT TO LM:',
+        title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
         color='green',
     )
     referred_modalities = prompt.referred_modalities()
@@ -490,12 +506,22 @@ class LanguageModel(component.Component):
     )

   def _debug_response(
-      self, response: message_lib.Message, call_counter: int, elapse: float
-  ):
+      self,
+      response: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+      elapse: float
+  ) -> None:
     """Outputs debugging information about the response."""
+    title_suffix = ' ('
+    if usage and usage.completion_tokens != 0:
+      title_suffix += f'{usage.completion_tokens} tokens '
+    title_suffix += f'in {elapse:.2f} seconds)'
+    title_suffix = console.colored(title_suffix, 'red')
+
     console.write(
         str(response) + '\n',
-        title=f'\n[{call_counter}] LM RESPONSE (in {elapse:.2f} seconds):',
+        title=f'\n[{call_counter}] LM RESPONSE{title_suffix}:',
         color='blue',
     )

@@ -542,7 +568,7 @@ class LanguageModel(component.Component):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, None)

     if debug & LMDebugMode.PROMPT:
       console.write(
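
Taken together, the hunks above thread token accounting through the debug path: `_debug` now receives the sampling result's `LMSamplingUsage` and appends token counts to the INFO, PROMPT and RESPONSE titles, while the call site in the last hunk (around line 568) passes `None` and keeps its title bare. A minimal sketch of what a caller would now see, assuming a configured OpenAI model and illustrative token counts (the exact wording comes from the f-strings above):

    from langfun.core.llms import openai

    lm = openai.Gpt4Turbo(api_key='<your key>', debug=True)
    lm('hi')
    # Debug titles built by the code above, e.g.:
    #   [0] LM INFO (total 42 tokens):
    #   [0] PROMPT SENT TO LM (30 tokens):
    #   [0] LM RESPONSE (12 tokens in 0.85 seconds):
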
--- a/langfun/core/llms/__init__.py
+++ b/langfun/core/llms/__init__.py
@@ -35,8 +35,12 @@ from langfun.core.llms.google_genai import Palm2_IT
 from langfun.core.llms.openai import OpenAI

 from langfun.core.llms.openai import Gpt4Turbo
-from langfun.core.llms.openai import Gpt4Turbo_0125
-from langfun.core.llms.openai import Gpt4TurboVision
+from langfun.core.llms.openai import Gpt4Turbo_20240409
+from langfun.core.llms.openai import Gpt4TurboPreview
+from langfun.core.llms.openai import Gpt4TurboPreview_0125
+from langfun.core.llms.openai import Gpt4TurboPreview_1106
+from langfun.core.llms.openai import Gpt4VisionPreview
+from langfun.core.llms.openai import Gpt4VisionPreview_1106
 from langfun.core.llms.openai import Gpt4
 from langfun.core.llms.openai import Gpt4_0613
 from langfun.core.llms.openai import Gpt4_32K
@@ -57,6 +61,12 @@ from langfun.core.llms.openai import Gpt3Curie
 from langfun.core.llms.openai import Gpt3Babbage
 from langfun.core.llms.openai import Gpt3Ada

+from langfun.core.llms.anthropic import Anthropic
+from langfun.core.llms.anthropic import Claude3Opus
+from langfun.core.llms.anthropic import Claude3Sonnet
+from langfun.core.llms.anthropic import Claude3Haiku
+
+
 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote

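
With these re-exports in place, the new Anthropic wrappers become importable from the `langfun.core.llms` namespace alongside the OpenAI classes. A quick sketch (assumes `ANTHROPIC_API_KEY` is set in the environment, per the `api_key` docstring in the new module below):

    from langfun.core import llms

    lm = llms.Claude3Haiku()   # key read from ANTHROPIC_API_KEY
    print(lm.model_id)         # 'claude-3-haiku-20240307'
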
--- /dev/null
+++ b/langfun/core/llms/anthropic.py
@@ -0,0 +1,249 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Language models from Anthropic."""
+
+import base64
+import functools
+import os
+from typing import Annotated, Any
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+import pyglove as pg
+import requests
+
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # See https://docs.anthropic.com/claude/docs/models-overview
+    'claude-3-opus-20240229': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'claude-2.1': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'claude-2.0': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'claude-instant-1.2': pg.Dict(max_tokens=4096, max_concurrency=16),
+}
+
+
+class AnthropicError(Exception):  # pylint: disable=g-bad-exception-name
+  """Base class for Anthropic errors."""
+
+
+class RateLimitError(AnthropicError):
+  """Error for rate limit reached."""
+
+
+class OverloadedError(AnthropicError):
+  """Anthropic's server is temporarily overloaded."""
+
+
+_ANTHROPIC_MESSAGE_API_ENDPOINT = 'https://api.anthropic.com/v1/messages'
+_ANTHROPIC_API_VERSION = '2023-06-01'
+
+
+@lf.use_init_args(['model'])
+class Anthropic(lf.LanguageModel):
+  """Anthropic LLMs (Claude) through REST APIs.
+
+  See https://docs.anthropic.com/claude/reference/messages_post
+  """
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      'The name of the model to use.',
+  ]
+
+  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
+      True
+  )
+
+  api_key: Annotated[
+      str | None,
+      (
+          'API key. If None, the key will be read from environment variable '
+          "'ANTHROPIC_API_KEY'."
+      ),
+  ] = None
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._api_key = None
+    self.__dict__.pop('_api_initialized', None)
+
+  @functools.cached_property
+  def _api_initialized(self):
+    api_key = self.api_key or os.environ.get('ANTHROPIC_API_KEY', None)
+    if not api_key:
+      raise ValueError(
+          'Please specify `api_key` during `__init__` or set environment '
+          'variable `ANTHROPIC_API_KEY` with your Anthropic API key.'
+      )
+    self._api_key = api_key
+    return True
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model
+
+  @property
+  def max_concurrency(self) -> int:
+    return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
+
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    assert self._api_initialized
+    return self._parallel_execute_with_currency_control(
+        self._sample_single, prompts, retry_on_errors=(RateLimitError)
+    )
+
+  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    # Anthropic requires `max_tokens` to be specified.
+    max_tokens = (
+        options.max_tokens
+        or SUPPORTED_MODELS_AND_SETTINGS[self.model].max_tokens
+    )
+    args = dict(
+        model=self.model,
+        max_tokens=max_tokens,
+        stream=False,
+    )
+    if options.stop:
+      args['stop_sequences'] = options.stop
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    if options.top_k is not None:
+      args['top_k'] = options.top_k
+    if options.top_p is not None:
+      args['top_p'] = options.top_p
+    return args
+
+  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
+    """Converts a message to Anthropic's content protocol (list of dicts)."""
+    # Refer: https://docs.anthropic.com/claude/reference/messages-examples
+    if self.multimodal:
+      content = []
+      for chunk in prompt.chunk():
+        if isinstance(chunk, str):
+          item = dict(type='text', text=chunk)
+        elif isinstance(chunk, lf_modalities.Image):
+          # NOTE(daiyip): Anthropic only supports inline image content, not
+          # image URLs.
+          item = dict(
+              type='image',
+              source=dict(
+                  type='base64',
+                  media_type=chunk.mime_type,
+                  data=base64.b64encode(chunk.to_bytes()).decode(),
+              ),
+          )
+        else:
+          raise ValueError(f'Unsupported modality object: {chunk!r}.')
+        content.append(item)
+      return content
+    else:
+      return [dict(type='text', text=prompt.text)]
+
+  def _message_from_content(self, content: list[dict[str, Any]]) -> lf.Message:
+    """Converts Anthropic's content protocol to message."""
+    # Refer: https://docs.anthropic.com/claude/reference/messages-examples
+    return lf.AIMessage.from_chunks(
+        [x['text'] for x in content if x['type'] == 'text']
+    )
+
+  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
+    """Parses Anthropic's response."""
+    # NOTE(daiyip): Refer to https://docs.anthropic.com/claude/reference/errors
+    output = response.json()
+    if response.status_code == 200:
+      message = self._message_from_content(output['content'])
+      input_tokens = output['usage']['input_tokens']
+      output_tokens = output['usage']['output_tokens']
+      return lf.LMSamplingResult(
+          [lf.LMSample(message)],
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=input_tokens,
+              completion_tokens=output_tokens,
+              total_tokens=input_tokens + output_tokens,
+          ),
+      )
+    else:
+      if response.status_code == 429:
+        error_cls = RateLimitError
+      elif response.status_code == 529:
+        error_cls = OverloadedError
+      else:
+        error_cls = AnthropicError
+      error = output['error']
+      raise error_cls(f'{error["type"]}: {error["message"]}')
+
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    request = dict()
+    request.update(self._get_request_args(self.sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
+    )
+    response = requests.post(
+        _ANTHROPIC_MESSAGE_API_ENDPOINT,
+        json=request,
+        headers={
+            'x-api-key': self._api_key,
+            'anthropic-version': _ANTHROPIC_API_VERSION,
+            'content-type': 'application/json',
+        },
+        timeout=self.timeout,
+    )
+    return self._parse_response(response)
+
+
+class Claude3(Anthropic):
+  """Base class for Claude 3 models. 200K input tokens and 4K output tokens."""
+  multimodal = True
+
+
+class Claude3Opus(Claude3):
+  """Anthropic's most powerful model."""
+
+  model = 'claude-3-opus-20240229'
+
+
+class Claude3Sonnet(Claude3):
+  """A balance between Opus and Haiku."""
+
+  model = 'claude-3-sonnet-20240229'
+
+
+class Claude3Haiku(Claude3):
+  """Anthropic's most compact model."""
+
+  model = 'claude-3-haiku-20240307'
+
+
+class Claude2(Anthropic):
+  """Predecessor to Claude 3 with a 100K context window."""
+  model = 'claude-2.0'
+
+
+class Claude21(Anthropic):
+  """Updated Claude 2 model with improved accuracy and 200K context window."""
+  model = 'claude-2.1'
+
+
+class ClaudeInstant(Anthropic):
+  """Cheapest small and fast model, 100K context window."""
+  model = 'claude-instant-1.2'
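
The new module is self-contained: request construction (`_get_request_args`, `_content_from_message`), transport (`requests.post` against the v1/messages endpoint), and response and error parsing (`_parse_response`) are all defined above. A hedged usage sketch based only on these definitions (the API key is a placeholder; `response.usage` population is demonstrated by the tests that follow):

    from langfun.core.llms import anthropic

    lm = anthropic.Claude3Opus(api_key='sk-ant-...')  # placeholder key
    response = lm('Explain unified diffs in one sentence.')
    print(response.text)
    print(response.usage)  # prompt/completion/total token counts
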
--- /dev/null
+++ b/langfun/core/llms/anthropic_test.py
@@ -0,0 +1,167 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for Anthropic models."""
+
+import base64
+import os
+from typing import Any
+import unittest
+from unittest import mock
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import anthropic
+import pyglove as pg
+import requests
+
+
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'content': [{
+          'type': 'text',
+          'text': (
+              f'hello with temperature={json.get("temperature")}, '
+              f'top_k={json.get("top_k")}, '
+              f'top_p={json.get("top_p")}, '
+              f'max_tokens={json.get("max_tokens")}, '
+              f'stop={json.get("stop_sequences")}.'
+          ),
+      }],
+      'usage': {
+          'input_tokens': 2,
+          'output_tokens': 1,
+      },
+  }).encode()
+  return response
+
+
+image_content = (
+    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
+    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
+    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
+    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
+    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
+    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
+    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
+    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
+)
+
+
+def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  v = json['messages'][0]['content'][0]
+  image = lf_modalities.Image.from_bytes(base64.b64decode(v['source']['data']))
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'content': [{
+          'type': 'text',
+          'text': f'{v["type"]}: {image.mime_type}',
+      }],
+      'usage': {
+          'input_tokens': 2,
+          'output_tokens': 1,
+      },
+  }).encode()
+  return response
+
+
+def mock_requests_post_error(status_code, error_type, error_message):
+  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
+    del url, json, kwargs
+    response = requests.Response()
+    response.status_code = status_code
+    response._content = pg.to_json_str(
+        {
+            'error': {
+                'type': error_type,
+                'message': error_message,
+            }
+        }
+    ).encode()
+    return response
+
+  return _mock_requests
+
+
+class AnthropicTest(unittest.TestCase):
+
+  def test_basics(self):
+    self.assertEqual(
+        anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
+    )
+    self.assertEqual(anthropic.Claude3Haiku().max_concurrency, 16)
+
+  def test_api_key(self):
+    lm = anthropic.Claude3Haiku()
+    with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+      lm('hi')
+
+    with mock.patch('requests.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+
+      lm = anthropic.Claude3Haiku(api_key='fake key')
+      self.assertRegex(lm('hi').text, 'hello.*')
+
+      os.environ['ANTHROPIC_API_KEY'] = 'abc'
+      lm = anthropic.Claude3Haiku()
+      self.assertRegex(lm('hi').text, 'hello.*')
+      del os.environ['ANTHROPIC_API_KEY']
+
+  def test_call(self):
+    with mock.patch('requests.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+      lm = anthropic.Claude3Haiku(api_key='fake_key')
+      response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
+      self.assertEqual(
+          response.text,
+          (
+              'hello with temperature=0.0, top_k=0.1, top_p=0.2, '
+              "max_tokens=4096, stop=['\\n']."
+          ),
+      )
+      self.assertIsNotNone(response.usage)
+      self.assertIsNotNone(response.usage.prompt_tokens, 2)
+      self.assertIsNotNone(response.usage.completion_tokens, 1)
+      self.assertIsNotNone(response.usage.total_tokens, 3)
+
+  def test_mm_call(self):
+    with mock.patch('requests.post') as mock_mm_request:
+      mock_mm_request.side_effect = mock_mm_requests_post
+      lm = anthropic.Claude3Haiku(api_key='fake_key')
+      response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
+      self.assertEqual(response.text, 'image: image/png')
+
+  def test_call_errors(self):
+    for status_code, error_type, error_message in [
+        (429, 'rate_limit', 'Rate limit exceeded.'),
+        (529, 'service_unavailable', 'Service unavailable.'),
+        (500, 'bad_request', 'Bad request.'),
+    ]:
+      with mock.patch('requests.post') as mock_mm_request:
+        mock_mm_request.side_effect = mock_requests_post_error(
+            status_code, error_type, error_message
+        )
+        lm = anthropic.Claude3Haiku(api_key='fake_key')
+        with self.assertRaisesRegex(
+            Exception, f'{error_type}: {error_message}'
+        ):
+          lm('hello', lm=lm, max_attempts=1)
+
+
+if __name__ == '__main__':
+  unittest.main()
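
These tests stub `requests.post`, so they run without network access or a real key; with the package installed they should be runnable as a normal unittest module (assumed invocation):

    python -m unittest langfun.core.llms.anthropic_test
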
--- a/langfun/core/llms/fake_test.py
+++ b/langfun/core/llms/fake_test.py
@@ -39,8 +39,8 @@ class EchoTest(unittest.TestCase):
     with contextlib.redirect_stdout(string_io):
       self.assertEqual(lm('hi'), 'hi')
     debug_info = string_io.getvalue()
-    self.assertIn('[0] LM INFO:', debug_info)
-    self.assertIn('[0] PROMPT SENT TO LM:', debug_info)
+    self.assertIn('[0] LM INFO', debug_info)
+    self.assertIn('[0] PROMPT SENT TO LM', debug_info)
     self.assertIn('[0] LM RESPONSE', debug_info)

   def test_score(self):
@@ -84,8 +84,8 @@ class StaticResponseTest(unittest.TestCase):
       self.assertEqual(lm('hi'), canned_response)

     debug_info = string_io.getvalue()
-    self.assertIn('[0] LM INFO:', debug_info)
-    self.assertIn('[0] PROMPT SENT TO LM:', debug_info)
+    self.assertIn('[0] LM INFO', debug_info)
+    self.assertIn('[0] PROMPT SENT TO LM', debug_info)
     self.assertIn('[0] LM RESPONSE', debug_info)

--- a/langfun/core/llms/openai.py
+++ b/langfun/core/llms/openai.py
@@ -31,10 +31,13 @@ SUPPORTED_MODELS_AND_SETTINGS = [
     # The concurrent requests is estimated by TPM/RPM from
     # https://platform.openai.com/account/limits
     # GPT-4 Turbo models.
-    ('gpt-4-turbo-preview', 1),  # GPT-4 Turbo.
-    ('gpt-4-0125-preview', 1),  # GPT-4 Turbo
-    ('gpt-4-1106-preview', 1),  # GPT-4 Turbo
-    ('gpt-4-vision-preview', 1),  # GPT-4 Turbo with Vision.
+    ('gpt-4-turbo', 8),  # GPT-4 Turbo with Vision
+    ('gpt-4-turbo-2024-04-09', 8),  # GPT-4 Turbo with Vision, 04/09/2024
+    ('gpt-4-turbo-preview', 8),  # GPT-4 Turbo Preview
+    ('gpt-4-0125-preview', 8),  # GPT-4 Turbo Preview, 01/25/2024
+    ('gpt-4-1106-preview', 8),  # GPT-4 Turbo Preview, 11/06/2023
+    ('gpt-4-vision-preview', 8),  # GPT-4 Turbo Vision Preview.
+    ('gpt-4-1106-vision-preview', 8),  # GPT-4 Turbo Vision Preview, 11/06/2023
     # GPT-4 models.
     ('gpt-4', 4),
     ('gpt-4-0613', 4),
@@ -284,26 +287,43 @@ class Gpt4(OpenAI):


 class Gpt4Turbo(Gpt4):
-  """GPT-4 Turbo with 128K context window size. Knowledge up to 4-2023."""
-  model = 'gpt-4-turbo-preview'
+  """GPT-4 Turbo with 128K context window. Knowledge up to Dec. 2023."""
+  model = 'gpt-4-turbo'
+  multimodal = True


-class Gpt4TurboVision(Gpt4Turbo):
-  """GPT-4 Turbo with vision."""
-  model = 'gpt-4-vision-preview'
+class Gpt4Turbo_20240409(Gpt4Turbo):  # pylint:disable=invalid-name
+  """GPT-4 Turbo with 128K context window. Knowledge up to Dec. 2023."""
+  model = 'gpt-4-turbo-2024-04-09'
   multimodal = True


-class Gpt4Turbo_0125(Gpt4Turbo):  # pylint:disable=invalid-name
-  """GPT-4 Turbo with 128K context window size. Knowledge up to 4-2023."""
+class Gpt4TurboPreview(Gpt4):
+  """GPT-4 Turbo Preview with 128k context window. Knowledge up to Dec. 2023."""
+  model = 'gpt-4-turbo-preview'
+
+
+class Gpt4TurboPreview_0125(Gpt4TurboPreview):  # pylint: disable=invalid-name
+  """GPT-4 Turbo Preview with 128k context window. Knowledge up to Dec. 2023."""
   model = 'gpt-4-0125-preview'


-class Gpt4Turbo_1106(Gpt4Turbo):  # pylint:disable=invalid-name
-  """GPT-4 Turbo @20231106. 128K context window. Knowledge up to 4-2023."""
+class Gpt4TurboPreview_1106(Gpt4TurboPreview):  # pylint: disable=invalid-name
+  """GPT-4 Turbo Preview with 128k context window. Knowledge up to Apr. 2023."""
   model = 'gpt-4-1106-preview'


+class Gpt4VisionPreview(Gpt4):
+  """GPT-4 Turbo vision preview. 128k context window. Knowledge to Apr. 2023."""
+  model = 'gpt-4-vision-preview'
+  multimodal = True
+
+
+class Gpt4VisionPreview_1106(Gpt4):  # pylint: disable=invalid-name
+  """GPT-4 Turbo vision preview. 128k context window. Knowledge to Apr. 2023."""
+  model = 'gpt-4-1106-vision-preview'
+
+
 class Gpt4_0613(Gpt4):  # pylint:disable=invalid-name
   """GPT-4 @20230613. 8K context window. Knowledge up to 9-2021."""
   model = 'gpt-4-0613'
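
Note the rename: `Gpt4Turbo` now binds the multimodal `gpt-4-turbo` model, the old preview aliases move to the `Gpt4TurboPreview*` classes, and the removed `Gpt4TurboVision` (model `gpt-4-vision-preview`) resurfaces as `Gpt4VisionPreview`. A sketch of the renamed surface (fake key, as in the tests below):

    from langfun.core.llms import openai

    lm = openai.Gpt4Turbo_20240409(api_key='test_key')         # pinned 2024-04-09 snapshot
    legacy = openai.Gpt4TurboPreview_0125(api_key='test_key')  # old 'gpt-4-0125-preview'
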
--- a/langfun/core/llms/openai_test.py
+++ b/langfun/core/llms/openai_test.py
@@ -157,17 +157,19 @@ class OpenaiTest(unittest.TestCase):
   def test_call_chat_completion_vision(self):
     with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
       mock_chat_completion.side_effect = mock_chat_completion_query_vision
-      lm = openai.Gpt4TurboVision(api_key='test_key')
-      self.assertEqual(
-          lm(
-              lf.UserMessage(
-                  'hello {{image}}',
-                  image=lf_modalities.Image.from_uri('https://fake/image')
-              ),
-              sampling_options=lf.LMSamplingOptions(n=2)
-          ),
-          'Sample 0 for message: https://fake/image',
-      )
+      lm_1 = openai.Gpt4Turbo(api_key='test_key')
+      lm_2 = openai.Gpt4VisionPreview(api_key='test_key')
+      for lm in (lm_1, lm_2):
+        self.assertEqual(
+            lm(
+                lf.UserMessage(
+                    'hello {{image}}',
+                    image=lf_modalities.Image.from_uri('https://fake/image')
+                ),
+                sampling_options=lf.LMSamplingOptions(n=2)
+            ),
+            'Sample 0 for message: https://fake/image',
+        )

   def test_sample_completion(self):
     with mock.patch('openai.Completion.create') as mock_completion:
--- a/langfun-0.0.2.dev20240418.dist-info/METADATA
+++ b/langfun-0.0.2.dev20240420.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.0.2.dev20240418
+Version: 0.0.2.dev20240420
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors
--- a/langfun-0.0.2.dev20240418.dist-info/RECORD
+++ b/langfun-0.0.2.dev20240420.dist-info/RECORD
@@ -8,7 +8,7 @@ langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
 langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
 langfun/core/langfunc.py,sha256=WXdTc3QsmGD_n80KD9dFRr5MHpGZ9E_y_Rhtk4t9-3w,11852
 langfun/core/langfunc_test.py,sha256=sQaKuZpGGmG80GRifhbxkj7nfzQLJKj4Vuw5y1s1K3U,8378
-langfun/core/language_model.py,sha256=Tzswu0hyXOQOZ3fZ_Mz_Cc0ei7tVj8rTay9jJEgM6mI,17510
+langfun/core/language_model.py,sha256=1_GO6oEm0wXnE7aRRLOdT-A4j_6YvRanS5oMgfobcIs,18331
 langfun/core/language_model_test.py,sha256=KvXXOr64TsSs3WkEALCLLZSlz09i7hBiHDOZ_8Eq8_o,13047
 langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
 langfun/core/message.py,sha256=QhvV9t5qaryPcruyxxcXi3gm9QDInkSldwTtK6sVJ3c,15734
@@ -46,15 +46,17 @@ langfun/core/eval/matching.py,sha256=aqNlYrlav7YmsB7rUlsdfoi1RLA5CYqn2RGPxRlPc78
 langfun/core/eval/matching_test.py,sha256=FFHYD7IDuKe5RMjkx74ksukiwUhO5a_SS340JaIPMws,4898
 langfun/core/eval/scoring.py,sha256=aKeanBJf1yO3Q9JEtgPWoiZk_3M_GiqwXVXX7x_g22w,6172
 langfun/core/eval/scoring_test.py,sha256=YH1cIxBWtfdKcAV9Fh10vLkV5J-gxk8b6nxW4Z2u5pk,4024
-langfun/core/llms/__init__.py,sha256=gROJ8AjMq_ebXFcEfsyzYGCS6NsGfzf9d43nLu_TIdw,2504
+langfun/core/llms/__init__.py,sha256=c_9lVKzFjnxHKgRjY_dUiJzBmW1jWALy3mtYv0uMyl0,2953
+langfun/core/llms/anthropic.py,sha256=p-tjttvithBg2b4tgxIS2F-Zk5AYAh5e-lW-8e1p4wc,7865
+langfun/core/llms/anthropic_test.py,sha256=OuLDxeiPRdqsfKILS0R6jJLTRs3-1KCIotPPr7IbIDU,5502
 langfun/core/llms/fake.py,sha256=b-Xk5IPTbUt-elsyzd_i3n1tqzc_kgETXrEvgJruSMk,2824
-langfun/core/llms/fake_test.py,sha256=AThvNyhZbkpsn-YO798uLgqB6TSw5XP2SKpKvcXEytw,4188
+langfun/core/llms/fake_test.py,sha256=ZlDQgL41EX3eYTfBQNp2nB2LciqCmtoHgCsGvW4XhwI,4184
 langfun/core/llms/google_genai.py,sha256=n8zyJwh9UCTgb6-8LyvmjVNFGZQ4-zfzZ0ulkhHAnR8,8624
 langfun/core/llms/google_genai_test.py,sha256=_UcGTfl16-aDUlEWFC2W2F8y9jPUs53RBYA6MOCpGXw,7525
 langfun/core/llms/llama_cpp.py,sha256=Y_KkMUf3Xfac49koMUtUslKl3h-HWp3-ntq7Jaa3bdo,2385
 langfun/core/llms/llama_cpp_test.py,sha256=ZxC6defGd_HX9SFRU9U4cJiQnBKundbOrchbXuC1Z2M,1683
-langfun/core/llms/openai.py,sha256=1EUd8WTI6EpcU_fzD90-4M11RdL9Mj4S9zfrzUZIyGM,11463
-langfun/core/llms/openai_test.py,sha256=hiByS95g3pXtjB2XfIdVCKiAZDb_-Qirb2_LsSyskpY,8166
+langfun/core/llms/openai.py,sha256=Z_pujF3B2QMzWBgOdV67DKAfZ8Wmyeb_6F9BkcGHyaE,12344
+langfun/core/llms/openai_test.py,sha256=S83nVUq1Za15-rq-tCGOZPGPGByVgk0YdamoO7gnNpw,8270
 langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
 langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
 langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
@@ -97,8 +99,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
 langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
 langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
 langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
-langfun-0.0.2.dev20240418.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
-langfun-0.0.2.dev20240418.dist-info/METADATA,sha256=lZtyyTUeLg3CBgl4K0U82O_5YWUSh6TDpYHG893cXA8,3405
-langfun-0.0.2.dev20240418.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-langfun-0.0.2.dev20240418.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
-langfun-0.0.2.dev20240418.dist-info/RECORD,,
+langfun-0.0.2.dev20240420.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
+langfun-0.0.2.dev20240420.dist-info/METADATA,sha256=R4bRp7OO2PSjDyKe48YvIbMptLTkeqesP98ZxJ17woc,3405
+langfun-0.0.2.dev20240420.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+langfun-0.0.2.dev20240420.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
+langfun-0.0.2.dev20240420.dist-info/RECORD,,