langfun 0.0.2.dev20240603__py3-none-any.whl → 0.0.2.dev20240604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -1
- langfun/core/__init__.py +5 -0
- langfun/core/language_model.py +27 -1
- langfun/core/llms/__init__.py +3 -0
- langfun/core/llms/anthropic.py +44 -80
- langfun/core/llms/anthropic_test.py +1 -1
- langfun/core/llms/groq.py +42 -87
- langfun/core/llms/groq_test.py +1 -1
- langfun/core/llms/llama_cpp.py +52 -42
- langfun/core/llms/llama_cpp_test.py +14 -8
- langfun/core/llms/rest.py +112 -0
- langfun/core/llms/rest_test.py +111 -0
- {langfun-0.0.2.dev20240603.dist-info → langfun-0.0.2.dev20240604.dist-info}/METADATA +1 -1
- {langfun-0.0.2.dev20240603.dist-info → langfun-0.0.2.dev20240604.dist-info}/RECORD +17 -15
- {langfun-0.0.2.dev20240603.dist-info → langfun-0.0.2.dev20240604.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240603.dist-info → langfun-0.0.2.dev20240604.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240603.dist-info → langfun-0.0.2.dev20240604.dist-info}/top_level.txt +0 -0
langfun/__init__.py
CHANGED
langfun/core/__init__.py
CHANGED
@@ -106,6 +106,11 @@ from langfun.core.language_model import LMScoringResult
 from langfun.core.language_model import LMCache
 from langfun.core.language_model import LMDebugMode
 
+from langfun.core.language_model import LMError
+from langfun.core.language_model import RetryableLMError
+from langfun.core.language_model import RateLimitError
+from langfun.core.language_model import TemporaryLMError
+
 # Components for building agents.
 from langfun.core.memory import Memory
 
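With these four names re-exported from the package root, client code can treat transient failures uniformly without importing provider modules. A minimal usage sketch (the helper function below is illustrative, not part of langfun):

import langfun as lf


def text_of(lm: lf.LanguageModel, prompt: str) -> str:
  try:
    return lm(prompt).text
  except lf.RetryableLMError:
    # lf.RateLimitError and lf.TemporaryLMError both subclass
    # lf.RetryableLMError, so one handler covers transient failures
    # from any backend.
    raise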
langfun/core/language_model.py
CHANGED
@@ -29,6 +29,32 @@ TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
 DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
 
 
+#
+# Common errors during calling language models.
+#
+
+
+class LMError(RuntimeError):
+  """Base class for language model errors."""
+
+
+class RetryableLMError(LMError):
+  """Base class for LLM errors that can be solved by retrying."""
+
+
+class RateLimitError(RetryableLMError):
+  """Error for rate limit reached."""
+
+
+class TemporaryLMError(RetryableLMError):
+  """Error for temporary service issues that can be retried."""
+
+
+#
+# Language model input/output interfaces.
+#
+
+
 class LMSample(pg.Object):
   """Response candidate."""
 
@@ -445,7 +471,7 @@ class LanguageModel(component.Component):
           None,
           Union[Type[Exception], Tuple[Type[Exception], str]],
           Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
-      ] = None,
+      ] = RetryableLMError,
   ) -> Any:
     """Helper method for subclasses for implementing _sample."""
     return concurrent.concurrent_execute(
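The second hunk changes the default of `retry_on_errors` in `_parallel_execute_with_currency_control` to `RetryableLMError`, so subclasses get retries on rate limits and transient outages without naming provider-specific exception types. A self-contained sketch of the semantics this buys (the backoff loop is illustrative; langfun itself delegates retrying to `concurrent.concurrent_execute`):

import time


class LMError(RuntimeError):
  """Base class for language model errors (mirrors the diff above)."""


class RetryableLMError(LMError):
  """Errors that are worth retrying."""


class RateLimitError(RetryableLMError):
  """Rate limit reached."""


def call_with_retries(fn, max_attempts=3):
  """Retries only retryable errors, with exponential backoff."""
  for attempt in range(max_attempts):
    try:
      return fn()
    except RetryableLMError:
      if attempt == max_attempts - 1:
        raise  # Out of attempts: surface the last retryable error.
      time.sleep(2 ** attempt)  # Wait 1s, 2s, 4s, ... between attempts.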
langfun/core/llms/__init__.py
CHANGED
@@ -24,6 +24,9 @@ from langfun.core.llms.fake import StaticMapping
 from langfun.core.llms.fake import StaticResponse
 from langfun.core.llms.fake import StaticSequence
 
+# REST-based models.
+from langfun.core.llms.rest import REST
+
 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
 from langfun.core.llms.google_genai import GeminiPro
langfun/core/llms/anthropic.py
CHANGED
@@ -14,14 +14,13 @@
 """Language models from Anthropic."""
 
 import base64
-import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
-import requests
 
 
 SUPPORTED_MODELS_AND_SETTINGS = {
@@ -38,24 +37,8 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 }
 
 
-class AnthropicError(Exception):  # pylint: disable=g-bad-exception-name
-  """Base class for Anthropic errors."""
-
-
-class RateLimitError(AnthropicError):
-  """Error for rate limit reached."""
-
-
-class OverloadedError(AnthropicError):
-  """Anthropic's server is temporarily overloaded."""
-
-
-_ANTHROPIC_MESSAGE_API_ENDPOINT = 'https://api.anthropic.com/v1/messages'
-_ANTHROPIC_API_VERSION = '2023-06-01'
-
-
 @lf.use_init_args(['model'])
-class Anthropic(lf.LanguageModel):
+class Anthropic(rest.REST):
   """Anthropic LLMs (Claude) through REST APIs.
 
   See https://docs.anthropic.com/claude/reference/messages_post
@@ -80,14 +63,18 @@ class Anthropic(lf.LanguageModel):
       ),
   ] = None
 
+  api_endpoint: str = 'https://api.anthropic.com/v1/messages'
+
+  api_version: Annotated[
+      str,
+      'Anthropic API version.'
+  ] = '2023-06-01'
+
   def _on_bound(self):
     super()._on_bound()
     self._api_key = None
-    self.__dict__.pop('_api_initialized', None)
-    self.__dict__.pop('_session', None)
 
-  @functools.cached_property
-  def _api_initialized(self):
+  def _initialize(self):
     api_key = self.api_key or os.environ.get('ANTHROPIC_API_KEY', None)
     if not api_key:
       raise ValueError(
@@ -95,18 +82,14 @@ class Anthropic(lf.LanguageModel):
           'variable `ANTHROPIC_API_KEY` with your Anthropic API key.'
       )
     self._api_key = api_key
-    return True
 
-  @functools.cached_property
-  def _session(self) -> requests.Session:
-    assert self._api_initialized
-    s = requests.Session()
-    s.headers.update({
+  @property
+  def headers(self) -> dict[str, Any]:
+    return {
         'x-api-key': self._api_key,
-        'anthropic-version': _ANTHROPIC_API_VERSION,
+        'anthropic-version': self.api_version,
         'content-type': 'application/json',
-    })
-    return s
+    }
 
   @property
   def model_id(self) -> str:
@@ -121,13 +104,24 @@ class Anthropic(lf.LanguageModel):
         requests_per_min=rpm, tokens_per_min=tpm
     )
 
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized
-    return self._parallel_execute_with_currency_control(
-        self._sample_single, prompts, retry_on_errors=(RateLimitError, OverloadedError)
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
     )
+    return request
 
-  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # Authropic requires `max_tokens` to be specified.
     max_tokens = (
@@ -174,6 +168,19 @@ class Anthropic(lf.LanguageModel):
     else:
       return [dict(type='text', text=prompt.text)]
 
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    message = self._message_from_content(json['content'])
+    input_tokens = json['usage']['input_tokens']
+    output_tokens = json['usage']['output_tokens']
+    return lf.LMSamplingResult(
+        [lf.LMSample(message)],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+        ),
+    )
+
   def _message_from_content(self, content: list[dict[str, Any]]) -> lf.Message:
     """Converts Anthropic's content protocol to message."""
     # Refer: https://docs.anthropic.com/claude/reference/messages-examples
@@ -181,49 +188,6 @@ class Anthropic(lf.LanguageModel):
         [x['text'] for x in content if x['type'] == 'text']
     )
 
-  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
-    """Parses Anthropic's response."""
-    # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
-    if response.status_code == 200:
-      output = response.json()
-      message = self._message_from_content(output['content'])
-      input_tokens = output['usage']['input_tokens']
-      output_tokens = output['usage']['output_tokens']
-      return lf.LMSamplingResult(
-          [lf.LMSample(message)],
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=input_tokens,
-              completion_tokens=output_tokens,
-              total_tokens=input_tokens + output_tokens,
-          ),
-      )
-    else:
-      if response.status_code == 429:
-        error_cls = RateLimitError
-      elif response.status_code in (502, 529):
-        error_cls = OverloadedError
-      else:
-        error_cls = AnthropicError
-      raise error_cls(f'{response.status_code}: {response.content}')
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    request = dict()
-    request.update(self._get_request_args(self.sampling_options))
-    request.update(
-        dict(
-            messages=[
-                dict(role='user', content=self._content_from_message(prompt))
-            ]
-        )
-    )
-    try:
-      response = self._session.post(
-          _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
-      )
-      return self._parse_response(response)
-    except ConnectionError as e:
-      raise OverloadedError(str(e)) from e
-
 
 class Claude3(Anthropic):
   """Base class for Claude 3 models. 200K input tokens and 4K output tokens."""
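Taken together, the Anthropic class now only describes the wire format: request() builds the JSON body, result() maps the JSON reply to an LMSamplingResult, and transport plus error mapping live in rest.REST. For orientation, a sketch of the payload shape request() assembles for a plain-text prompt (the model ID and max_tokens value below are illustrative, not taken from this diff):

payload = {
    # From _request_args(): model and sampling settings.
    'model': 'claude-3-haiku-20240307',  # Illustrative model ID.
    'max_tokens': 4096,  # Illustrative; Anthropic requires max_tokens to be set.
    # From _content_from_message(): a single user turn in Anthropic's
    # content-block protocol.
    'messages': [
        {'role': 'user', 'content': [{'type': 'text', 'text': 'Hello'}]},
    ],
}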
langfun/core/llms/groq.py
CHANGED
@@ -13,14 +13,13 @@
 # limitations under the License.
 """Language models from Groq."""
 
-import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
-import requests
 
 
 SUPPORTED_MODELS_AND_SETTINGS = {
@@ -33,23 +32,8 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 }
 
 
-class GroqError(Exception):  # pylint: disable=g-bad-exception-name
-  """Base class for Groq errors."""
-
-
-class RateLimitError(GroqError):
-  """Error for rate limit reached."""
-
-
-class OverloadedError(GroqError):
-  """Groq's server is temporarily overloaded."""
-
-
-_CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
-
-
 @lf.use_init_args(['model'])
-class Groq(lf.LanguageModel):
+class Groq(rest.REST):
   """Groq LLMs through REST APIs (OpenAI compatible).
 
   See https://platform.openai.com/docs/api-reference/chat
@@ -74,14 +58,13 @@ class Groq(lf.LanguageModel):
       ),
   ] = None
 
+  api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'
+
   def _on_bound(self):
     super()._on_bound()
     self._api_key = None
-    self.__dict__.pop('_api_initialized', None)
-    self.__dict__.pop('_session', None)
 
-  @functools.cached_property
-  def _api_initialized(self):
+  def _initialize(self):
    api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
    if not api_key:
      raise ValueError(
@@ -89,17 +72,13 @@ class Groq(lf.LanguageModel):
          'variable `GROQ_API_KEY` with your Groq API key.'
      )
    self._api_key = api_key
-    return True
 
-  @functools.cached_property
-  def _session(self) -> requests.Session:
-    assert self._api_initialized
-    s = requests.Session()
-    s.headers.update({
+  @property
+  def headers(self) -> dict[str, Any]:
+    return {
        'Authorization': f'Bearer {self._api_key}',
        'Content-Type': 'application/json',
-    })
-    return s
+    }
 
   @property
   def model_id(self) -> str:
@@ -110,7 +89,24 @@ class Groq(lf.LanguageModel):
   def max_concurrency(self) -> int:
     return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
 
-  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
+    )
+    return request
+
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
     args = dict(
@@ -148,6 +144,21 @@ class Groq(lf.LanguageModel):
       content.append(item)
     return content
 
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    samples = [
+        lf.LMSample(self._message_from_choice(choice), score=0.0)
+        for choice in json['choices']
+    ]
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples,
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['prompt_tokens'],
+            completion_tokens=usage['completion_tokens'],
+            total_tokens=usage['total_tokens'],
+        ),
+    )
+
   def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
     """Converts Groq's content protocol to message."""
     # Refer: https://platform.openai.com/docs/api-reference/chat/create
@@ -158,62 +169,6 @@ class Groq(lf.LanguageModel):
         [x['text'] for x in content if x['type'] == 'text']
     )
 
-  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
-    """Parses Groq's response."""
-    # Refer: https://platform.openai.com/docs/api-reference/chat/object
-    if response.status_code == 200:
-      output = response.json()
-      samples = [
-          lf.LMSample(self._message_from_choice(choice), score=0.0)
-          for choice in output['choices']
-      ]
-      usage = output['usage']
-      return lf.LMSamplingResult(
-          samples,
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=usage['prompt_tokens'],
-              completion_tokens=usage['completion_tokens'],
-              total_tokens=usage['total_tokens'],
-          ),
-      )
-    else:
-      # https://platform.openai.com/docs/guides/error-codes/api-errors
-      if response.status_code == 429:
-        error_cls = RateLimitError
-      elif response.status_code in (500, 502, 503):
-        error_cls = OverloadedError
-      else:
-        error_cls = GroqError
-      raise error_cls(f'{response.status_code}: {response.content}')
-
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized
-    return self._parallel_execute_with_currency_control(
-        self._sample_single,
-        prompts,
-        retry_on_errors=(RateLimitError, OverloadedError),
-    )
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    request = dict()
-    request.update(self._get_request_args(self.sampling_options))
-    request.update(
-        dict(
-            messages=[
-                dict(role='user', content=self._content_from_message(prompt))
-            ]
-        )
-    )
-    try:
-      response = self._session.post(
-          _CHAT_COMPLETE_API_ENDPOINT,
-          json=request,
-          timeout=self.timeout,
-      )
-      return self._parse_response(response)
-    except ConnectionError as e:
-      raise OverloadedError(str(e)) from e
-
 
 class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
   """Llama3-8B with 8K context window.
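Groq follows the same split: request() emits an OpenAI-compatible chat payload and result() consumes the response body, while REST handles status codes. An illustrative response in the shape result() and _message_from_choice() expect (all field values invented for the example):

response_json = {
    'choices': [{
        'message': {
            'role': 'assistant',
            'content': [{'type': 'text', 'text': 'Hi there!'}],
        },
    }],
    'usage': {'prompt_tokens': 5, 'completion_tokens': 3, 'total_tokens': 8},
}
# result() wraps each choice in an lf.LMSample (score 0.0) and copies the
# usage counters into lf.LMSamplingUsage.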
langfun/core/llms/groq_test.py
CHANGED
langfun/core/llms/llama_cpp.py
CHANGED
@@ -13,62 +13,72 @@
 # limitations under the License.
 """Language models from llama.cpp."""
 
-from typing import Annotated
+from typing import Any
 
 import langfun.core as lf
-import requests
+from langfun.core.llms import rest
+import pyglove as pg
 
 
-
-class LlamaCppRemote(lf.LanguageModel):
+class LlamaCppRemote(rest.REST):
   """The remote LLaMA C++ model.
 
   The Remote LLaMA C++ models can be launched via
   https://github.com/ggerganov/llama.cpp/tree/master/examples/server
   """
 
-  url: Annotated[
-      str,
-      "The URL of the LLaMA C++ server.",
-  ] = ""
-
-  name: Annotated[
-      str,
-      "The abbreviation for the LLaMA CPP-based model name.",
-  ] = ""
+  @pg.explicit_method_override
+  def __init__(self, url: str, model: str | None = None, **kwargs):
+    super().__init__(api_endpoint=f'{url}/completion', model=model, **kwargs)
 
   @property
   def model_id(self) -> str:
     """Returns a string to identify the model."""
-    return f"LLaMAC++({self.name})"
+    return f'LLaMAC++({self.model or ""})'
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    # NOTE(daiyip): multi-modal is current not supported.
+    request['prompt'] = prompt.text
+    return request
+
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    args = dict(
+        n_predict=options.max_tokens or 1024,
+        top_k=options.top_k or 50,
+        top_p=options.top_p or 0.95,
+    )
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    return args
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    return lf.LMSamplingResult(
+        [lf.LMSample(item['content'], score=0.0) for item in json['items']]
+    )
 
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    def _complete_fn(cur_prompts):
-      results = []
-      for prompt in cur_prompts:
-        result = lf.LMSamplingResult()
-        for _ in range(self.sampling_options.n or 1):
-          data = {
-              "prompt": prompt.text,
-              "n_predict": self.sampling_options.max_tokens,
-              "top_k": self.sampling_options.top_k or 50,
-              "top_p": self.sampling_options.top_p or 0.95,
-          }
-          if self.sampling_options.temperature is not None:
-            data["temperature"] = self.sampling_options.temperature
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    request = self.request(prompt, self.sampling_options)
 
-          response = requests.post(
-              f"{self.url}/completion",
-              json=data,
-              headers={"Content-Type": "application/json"},
-              timeout=self.timeout,
-          )
-          decoded_response = response.json()
-          response = decoded_response["content"]
-          result.samples.append(lf.LMSample(response, score=0.0))
-        results.append(result)
-      return results
+    def _sample_one_example(request):
+      response = self._session.post(
+          self.api_endpoint,
+          json=request,
+          timeout=self.timeout,
+      )
+      if response.status_code == 200:
+        return response.json()
+      else:
+        error_cls = self._error_cls_from_status(response.status_code)
+        raise error_cls(f'{response.status_code}: {response.content}')
 
-    return self._parallel_execute_with_currency_control(
-        _complete_fn, [prompts]
-    )[0]
+    items = self._parallel_execute_with_currency_control(
+        _sample_one_example,
+        [request] * (self.sampling_options.n or 1),
+    )
+    return self.result(dict(items=items))
langfun/core/llms/llama_cpp_test.py
CHANGED
@@ -17,7 +17,6 @@ import typing
 import unittest
 from unittest import mock
 
-import langfun.core as lf
 from langfun.core.llms import llama_cpp
 
 
@@ -25,6 +24,9 @@ def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs):
   del kwargs
 
   class TEMP:
+    @property
+    def status_code(self):
+      return 200
 
     def json(self):
       return {"content": json["prompt"] + "\n" + url}
@@ -36,19 +38,23 @@ class LlamaCppRemoteTest(unittest.TestCase):
   """Tests for the LlamaCppRemote model."""
 
   def test_call_completion(self):
-    with mock.patch("requests.post") as mock_request:
+    with mock.patch("requests.Session.post") as mock_request:
       mock_request.side_effect = mock_requests_post
-      lm = llama_cpp.LlamaCppRemote(
-          url="http://127.0.0.1:8080")
+      lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
+      [result] = lm.sample(["hello"], n=2)
       self.assertEqual(
-          lm("hello").text,
+          len(result.samples),
+          2
+      )
+      self.assertEqual(
+          str(result.samples[0].response),
           "hello\nhttp://127.0.0.1:8080/completion",
       )
 
-  def test_name(self):
-    lm = llama_cpp.LlamaCppRemote()
+  def test_model_id(self):
+    lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
     self.assertEqual(lm.model_id, "LLaMAC++()")
-    lm = llama_cpp.LlamaCppRemote(url="xxx", name="x")
+    lm = llama_cpp.LlamaCppRemote("xxx", model="x")
     self.assertEqual(lm.model_id, "LLaMAC++(x)")
 
 
langfun/core/llms/rest.py
ADDED
@@ -0,0 +1,112 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Base class for language models through REST APIs."""
+
+import functools
+from typing import Annotated, Any, Callable
+
+import langfun.core as lf
+import requests
+
+
+class REST(lf.LanguageModel):
+  """REST-based language model."""
+
+  api_endpoint: Annotated[
+      str,
+      'The endpoint of the REST API.'
+  ]
+
+  request: Annotated[
+      Callable[[lf.Message, lf.LMSamplingOptions], dict[str, Any]],
+      'A function to convert a Langfun message to a JSON request.'
+  ]
+
+  result: Annotated[
+      Callable[[dict[str, Any]], lf.LMSamplingResult],
+      'A function to convert a JSON response to an LMSamplingResult.'
+  ]
+
+  model: Annotated[
+      str | None,
+      'Model ID.'
+  ] = None
+
+  headers: Annotated[
+      dict[str, Any] | None,
+      'The headers for the REST API.'
+  ] = None
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model or 'unknown'
+
+  @functools.cached_property
+  def _api_initialized(self) -> bool:
+    """Returns whether the API is initialized."""
+    self._initialize()
+    return True
+
+  def _initialize(self) -> None:
+    """Initializes the API. Subclasses can override."""
+
+  @functools.cached_property
+  def _session(self) -> requests.Session:
+    assert self._api_initialized
+    s = requests.Session()
+    s.headers.update(self.headers or {})
+    return s
+
+  def _on_bound(self):
+    super()._on_bound()
+    self.__dict__.pop('_session', None)
+    self.__dict__.pop('_api_initialized', None)
+
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    assert self._api_initialized
+    return self._parallel_execute_with_currency_control(
+        self._sample_single, prompts
+    )
+
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    try:
+      response = self._session.post(
+          self.api_endpoint,
+          json=self.request(prompt, self.sampling_options),
+          timeout=self.timeout,
+      )
+      return self._parse_response(response)
+    except ConnectionError as e:
+      raise lf.LMError(str(e)) from e
+
+  def _error(self, status_code: int, content: str) -> lf.LMError:
+    if status_code == 429:
+      error_cls = lf.RateLimitError
+    elif status_code in (
+        500,  # Server side issue (might be bug).
+        502,  # Bad gateway (upstream issue, might retry).
+        503,  # Servers currently under load, retry after a brief wait.
+    ):
+      error_cls = lf.TemporaryLMError
+    else:
+      error_cls = lf.LMError
+    return error_cls(f'{status_code}: {content}')
+
+  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
+    """Parses Anthropic's response."""
+    if response.status_code == 200:
+      return self.result(response.json())
+    else:
+      raise self._error(response.status_code, response.content)
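rest.py concentrates everything that used to be duplicated per provider: session management, concurrency control, and the mapping from HTTP status codes to the new error classes (429 to lf.RateLimitError; 500/502/503 to lf.TemporaryLMError, which is retryable; anything else to lf.LMError). A minimal hypothetical subclass for an imaginary echo-style JSON API, following the same override pattern as Anthropic and Groq above (the endpoint and field names are made up):

from typing import Any

import langfun.core as lf
from langfun.core.llms import rest


class EchoModel(rest.REST):
  """Hypothetical model for a service that replies with {'text': ...}."""

  api_endpoint: str = 'http://localhost:9999/v1/echo'  # Made-up endpoint.

  def request(
      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    # Convert the Langfun message into the service's JSON request.
    return dict(text=prompt.text, temperature=sampling_options.temperature)

  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
    # Convert the service's JSON response into Langfun samples.
    return lf.LMSamplingResult([lf.LMSample(json['text'])])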
langfun/core/llms/rest_test.py
ADDED
@@ -0,0 +1,111 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for REST models."""
+
+from typing import Any
+import unittest
+from unittest import mock
+import langfun.core as lf
+from langfun.core.llms import rest
+import pyglove as pg
+import requests
+
+
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'content': [(
+          f'hello with temperature={json.get("temperature")}, '
+          f'top_k={json.get("top_k")}, '
+          f'top_p={json.get("top_p")}, '
+          f'max_tokens={json.get("max_tokens")}, '
+          f'stop={json.get("stop_sequences")}.'
+      )],
+  }).encode()
+  return response
+
+
+def mock_requests_post_error(status_code, error_type, error_message):
+  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
+    del url, json, kwargs
+    response = requests.Response()
+    response.status_code = status_code
+    response._content = pg.to_json_str(
+        {
+            'error': {
+                'type': error_type,
+                'message': error_message,
+            }
+        }
+    ).encode()
+    return response
+
+  return _mock_requests
+
+
+class RestTest(unittest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._lm = rest.REST(
+        api_endpoint='https://fake-api.com',
+        request=lambda x, o: dict(
+            model='test-model',
+            prompt=x.text,
+            temperature=0.0,
+            top_k=0.1,
+            top_p=0.2,
+            stop_sequences=['\n'],
+            max_tokens=4096,
+        ),
+        result=lambda x: lf.LMSamplingResult(
+            [lf.LMSample(c) for c in x['content']]),
+        headers=dict(api_key='fake_key'),
+    )
+
+  def test_call(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+      self.assertEqual(self._lm.model_id, 'unknown')
+      response = self._lm(
+          'hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
+      self.assertEqual(
+          response.text,
+          (
+              'hello with temperature=0.0, top_k=0.1, top_p=0.2, '
+              "max_tokens=4096, stop=['\\n']."
+          ),
+      )
+      self.assertIsNone(response.usage)
+
+  def test_call_errors(self):
+    for status_code, error_type, error_message in [
+        (429, 'rate_limit', 'Rate limit exceeded.'),
+        (529, 'service_unavailable', 'Service unavailable.'),
+        (500, 'bad_request', 'Bad request.'),
+    ]:
+      with mock.patch('requests.Session.post') as mock_mm_request:
+        mock_mm_request.side_effect = mock_requests_post_error(
+            status_code, error_type, error_message
+        )
+        with self.assertRaisesRegex(
+            Exception, f'.*{status_code}: .*{error_message}'
+        ):
+          self._lm('hello', max_attempts=1)
+
+
+if __name__ == '__main__':
+  unittest.main()
{langfun-0.0.2.dev20240603.dist-info → langfun-0.0.2.dev20240604.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
-langfun/__init__.py,sha256=
-langfun/core/__init__.py,sha256=
+langfun/__init__.py,sha256=P62MnqA6-f0h8iYfQ3MT6Yg7a4qRnQeb4GrIn6dcSnY,2274
+langfun/core/__init__.py,sha256=ZheiCpop_GAZbVpnSS-uPBJaEEM15Td5xFGGizSGqko,4514
 langfun/core/component.py,sha256=oxesbC0BoE_TbtxwW5x-BAZWxZyyJbuPiX5S38RqCv0,9909
 langfun/core/component_test.py,sha256=uR-_Sz_42Jxc5qzLIB-f5_pXmNwnC01Xlbv5NOQSeSU,8021
 langfun/core/concurrent.py,sha256=TRc49pJ3HQro2kb5FtcWkHjhBm8UcgE8RJybU5cU3-0,24537
@@ -8,7 +8,7 @@ langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
 langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
 langfun/core/langfunc.py,sha256=RvIcRjIq0jWYRu1xim-FYe4HSrt97r3GMBO_PuagUmw,11060
 langfun/core/langfunc_test.py,sha256=_mfARnakX3oji5HDigFSLMd6yQ2wma-2Mgbztwqn73g,8501
-langfun/core/language_model.py,sha256=
+langfun/core/language_model.py,sha256=PocBg1t3uB0a_bJntLW5aagHhNbZsVdp2iduSBEW6ro,21240
 langfun/core/language_model_test.py,sha256=NZaSUls6cZdtxiqkqumWbtkx9zgNiJlsviYZOWkuHig,20137
 langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
 langfun/core/message.py,sha256=Rw3yC9HyGRjMhfDgyNjGlSCALEyDDbJ0_o6qTXeeDiQ,15738
@@ -48,19 +48,21 @@ langfun/core/eval/patching.py,sha256=R0s2eAd1m97exQt06dmUL0V_MBG0W2Hxg7fhNB7cXW0
 langfun/core/eval/patching_test.py,sha256=8kCd54Egjju22FMgtJuxEsrXkW8ifs-UUBHtrCG1L6w,4775
 langfun/core/eval/scoring.py,sha256=1J7IATo-8FXUR0SBqk9icztHiM0lWkBFcWUo-vUURgQ,6376
 langfun/core/eval/scoring_test.py,sha256=O8olHbrUEg60gMxwOkWzKBJZpZoUlmVnBANX5Se2SXM,4546
-langfun/core/llms/__init__.py,sha256=
-langfun/core/llms/anthropic.py,sha256=
-langfun/core/llms/anthropic_test.py,sha256=
+langfun/core/llms/__init__.py,sha256=3G7pJISeClgHGV34Gy2t_Nih4N08UhGbWe6uAff8TnA,4364
+langfun/core/llms/anthropic.py,sha256=pBYe8dVwswxKaqhNjA_jtZbyfvOaXtEo399Zty242iA,7097
+langfun/core/llms/anthropic_test.py,sha256=T-swuMkfnlgs8Fpif4rtXs579exGk0TsbLMirXDZCkg,5533
 langfun/core/llms/fake.py,sha256=Dd7-6ka9pFf3fcWZyczamjOqQ91MOI-m7We3Oc9Ffmo,2927
 langfun/core/llms/fake_test.py,sha256=ipKfdOcuqVcJ8lDXVpnBVb9HHG0hAVkFkMoHpWjC2cI,7212
 langfun/core/llms/google_genai.py,sha256=Rl5a5CyF_6Y0BYYArKk8yMaenv1rH3MUQLy6b3dfMRI,10202
 langfun/core/llms/google_genai_test.py,sha256=iTISk3tJ4-3gjWmzcKQhEbH3ke4AkEiCu8rAGtB7SvU,7535
-langfun/core/llms/groq.py,sha256=
-langfun/core/llms/groq_test.py,sha256=
-langfun/core/llms/llama_cpp.py,sha256=
-langfun/core/llms/llama_cpp_test.py,sha256=
+langfun/core/llms/groq.py,sha256=pqtyOZ_1_OJMOg8xATWT_B_SVbuT9nMRf4VkH9GzW8g,6308
+langfun/core/llms/groq_test.py,sha256=GYF_Qtq5S1H1TrKH38t6_lkdroqT7v-joYLDKnmS9e0,5274
+langfun/core/llms/llama_cpp.py,sha256=9tXQntSCDtjTF3bnyJrAPCr4N6wycy5nXYvp9uduygE,2843
+langfun/core/llms/llama_cpp_test.py,sha256=MWO_qaOeKjRniGjcaWPDScd7HPaIJemqUZoslrt4FPs,1806
 langfun/core/llms/openai.py,sha256=IN46gIqfY6aEEfxCPNmyH1hrep3oWBhJDwVFilfqNkM,13657
 langfun/core/llms/openai_test.py,sha256=QWDzTgi8F2Z9u9ip6alK4rDEp_YraVTxWlDX5XOsKJk,14858
+langfun/core/llms/rest.py,sha256=laopuq-zD8V-3Y6eFDngftHEbE66VlUkCD2-rvvRaLU,3388
+langfun/core/llms/rest_test.py,sha256=Zw08Xbl_O2OQuUglLXHsPsY5KW2VEcPGl1gR8PRHuFY,3449
 langfun/core/llms/vertexai.py,sha256=eILbXoMSza5r4FLGlIdH6-eD8Ggy9Z4PdjLaBDxy29A,11162
 langfun/core/llms/vertexai_test.py,sha256=G18BG36h5KvmX2zutDTLjtYCRjTuP_nWIFm4FMnLnyY,7651
 langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
@@ -111,8 +113,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
 langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
 langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
 langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
-langfun-0.0.2.dev20240603.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
-langfun-0.0.2.dev20240603.dist-info/METADATA,sha256=
-langfun-0.0.2.dev20240603.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-langfun-0.0.2.dev20240603.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
-langfun-0.0.2.dev20240603.dist-info/RECORD,,
+langfun-0.0.2.dev20240604.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
+langfun-0.0.2.dev20240604.dist-info/METADATA,sha256=WGM_N4Nizh0eT4OvlSgopXgcnHy7uset6O7iAjzqUdk,3550
+langfun-0.0.2.dev20240604.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+langfun-0.0.2.dev20240604.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
+langfun-0.0.2.dev20240604.dist-info/RECORD,,
{langfun-0.0.2.dev20240603.dist-info → langfun-0.0.2.dev20240604.dist-info}/LICENSE: File without changes
{langfun-0.0.2.dev20240603.dist-info → langfun-0.0.2.dev20240604.dist-info}/WHEEL: File without changes
{langfun-0.0.2.dev20240603.dist-info → langfun-0.0.2.dev20240604.dist-info}/top_level.txt: File without changes