langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (49)
  1. langfun/__init__.py +2 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +2 -0
  7. langfun/core/eval/base.py +202 -23
  8. langfun/core/eval/base_test.py +49 -10
  9. langfun/core/eval/matching.py +26 -9
  10. langfun/core/eval/matching_test.py +2 -1
  11. langfun/core/eval/scoring.py +15 -6
  12. langfun/core/eval/scoring_test.py +2 -1
  13. langfun/core/langfunc.py +0 -5
  14. langfun/core/langfunc_test.py +6 -4
  15. langfun/core/language_model.py +124 -24
  16. langfun/core/language_model_test.py +249 -26
  17. langfun/core/llms/__init__.py +19 -2
  18. langfun/core/llms/anthropic.py +263 -0
  19. langfun/core/llms/anthropic_test.py +167 -0
  20. langfun/core/llms/cache/in_memory_test.py +37 -28
  21. langfun/core/llms/fake.py +31 -22
  22. langfun/core/llms/fake_test.py +122 -11
  23. langfun/core/llms/google_genai_test.py +8 -3
  24. langfun/core/llms/groq.py +260 -0
  25. langfun/core/llms/groq_test.py +170 -0
  26. langfun/core/llms/llama_cpp.py +3 -1
  27. langfun/core/llms/openai.py +97 -79
  28. langfun/core/llms/openai_test.py +285 -59
  29. langfun/core/modalities/video.py +5 -2
  30. langfun/core/structured/__init__.py +3 -0
  31. langfun/core/structured/completion_test.py +2 -2
  32. langfun/core/structured/function_generation.py +245 -0
  33. langfun/core/structured/function_generation_test.py +329 -0
  34. langfun/core/structured/mapping.py +56 -2
  35. langfun/core/structured/mapping_test.py +17 -0
  36. langfun/core/structured/parsing_test.py +18 -13
  37. langfun/core/structured/prompting.py +27 -6
  38. langfun/core/structured/prompting_test.py +79 -12
  39. langfun/core/structured/schema.py +4 -2
  40. langfun/core/structured/schema_generation_test.py +2 -2
  41. langfun/core/structured/schema_test.py +4 -6
  42. langfun/core/template.py +125 -10
  43. langfun/core/template_test.py +75 -0
  44. langfun/core/templates/selfplay_test.py +6 -2
  45. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
  46. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +49 -43
  47. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
  48. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
  49. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
langfun/core/language_model_test.py
@@ -38,9 +38,19 @@ class MockModel(lm_lib.LanguageModel):
  def fake_sample(prompts):
  if context.attempt >= self.failures_before_attempt:
  return [
- lm_lib.LMSamplingResult([lm_lib.LMSample( # pylint: disable=g-complex-comprehension
- response=prompt.text * self.sampling_options.top_k,
- score=self.sampling_options.temperature)])
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample( # pylint: disable=g-complex-comprehension
+ response=prompt.text * self.sampling_options.top_k,
+ score=self.sampling_options.temperature or -1.0,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(
+ prompt_tokens=100,
+ completion_tokens=100,
+ total_tokens=200,
+ ),
+ )
  for prompt in prompts
  ]
  context.attempt += 1
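
Note: the hunk above gives the mock model's results a usage field. A minimal sketch of constructing such a result with the names exported from langfun.core (the token counts are illustrative only, not values the library produces):

    import langfun.core as lf

    result = lf.LMSamplingResult(
        [lf.LMSample('hello', score=0.0)],
        usage=lf.LMSamplingUsage(
            prompt_tokens=100, completion_tokens=100, total_tokens=200
        ),
    )
    assert result.usage.total_tokens == 200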
@@ -73,13 +83,13 @@ class LMSamplingOptionsTest(unittest.TestCase):
  def test_cache_key(self):
  options = lm_lib.LMSamplingOptions()
  key1 = options.cache_key()
- self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+ self.assertEqual(key1, (None, None, 1, 40, None, None))
  with options.override(temperature=1.0, max_tokens=256):
  key2 = options.cache_key()
  self.assertEqual(key2, (1.0, 256, 1, 40, None, None))

  # Make sure key1 does not change upon override.
- self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+ self.assertEqual(key1, (None, None, 1, 40, None, None))


  class LanguageModelTest(unittest.TestCase):
@@ -100,8 +110,38 @@ class LanguageModelTest(unittest.TestCase):
  self.assertEqual(
  lm.sample(prompts=['foo', 'bar']),
  [
- lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=0.0)]),
- lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=0.0)]),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo',
+ score=-1.0,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=-1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ ),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'bar',
+ score=-1.0,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=-1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ ),
  ],
  )
  # Test override sampling_options.
@@ -112,38 +152,128 @@ class LanguageModelTest(unittest.TestCase):
  ),
  [
  lm_lib.LMSamplingResult(
- [lm_lib.LMSample('foo' * 2, score=0.5)]
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo' * 2,
+ score=0.5,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=0.5,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
  ),
  lm_lib.LMSamplingResult(
- [lm_lib.LMSample('bar' * 2, score=0.5)]
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'bar' * 2,
+ score=0.5,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=0.5,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(
+ prompt_tokens=100, completion_tokens=100, total_tokens=200
+ ),
  ),
- ],
+ ]
  )
  # Test override individual flags within sampling_options.
  self.assertEqual(
  lm.sample(prompts=['foo', 'bar'], temperature=1.0),
  [
- lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=1.0)]),
- lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=1.0)]),
- ],
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo',
+ score=1.0,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ ),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'bar',
+ score=1.0,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(
+ prompt_tokens=100, completion_tokens=100, total_tokens=200
+ ),
+ ),
+ ]
  )
  self.assertEqual(
  lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
  [
  lm_lib.LMSamplingResult(
- [lm_lib.LMSample('foo' * 2, score=0.7)]
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo' * 2,
+ score=0.7,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=0.7,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
  ),
  lm_lib.LMSamplingResult(
- [lm_lib.LMSample('bar' * 2, score=0.7)]
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'bar' * 2,
+ score=0.7,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=0.7,
+ logprobs=None,
+ ),
+ ],
+ usage=lm_lib.LMSamplingUsage(
+ prompt_tokens=100, completion_tokens=100, total_tokens=200
+ ),
  ),
- ],
+ ]
  )

  def test_call(self):
  lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
  response = lm(prompt='foo')
  self.assertEqual(response.text, 'foo')
- self.assertEqual(response.score, 0.0)
+ self.assertEqual(response.score, -1.0)
+ self.assertIsNone(response.logprobs)
+ self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200))

  # Test override sampling_options.
  self.assertEqual(
@@ -158,11 +288,42 @@ class LanguageModelTest(unittest.TestCase):
  self.assertEqual(
  lm.sample(prompts=['foo', 'bar']),
  [
- lm_lib.LMSamplingResult([lm_lib.LMSample(
- message_lib.AIMessage('foo', cache_seed=0), score=0.0)]),
- lm_lib.LMSamplingResult([lm_lib.LMSample(
- message_lib.AIMessage('bar', cache_seed=0), score=0.0)]),
- ])
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo',
+ cache_seed=0,
+ score=-1.0,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=-1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ ),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'bar',
+ cache_seed=0,
+ score=-1.0,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=-1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ ),
+ ],
+ )
  self.assertEqual(cache.stats.num_queries, 2)
  self.assertEqual(cache.stats.num_hits, 0)
  self.assertEqual(cache.stats.num_updates, 2)
@@ -181,10 +342,40 @@ class LanguageModelTest(unittest.TestCase):
  self.assertEqual(
  lm.sample(prompts=['foo', 'baz'], temperature=1.0),
  [
- lm_lib.LMSamplingResult([lm_lib.LMSample(
- message_lib.AIMessage('foo', cache_seed=0), score=1.0)]),
- lm_lib.LMSamplingResult([lm_lib.LMSample(
- message_lib.AIMessage('baz', cache_seed=0), score=1.0)]),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'foo',
+ cache_seed=0,
+ score=1.0,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ ),
+ lm_lib.LMSamplingResult(
+ [
+ lm_lib.LMSample(
+ message_lib.AIMessage(
+ 'baz',
+ cache_seed=0,
+ score=1.0,
+ logprobs=None,
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
+ ),
  ],
  )
  self.assertEqual(cache.stats.num_queries, 6)
@@ -341,6 +532,38 @@ class LanguageModelTest(unittest.TestCase):
  with self.assertRaises(NotImplementedError):
  MockModel().score('hi', ['1', '2'])

+ def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
+ lm = MockModel()
+ self.assertEqual(
+ lm_lib.DEFAULT_MAX_CONCURRENCY,
+ lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
+ )
+ self.assertEqual(
+ lm_lib.DEFAULT_MAX_CONCURRENCY,
+ lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
+ )
+
+ def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
+ lm = MockModel()
+ test_rpm = 1e4
+ self.assertEqual(
+ lm.rate_to_max_concurrency(requests_per_min=test_rpm),
+ int(test_rpm / 60)
+ )
+
+ def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
+ lm = MockModel()
+ test_tpm = 1e7
+ self.assertEqual(
+ lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
+ int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
+ )
+
+ def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
+ lm = MockModel()
+ self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
+ self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
+

  if __name__ == '__main__':
  unittest.main()
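
Note: the new tests above pin down rate_to_max_concurrency only through its observable behavior. A rough sketch of logic consistent with those assertions follows; it is not the library's implementation, and the two module constants are assumed stand-ins for lm_lib.DEFAULT_MAX_CONCURRENCY and lm_lib.TOKENS_PER_REQUEST:

    DEFAULT_MAX_CONCURRENCY = 1  # assumed stand-in value
    TOKENS_PER_REQUEST = 250     # assumed stand-in value

    def rate_to_max_concurrency(requests_per_min: float = 0,
                                tokens_per_min: float = 0) -> int:
      # Prefer the token budget when one is given; otherwise fall back to RPM.
      if tokens_per_min > 0:
        return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
      if requests_per_min > 0:
        return max(int(requests_per_min / 60), 1)
      return DEFAULT_MAX_CONCURRENCY  # no usable rate limit given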
langfun/core/llms/__init__.py
@@ -35,8 +35,12 @@ from langfun.core.llms.google_genai import Palm2_IT
  from langfun.core.llms.openai import OpenAI

  from langfun.core.llms.openai import Gpt4Turbo
- from langfun.core.llms.openai import Gpt4Turbo_0125
- from langfun.core.llms.openai import Gpt4TurboVision
+ from langfun.core.llms.openai import Gpt4Turbo_20240409
+ from langfun.core.llms.openai import Gpt4TurboPreview
+ from langfun.core.llms.openai import Gpt4TurboPreview_0125
+ from langfun.core.llms.openai import Gpt4TurboPreview_1106
+ from langfun.core.llms.openai import Gpt4VisionPreview
+ from langfun.core.llms.openai import Gpt4VisionPreview_1106
  from langfun.core.llms.openai import Gpt4
  from langfun.core.llms.openai import Gpt4_0613
  from langfun.core.llms.openai import Gpt4_32K
@@ -57,6 +61,19 @@ from langfun.core.llms.openai import Gpt3Curie
  from langfun.core.llms.openai import Gpt3Babbage
  from langfun.core.llms.openai import Gpt3Ada

+ from langfun.core.llms.anthropic import Anthropic
+ from langfun.core.llms.anthropic import Claude3Opus
+ from langfun.core.llms.anthropic import Claude3Sonnet
+ from langfun.core.llms.anthropic import Claude3Haiku
+
+ from langfun.core.llms.groq import Groq
+ from langfun.core.llms.groq import GroqLlama3_70B
+ from langfun.core.llms.groq import GroqLlama3_8B
+ from langfun.core.llms.groq import GroqLlama2_70B
+ from langfun.core.llms.groq import GroqMistral_8x7B
+ from langfun.core.llms.groq import GroqGemma7B_IT
+
+
  # LLaMA C++ models.
  from langfun.core.llms.llama_cpp import LlamaCppRemote

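Note: with the import block above, the new Anthropic and Groq wrappers are re-exported from langfun.core.llms. An illustrative instantiation; the Anthropic classes fall back to the ANTHROPIC_API_KEY environment variable when no api_key is passed (per anthropic.py below), while the Groq classes are assumed to resolve their key analogously:

    from langfun.core import llms

    claude = llms.Claude3Haiku()    # reads ANTHROPIC_API_KEY if api_key is not set
    llama3 = llms.GroqLlama3_8B()   # assumed to resolve its API key the same way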
langfun/core/llms/anthropic.py (new file)
@@ -0,0 +1,263 @@
+ # Copyright 2023 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Language models from Anthropic."""
+
+ import base64
+ import functools
+ import os
+ from typing import Annotated, Any
+
+ import langfun.core as lf
+ from langfun.core import modalities as lf_modalities
+ import pyglove as pg
+ import requests
+
+
+ SUPPORTED_MODELS_AND_SETTINGS = {
+ # See https://docs.anthropic.com/claude/docs/models-overview
+ # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
+ # RPM/TPM for Claude-2.1, Claude-2.0, and Claude-Instant-1.2 estimated
+ # as RPM/TPM of the largest-available model (Claude-3-Opus).
+ 'claude-3-opus-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+ 'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+ 'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+ 'claude-2.1': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+ 'claude-2.0': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+ 'claude-instant-1.2': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+ }
+
+
+ class AnthropicError(Exception): # pylint: disable=g-bad-exception-name
+ """Base class for Anthropic errors."""
+
+
+ class RateLimitError(AnthropicError):
+ """Error for rate limit reached."""
+
+
+ class OverloadedError(AnthropicError):
+ """Anthropic's server is temporarily overloaded."""
+
+
+ _ANTHROPIC_MESSAGE_API_ENDPOINT = 'https://api.anthropic.com/v1/messages'
+ _ANTHROPIC_API_VERSION = '2023-06-01'
+
+
+ @lf.use_init_args(['model'])
+ class Anthropic(lf.LanguageModel):
+ """Anthropic LLMs (Claude) through REST APIs.
+
+ See https://docs.anthropic.com/claude/reference/messages_post
+ """
+
+ model: pg.typing.Annotated[
+ pg.typing.Enum(
+ pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+ ),
+ 'The name of the model to use.',
+ ]
+
+ multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
+ True
+ )
+
+ api_key: Annotated[
+ str | None,
+ (
+ 'API key. If None, the key will be read from environment variable '
+ "'ANTHROPIC_API_KEY'."
+ ),
+ ] = None
+
+ def _on_bound(self):
+ super()._on_bound()
+ self._api_key = None
+ self.__dict__.pop('_api_initialized', None)
+ self.__dict__.pop('_session', None)
+
+ @functools.cached_property
+ def _api_initialized(self):
+ api_key = self.api_key or os.environ.get('ANTHROPIC_API_KEY', None)
+ if not api_key:
+ raise ValueError(
+ 'Please specify `api_key` during `__init__` or set environment '
+ 'variable `ANTHROPIC_API_KEY` with your Anthropic API key.'
+ )
+ self._api_key = api_key
+ return True
+
+ @functools.cached_property
+ def _session(self) -> requests.Session:
+ assert self._api_initialized
+ s = requests.Session()
+ s.headers.update({
+ 'x-api-key': self._api_key,
+ 'anthropic-version': _ANTHROPIC_API_VERSION,
+ 'content-type': 'application/json',
+ })
+ return s
+
+ @property
+ def model_id(self) -> str:
+ """Returns a string to identify the model."""
+ return self.model
+
+ @property
+ def max_concurrency(self) -> int:
+ rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+ tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+ return self.rate_to_max_concurrency(
+ requests_per_min=rpm, tokens_per_min=tpm
+ )
+
+ def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+ assert self._api_initialized
+ return self._parallel_execute_with_currency_control(
+ self._sample_single, prompts, retry_on_errors=(RateLimitError)
+ )
+
+ def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+ """Returns a dict as request arguments."""
+ # Authropic requires `max_tokens` to be specified.
+ max_tokens = (
+ options.max_tokens
+ or SUPPORTED_MODELS_AND_SETTINGS[self.model].max_tokens
+ )
+ args = dict(
+ model=self.model,
+ max_tokens=max_tokens,
+ stream=False,
+ )
+ if options.stop:
+ args['stop_sequences'] = options.stop
+ if options.temperature is not None:
+ args['temperature'] = options.temperature
+ if options.top_k is not None:
+ args['top_k'] = options.top_k
+ if options.top_p is not None:
+ args['top_p'] = options.top_p
+ return args
+
+ def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
+ """Converts an message to Anthropic's content protocol (list of dicts)."""
+ # Refer: https://docs.anthropic.com/claude/reference/messages-examples
+ if self.multimodal:
+ content = []
+ for chunk in prompt.chunk():
+ if isinstance(chunk, str):
+ item = dict(type='text', text=chunk)
+ elif isinstance(chunk, lf_modalities.Image):
+ # NOTE(daiyip): Anthropic only support image content instead of URL.
+ item = dict(
+ type='image',
+ source=dict(
+ type='base64',
+ media_type=chunk.mime_type,
+ data=base64.b64encode(chunk.to_bytes()).decode(),
+ ),
+ )
+ else:
+ raise ValueError(f'Unsupported modality object: {chunk!r}.')
+ content.append(item)
+ return content
+ else:
+ return [dict(type='text', text=prompt.text)]
+
+ def _message_from_content(self, content: list[dict[str, Any]]) -> lf.Message:
+ """Converts Anthropic's content protocol to message."""
+ # Refer: https://docs.anthropic.com/claude/reference/messages-examples
+ return lf.AIMessage.from_chunks(
+ [x['text'] for x in content if x['type'] == 'text']
+ )
+
+ def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
+ """Parses Anthropic's response."""
+ # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
+ if response.status_code == 200:
+ output = response.json()
+ message = self._message_from_content(output['content'])
+ input_tokens = output['usage']['input_tokens']
+ output_tokens = output['usage']['output_tokens']
+ return lf.LMSamplingResult(
+ [lf.LMSample(message)],
+ usage=lf.LMSamplingUsage(
+ prompt_tokens=input_tokens,
+ completion_tokens=output_tokens,
+ total_tokens=input_tokens + output_tokens,
+ ),
+ )
+ else:
+ if response.status_code == 429:
+ error_cls = RateLimitError
+ elif response.status_code in (502, 529):
+ error_cls = OverloadedError
+ else:
+ error_cls = AnthropicError
+ raise error_cls(f'{response.status_code}: {response.content}')
+
+ def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+ request = dict()
+ request.update(self._get_request_args(self.sampling_options))
+ request.update(
+ dict(
+ messages=[
+ dict(role='user', content=self._content_from_message(prompt))
+ ]
+ )
+ )
+ try:
+ response = self._session.post(
+ _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
+ )
+ return self._parse_response(response)
+ except ConnectionError as e:
+ raise OverloadedError(str(e)) from e
+
+
+ class Claude3(Anthropic):
+ """Base class for Claude 3 models. 200K input tokens and 4K output tokens."""
+ multimodal = True
+
+
+ class Claude3Opus(Claude3):
+ """Anthropic's most powerful model."""
+
+ model = 'claude-3-opus-20240229'
+
+
+ class Claude3Sonnet(Claude3):
+ """A balance between between Opus and Haiku."""
+
+ model = 'claude-3-sonnet-20240229'
+
+
+ class Claude3Haiku(Claude3):
+ """Anthropic's most compact model."""
+
+ model = 'claude-3-haiku-20240307'
+
+
+ class Claude2(Anthropic):
+ """Predecessor to Claude 3 with 100K context window.."""
+ model = 'claude-2.0'
+
+
+ class Claude21(Anthropic):
+ """Updated Claude 2 model with improved accuracy and 200K context window."""
+ model = 'claude-2.1'
+
+
+ class ClaudeInstant(Anthropic):
+ """Cheapest small and fast model, 100K context window."""
+ model = 'claude-instant-1.2'
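
Note: a minimal usage sketch for the new Anthropic backend, assuming a valid ANTHROPIC_API_KEY is set in the environment; the call and usage behavior mirrors what the language_model tests above assert:

    from langfun.core.llms import anthropic

    lm = anthropic.Claude3Opus()   # or anthropic.Anthropic('claude-3-opus-20240229')
    response = lm(prompt='Say hello in one short sentence.')
    print(response.text)    # model reply text
    print(response.usage)   # LMSamplingUsage with prompt/completion/total tokens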