langfun 0.0.2.dev20240420__tar.gz → 0.0.2.dev20240423__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/PKG-INFO +1 -1
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/component.py +6 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/component_test.py +1 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/language_model.py +14 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/language_model_test.py +32 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/__init__.py +7 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/anthropic.py +36 -22
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/anthropic_test.py +7 -7
- langfun-0.0.2.dev20240423/langfun/core/llms/groq.py +260 -0
- langfun-0.0.2.dev20240423/langfun/core/llms/groq_test.py +170 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/openai.py +55 -50
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/openai_test.py +3 -3
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/template.py +26 -8
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/template_test.py +9 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun.egg-info/PKG-INFO +1 -1
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun.egg-info/SOURCES.txt +2 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/LICENSE +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/README.md +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/__init__.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/__init__.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/__init__.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/__init__.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/correction.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/correction_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/errors.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/errors_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/execution.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/execution_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/generation.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/generation_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/parsing.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/parsing_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/permissions.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/permissions_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/concurrent.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/concurrent_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/console.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/console_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/__init__.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/base.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/base_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/matching.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/matching_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/scoring.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/scoring_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/langfunc.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/langfunc_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/cache/__init__.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/cache/base.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/cache/in_memory.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/cache/in_memory_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/fake.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/fake_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/google_genai.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/google_genai_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/llama_cpp.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/llama_cpp_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/memories/__init__.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/memories/conversation_history.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/memories/conversation_history_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/memory.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/message.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/message_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/__init__.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/image.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/image_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/mime.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/mime_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/video.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/video_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modality.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modality_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/natural_language.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/natural_language_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/sampling.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/sampling_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/__init__.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/completion.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/completion_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/description.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/description_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/function_generation.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/function_generation_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/mapping.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/mapping_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/parsing.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/parsing_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/prompting.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/prompting_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/schema.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/schema_generation.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/schema_generation_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/schema_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/scoring.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/scoring_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/subscription.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/subscription_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/__init__.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/completion.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/completion_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/conversation.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/conversation_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/demonstration.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/demonstration_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/selfplay.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/selfplay_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/text_formatting.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/text_formatting_test.py +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun.egg-info/dependency_links.txt +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun.egg-info/requires.txt +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun.egg-info/top_level.txt +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/setup.cfg +0 -0
- {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/setup.py +0 -0
@@ -210,6 +210,12 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
|
|
210
210
|
return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
|
211
211
|
|
212
212
|
|
213
|
+
def all_contextual_values() -> dict[str, Any]:
  """Returns the values of every `lf.context` override currently in scope.

  Returns:
    A new dict mapping each overridden variable name to the `.value` of its
    override object. Empty when no overrides are registered.
  """
  scoped_overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
  return {name: override.value for name, override in scoped_overrides.items()}
|
217
|
+
|
218
|
+
|
213
219
|
@contextlib.contextmanager
|
214
220
|
def _contextual_scope(
|
215
221
|
tls: threading.local, tls_key, **variables
|
@@ -84,6 +84,7 @@ class ComponentContextTest(unittest.TestCase):
|
|
84
84
|
lf.get_contextual_override('y'),
|
85
85
|
lf.ContextualOverride(3, cascade=False, override_attrs=False),
|
86
86
|
)
|
87
|
+
self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
|
87
88
|
|
88
89
|
# Member attributes take precedence over `lf.context`.
|
89
90
|
self.assertEqual(a1.x, 1)
|
@@ -24,6 +24,9 @@ from langfun.core import console
|
|
24
24
|
from langfun.core import message as message_lib
|
25
25
|
import pyglove as pg
|
26
26
|
|
27
|
+
# Rough number of tokens one request is assumed to consume. Used to turn a
# tokens-per-minute limit into an equivalent request rate.
TOKENS_PER_REQUEST = 250

# Max concurrency used when neither RPM nor TPM data is available.
DEFAULT_MAX_CONCURRENCY = 1
|
29
|
+
|
27
30
|
|
28
31
|
class LMSample(pg.Object):
|
29
32
|
"""Response candidate."""
|
@@ -604,3 +607,14 @@ class LanguageModel(component.Component):
|
|
604
607
|
f'score: {r.score}',
|
605
608
|
color='blue',
|
606
609
|
)
|
610
|
+
|
611
|
+
def rate_to_max_concurrency(
    self, requests_per_min: float = 0, tokens_per_min: float = 0
) -> int:
  """Converts per-minute rate limits into a max concurrency.

  A positive token rate takes priority over the request rate. The token rate
  is turned into a request rate using `TOKENS_PER_REQUEST` tokens per request.
  The chosen rate is then divided by 60 and truncated to an int, with a floor
  of 1.

  Args:
    requests_per_min: Allowed requests per minute; ignored unless positive.
    tokens_per_min: Allowed tokens per minute; ignored unless positive.

  Returns:
    The estimated max concurrency, or `DEFAULT_MAX_CONCURRENCY` when neither
    rate is positive.
  """
  if tokens_per_min > 0:
    per_second = tokens_per_min / TOKENS_PER_REQUEST / 60
  elif requests_per_min > 0:
    per_second = requests_per_min / 60
  else:
    return DEFAULT_MAX_CONCURRENCY
  # Small rates truncate to 0; concurrency must stay at least 1.
  return max(int(per_second), 1)
|
@@ -394,6 +394,38 @@ class LanguageModelTest(unittest.TestCase):
|
|
394
394
|
with self.assertRaises(NotImplementedError):
|
395
395
|
MockModel().score('hi', ['1', '2'])
|
396
396
|
|
397
|
+
def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
  model = MockModel()
  # Zero and negative rates are both treated as "no rate data".
  for rpm, tpm in ((0, 0), (-1, -1)):
    self.assertEqual(
        lm_lib.DEFAULT_MAX_CONCURRENCY,
        model.rate_to_max_concurrency(requests_per_min=rpm, tokens_per_min=tpm),
    )
|
407
|
+
|
408
|
+
def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
  rpm = 1e4
  expected = int(rpm / 60)
  actual = MockModel().rate_to_max_concurrency(requests_per_min=rpm)
  self.assertEqual(actual, expected)
|
415
|
+
|
416
|
+
def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
  tpm = 1e7
  expected = int(tpm / lm_lib.TOKENS_PER_REQUEST / 60)
  # A positive TPM wins even when an RPM is also supplied.
  actual = MockModel().rate_to_max_concurrency(
      requests_per_min=1, tokens_per_min=tpm
  )
  self.assertEqual(actual, expected)
|
423
|
+
|
424
|
+
def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
  model = MockModel()
  # Rates too small to yield one request per second are floored to 1.
  for rate_kwargs in (dict(requests_per_min=1), dict(tokens_per_min=1)):
    self.assertEqual(model.rate_to_max_concurrency(**rate_kwargs), 1)
|
428
|
+
|
397
429
|
|
398
430
|
if __name__ == '__main__':
|
399
431
|
unittest.main()
|
@@ -66,6 +66,13 @@ from langfun.core.llms.anthropic import Claude3Opus
|
|
66
66
|
from langfun.core.llms.anthropic import Claude3Sonnet
|
67
67
|
from langfun.core.llms.anthropic import Claude3Haiku
|
68
68
|
|
69
|
+
from langfun.core.llms.groq import Groq
|
70
|
+
from langfun.core.llms.groq import GroqLlama3_70B
|
71
|
+
from langfun.core.llms.groq import GroqLlama3_8B
|
72
|
+
from langfun.core.llms.groq import GroqLlama2_70B
|
73
|
+
from langfun.core.llms.groq import GroqMistral_8x7B
|
74
|
+
from langfun.core.llms.groq import GroqGemma7B_IT
|
75
|
+
|
69
76
|
|
70
77
|
# LLaMA C++ models.
|
71
78
|
from langfun.core.llms.llama_cpp import LlamaCppRemote
|
@@ -26,12 +26,15 @@ import requests
|
|
26
26
|
|
27
27
|
SUPPORTED_MODELS_AND_SETTINGS = {
    # Models: https://docs.anthropic.com/claude/docs/models-overview
    # Rate limits: https://docs.anthropic.com/claude/reference/rate-limits
    #
    # Claude 3 family.
    'claude-3-opus-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
    'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
    'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
    # Older models. Their RPM/TPM are estimated as those of the
    # largest-available model (Claude-3-Opus).
    'claude-2.1': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
    'claude-2.0': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
    'claude-instant-1.2': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
}
|
36
39
|
|
37
40
|
|
@@ -81,6 +84,7 @@ class Anthropic(lf.LanguageModel):
|
|
81
84
|
super()._on_bound()
|
82
85
|
self._api_key = None
|
83
86
|
self.__dict__.pop('_api_initialized', None)
|
87
|
+
self.__dict__.pop('_session', None)
|
84
88
|
|
85
89
|
@functools.cached_property
|
86
90
|
def _api_initialized(self):
|
@@ -93,6 +97,17 @@ class Anthropic(lf.LanguageModel):
|
|
93
97
|
self._api_key = api_key
|
94
98
|
return True
|
95
99
|
|
100
|
+
@functools.cached_property
def _session(self) -> requests.Session:
  """HTTP session preloaded with Anthropic auth/version headers.

  Created on first access and cached on the instance.
  """
  assert self._api_initialized
  session = requests.Session()
  session.headers['x-api-key'] = self._api_key
  session.headers['anthropic-version'] = _ANTHROPIC_API_VERSION
  session.headers['content-type'] = 'application/json'
  return session
|
110
|
+
|
96
111
|
@property
|
97
112
|
def model_id(self) -> str:
|
98
113
|
"""Returns a string to identify the model."""
|
@@ -100,7 +115,11 @@ class Anthropic(lf.LanguageModel):
|
|
100
115
|
|
101
116
|
@property
def max_concurrency(self) -> int:
  """Max concurrency derived from the model's RPM/TPM settings.

  Missing `rpm`/`tpm` entries are treated as 0.
  """
  settings = SUPPORTED_MODELS_AND_SETTINGS[self.model]
  return self.rate_to_max_concurrency(
      requests_per_min=settings.get('rpm', 0),
      tokens_per_min=settings.get('tpm', 0),
  )
|
104
123
|
|
105
124
|
def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
|
106
125
|
assert self._api_initialized
|
@@ -165,8 +184,8 @@ class Anthropic(lf.LanguageModel):
|
|
165
184
|
def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
|
166
185
|
"""Parses Anthropic's response."""
|
167
186
|
# NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
|
168
|
-
output = response.json()
|
169
187
|
if response.status_code == 200:
|
188
|
+
output = response.json()
|
170
189
|
message = self._message_from_content(output['content'])
|
171
190
|
input_tokens = output['usage']['input_tokens']
|
172
191
|
output_tokens = output['usage']['output_tokens']
|
@@ -181,12 +200,11 @@ class Anthropic(lf.LanguageModel):
|
|
181
200
|
else:
|
182
201
|
if response.status_code == 429:
|
183
202
|
error_cls = RateLimitError
|
184
|
-
elif response.status_code
|
203
|
+
elif response.status_code in (502, 529):
|
185
204
|
error_cls = OverloadedError
|
186
205
|
else:
|
187
206
|
error_cls = AnthropicError
|
188
|
-
|
189
|
-
raise error_cls(f'{error["type"]}: {error["message"]}')
|
207
|
+
raise error_cls(f'{response.status_code}: {response.content}')
|
190
208
|
|
191
209
|
def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
|
192
210
|
request = dict()
|
@@ -198,17 +216,13 @@ class Anthropic(lf.LanguageModel):
|
|
198
216
|
]
|
199
217
|
)
|
200
218
|
)
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
},
|
209
|
-
timeout=self.timeout,
|
210
|
-
)
|
211
|
-
return self._parse_response(response)
|
219
|
+
try:
|
220
|
+
response = self._session.post(
|
221
|
+
_ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
|
222
|
+
)
|
223
|
+
return self._parse_response(response)
|
224
|
+
except ConnectionError as e:
|
225
|
+
raise OverloadedError(str(e)) from e
|
212
226
|
|
213
227
|
|
214
228
|
class Claude3(Anthropic):
|
@@ -98,20 +98,20 @@ def mock_requests_post_error(status_code, error_type, error_message):
|
|
98
98
|
return _mock_requests
|
99
99
|
|
100
100
|
|
101
|
-
class
|
101
|
+
class AnthropicTest(unittest.TestCase):
|
102
102
|
|
103
103
|
def test_basics(self):
|
104
104
|
self.assertEqual(
|
105
105
|
anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
|
106
106
|
)
|
107
|
-
self.
|
107
|
+
self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)
|
108
108
|
|
109
109
|
def test_api_key(self):
|
110
110
|
lm = anthropic.Claude3Haiku()
|
111
111
|
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
112
112
|
lm('hi')
|
113
113
|
|
114
|
-
with mock.patch('requests.post') as mock_request:
|
114
|
+
with mock.patch('requests.Session.post') as mock_request:
|
115
115
|
mock_request.side_effect = mock_requests_post
|
116
116
|
|
117
117
|
lm = anthropic.Claude3Haiku(api_key='fake key')
|
@@ -123,7 +123,7 @@ class AuthropicTest(unittest.TestCase):
|
|
123
123
|
del os.environ['ANTHROPIC_API_KEY']
|
124
124
|
|
125
125
|
def test_call(self):
|
126
|
-
with mock.patch('requests.post') as mock_request:
|
126
|
+
with mock.patch('requests.Session.post') as mock_request:
|
127
127
|
mock_request.side_effect = mock_requests_post
|
128
128
|
lm = anthropic.Claude3Haiku(api_key='fake_key')
|
129
129
|
response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
|
@@ -140,7 +140,7 @@ class AuthropicTest(unittest.TestCase):
|
|
140
140
|
self.assertIsNotNone(response.usage.total_tokens, 3)
|
141
141
|
|
142
142
|
def test_mm_call(self):
|
143
|
-
with mock.patch('requests.post') as mock_mm_request:
|
143
|
+
with mock.patch('requests.Session.post') as mock_mm_request:
|
144
144
|
mock_mm_request.side_effect = mock_mm_requests_post
|
145
145
|
lm = anthropic.Claude3Haiku(api_key='fake_key')
|
146
146
|
response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
|
@@ -152,13 +152,13 @@ class AuthropicTest(unittest.TestCase):
|
|
152
152
|
(529, 'service_unavailable', 'Service unavailable.'),
|
153
153
|
(500, 'bad_request', 'Bad request.'),
|
154
154
|
]:
|
155
|
-
with mock.patch('requests.post') as mock_mm_request:
|
155
|
+
with mock.patch('requests.Session.post') as mock_mm_request:
|
156
156
|
mock_mm_request.side_effect = mock_requests_post_error(
|
157
157
|
status_code, error_type, error_message
|
158
158
|
)
|
159
159
|
lm = anthropic.Claude3Haiku(api_key='fake_key')
|
160
160
|
with self.assertRaisesRegex(
|
161
|
-
Exception, f'{
|
161
|
+
Exception, f'.*{status_code}: .*{error_message}'
|
162
162
|
):
|
163
163
|
lm('hello', lm=lm, max_attempts=1)
|
164
164
|
|
@@ -0,0 +1,260 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Language models from Groq."""
|
15
|
+
|
16
|
+
import functools
|
17
|
+
import os
|
18
|
+
from typing import Annotated, Any
|
19
|
+
|
20
|
+
import langfun.core as lf
|
21
|
+
from langfun.core import modalities as lf_modalities
|
22
|
+
import pyglove as pg
|
23
|
+
import requests
|
24
|
+
|
25
|
+
|
26
|
+
SUPPORTED_MODELS_AND_SETTINGS = {
    # See https://console.groq.com/docs/models
    # `max_tokens` mirrors the context size encoded in each model name.
    #
    # Meta Llama 3.
    'llama3-8b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
    'llama3-70b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
    # Meta Llama 2.
    'llama2-70b-4096': pg.Dict(max_tokens=4096, max_concurrency=16),
    # Mistral AI Mixtral.
    'mixtral-8x7b-32768': pg.Dict(max_tokens=32768, max_concurrency=16),
    # Google Gemma.
    'gemma-7b-it': pg.Dict(max_tokens=8192, max_concurrency=16),
}
|
34
|
+
|
35
|
+
|
36
|
+
class GroqError(Exception):  # pylint: disable=g-bad-exception-name
  """Root of the exception hierarchy raised by the Groq language models."""
|
38
|
+
|
39
|
+
|
40
|
+
class RateLimitError(GroqError):
  """Raised when Groq rejects a request for exceeding its rate limit."""
|
42
|
+
|
43
|
+
|
44
|
+
class OverloadedError(GroqError):
  """Raised when Groq's server is temporarily overloaded or unreachable."""
|
46
|
+
|
47
|
+
|
48
|
+
_CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
|
49
|
+
|
50
|
+
|
51
|
+
@lf.use_init_args(['model'])
class Groq(lf.LanguageModel):
  """Groq LLMs through REST APIs (OpenAI compatible).

  See https://platform.openai.com/docs/api-reference/chat
  """

  model: pg.typing.Annotated[
      pg.typing.Enum(
          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
      ),
      'The name of the model to use.',
  ]

  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
      False
  )

  api_key: Annotated[
      str | None,
      (
          'API key. If None, the key will be read from environment variable '
          "'GROQ_API_KEY'."
      ),
  ] = None

  def _on_bound(self):
    super()._on_bound()
    # Drop cached credentials and HTTP session so both are rebuilt lazily
    # from the current `api_key` the next time they are needed.
    self._api_key = None
    self.__dict__.pop('_api_initialized', None)
    self.__dict__.pop('_session', None)

  @functools.cached_property
  def _api_initialized(self):
    """Resolves the API key once per binding.

    Returns:
      True once `self._api_key` has been set.

    Raises:
      ValueError: if neither `api_key` nor `GROQ_API_KEY` is set.
    """
    api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
    if not api_key:
      raise ValueError(
          'Please specify `api_key` during `__init__` or set environment '
          'variable `GROQ_API_KEY` with your Groq API key.'
      )
    self._api_key = api_key
    return True

  @functools.cached_property
  def _session(self) -> requests.Session:
    """HTTP session carrying the Groq auth headers, created on first use."""
    assert self._api_initialized
    s = requests.Session()
    s.headers.update({
        'Authorization': f'Bearer {self._api_key}',
        'Content-Type': 'application/json',
    })
    return s

  @property
  def model_id(self) -> str:
    """Returns a string to identify the model."""
    return self.model

  @property
  def max_concurrency(self) -> int:
    """Returns the per-model max concurrency from the settings table."""
    return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency

  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
    """Returns a dict as request arguments.

    Optional sampling options are only sent when they are set.
    """
    # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
    args = dict(
        model=self.model,
        n=options.n,
        stream=False,
    )

    if options.temperature is not None:
      args['temperature'] = options.temperature
    if options.max_tokens is not None:
      args['max_tokens'] = options.max_tokens
    if options.top_p is not None:
      args['top_p'] = options.top_p
    if options.stop:
      args['stop'] = options.stop
    return args

  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
    """Converts a message to Groq's content protocol (list of dicts).

    Raises:
      ValueError: for a non-text chunk that is not a URI-backed image on a
        multimodal model.
    """
    # Refer: https://platform.openai.com/docs/api-reference/chat/create
    content = []
    for chunk in prompt.chunk():
      if isinstance(chunk, str):
        item = dict(type='text', text=chunk)
      elif (
          self.multimodal
          and isinstance(chunk, lf_modalities.Image)
          and chunk.uri
      ):
        # NOTE(daiyip): Groq only support image URL.
        # NOTE(review): the OpenAI chat schema nests the URL as
        # `image_url={'url': ...}`; confirm Groq accepts a bare string here.
        item = dict(type='image_url', image_url=chunk.uri)
      else:
        raise ValueError(f'Unsupported modality object: {chunk!r}.')
      content.append(item)
    return content

  def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
    """Converts Groq's content protocol to message."""
    # Refer: https://platform.openai.com/docs/api-reference/chat/create
    content = choice['message']['content']
    if content is None:
      # The OpenAI-compatible schema allows a null `content`; treat it as empty
      # text instead of failing while iterating over None.
      return lf.AIMessage('')
    if isinstance(content, str):
      return lf.AIMessage(content)
    return lf.AIMessage.from_chunks(
        [x['text'] for x in content if x['type'] == 'text']
    )

  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
    """Parses Groq's response.

    Raises:
      RateLimitError: on HTTP 429.
      OverloadedError: on HTTP 500/502/503.
      GroqError: on any other non-200 status.
    """
    # Refer: https://platform.openai.com/docs/api-reference/chat/object
    if response.status_code == 200:
      output = response.json()
      samples = [
          lf.LMSample(self._message_from_choice(choice), score=0.0)
          for choice in output['choices']
      ]
      usage = output['usage']
      return lf.LMSamplingResult(
          samples,
          usage=lf.LMSamplingUsage(
              prompt_tokens=usage['prompt_tokens'],
              completion_tokens=usage['completion_tokens'],
              total_tokens=usage['total_tokens'],
          ),
      )
    else:
      # https://platform.openai.com/docs/guides/error-codes/api-errors
      if response.status_code == 429:
        error_cls = RateLimitError
      elif response.status_code in (500, 502, 503):
        error_cls = OverloadedError
      else:
        error_cls = GroqError
      raise error_cls(f'{response.status_code}: {response.content}')

  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
    assert self._api_initialized
    # Rate-limit and overload errors are retried by the base class.
    return self._parallel_execute_with_currency_control(
        self._sample_single,
        prompts,
        retry_on_errors=(RateLimitError, OverloadedError),
    )

  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
    """Sends one prompt as a single-turn user message and parses the reply."""
    request = dict()
    request.update(self._get_request_args(self.sampling_options))
    request.update(
        dict(
            messages=[
                dict(role='user', content=self._content_from_message(prompt))
            ]
        )
    )
    try:
      response = self._session.post(
          _CHAT_COMPLETE_API_ENDPOINT,
          json=request,
          timeout=self.timeout,
      )
    except (ConnectionError, requests.exceptions.ConnectionError) as e:
      # `requests` raises `requests.exceptions.ConnectionError`, which is NOT a
      # subclass of the builtin `ConnectionError`. Catch both so transient
      # network failures become a retryable `OverloadedError`.
      raise OverloadedError(str(e)) from e
    return self._parse_response(response)
|
216
|
+
|
217
|
+
|
218
|
+
class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
  """Meta's Llama3-8B on Groq, with an 8K context window.

  See: https://huggingface.co/meta-llama/Meta-Llama-3-8B
  """

  model = 'llama3-8b-8192'
|
225
|
+
|
226
|
+
|
227
|
+
class GroqLlama3_70B(Groq):  # pylint: disable=invalid-name
  """Meta's Llama3-70B on Groq, with an 8K context window.

  See: https://huggingface.co/meta-llama/Meta-Llama-3-70B
  """

  model = 'llama3-70b-8192'
|
234
|
+
|
235
|
+
|
236
|
+
class GroqLlama2_70B(Groq):  # pylint: disable=invalid-name
  """Meta's Llama2-70B on Groq, with a 4K context window.

  See: https://huggingface.co/meta-llama/Llama-2-70b
  """

  model = 'llama2-70b-4096'
|
243
|
+
|
244
|
+
|
245
|
+
class GroqMistral_8x7B(Groq):  # pylint: disable=invalid-name
  """Mixtral 8x7B with 32K context window.

  Note: despite the class name, this serves Mistral AI's Mixtral model.

  See: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
  """

  model = 'mixtral-8x7b-32768'
|
252
|
+
|
253
|
+
|
254
|
+
class GroqGemma7B_IT(Groq):  # pylint: disable=invalid-name
  """Google's instruction-tuned Gemma 7B on Groq, with an 8K context window.

  See: https://huggingface.co/google/gemma-1.1-7b-it
  """

  model = 'gemma-7b-it'
|