langfun-0.0.2.dev20240601.tar.gz → langfun-0.0.2.dev20240604.tar.gz
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/PKG-INFO +1 -1
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/__init__.py +1 -1
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/__init__.py +5 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/language_model.py +27 -1
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/__init__.py +3 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/anthropic.py +44 -80
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/anthropic_test.py +1 -1
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/groq.py +42 -87
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/groq_test.py +1 -1
- langfun-0.0.2.dev20240604/langfun/core/llms/llama_cpp.py +84 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/llama_cpp_test.py +14 -8
- langfun-0.0.2.dev20240604/langfun/core/llms/rest.py +112 -0
- langfun-0.0.2.dev20240604/langfun/core/llms/rest_test.py +111 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun.egg-info/PKG-INFO +1 -1
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun.egg-info/SOURCES.txt +2 -0
- langfun-0.0.2.dev20240601/langfun/core/llms/llama_cpp.py +0 -74
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/LICENSE +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/README.md +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/__init__.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/__init__.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/correction.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/correction_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/errors.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/errors_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/execution.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/execution_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/generation.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/generation_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/parsing.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/parsing_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/permissions.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/permissions_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/component.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/component_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/concurrent.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/concurrent_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/console.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/console_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/__init__.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/base.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/base_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/matching.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/matching_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/patching.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/patching_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/scoring.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/scoring_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/langfunc.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/langfunc_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/language_model_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/cache/__init__.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/cache/base.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/cache/in_memory.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/cache/in_memory_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/fake.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/fake_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/google_genai.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/google_genai_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/openai.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/openai_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/vertexai.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/vertexai_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/memories/__init__.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/memories/conversation_history.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/memories/conversation_history_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/memory.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/message.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/message_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/__init__.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/audio.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/audio_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/image.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/image_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/mime.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/mime_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/ms_office.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/ms_office_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/pdf.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/pdf_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/video.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/video_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modality.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modality_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/natural_language.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/natural_language_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/sampling.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/sampling_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/__init__.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/completion.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/completion_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/description.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/description_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/function_generation.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/function_generation_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/mapping.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/mapping_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/parsing.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/parsing_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/prompting.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/prompting_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/schema.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/schema_generation.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/schema_generation_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/schema_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/scoring.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/scoring_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/subscription.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/subscription_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/template.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/template_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/__init__.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/completion.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/completion_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/conversation.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/conversation_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/demonstration.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/demonstration_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/selfplay.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/selfplay_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/text_formatting.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/text_formatting_test.py +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun.egg-info/dependency_links.txt +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun.egg-info/requires.txt +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun.egg-info/top_level.txt +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/setup.cfg +0 -0
- {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/setup.py +0 -0
```diff
--- langfun-0.0.2.dev20240601/langfun/core/__init__.py
+++ langfun-0.0.2.dev20240604/langfun/core/__init__.py
@@ -106,6 +106,11 @@ from langfun.core.language_model import LMScoringResult
 from langfun.core.language_model import LMCache
 from langfun.core.language_model import LMDebugMode
 
+from langfun.core.language_model import LMError
+from langfun.core.language_model import RetryableLMError
+from langfun.core.language_model import RateLimitError
+from langfun.core.language_model import TemporaryLMError
+
 # Components for building agents.
 from langfun.core.memory import Memory
 
```
```diff
--- langfun-0.0.2.dev20240601/langfun/core/language_model.py
+++ langfun-0.0.2.dev20240604/langfun/core/language_model.py
@@ -29,6 +29,32 @@ TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
 DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
 
 
+#
+# Common errors during calling language models.
+#
+
+
+class LMError(RuntimeError):
+  """Base class for language model errors."""
+
+
+class RetryableLMError(LMError):
+  """Base class for LLM errors that can be solved by retrying."""
+
+
+class RateLimitError(RetryableLMError):
+  """Error for rate limit reached."""
+
+
+class TemporaryLMError(RetryableLMError):
+  """Error for temporary service issues that can be retried."""
+
+
+#
+# Language model input/output interfaces.
+#
+
+
 class LMSample(pg.Object):
   """Response candidate."""
 
@@ -445,7 +471,7 @@ class LanguageModel(component.Component):
           None,
           Union[Type[Exception], Tuple[Type[Exception], str]],
           Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
-      ] = None,
+      ] = RetryableLMError,
   ) -> Any:
     """Helper method for subclasses for implementing _sample."""
     return concurrent.concurrent_execute(
```
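With these classes in place, callers can branch on how recoverable a failure is without matching provider-specific exception types, and `_parallel_execute_with_currency_control` now retries any `RetryableLMError` by default, so providers no longer need to enumerate their own retry tuples (the Anthropic and Groq hunks below delete exactly that boilerplate). A minimal caller-side sketch — the helper function is hypothetical; the error classes and the callable-model style are langfun's:

```python
import langfun.core as lf


def call_or_none(lm: lf.LanguageModel, prompt: str):
  """Hypothetical helper showing the shared error taxonomy."""
  try:
    return lm(prompt)
  except lf.RateLimitError:
    return None  # Still rate-limited after the built-in retries.
  except lf.RetryableLMError:
    return None  # Other transient failure, e.g. a temporarily overloaded server.
  except lf.LMError:
    raise        # Non-retryable model error: surface it to the caller.
```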
```diff
--- langfun-0.0.2.dev20240601/langfun/core/llms/__init__.py
+++ langfun-0.0.2.dev20240604/langfun/core/llms/__init__.py
@@ -24,6 +24,9 @@ from langfun.core.llms.fake import StaticMapping
 from langfun.core.llms.fake import StaticResponse
 from langfun.core.llms.fake import StaticSequence
 
+# REST-based models.
+from langfun.core.llms.rest import REST
+
 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
 from langfun.core.llms.google_genai import GeminiPro
```
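This export also puts the new base class (defined in the added `langfun/core/llms/rest.py`, whose hunks are not reproduced on this page) on the public `llms` surface, so user code can subclass it for custom backends:

```python
# Sketch only; see the fuller inferred example after the Anthropic hunks below.
from langfun.core import llms


class MyBackend(llms.REST):
  ...
```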
```diff
--- langfun-0.0.2.dev20240601/langfun/core/llms/anthropic.py
+++ langfun-0.0.2.dev20240604/langfun/core/llms/anthropic.py
@@ -14,14 +14,13 @@
 """Language models from Anthropic."""
 
 import base64
-import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
-import requests
 
 
 SUPPORTED_MODELS_AND_SETTINGS = {
@@ -38,24 +37,8 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 }
 
 
-class AnthropicError(Exception):  # pylint: disable=g-bad-exception-name
-  """Base class for Anthropic errors."""
-
-
-class RateLimitError(AnthropicError):
-  """Error for rate limit reached."""
-
-
-class OverloadedError(AnthropicError):
-  """Anthropic's server is temporarily overloaded."""
-
-
-_ANTHROPIC_MESSAGE_API_ENDPOINT = 'https://api.anthropic.com/v1/messages'
-_ANTHROPIC_API_VERSION = '2023-06-01'
-
-
 @lf.use_init_args(['model'])
-class Anthropic(lf.LanguageModel):
+class Anthropic(rest.REST):
   """Anthropic LLMs (Claude) through REST APIs.
 
   See https://docs.anthropic.com/claude/reference/messages_post
@@ -80,14 +63,18 @@ class Anthropic(lf.LanguageModel):
       ),
   ] = None
 
+  api_endpoint: str = 'https://api.anthropic.com/v1/messages'
+
+  api_version: Annotated[
+      str,
+      'Anthropic API version.'
+  ] = '2023-06-01'
+
   def _on_bound(self):
     super()._on_bound()
     self._api_key = None
-    self.__dict__.pop('_api_initialized', None)
-    self.__dict__.pop('_session', None)
 
-  @functools.cached_property
-  def _api_initialized(self):
+  def _initialize(self):
     api_key = self.api_key or os.environ.get('ANTHROPIC_API_KEY', None)
     if not api_key:
       raise ValueError(
@@ -95,18 +82,14 @@ class Anthropic(lf.LanguageModel):
           'variable `ANTHROPIC_API_KEY` with your Anthropic API key.'
       )
     self._api_key = api_key
-    return True
 
-  @functools.cached_property
-  def _session(self) -> requests.Session:
-    assert self._api_initialized
-    s = requests.Session()
-    s.headers.update({
+  @property
+  def headers(self) -> dict[str, Any]:
+    return {
         'x-api-key': self._api_key,
-        'anthropic-version': _ANTHROPIC_API_VERSION,
+        'anthropic-version': self.api_version,
         'content-type': 'application/json',
-    })
-    return s
+    }
 
   @property
   def model_id(self) -> str:
@@ -121,13 +104,24 @@ class Anthropic(lf.LanguageModel):
         requests_per_min=rpm, tokens_per_min=tpm
     )
 
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized
-    return self._parallel_execute_with_currency_control(
-        self._sample_single, prompts, retry_on_errors=(RateLimitError, OverloadedError)
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
     )
+    return request
 
-  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # Authropic requires `max_tokens` to be specified.
     max_tokens = (
@@ -174,6 +168,19 @@ class Anthropic(lf.LanguageModel):
     else:
       return [dict(type='text', text=prompt.text)]
 
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    message = self._message_from_content(json['content'])
+    input_tokens = json['usage']['input_tokens']
+    output_tokens = json['usage']['output_tokens']
+    return lf.LMSamplingResult(
+        [lf.LMSample(message)],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+        ),
+    )
+
   def _message_from_content(self, content: list[dict[str, Any]]) -> lf.Message:
     """Converts Anthropic's content protocol to message."""
     # Refer: https://docs.anthropic.com/claude/reference/messages-examples
@@ -181,49 +188,6 @@ class Anthropic(lf.LanguageModel):
         [x['text'] for x in content if x['type'] == 'text']
     )
 
-  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
-    """Parses Anthropic's response."""
-    # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
-    if response.status_code == 200:
-      output = response.json()
-      message = self._message_from_content(output['content'])
-      input_tokens = output['usage']['input_tokens']
-      output_tokens = output['usage']['output_tokens']
-      return lf.LMSamplingResult(
-          [lf.LMSample(message)],
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=input_tokens,
-              completion_tokens=output_tokens,
-              total_tokens=input_tokens + output_tokens,
-          ),
-      )
-    else:
-      if response.status_code == 429:
-        error_cls = RateLimitError
-      elif response.status_code in (502, 529):
-        error_cls = OverloadedError
-      else:
-        error_cls = AnthropicError
-      raise error_cls(f'{response.status_code}: {response.content}')
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    request = dict()
-    request.update(self._get_request_args(self.sampling_options))
-    request.update(
-        dict(
-            messages=[
-                dict(role='user', content=self._content_from_message(prompt))
-            ]
-        )
-    )
-    try:
-      response = self._session.post(
-          _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
-      )
-      return self._parse_response(response)
-    except ConnectionError as e:
-      raise OverloadedError(str(e)) from e
-
 
 class Claude3(Anthropic):
   """Base class for Claude 3 models. 200K input tokens and 4K output tokens."""
```
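Although `rest.py` itself is not reproduced here, the subclass side of the refactor shows its contract: a provider declares `api_endpoint`, supplies `headers`, builds the JSON payload in `request()`, parses the response body in `result()`, and performs one-time credential checks in `_initialize()`, while sampling, the shared HTTP session, and retry handling come from the base class. A sketch of a custom backend inferred from that surface — the endpoint, auth scheme, and JSON field names are all illustrative, not part of this diff:

```python
from typing import Any

import langfun.core as lf
from langfun.core.llms import rest


class EchoBackend(rest.REST):
  """Hypothetical REST-backed model mirroring the override points above."""

  api_endpoint: str = 'https://example.com/v1/complete'  # Illustrative URL.

  def _initialize(self):
    # One-time setup before the first request, like the API-key checks above.
    self._api_key = 'placeholder'

  @property
  def headers(self) -> dict[str, Any]:
    return {'authorization': f'Bearer {self._api_key}'}

  def request(
      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    # Provider-specific JSON payload for a single prompt.
    return dict(prompt=prompt.text, max_tokens=sampling_options.max_tokens)

  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
    # Map the provider's response body back to langfun's sampling result.
    return lf.LMSamplingResult([lf.LMSample(json['text'], score=0.0)])
```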
```diff
--- langfun-0.0.2.dev20240601/langfun/core/llms/groq.py
+++ langfun-0.0.2.dev20240604/langfun/core/llms/groq.py
@@ -13,14 +13,13 @@
 # limitations under the License.
 """Language models from Groq."""
 
-import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
-import requests
 
 
 SUPPORTED_MODELS_AND_SETTINGS = {
@@ -33,23 +32,8 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 }
 
 
-class GroqError(Exception):  # pylint: disable=g-bad-exception-name
-  """Base class for Groq errors."""
-
-
-class RateLimitError(GroqError):
-  """Error for rate limit reached."""
-
-
-class OverloadedError(GroqError):
-  """Groq's server is temporarily overloaded."""
-
-
-_CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
-
-
 @lf.use_init_args(['model'])
-class Groq(lf.LanguageModel):
+class Groq(rest.REST):
   """Groq LLMs through REST APIs (OpenAI compatible).
 
   See https://platform.openai.com/docs/api-reference/chat
@@ -74,14 +58,13 @@ class Groq(lf.LanguageModel):
       ),
   ] = None
 
+  api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'
+
   def _on_bound(self):
     super()._on_bound()
     self._api_key = None
-    self.__dict__.pop('_api_initialized', None)
-    self.__dict__.pop('_session', None)
 
-  @functools.cached_property
-  def _api_initialized(self):
+  def _initialize(self):
    api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
    if not api_key:
      raise ValueError(
@@ -89,17 +72,13 @@ class Groq(lf.LanguageModel):
           'variable `GROQ_API_KEY` with your Groq API key.'
       )
     self._api_key = api_key
-    return True
 
-  @functools.cached_property
-  def _session(self) -> requests.Session:
-    assert self._api_initialized
-    s = requests.Session()
-    s.headers.update({
+  @property
+  def headers(self) -> dict[str, Any]:
+    return {
        'Authorization': f'Bearer {self._api_key}',
        'Content-Type': 'application/json',
-    })
-    return s
+    }
 
   @property
   def model_id(self) -> str:
@@ -110,7 +89,24 @@ class Groq(lf.LanguageModel):
   def max_concurrency(self) -> int:
     return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
 
-  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
+    )
+    return request
+
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
     args = dict(
@@ -148,6 +144,21 @@ class Groq(lf.LanguageModel):
       content.append(item)
     return content
 
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    samples = [
+        lf.LMSample(self._message_from_choice(choice), score=0.0)
+        for choice in json['choices']
+    ]
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples,
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['prompt_tokens'],
+            completion_tokens=usage['completion_tokens'],
+            total_tokens=usage['total_tokens'],
+        ),
+    )
+
   def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
     """Converts Groq's content protocol to message."""
     # Refer: https://platform.openai.com/docs/api-reference/chat/create
@@ -158,62 +169,6 @@ class Groq(lf.LanguageModel):
         [x['text'] for x in content if x['type'] == 'text']
     )
 
-  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
-    """Parses Groq's response."""
-    # Refer: https://platform.openai.com/docs/api-reference/chat/object
-    if response.status_code == 200:
-      output = response.json()
-      samples = [
-          lf.LMSample(self._message_from_choice(choice), score=0.0)
-          for choice in output['choices']
-      ]
-      usage = output['usage']
-      return lf.LMSamplingResult(
-          samples,
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=usage['prompt_tokens'],
-              completion_tokens=usage['completion_tokens'],
-              total_tokens=usage['total_tokens'],
-          ),
-      )
-    else:
-      # https://platform.openai.com/docs/guides/error-codes/api-errors
-      if response.status_code == 429:
-        error_cls = RateLimitError
-      elif response.status_code in (500, 502, 503):
-        error_cls = OverloadedError
-      else:
-        error_cls = GroqError
-      raise error_cls(f'{response.status_code}: {response.content}')
-
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized
-    return self._parallel_execute_with_currency_control(
-        self._sample_single,
-        prompts,
-        retry_on_errors=(RateLimitError, OverloadedError),
-    )
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    request = dict()
-    request.update(self._get_request_args(self.sampling_options))
-    request.update(
-        dict(
-            messages=[
-                dict(role='user', content=self._content_from_message(prompt))
-            ]
-        )
-    )
-    try:
-      response = self._session.post(
-          _CHAT_COMPLETE_API_ENDPOINT,
-          json=request,
-          timeout=self.timeout,
-      )
-      return self._parse_response(response)
-    except ConnectionError as e:
-      raise OverloadedError(str(e)) from e
-
 
 class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
   """Llama3-8B with 8K context window.
```
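The Groq changes mirror the Anthropic ones: the per-provider error classes and the module-level endpoint constant are deleted, the endpoint becomes the `api_endpoint` field, and `_parse_response`/`_sample`/`_sample_single` collapse into `request()`/`result()` overrides. The public behavior is unchanged — models are constructed and called as before, with the key taken from `api_key` or the `GROQ_API_KEY` environment variable. A usage sketch (the key value is a placeholder):

```python
import os

from langfun.core.llms import groq

os.environ['GROQ_API_KEY'] = 'placeholder-key'  # Or pass api_key=... directly.
lm = groq.GroqLlama3_8B()
# Calling `lm` now POSTs to the `api_endpoint` field
# (https://api.groq.com/openai/v1/chat/completions) via the shared REST base.
```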
```diff
--- /dev/null
+++ langfun-0.0.2.dev20240604/langfun/core/llms/llama_cpp.py
@@ -0,0 +1,84 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Language models from llama.cpp."""
+
+from typing import Any
+
+import langfun.core as lf
+from langfun.core.llms import rest
+import pyglove as pg
+
+
+class LlamaCppRemote(rest.REST):
+  """The remote LLaMA C++ model.
+
+  The Remote LLaMA C++ models can be launched via
+  https://github.com/ggerganov/llama.cpp/tree/master/examples/server
+  """
+
+  @pg.explicit_method_override
+  def __init__(self, url: str, model: str | None = None, **kwargs):
+    super().__init__(api_endpoint=f'{url}/completion', model=model, **kwargs)
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return f'LLaMAC++({self.model or ""})'
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    # NOTE(daiyip): multi-modal is current not supported.
+    request['prompt'] = prompt.text
+    return request
+
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    args = dict(
+        n_predict=options.max_tokens or 1024,
+        top_k=options.top_k or 50,
+        top_p=options.top_p or 0.95,
+    )
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    return args
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    return lf.LMSamplingResult(
+        [lf.LMSample(item['content'], score=0.0) for item in json['items']]
+    )
+
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    request = self.request(prompt, self.sampling_options)
+
+    def _sample_one_example(request):
+      response = self._session.post(
+          self.api_endpoint,
+          json=request,
+          timeout=self.timeout,
+      )
+      if response.status_code == 200:
+        return response.json()
+      else:
+        error_cls = self._error_cls_from_status(response.status_code)
+        raise error_cls(f'{response.status_code}: {response.content}')
+
+    items = self._parallel_execute_with_currency_control(
+        _sample_one_example,
+        [request] * (self.sampling_options.n or 1),
+    )
+    return self.result(dict(items=items))
```
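A usage sketch mirroring the updated unit test below — the server address is illustrative; each of the `n` samples triggers its own POST to `/completion`, since `_sample_single` fans the same request out `n` times:

```python
from langfun.core.llms import llama_cpp

lm = llama_cpp.LlamaCppRemote('http://127.0.0.1:8080')
[result] = lm.sample(['hello'], n=2)  # One result per prompt, two samples each.
print(result.samples[0].response)
```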
```diff
--- langfun-0.0.2.dev20240601/langfun/core/llms/llama_cpp_test.py
+++ langfun-0.0.2.dev20240604/langfun/core/llms/llama_cpp_test.py
@@ -17,7 +17,6 @@ import typing
 import unittest
 from unittest import mock
 
-import langfun.core as lf
 from langfun.core.llms import llama_cpp
 
 
@@ -25,6 +24,9 @@ def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs):
   del kwargs
 
   class TEMP:
+    @property
+    def status_code(self):
+      return 200
 
     def json(self):
       return {"content": json["prompt"] + "\n" + url}
@@ -36,19 +38,23 @@ class LlamaCppRemoteTest(unittest.TestCase):
   """Tests for the LlamaCppRemote model."""
 
   def test_call_completion(self):
-    with mock.patch("requests.post") as mock_request:
+    with mock.patch("requests.Session.post") as mock_request:
       mock_request.side_effect = mock_requests_post
-      lm = llama_cpp.LlamaCppRemote(
-          url="http://127.0.0.1:8080")
+      lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
+      [result] = lm.sample(["hello"], n=2)
       self.assertEqual(
-          lm("hello"),
+          len(result.samples),
+          2
+      )
+      self.assertEqual(
+          str(result.samples[0].response),
           "hello\nhttp://127.0.0.1:8080/completion",
       )
 
-  def test_name(self):
-    lm = llama_cpp.LlamaCppRemote()
+  def test_model_id(self):
+    lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
     self.assertEqual(lm.model_id, "LLaMAC++()")
-    lm = llama_cpp.LlamaCppRemote(url="xxx", model="x")
+    lm = llama_cpp.LlamaCppRemote("xxx", model="x")
     self.assertEqual(lm.model_id, "LLaMAC++(x)")
```