langfun 0.0.2.dev20240603__py3-none-any.whl → 0.0.2.dev20240605__py3-none-any.whl


langfun/__init__.py CHANGED
@@ -63,7 +63,7 @@ Image = modalities.Image
 Video = modalities.Video
 PDF = modalities.PDF
 
-# Error types.
+# Additional error types.
 MappingError = structured.MappingError
 SchemaError = structured.SchemaError
 JsonError = structured.JsonError
langfun/core/__init__.py CHANGED
@@ -106,6 +106,11 @@ from langfun.core.language_model import LMScoringResult
 from langfun.core.language_model import LMCache
 from langfun.core.language_model import LMDebugMode
 
+from langfun.core.language_model import LMError
+from langfun.core.language_model import RetryableLMError
+from langfun.core.language_model import RateLimitError
+from langfun.core.language_model import TemporaryLMError
+
 # Components for building agents.
 from langfun.core.memory import Memory
 
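The new error types are exported from `langfun.core` so callers can handle transport-level failures from any backend uniformly. A minimal sketch (the fake model is a stand-in for any `lf.LanguageModel` and never actually fails):

    import langfun.core as lf
    from langfun.core.llms import fake

    lm = fake.StaticResponse(response='hi')
    try:
      print(lm('hello').text)
    except lf.RetryableLMError:
      ...  # RateLimitError and TemporaryLMError both land here.
    except lf.LMError:
      raise  # Non-retryable model failure.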
langfun/core/language_model.py CHANGED
@@ -29,6 +29,32 @@ TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
 DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
 
 
+#
+# Common errors during calling language models.
+#
+
+
+class LMError(RuntimeError):
+  """Base class for language model errors."""
+
+
+class RetryableLMError(LMError):
+  """Base class for LLM errors that can be solved by retrying."""
+
+
+class RateLimitError(RetryableLMError):
+  """Error for rate limit reached."""
+
+
+class TemporaryLMError(RetryableLMError):
+  """Error for temporary service issues that can be retried."""
+
+
+#
+# Language model input/output interfaces.
+#
+
+
 class LMSample(pg.Object):
   """Response candidate."""
 
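For orientation, the four classes above form a small hierarchy rooted at RuntimeError, which the retry logic below keys on:

    LMError (RuntimeError)
    └── RetryableLMError
        ├── RateLimitError
        └── TemporaryLMError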
@@ -445,7 +471,7 @@ class LanguageModel(component.Component):
           None,
           Union[Type[Exception], Tuple[Type[Exception], str]],
           Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
-      ] = None,
+      ] = RetryableLMError,
   ) -> Any:
     """Helper method for subclasses for implementing _sample."""
     return concurrent.concurrent_execute(
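With this new default, a subclass's `_sample` no longer needs to wire up `retry_on_errors` explicitly; anything raised as a `RetryableLMError` subclass is retried. A minimal sketch of such a `_sample`, mirroring the `rest.py` base class added later in this diff:

    def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
      assert self._api_initialized
      # retry_on_errors now defaults to RetryableLMError.
      return self._parallel_execute_with_currency_control(
          self._sample_single, prompts
      )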
langfun/core/llms/__init__.py CHANGED
@@ -24,6 +24,9 @@ from langfun.core.llms.fake import StaticMapping
 from langfun.core.llms.fake import StaticResponse
 from langfun.core.llms.fake import StaticSequence
 
+# REST-based models.
+from langfun.core.llms.rest import REST
+
 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
 from langfun.core.llms.google_genai import GeminiPro
langfun/core/llms/anthropic.py CHANGED
@@ -14,14 +14,13 @@
 """Language models from Anthropic."""
 
 import base64
-import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
-import requests
 
 
 SUPPORTED_MODELS_AND_SETTINGS = {
@@ -38,24 +37,8 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 }
 
 
-class AnthropicError(Exception):  # pylint: disable=g-bad-exception-name
-  """Base class for Anthropic errors."""
-
-
-class RateLimitError(AnthropicError):
-  """Error for rate limit reached."""
-
-
-class OverloadedError(AnthropicError):
-  """Anthropic's server is temporarily overloaded."""
-
-
-_ANTHROPIC_MESSAGE_API_ENDPOINT = 'https://api.anthropic.com/v1/messages'
-_ANTHROPIC_API_VERSION = '2023-06-01'
-
-
 @lf.use_init_args(['model'])
-class Anthropic(lf.LanguageModel):
+class Anthropic(rest.REST):
   """Anthropic LLMs (Claude) through REST APIs.
 
   See https://docs.anthropic.com/claude/reference/messages_post
@@ -80,14 +63,18 @@ class Anthropic(lf.LanguageModel):
       ),
   ] = None
 
+  api_endpoint: str = 'https://api.anthropic.com/v1/messages'
+
+  api_version: Annotated[
+      str,
+      'Anthropic API version.'
+  ] = '2023-06-01'
+
   def _on_bound(self):
     super()._on_bound()
     self._api_key = None
-    self.__dict__.pop('_api_initialized', None)
-    self.__dict__.pop('_session', None)
 
-  @functools.cached_property
-  def _api_initialized(self):
+  def _initialize(self):
     api_key = self.api_key or os.environ.get('ANTHROPIC_API_KEY', None)
     if not api_key:
       raise ValueError(
@@ -95,18 +82,14 @@ class Anthropic(lf.LanguageModel):
           'variable `ANTHROPIC_API_KEY` with your Anthropic API key.'
       )
     self._api_key = api_key
-    return True
 
-  @functools.cached_property
-  def _session(self) -> requests.Session:
-    assert self._api_initialized
-    s = requests.Session()
-    s.headers.update({
+  @property
+  def headers(self) -> dict[str, Any]:
+    return {
         'x-api-key': self._api_key,
-        'anthropic-version': _ANTHROPIC_API_VERSION,
+        'anthropic-version': self.api_version,
         'content-type': 'application/json',
-    })
-    return s
+    }
 
   @property
   def model_id(self) -> str:
@@ -121,13 +104,24 @@ class Anthropic(lf.LanguageModel):
         requests_per_min=rpm, tokens_per_min=tpm
     )
 
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized
-    return self._parallel_execute_with_currency_control(
-        self._sample_single, prompts, retry_on_errors=(RateLimitError)
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
     )
+    return request
 
-  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # Anthropic requires `max_tokens` to be specified.
     max_tokens = (
@@ -174,6 +168,19 @@ class Anthropic(lf.LanguageModel):
     else:
       return [dict(type='text', text=prompt.text)]
 
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    message = self._message_from_content(json['content'])
+    input_tokens = json['usage']['input_tokens']
+    output_tokens = json['usage']['output_tokens']
+    return lf.LMSamplingResult(
+        [lf.LMSample(message)],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+        ),
+    )
+
   def _message_from_content(self, content: list[dict[str, Any]]) -> lf.Message:
     """Converts Anthropic's content protocol to message."""
     # Refer: https://docs.anthropic.com/claude/reference/messages-examples
@@ -181,49 +188,6 @@ class Anthropic(lf.LanguageModel):
         [x['text'] for x in content if x['type'] == 'text']
     )
 
-  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
-    """Parses Anthropic's response."""
-    # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
-    if response.status_code == 200:
-      output = response.json()
-      message = self._message_from_content(output['content'])
-      input_tokens = output['usage']['input_tokens']
-      output_tokens = output['usage']['output_tokens']
-      return lf.LMSamplingResult(
-          [lf.LMSample(message)],
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=input_tokens,
-              completion_tokens=output_tokens,
-              total_tokens=input_tokens + output_tokens,
-          ),
-      )
-    else:
-      if response.status_code == 429:
-        error_cls = RateLimitError
-      elif response.status_code in (502, 529):
-        error_cls = OverloadedError
-      else:
-        error_cls = AnthropicError
-      raise error_cls(f'{response.status_code}: {response.content}')
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    request = dict()
-    request.update(self._get_request_args(self.sampling_options))
-    request.update(
-        dict(
-            messages=[
-                dict(role='user', content=self._content_from_message(prompt))
-            ]
-        )
-    )
-    try:
-      response = self._session.post(
-          _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
-      )
-      return self._parse_response(response)
-    except ConnectionError as e:
-      raise OverloadedError(str(e)) from e
-
 
 class Claude3(Anthropic):
   """Base class for Claude 3 models. 200K input tokens and 4K output tokens."""
langfun/core/llms/anthropic_test.py CHANGED
@@ -160,7 +160,7 @@ class AnthropicTest(unittest.TestCase):
     with self.assertRaisesRegex(
         Exception, f'.*{status_code}: .*{error_message}'
     ):
-      lm('hello', lm=lm, max_attempts=1)
+      lm('hello', max_attempts=1)
 
 
 if __name__ == '__main__':
langfun/core/llms/groq.py CHANGED
@@ -13,14 +13,13 @@
 # limitations under the License.
 """Language models from Groq."""
 
-import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
-import requests
 
 
 SUPPORTED_MODELS_AND_SETTINGS = {
@@ -33,23 +32,8 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 }
 
 
-class GroqError(Exception):  # pylint: disable=g-bad-exception-name
-  """Base class for Groq errors."""
-
-
-class RateLimitError(GroqError):
-  """Error for rate limit reached."""
-
-
-class OverloadedError(GroqError):
-  """Groq's server is temporarily overloaded."""
-
-
-_CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
-
-
 @lf.use_init_args(['model'])
-class Groq(lf.LanguageModel):
+class Groq(rest.REST):
   """Groq LLMs through REST APIs (OpenAI compatible).
 
   See https://platform.openai.com/docs/api-reference/chat
@@ -74,14 +58,13 @@ class Groq(lf.LanguageModel):
       ),
   ] = None
 
+  api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'
+
   def _on_bound(self):
     super()._on_bound()
     self._api_key = None
-    self.__dict__.pop('_api_initialized', None)
-    self.__dict__.pop('_session', None)
 
-  @functools.cached_property
-  def _api_initialized(self):
+  def _initialize(self):
    api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
    if not api_key:
      raise ValueError(
@@ -89,17 +72,13 @@ class Groq(lf.LanguageModel):
          'variable `GROQ_API_KEY` with your Groq API key.'
      )
    self._api_key = api_key
-    return True
 
-  @functools.cached_property
-  def _session(self) -> requests.Session:
-    assert self._api_initialized
-    s = requests.Session()
-    s.headers.update({
+  @property
+  def headers(self) -> dict[str, Any]:
+    return {
        'Authorization': f'Bearer {self._api_key}',
        'Content-Type': 'application/json',
-    })
-    return s
+    }
 
   @property
   def model_id(self) -> str:
@@ -110,7 +89,24 @@ class Groq(lf.LanguageModel):
   def max_concurrency(self) -> int:
     return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
 
-  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
+    )
+    return request
+
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
     args = dict(
@@ -148,6 +144,21 @@ class Groq(lf.LanguageModel):
       content.append(item)
     return content
 
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    samples = [
+        lf.LMSample(self._message_from_choice(choice), score=0.0)
+        for choice in json['choices']
+    ]
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples,
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['prompt_tokens'],
+            completion_tokens=usage['completion_tokens'],
+            total_tokens=usage['total_tokens'],
+        ),
+    )
+
   def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
     """Converts Groq's content protocol to message."""
     # Refer: https://platform.openai.com/docs/api-reference/chat/create
@@ -158,62 +169,6 @@ class Groq(lf.LanguageModel):
         [x['text'] for x in content if x['type'] == 'text']
     )
 
-  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
-    """Parses Groq's response."""
-    # Refer: https://platform.openai.com/docs/api-reference/chat/object
-    if response.status_code == 200:
-      output = response.json()
-      samples = [
-          lf.LMSample(self._message_from_choice(choice), score=0.0)
-          for choice in output['choices']
-      ]
-      usage = output['usage']
-      return lf.LMSamplingResult(
-          samples,
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=usage['prompt_tokens'],
-              completion_tokens=usage['completion_tokens'],
-              total_tokens=usage['total_tokens'],
-          ),
-      )
-    else:
-      # https://platform.openai.com/docs/guides/error-codes/api-errors
-      if response.status_code == 429:
-        error_cls = RateLimitError
-      elif response.status_code in (500, 502, 503):
-        error_cls = OverloadedError
-      else:
-        error_cls = GroqError
-      raise error_cls(f'{response.status_code}: {response.content}')
-
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized
-    return self._parallel_execute_with_currency_control(
-        self._sample_single,
-        prompts,
-        retry_on_errors=(RateLimitError, OverloadedError),
-    )
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    request = dict()
-    request.update(self._get_request_args(self.sampling_options))
-    request.update(
-        dict(
-            messages=[
-                dict(role='user', content=self._content_from_message(prompt))
-            ]
-        )
-    )
-    try:
-      response = self._session.post(
-          _CHAT_COMPLETE_API_ENDPOINT,
-          json=request,
-          timeout=self.timeout,
-      )
-      return self._parse_response(response)
-    except ConnectionError as e:
-      raise OverloadedError(str(e)) from e
-
 
 class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
   """Llama3-8B with 8K context window.
langfun/core/llms/groq_test.py CHANGED
@@ -163,7 +163,7 @@ class AuthropicTest(unittest.TestCase):
     with self.assertRaisesRegex(
         Exception, f'{status_code}:.*{error_type}'
     ):
-      lm('hello', lm=lm, max_attempts=1)
+      lm('hello', max_attempts=1)
 
 
 if __name__ == '__main__':
langfun/core/llms/llama_cpp.py CHANGED
@@ -13,62 +13,72 @@
 # limitations under the License.
 """Language models from llama.cpp."""
 
-from typing import Annotated
+from typing import Any
 
 import langfun.core as lf
-import requests
+from langfun.core.llms import rest
+import pyglove as pg
 
 
-@lf.use_init_args(["url"])
-class LlamaCppRemote(lf.LanguageModel):
+class LlamaCppRemote(rest.REST):
   """The remote LLaMA C++ model.
 
   The Remote LLaMA C++ models can be launched via
   https://github.com/ggerganov/llama.cpp/tree/master/examples/server
   """
 
-  url: Annotated[
-      str,
-      "The name of the model to use.",
-  ] = ""
-
-  name: Annotated[
-      str,
-      "The abbreviation for the LLaMA CPP-based model name.",
-  ] = ""
+  @pg.explicit_method_override
+  def __init__(self, url: str, model: str | None = None, **kwargs):
+    super().__init__(api_endpoint=f'{url}/completion', model=model, **kwargs)
 
   @property
   def model_id(self) -> str:
     """Returns a string to identify the model."""
-    return f"LLaMAC++({self.name})"
+    return f'LLaMAC++({self.model or ""})'
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    # NOTE(daiyip): multi-modal is currently not supported.
+    request['prompt'] = prompt.text
+    return request
+
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    args = dict(
+        n_predict=options.max_tokens or 1024,
+        top_k=options.top_k or 50,
+        top_p=options.top_p or 0.95,
+    )
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    return args
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    return lf.LMSamplingResult(
+        [lf.LMSample(item['content'], score=0.0) for item in json['items']]
+    )
 
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    def _complete_fn(cur_prompts):
-      results = []
-      for prompt in cur_prompts:
-        result = lf.LMSamplingResult()
-        for _ in range(self.sampling_options.n or 1):
-          data = {
-              "prompt": prompt.text,
-              "n_predict": self.sampling_options.max_tokens,
-              "top_k": self.sampling_options.top_k or 50,
-              "top_p": self.sampling_options.top_p or 0.95,
-          }
-          if self.sampling_options.temperature is not None:
-            data["temperature"] = self.sampling_options.temperature
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    request = self.request(prompt, self.sampling_options)
 
-          response = requests.post(
-              f"{self.url}/completion",
-              json=data,
-              headers={"Content-Type": "application/json"},
-              timeout=self.timeout,
-          )
-          decoded_response = response.json()
-          response = decoded_response["content"]
-          result.samples.append(lf.LMSample(response, score=0.0))
-        results.append(result)
-      return results
+    def _sample_one_example(request):
+      response = self._session.post(
+          self.api_endpoint,
+          json=request,
+          timeout=self.timeout,
+      )
+      if response.status_code == 200:
+        return response.json()
+      else:
+        raise self._error(response.status_code, response.content)
 
-    return self._parallel_execute_with_currency_control(
-        _complete_fn, [prompts]
-    )[0]
+    items = self._parallel_execute_with_currency_control(
+        _sample_one_example,
+        [request] * (self.sampling_options.n or 1),
+    )
+    return self.result(dict(items=items))
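Mirroring the updated tests below: the server URL is now the first positional argument, and the endpoint is derived from it:

    from langfun.core.llms import llama_cpp

    lm = llama_cpp.LlamaCppRemote('http://127.0.0.1:8080')
    assert lm.api_endpoint == 'http://127.0.0.1:8080/completion'
    assert lm.model_id == 'LLaMAC++()'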
langfun/core/llms/llama_cpp_test.py CHANGED
@@ -17,7 +17,6 @@ import typing
 import unittest
 from unittest import mock
 
-import langfun.core as lf
 from langfun.core.llms import llama_cpp
 
 
@@ -25,6 +24,9 @@ def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs):
   del kwargs
 
   class TEMP:
+    @property
+    def status_code(self):
+      return 200
 
     def json(self):
       return {"content": json["prompt"] + "\n" + url}
@@ -36,19 +38,23 @@ class LlamaCppRemoteTest(unittest.TestCase):
   """Tests for the LlamaCppRemote model."""
 
   def test_call_completion(self):
-    with mock.patch("requests.post") as mock_request:
+    with mock.patch("requests.Session.post") as mock_request:
       mock_request.side_effect = mock_requests_post
-      lm = llama_cpp.LlamaCppRemote(url="http://127.0.0.1:8080")
-      response = lm("hello", sampling_options=lf.LMSamplingOptions(n=1))
+      lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
+      [result] = lm.sample(["hello"], n=2)
       self.assertEqual(
-          response.text,
+          len(result.samples),
+          2
+      )
+      self.assertEqual(
+          str(result.samples[0].response),
           "hello\nhttp://127.0.0.1:8080/completion",
       )
 
-  def test_name(self):
-    lm = llama_cpp.LlamaCppRemote()
+  def test_model_id(self):
+    lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
     self.assertEqual(lm.model_id, "LLaMAC++()")
-    lm = llama_cpp.LlamaCppRemote(url="xxx", name="x")
+    lm = llama_cpp.LlamaCppRemote("xxx", model="x")
     self.assertEqual(lm.model_id, "LLaMAC++(x)")
 
 
langfun/core/llms/rest.py ADDED
@@ -0,0 +1,112 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Base class for language models through REST APIs."""
+
+import functools
+from typing import Annotated, Any, Callable
+
+import langfun.core as lf
+import requests
+
+
+class REST(lf.LanguageModel):
+  """REST-based language model."""
+
+  api_endpoint: Annotated[
+      str,
+      'The endpoint of the REST API.'
+  ]
+
+  request: Annotated[
+      Callable[[lf.Message, lf.LMSamplingOptions], dict[str, Any]],
+      'A function to convert a Langfun message to a JSON request.'
+  ]
+
+  result: Annotated[
+      Callable[[dict[str, Any]], lf.LMSamplingResult],
+      'A function to convert a JSON response to an LMSamplingResult.'
+  ]
+
+  model: Annotated[
+      str | None,
+      'Model ID.'
+  ] = None
+
+  headers: Annotated[
+      dict[str, Any] | None,
+      'The headers for the REST API.'
+  ] = None
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model or 'unknown'
+
+  @functools.cached_property
+  def _api_initialized(self) -> bool:
+    """Returns whether the API is initialized."""
+    self._initialize()
+    return True
+
+  def _initialize(self) -> None:
+    """Initializes the API. Subclasses can override."""
+
+  @functools.cached_property
+  def _session(self) -> requests.Session:
+    assert self._api_initialized
+    s = requests.Session()
+    s.headers.update(self.headers or {})
+    return s
+
+  def _on_bound(self):
+    super()._on_bound()
+    self.__dict__.pop('_session', None)
+    self.__dict__.pop('_api_initialized', None)
+
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    assert self._api_initialized
+    return self._parallel_execute_with_currency_control(
+        self._sample_single, prompts
+    )
+
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    try:
+      response = self._session.post(
+          self.api_endpoint,
+          json=self.request(prompt, self.sampling_options),
+          timeout=self.timeout,
+      )
+      return self._parse_response(response)
+    except ConnectionError as e:
+      raise lf.LMError(str(e)) from e
+
+  def _error(self, status_code: int, content: str) -> lf.LMError:
+    if status_code == 429:
+      error_cls = lf.RateLimitError
+    elif status_code in (
+        500,  # Server side issue (might be bug).
+        502,  # Bad gateway (upstream issue, might retry).
+        503,  # Servers currently under load, retry after a brief wait.
+    ):
+      error_cls = lf.TemporaryLMError
+    else:
+      error_cls = lf.LMError
+    return error_cls(f'{status_code}: {content}')
+
+  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
+    """Parses the REST response into an LMSamplingResult."""
+    if response.status_code == 200:
+      return self.result(response.json())
+    else:
+      raise self._error(response.status_code, response.content)
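Concrete backends override `request`, `result`, `headers`, and optionally `_initialize`, as the Anthropic and Groq migrations above do. A minimal sketch against a hypothetical endpoint (the URL and both JSON shapes below are made up for illustration; only the overridden hooks come from the REST contract):

    from typing import Any

    import langfun.core as lf
    from langfun.core.llms import rest


    class MyCompletionModel(rest.REST):
      """Hypothetical model speaking a simple JSON completion protocol."""

      api_endpoint: str = 'https://example.com/v1/complete'  # Hypothetical.

      @property
      def headers(self) -> dict[str, Any]:
        return {'Content-Type': 'application/json'}

      def request(
          self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
      ) -> dict[str, Any]:
        # Hypothetical wire format: a bare prompt plus a token budget.
        return dict(prompt=prompt.text, max_tokens=sampling_options.max_tokens)

      def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
        # Hypothetical reply shape: {'completion': '...'}.
        return lf.LMSamplingResult([lf.LMSample(json['completion'])])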
langfun/core/llms/rest_test.py ADDED
@@ -0,0 +1,111 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for REST models."""
+
+from typing import Any
+import unittest
+from unittest import mock
+import langfun.core as lf
+from langfun.core.llms import rest
+import pyglove as pg
+import requests
+
+
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'content': [(
+          f'hello with temperature={json.get("temperature")}, '
+          f'top_k={json.get("top_k")}, '
+          f'top_p={json.get("top_p")}, '
+          f'max_tokens={json.get("max_tokens")}, '
+          f'stop={json.get("stop_sequences")}.'
+      )],
+  }).encode()
+  return response
+
+
+def mock_requests_post_error(status_code, error_type, error_message):
+  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
+    del url, json, kwargs
+    response = requests.Response()
+    response.status_code = status_code
+    response._content = pg.to_json_str(
+        {
+            'error': {
+                'type': error_type,
+                'message': error_message,
+            }
+        }
+    ).encode()
+    return response
+
+  return _mock_requests
+
+
+class RestTest(unittest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._lm = rest.REST(
+        api_endpoint='https://fake-api.com',
+        request=lambda x, o: dict(
+            model='test-model',
+            prompt=x.text,
+            temperature=0.0,
+            top_k=0.1,
+            top_p=0.2,
+            stop_sequences=['\n'],
+            max_tokens=4096,
+        ),
+        result=lambda x: lf.LMSamplingResult(
+            [lf.LMSample(c) for c in x['content']]),
+        headers=dict(api_key='fake_key'),
+    )
+
+  def test_call(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+      self.assertEqual(self._lm.model_id, 'unknown')
+      response = self._lm(
+          'hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
+      self.assertEqual(
+          response.text,
+          (
+              'hello with temperature=0.0, top_k=0.1, top_p=0.2, '
+              "max_tokens=4096, stop=['\\n']."
+          ),
+      )
+      self.assertIsNone(response.usage)
+
+  def test_call_errors(self):
+    for status_code, error_type, error_message in [
+        (429, 'rate_limit', 'Rate limit exceeded.'),
+        (529, 'service_unavailable', 'Service unavailable.'),
+        (500, 'bad_request', 'Bad request.'),
+    ]:
+      with mock.patch('requests.Session.post') as mock_mm_request:
+        mock_mm_request.side_effect = mock_requests_post_error(
+            status_code, error_type, error_message
+        )
+        with self.assertRaisesRegex(
+            Exception, f'.*{status_code}: .*{error_message}'
+        ):
+          self._lm('hello', max_attempts=1)
+
+
+if __name__ == '__main__':
+  unittest.main()
langfun/core/llms/vertexai.py CHANGED
@@ -179,7 +179,16 @@ class VertexAI(lf.LanguageModel):
     assert self._api_initialized, 'Vertex AI API is not initialized.'
     # TODO(yifenglu): It seems this exception is due to the instability of the
     # API. We should revisit this later.
-    retry_on_errors = [(Exception, 'InternalServerError')]
+    retry_on_errors = [
+        (Exception, 'InternalServerError'),
+        (
+            Exception,
+            (
+                'ValueError: Response candidate content has no parts (and thus'
+                ' no text).'
+            ),
+        ),
+    ]
 
     return lf.concurrent_execute(
         self._sample_single,
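For context, an `(ExceptionType, message)` entry in `retry_on_errors` triggers a retry only when the raised error is an instance of the type and its message matches the given pattern, so the broad `Exception` entries above are narrowed by their message strings. A sketch with a hypothetical `flaky_call`:

    import langfun.core as lf

    def flaky_call(x):
      ...  # May raise errors whose message contains 'InternalServerError'.

    results = lf.concurrent_execute(
        flaky_call,
        [1, 2, 3],
        retry_on_errors=[(Exception, 'InternalServerError')],
    )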
langfun-0.0.2.dev20240603.dist-info/METADATA → langfun-0.0.2.dev20240605.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.0.2.dev20240603
+Version: 0.0.2.dev20240605
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors
langfun-0.0.2.dev20240603.dist-info/RECORD → langfun-0.0.2.dev20240605.dist-info/RECORD RENAMED
@@ -1,5 +1,5 @@
-langfun/__init__.py,sha256=LFsDp22pTeJHmzzKEg2OLmSVOPAym00DyF38LmrL2n4,2263
-langfun/core/__init__.py,sha256=nFJx6X7oB7IIWsAQqjbgZ_ScH-gsKg53YgAkuDvY0cw,4296
+langfun/__init__.py,sha256=P62MnqA6-f0h8iYfQ3MT6Yg7a4qRnQeb4GrIn6dcSnY,2274
+langfun/core/__init__.py,sha256=ZheiCpop_GAZbVpnSS-uPBJaEEM15Td5xFGGizSGqko,4514
 langfun/core/component.py,sha256=oxesbC0BoE_TbtxwW5x-BAZWxZyyJbuPiX5S38RqCv0,9909
 langfun/core/component_test.py,sha256=uR-_Sz_42Jxc5qzLIB-f5_pXmNwnC01Xlbv5NOQSeSU,8021
 langfun/core/concurrent.py,sha256=TRc49pJ3HQro2kb5FtcWkHjhBm8UcgE8RJybU5cU3-0,24537
@@ -8,7 +8,7 @@ langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
 langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
 langfun/core/langfunc.py,sha256=RvIcRjIq0jWYRu1xim-FYe4HSrt97r3GMBO_PuagUmw,11060
 langfun/core/langfunc_test.py,sha256=_mfARnakX3oji5HDigFSLMd6yQ2wma-2Mgbztwqn73g,8501
-langfun/core/language_model.py,sha256=owNCgefGoPeRCHrxBhMtNdOj3orbeVml4eqLf1n211o,20760
+langfun/core/language_model.py,sha256=PocBg1t3uB0a_bJntLW5aagHhNbZsVdp2iduSBEW6ro,21240
 langfun/core/language_model_test.py,sha256=NZaSUls6cZdtxiqkqumWbtkx9zgNiJlsviYZOWkuHig,20137
 langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
 langfun/core/message.py,sha256=Rw3yC9HyGRjMhfDgyNjGlSCALEyDDbJ0_o6qTXeeDiQ,15738
@@ -48,20 +48,22 @@ langfun/core/eval/patching.py,sha256=R0s2eAd1m97exQt06dmUL0V_MBG0W2Hxg7fhNB7cXW0
 langfun/core/eval/patching_test.py,sha256=8kCd54Egjju22FMgtJuxEsrXkW8ifs-UUBHtrCG1L6w,4775
 langfun/core/eval/scoring.py,sha256=1J7IATo-8FXUR0SBqk9icztHiM0lWkBFcWUo-vUURgQ,6376
 langfun/core/eval/scoring_test.py,sha256=O8olHbrUEg60gMxwOkWzKBJZpZoUlmVnBANX5Se2SXM,4546
-langfun/core/llms/__init__.py,sha256=XHK_ZpfEppCF-ixfpIvmrOvH2P6XgkjMhS7zBa8yYk4,4302
-langfun/core/llms/anthropic.py,sha256=7W9YdPN3SlAFhAIQlihMkrpo7tTY_4NvD0KIlCrqcsk,8505
-langfun/core/llms/anthropic_test.py,sha256=TMM30myyEhwF99Le4RvJEXOn8RYl0q1FRkt9Q9nl1jk,5540
+langfun/core/llms/__init__.py,sha256=3G7pJISeClgHGV34Gy2t_Nih4N08UhGbWe6uAff8TnA,4364
+langfun/core/llms/anthropic.py,sha256=pBYe8dVwswxKaqhNjA_jtZbyfvOaXtEo399Zty242iA,7097
+langfun/core/llms/anthropic_test.py,sha256=T-swuMkfnlgs8Fpif4rtXs579exGk0TsbLMirXDZCkg,5533
 langfun/core/llms/fake.py,sha256=Dd7-6ka9pFf3fcWZyczamjOqQ91MOI-m7We3Oc9Ffmo,2927
 langfun/core/llms/fake_test.py,sha256=ipKfdOcuqVcJ8lDXVpnBVb9HHG0hAVkFkMoHpWjC2cI,7212
 langfun/core/llms/google_genai.py,sha256=Rl5a5CyF_6Y0BYYArKk8yMaenv1rH3MUQLy6b3dfMRI,10202
 langfun/core/llms/google_genai_test.py,sha256=iTISk3tJ4-3gjWmzcKQhEbH3ke4AkEiCu8rAGtB7SvU,7535
-langfun/core/llms/groq.py,sha256=NaGItVL_pkOpqPpI4bPGU27xLFRoaeizZ49v2s-4ERs,7844
-langfun/core/llms/groq_test.py,sha256=M6GtlrsOvDun_j-sR8cPh4W_moHWZNSTiThu3kuwbbc,5281
-langfun/core/llms/llama_cpp.py,sha256=Y_KkMUf3Xfac49koMUtUslKl3h-HWp3-ntq7Jaa3bdo,2385
-langfun/core/llms/llama_cpp_test.py,sha256=ZxC6defGd_HX9SFRU9U4cJiQnBKundbOrchbXuC1Z2M,1683
+langfun/core/llms/groq.py,sha256=pqtyOZ_1_OJMOg8xATWT_B_SVbuT9nMRf4VkH9GzW8g,6308
+langfun/core/llms/groq_test.py,sha256=GYF_Qtq5S1H1TrKH38t6_lkdroqT7v-joYLDKnmS9e0,5274
+langfun/core/llms/llama_cpp.py,sha256=9tXQntSCDtjTF3bnyJrAPCr4N6wycy5nXYvp9uduygE,2843
+langfun/core/llms/llama_cpp_test.py,sha256=MWO_qaOeKjRniGjcaWPDScd7HPaIJemqUZoslrt4FPs,1806
 langfun/core/llms/openai.py,sha256=IN46gIqfY6aEEfxCPNmyH1hrep3oWBhJDwVFilfqNkM,13657
 langfun/core/llms/openai_test.py,sha256=QWDzTgi8F2Z9u9ip6alK4rDEp_YraVTxWlDX5XOsKJk,14858
-langfun/core/llms/vertexai.py,sha256=eILbXoMSza5r4FLGlIdH6-eD8Ggy9Z4PdjLaBDxy29A,11162
+langfun/core/llms/rest.py,sha256=laopuq-zD8V-3Y6eFDngftHEbE66VlUkCD2-rvvRaLU,3388
+langfun/core/llms/rest_test.py,sha256=Zw08Xbl_O2OQuUglLXHsPsY5KW2VEcPGl1gR8PRHuFY,3449
+langfun/core/llms/vertexai.py,sha256=wIpckH-rMHUBA1vhauQk4LVrSsPQEsVntz7kLDKwm9g,11359
 langfun/core/llms/vertexai_test.py,sha256=G18BG36h5KvmX2zutDTLjtYCRjTuP_nWIFm4FMnLnyY,7651
 langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
 langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
@@ -111,8 +113,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
 langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
 langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
 langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
-langfun-0.0.2.dev20240603.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
-langfun-0.0.2.dev20240603.dist-info/METADATA,sha256=G7sGJIdQQ5xbDndwUtXcIw1m-xHA1taiAHngx0mamkk,3550
-langfun-0.0.2.dev20240603.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-langfun-0.0.2.dev20240603.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
-langfun-0.0.2.dev20240603.dist-info/RECORD,,
+langfun-0.0.2.dev20240605.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
+langfun-0.0.2.dev20240605.dist-info/METADATA,sha256=NMWv4oYMcRuZXl22coMTShqGC8hj_Y2PGfTk7n-Alt0,3550
+langfun-0.0.2.dev20240605.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+langfun-0.0.2.dev20240605.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
+langfun-0.0.2.dev20240605.dist-info/RECORD,,