langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +2 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/base.py +202 -23
- langfun/core/eval/base_test.py +49 -10
- langfun/core/eval/matching.py +26 -9
- langfun/core/eval/matching_test.py +2 -1
- langfun/core/eval/scoring.py +15 -6
- langfun/core/eval/scoring_test.py +2 -1
- langfun/core/langfunc.py +0 -5
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +124 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +19 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +31 -22
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +97 -79
- langfun/core/llms/openai_test.py +285 -59
- langfun/core/modalities/video.py +5 -2
- langfun/core/structured/__init__.py +3 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +56 -2
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +27 -6
- langfun/core/structured/prompting_test.py +79 -12
- langfun/core/structured/schema.py +4 -2
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +4 -6
- langfun/core/template.py +125 -10
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +49 -43
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
@@ -38,9 +38,19 @@ class MockModel(lm_lib.LanguageModel):
|
|
38
38
|
def fake_sample(prompts):
|
39
39
|
if context.attempt >= self.failures_before_attempt:
|
40
40
|
return [
|
41
|
-
lm_lib.LMSamplingResult(
|
42
|
-
|
43
|
-
|
41
|
+
lm_lib.LMSamplingResult(
|
42
|
+
[
|
43
|
+
lm_lib.LMSample( # pylint: disable=g-complex-comprehension
|
44
|
+
response=prompt.text * self.sampling_options.top_k,
|
45
|
+
score=self.sampling_options.temperature or -1.0,
|
46
|
+
)
|
47
|
+
],
|
48
|
+
usage=lm_lib.LMSamplingUsage(
|
49
|
+
prompt_tokens=100,
|
50
|
+
completion_tokens=100,
|
51
|
+
total_tokens=200,
|
52
|
+
),
|
53
|
+
)
|
44
54
|
for prompt in prompts
|
45
55
|
]
|
46
56
|
context.attempt += 1
|
@@ -73,13 +83,13 @@ class LMSamplingOptionsTest(unittest.TestCase):
|
|
73
83
|
def test_cache_key(self):
|
74
84
|
options = lm_lib.LMSamplingOptions()
|
75
85
|
key1 = options.cache_key()
|
76
|
-
self.assertEqual(key1, (
|
86
|
+
self.assertEqual(key1, (None, None, 1, 40, None, None))
|
77
87
|
with options.override(temperature=1.0, max_tokens=256):
|
78
88
|
key2 = options.cache_key()
|
79
89
|
self.assertEqual(key2, (1.0, 256, 1, 40, None, None))
|
80
90
|
|
81
91
|
# Make sure key1 does not change upon override.
|
82
|
-
self.assertEqual(key1, (
|
92
|
+
self.assertEqual(key1, (None, None, 1, 40, None, None))
|
83
93
|
|
84
94
|
|
85
95
|
class LanguageModelTest(unittest.TestCase):
|
@@ -100,8 +110,38 @@ class LanguageModelTest(unittest.TestCase):
|
|
100
110
|
self.assertEqual(
|
101
111
|
lm.sample(prompts=['foo', 'bar']),
|
102
112
|
[
|
103
|
-
lm_lib.LMSamplingResult(
|
104
|
-
|
113
|
+
lm_lib.LMSamplingResult(
|
114
|
+
[
|
115
|
+
lm_lib.LMSample(
|
116
|
+
message_lib.AIMessage(
|
117
|
+
'foo',
|
118
|
+
score=-1.0,
|
119
|
+
logprobs=None,
|
120
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
121
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
122
|
+
),
|
123
|
+
score=-1.0,
|
124
|
+
logprobs=None,
|
125
|
+
)
|
126
|
+
],
|
127
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
128
|
+
),
|
129
|
+
lm_lib.LMSamplingResult(
|
130
|
+
[
|
131
|
+
lm_lib.LMSample(
|
132
|
+
message_lib.AIMessage(
|
133
|
+
'bar',
|
134
|
+
score=-1.0,
|
135
|
+
logprobs=None,
|
136
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
137
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
138
|
+
),
|
139
|
+
score=-1.0,
|
140
|
+
logprobs=None,
|
141
|
+
)
|
142
|
+
],
|
143
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
144
|
+
),
|
105
145
|
],
|
106
146
|
)
|
107
147
|
# Test override sampling_options.
|
@@ -112,38 +152,128 @@ class LanguageModelTest(unittest.TestCase):
|
|
112
152
|
),
|
113
153
|
[
|
114
154
|
lm_lib.LMSamplingResult(
|
115
|
-
[
|
155
|
+
[
|
156
|
+
lm_lib.LMSample(
|
157
|
+
message_lib.AIMessage(
|
158
|
+
'foo' * 2,
|
159
|
+
score=0.5,
|
160
|
+
logprobs=None,
|
161
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
162
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
163
|
+
),
|
164
|
+
score=0.5,
|
165
|
+
logprobs=None,
|
166
|
+
),
|
167
|
+
],
|
168
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
116
169
|
),
|
117
170
|
lm_lib.LMSamplingResult(
|
118
|
-
[
|
171
|
+
[
|
172
|
+
lm_lib.LMSample(
|
173
|
+
message_lib.AIMessage(
|
174
|
+
'bar' * 2,
|
175
|
+
score=0.5,
|
176
|
+
logprobs=None,
|
177
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
178
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
179
|
+
),
|
180
|
+
score=0.5,
|
181
|
+
logprobs=None,
|
182
|
+
),
|
183
|
+
],
|
184
|
+
usage=lm_lib.LMSamplingUsage(
|
185
|
+
prompt_tokens=100, completion_tokens=100, total_tokens=200
|
186
|
+
),
|
119
187
|
),
|
120
|
-
]
|
188
|
+
]
|
121
189
|
)
|
122
190
|
# Test override individual flags within sampling_options.
|
123
191
|
self.assertEqual(
|
124
192
|
lm.sample(prompts=['foo', 'bar'], temperature=1.0),
|
125
193
|
[
|
126
|
-
lm_lib.LMSamplingResult(
|
127
|
-
|
128
|
-
|
194
|
+
lm_lib.LMSamplingResult(
|
195
|
+
[
|
196
|
+
lm_lib.LMSample(
|
197
|
+
message_lib.AIMessage(
|
198
|
+
'foo',
|
199
|
+
score=1.0,
|
200
|
+
logprobs=None,
|
201
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
202
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
203
|
+
),
|
204
|
+
score=1.0,
|
205
|
+
logprobs=None,
|
206
|
+
),
|
207
|
+
],
|
208
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
209
|
+
),
|
210
|
+
lm_lib.LMSamplingResult(
|
211
|
+
[
|
212
|
+
lm_lib.LMSample(
|
213
|
+
message_lib.AIMessage(
|
214
|
+
'bar',
|
215
|
+
score=1.0,
|
216
|
+
logprobs=None,
|
217
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
218
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
219
|
+
),
|
220
|
+
score=1.0,
|
221
|
+
logprobs=None,
|
222
|
+
),
|
223
|
+
],
|
224
|
+
usage=lm_lib.LMSamplingUsage(
|
225
|
+
prompt_tokens=100, completion_tokens=100, total_tokens=200
|
226
|
+
),
|
227
|
+
),
|
228
|
+
]
|
129
229
|
)
|
130
230
|
self.assertEqual(
|
131
231
|
lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
|
132
232
|
[
|
133
233
|
lm_lib.LMSamplingResult(
|
134
|
-
[
|
234
|
+
[
|
235
|
+
lm_lib.LMSample(
|
236
|
+
message_lib.AIMessage(
|
237
|
+
'foo' * 2,
|
238
|
+
score=0.7,
|
239
|
+
logprobs=None,
|
240
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
241
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
242
|
+
),
|
243
|
+
score=0.7,
|
244
|
+
logprobs=None,
|
245
|
+
),
|
246
|
+
],
|
247
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
135
248
|
),
|
136
249
|
lm_lib.LMSamplingResult(
|
137
|
-
[
|
250
|
+
[
|
251
|
+
lm_lib.LMSample(
|
252
|
+
message_lib.AIMessage(
|
253
|
+
'bar' * 2,
|
254
|
+
score=0.7,
|
255
|
+
logprobs=None,
|
256
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
257
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
258
|
+
),
|
259
|
+
score=0.7,
|
260
|
+
logprobs=None,
|
261
|
+
),
|
262
|
+
],
|
263
|
+
usage=lm_lib.LMSamplingUsage(
|
264
|
+
prompt_tokens=100, completion_tokens=100, total_tokens=200
|
265
|
+
),
|
138
266
|
),
|
139
|
-
]
|
267
|
+
]
|
140
268
|
)
|
141
269
|
|
142
270
|
def test_call(self):
|
143
271
|
lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
|
144
272
|
response = lm(prompt='foo')
|
145
273
|
self.assertEqual(response.text, 'foo')
|
146
|
-
self.assertEqual(response.score,
|
274
|
+
self.assertEqual(response.score, -1.0)
|
275
|
+
self.assertIsNone(response.logprobs)
|
276
|
+
self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200))
|
147
277
|
|
148
278
|
# Test override sampling_options.
|
149
279
|
self.assertEqual(
|
@@ -158,11 +288,42 @@ class LanguageModelTest(unittest.TestCase):
|
|
158
288
|
self.assertEqual(
|
159
289
|
lm.sample(prompts=['foo', 'bar']),
|
160
290
|
[
|
161
|
-
lm_lib.LMSamplingResult(
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
291
|
+
lm_lib.LMSamplingResult(
|
292
|
+
[
|
293
|
+
lm_lib.LMSample(
|
294
|
+
message_lib.AIMessage(
|
295
|
+
'foo',
|
296
|
+
cache_seed=0,
|
297
|
+
score=-1.0,
|
298
|
+
logprobs=None,
|
299
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
300
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
301
|
+
),
|
302
|
+
score=-1.0,
|
303
|
+
logprobs=None,
|
304
|
+
)
|
305
|
+
],
|
306
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
307
|
+
),
|
308
|
+
lm_lib.LMSamplingResult(
|
309
|
+
[
|
310
|
+
lm_lib.LMSample(
|
311
|
+
message_lib.AIMessage(
|
312
|
+
'bar',
|
313
|
+
cache_seed=0,
|
314
|
+
score=-1.0,
|
315
|
+
logprobs=None,
|
316
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
317
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
318
|
+
),
|
319
|
+
score=-1.0,
|
320
|
+
logprobs=None,
|
321
|
+
)
|
322
|
+
],
|
323
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
324
|
+
),
|
325
|
+
],
|
326
|
+
)
|
166
327
|
self.assertEqual(cache.stats.num_queries, 2)
|
167
328
|
self.assertEqual(cache.stats.num_hits, 0)
|
168
329
|
self.assertEqual(cache.stats.num_updates, 2)
|
@@ -181,10 +342,40 @@ class LanguageModelTest(unittest.TestCase):
|
|
181
342
|
self.assertEqual(
|
182
343
|
lm.sample(prompts=['foo', 'baz'], temperature=1.0),
|
183
344
|
[
|
184
|
-
lm_lib.LMSamplingResult(
|
185
|
-
|
186
|
-
|
187
|
-
|
345
|
+
lm_lib.LMSamplingResult(
|
346
|
+
[
|
347
|
+
lm_lib.LMSample(
|
348
|
+
message_lib.AIMessage(
|
349
|
+
'foo',
|
350
|
+
cache_seed=0,
|
351
|
+
score=1.0,
|
352
|
+
logprobs=None,
|
353
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
354
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
355
|
+
),
|
356
|
+
score=1.0,
|
357
|
+
logprobs=None,
|
358
|
+
)
|
359
|
+
],
|
360
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
361
|
+
),
|
362
|
+
lm_lib.LMSamplingResult(
|
363
|
+
[
|
364
|
+
lm_lib.LMSample(
|
365
|
+
message_lib.AIMessage(
|
366
|
+
'baz',
|
367
|
+
cache_seed=0,
|
368
|
+
score=1.0,
|
369
|
+
logprobs=None,
|
370
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
371
|
+
tags=[message_lib.Message.TAG_LM_RESPONSE],
|
372
|
+
),
|
373
|
+
score=1.0,
|
374
|
+
logprobs=None,
|
375
|
+
)
|
376
|
+
],
|
377
|
+
usage=lm_lib.LMSamplingUsage(100, 100, 200),
|
378
|
+
),
|
188
379
|
],
|
189
380
|
)
|
190
381
|
self.assertEqual(cache.stats.num_queries, 6)
|
@@ -341,6 +532,38 @@ class LanguageModelTest(unittest.TestCase):
|
|
341
532
|
with self.assertRaises(NotImplementedError):
|
342
533
|
MockModel().score('hi', ['1', '2'])
|
343
534
|
|
535
|
+
def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
|
536
|
+
lm = MockModel()
|
537
|
+
self.assertEqual(
|
538
|
+
lm_lib.DEFAULT_MAX_CONCURRENCY,
|
539
|
+
lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
|
540
|
+
)
|
541
|
+
self.assertEqual(
|
542
|
+
lm_lib.DEFAULT_MAX_CONCURRENCY,
|
543
|
+
lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
|
544
|
+
)
|
545
|
+
|
546
|
+
def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
|
547
|
+
lm = MockModel()
|
548
|
+
test_rpm = 1e4
|
549
|
+
self.assertEqual(
|
550
|
+
lm.rate_to_max_concurrency(requests_per_min=test_rpm),
|
551
|
+
int(test_rpm / 60)
|
552
|
+
)
|
553
|
+
|
554
|
+
def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
|
555
|
+
lm = MockModel()
|
556
|
+
test_tpm = 1e7
|
557
|
+
self.assertEqual(
|
558
|
+
lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
|
559
|
+
int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
|
560
|
+
)
|
561
|
+
|
562
|
+
def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
|
563
|
+
lm = MockModel()
|
564
|
+
self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
|
565
|
+
self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
|
566
|
+
|
344
567
|
|
345
568
|
if __name__ == '__main__':
|
346
569
|
unittest.main()
|
langfun/core/llms/__init__.py
CHANGED
@@ -35,8 +35,12 @@ from langfun.core.llms.google_genai import Palm2_IT
|
|
35
35
|
from langfun.core.llms.openai import OpenAI
|
36
36
|
|
37
37
|
from langfun.core.llms.openai import Gpt4Turbo
|
38
|
-
from langfun.core.llms.openai import
|
39
|
-
from langfun.core.llms.openai import
|
38
|
+
from langfun.core.llms.openai import Gpt4Turbo_20240409
|
39
|
+
from langfun.core.llms.openai import Gpt4TurboPreview
|
40
|
+
from langfun.core.llms.openai import Gpt4TurboPreview_0125
|
41
|
+
from langfun.core.llms.openai import Gpt4TurboPreview_1106
|
42
|
+
from langfun.core.llms.openai import Gpt4VisionPreview
|
43
|
+
from langfun.core.llms.openai import Gpt4VisionPreview_1106
|
40
44
|
from langfun.core.llms.openai import Gpt4
|
41
45
|
from langfun.core.llms.openai import Gpt4_0613
|
42
46
|
from langfun.core.llms.openai import Gpt4_32K
|
@@ -57,6 +61,19 @@ from langfun.core.llms.openai import Gpt3Curie
|
|
57
61
|
from langfun.core.llms.openai import Gpt3Babbage
|
58
62
|
from langfun.core.llms.openai import Gpt3Ada
|
59
63
|
|
64
|
+
from langfun.core.llms.anthropic import Anthropic
|
65
|
+
from langfun.core.llms.anthropic import Claude3Opus
|
66
|
+
from langfun.core.llms.anthropic import Claude3Sonnet
|
67
|
+
from langfun.core.llms.anthropic import Claude3Haiku
|
68
|
+
|
69
|
+
from langfun.core.llms.groq import Groq
|
70
|
+
from langfun.core.llms.groq import GroqLlama3_70B
|
71
|
+
from langfun.core.llms.groq import GroqLlama3_8B
|
72
|
+
from langfun.core.llms.groq import GroqLlama2_70B
|
73
|
+
from langfun.core.llms.groq import GroqMistral_8x7B
|
74
|
+
from langfun.core.llms.groq import GroqGemma7B_IT
|
75
|
+
|
76
|
+
|
60
77
|
# LLaMA C++ models.
|
61
78
|
from langfun.core.llms.llama_cpp import LlamaCppRemote
|
62
79
|
|
@@ -0,0 +1,263 @@
|
|
1
|
+
# Copyright 2023 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Language models from Anthropic."""
|
15
|
+
|
16
|
+
import base64
|
17
|
+
import functools
|
18
|
+
import os
|
19
|
+
from typing import Annotated, Any
|
20
|
+
|
21
|
+
import langfun.core as lf
|
22
|
+
from langfun.core import modalities as lf_modalities
|
23
|
+
import pyglove as pg
|
24
|
+
import requests
|
25
|
+
|
26
|
+
|
27
|
+
# Per-model settings, keyed by Anthropic model name:
#   max_tokens: default output budget (Anthropic requires one per request).
#   rpm / tpm:  requests- and tokens-per-minute limits used to size concurrency.
#
# Model list: https://docs.anthropic.com/claude/docs/models-overview
# Rate limits: https://docs.anthropic.com/claude/reference/rate-limits
# RPM/TPM for Claude-2.1, Claude-2.0, and Claude-Instant-1.2 are estimated as
# the RPM/TPM of the largest-available model (Claude-3-Opus), so every model
# currently shares the same numbers.
SUPPORTED_MODELS_AND_SETTINGS = {
    model_name: pg.Dict(max_tokens=4096, rpm=4000, tpm=400000)
    for model_name in (
        'claude-3-opus-20240229',
        'claude-3-sonnet-20240229',
        'claude-3-haiku-20240307',
        'claude-2.1',
        'claude-2.0',
        'claude-instant-1.2',
    )
}
|
39
|
+
|
40
|
+
|
41
|
+
class AnthropicError(Exception):  # pylint: disable=g-bad-exception-name
  """Root of the errors raised while talking to the Anthropic API."""


class RateLimitError(AnthropicError):
  """The request was rejected because a rate limit was hit (HTTP 429)."""


class OverloadedError(AnthropicError):
  """Anthropic's server is temporarily overloaded (e.g. HTTP 502/529)."""


# REST endpoint of the Messages API, and the value sent in the
# `anthropic-version` header on every request.
_ANTHROPIC_MESSAGE_API_ENDPOINT = 'https://api.anthropic.com/v1/messages'
_ANTHROPIC_API_VERSION = '2023-06-01'
|
55
|
+
|
56
|
+
|
57
|
+
@lf.use_init_args(['model'])
class Anthropic(lf.LanguageModel):
  """Anthropic LLMs (Claude) through REST APIs.

  See https://docs.anthropic.com/claude/reference/messages_post
  """

  model: pg.typing.Annotated[
      pg.typing.Enum(
          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
      ),
      'The name of the model to use.',
  ]

  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
      True
  )

  api_key: Annotated[
      str | None,
      (
          'API key. If None, the key will be read from environment variable '
          "'ANTHROPIC_API_KEY'."
      ),
  ] = None

  def _on_bound(self):
    super()._on_bound()
    # Forget the resolved key and drop the cached properties, so that a
    # re-bound `api_key` is picked up again on the next request.
    self._api_key = None
    self.__dict__.pop('_api_initialized', None)
    self.__dict__.pop('_session', None)

  @functools.cached_property
  def _api_initialized(self):
    """Resolves the API key from `api_key` or `ANTHROPIC_API_KEY`.

    Returns:
      True, after storing the resolved key on `self._api_key`.

    Raises:
      ValueError: If neither source provides a key.
    """
    api_key = self.api_key or os.environ.get('ANTHROPIC_API_KEY', None)
    if not api_key:
      raise ValueError(
          'Please specify `api_key` during `__init__` or set environment '
          'variable `ANTHROPIC_API_KEY` with your Anthropic API key.'
      )
    self._api_key = api_key
    return True

  @functools.cached_property
  def _session(self) -> requests.Session:
    """Returns an HTTP session preloaded with auth and API-version headers."""
    assert self._api_initialized
    s = requests.Session()
    s.headers.update({
        'x-api-key': self._api_key,
        'anthropic-version': _ANTHROPIC_API_VERSION,
        'content-type': 'application/json',
    })
    return s

  @property
  def model_id(self) -> str:
    """Returns a string to identify the model."""
    return self.model

  @property
  def max_concurrency(self) -> int:
    """Derives the max concurrency from the model's RPM/TPM limits."""
    settings = SUPPORTED_MODELS_AND_SETTINGS[self.model]
    return self.rate_to_max_concurrency(
        requests_per_min=settings.get('rpm', 0),
        tokens_per_min=settings.get('tpm', 0),
    )

  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
    """Samples every prompt in parallel, retrying on transient errors."""
    assert self._api_initialized
    # Both error classes are transient: 429 rate limiting, and server overload
    # (HTTP 502/529 or a connection failure, see `_sample_single`).
    return self._parallel_execute_with_currency_control(
        self._sample_single,
        prompts,
        retry_on_errors=(RateLimitError, OverloadedError),
    )

  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
    """Returns a dict as request arguments."""
    # Anthropic requires `max_tokens` to be specified, so fall back to the
    # model's default output budget when the caller did not set one.
    max_tokens = (
        options.max_tokens
        or SUPPORTED_MODELS_AND_SETTINGS[self.model].max_tokens
    )
    args = dict(
        model=self.model,
        max_tokens=max_tokens,
        stream=False,
    )
    # Optional sampling knobs are only sent when set.
    if options.stop:
      args['stop_sequences'] = options.stop
    if options.temperature is not None:
      args['temperature'] = options.temperature
    if options.top_k is not None:
      args['top_k'] = options.top_k
    if options.top_p is not None:
      args['top_p'] = options.top_p
    return args

  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
    """Converts an message to Anthropic's content protocol (list of dicts).

    Args:
      prompt: The message to convert.

    Returns:
      A list of `text` / `image` content items. When `multimodal` is False,
      a single text item holding `prompt.text`.

    Raises:
      ValueError: If a chunk is neither text nor an image.
    """
    # Refer: https://docs.anthropic.com/claude/reference/messages-examples
    if not self.multimodal:
      return [dict(type='text', text=prompt.text)]

    content = []
    for chunk in prompt.chunk():
      if isinstance(chunk, str):
        item = dict(type='text', text=chunk)
      elif isinstance(chunk, lf_modalities.Image):
        # NOTE(daiyip): Anthropic only support image content instead of URL,
        # so the image bytes are inlined as base64.
        item = dict(
            type='image',
            source=dict(
                type='base64',
                media_type=chunk.mime_type,
                data=base64.b64encode(chunk.to_bytes()).decode(),
            ),
        )
      else:
        raise ValueError(f'Unsupported modality object: {chunk!r}.')
      content.append(item)
    return content

  def _message_from_content(self, content: list[dict[str, Any]]) -> lf.Message:
    """Converts Anthropic's content protocol to message (text items only)."""
    # Refer: https://docs.anthropic.com/claude/reference/messages-examples
    return lf.AIMessage.from_chunks(
        [x['text'] for x in content if x['type'] == 'text']
    )

  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
    """Parses Anthropic's response.

    Args:
      response: The HTTP response of a Messages API call.

    Returns:
      A sampling result with one sample and the reported token usage.

    Raises:
      RateLimitError: On HTTP 429.
      OverloadedError: On HTTP 502 or 529.
      AnthropicError: On any other non-200 status.
    """
    # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
    if response.status_code != 200:
      if response.status_code == 429:
        error_cls = RateLimitError
      elif response.status_code in (502, 529):
        error_cls = OverloadedError
      else:
        error_cls = AnthropicError
      raise error_cls(f'{response.status_code}: {response.content}')

    output = response.json()
    message = self._message_from_content(output['content'])
    input_tokens = output['usage']['input_tokens']
    output_tokens = output['usage']['output_tokens']
    return lf.LMSamplingResult(
        [lf.LMSample(message)],
        usage=lf.LMSamplingUsage(
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
        ),
    )

  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
    """Sends a single user prompt to the Messages API and parses the reply.

    Raises:
      OverloadedError: If the connection to the API fails, or the server
        reports overload.
      RateLimitError: If the request is rate limited.
      AnthropicError: For other API errors.
    """
    request = self._get_request_args(self.sampling_options)
    request['messages'] = [
        dict(role='user', content=self._content_from_message(prompt))
    ]
    try:
      response = self._session.post(
          _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
      )
    # `requests` raises its own `ConnectionError`, which derives from
    # `RequestException`/`IOError` and NOT from the builtin `ConnectionError`,
    # so both must be caught for connection failures to become retryable.
    except (ConnectionError, requests.exceptions.ConnectionError) as e:
      raise OverloadedError(str(e)) from e
    return self._parse_response(response)
|
226
|
+
|
227
|
+
|
228
|
+
class Claude3(Anthropic):
  """Base class for Claude 3 models. 200K input tokens and 4K output tokens."""

  # Claude 3 models accept image content alongside text.
  multimodal = True


class Claude3Opus(Claude3):
  """Anthropic's most powerful model."""

  model = 'claude-3-opus-20240229'


class Claude3Sonnet(Claude3):
  """A balance between Opus and Haiku."""

  model = 'claude-3-sonnet-20240229'


class Claude3Haiku(Claude3):
  """Anthropic's most compact model."""

  model = 'claude-3-haiku-20240307'
|
249
|
+
|
250
|
+
|
251
|
+
class Claude2(Anthropic):
  """Predecessor to Claude 3 with 100K context window."""

  model = 'claude-2.0'


class Claude21(Anthropic):
  """Updated Claude 2 model with improved accuracy and 200K context window."""

  model = 'claude-2.1'


class ClaudeInstant(Anthropic):
  """Cheapest small and fast model, 100K context window."""

  model = 'claude-instant-1.2'
|