langfun 0.0.2.dev20240420__py3-none-any.whl → 0.0.2.dev20240421__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -66,6 +66,13 @@ from langfun.core.llms.anthropic import Claude3Opus
66
66
  from langfun.core.llms.anthropic import Claude3Sonnet
67
67
  from langfun.core.llms.anthropic import Claude3Haiku
68
68
 
69
+ from langfun.core.llms.groq import Groq
70
+ from langfun.core.llms.groq import GroqLlama3_70B
71
+ from langfun.core.llms.groq import GroqLlama3_8B
72
+ from langfun.core.llms.groq import GroqLlama2_70B
73
+ from langfun.core.llms.groq import GroqMistral_8x7B
74
+ from langfun.core.llms.groq import GroqGemma7B_IT
75
+
69
76
 
70
77
  # LLaMA C++ models.
71
78
  from langfun.core.llms.llama_cpp import LlamaCppRemote
@@ -0,0 +1,251 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Language models from Groq."""
15
+
16
+ import functools
17
+ import os
18
+ from typing import Annotated, Any
19
+
20
+ import langfun.core as lf
21
+ from langfun.core import modalities as lf_modalities
22
+ import pyglove as pg
23
+ import requests
24
+
25
+
26
+ SUPPORTED_MODELS_AND_SETTINGS = {
27
+ # Refer to https://console.groq.com/docs/models
28
+ 'llama3-8b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
29
+ 'llama3-70b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
30
+ 'llama2-70b-4096': pg.Dict(max_tokens=4096, max_concurrency=16),
31
+ 'mixtral-8x7b-32768': pg.Dict(max_tokens=32768, max_concurrency=16),
32
+ 'gemma-7b-it': pg.Dict(max_tokens=8192, max_concurrency=16),
33
+ }
34
+
35
+
36
+ class GroqError(Exception): # pylint: disable=g-bad-exception-name
37
+ """Base class for Groq errors."""
38
+
39
+
40
+ class RateLimitError(GroqError):
41
+ """Error for rate limit reached."""
42
+
43
+
44
+ class OverloadedError(GroqError):
45
+ """Groq's server is temporarily overloaded."""
46
+
47
+
48
+ _CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
49
+
50
+
51
+ @lf.use_init_args(['model'])
52
+ class Groq(lf.LanguageModel):
53
+ """Groq LLMs through REST APIs (OpenAI compatible).
54
+
55
+ See https://platform.openai.com/docs/api-reference/chat
56
+ """
57
+
58
+ model: pg.typing.Annotated[
59
+ pg.typing.Enum(
60
+ pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
61
+ ),
62
+ 'The name of the model to use.',
63
+ ]
64
+
65
+ multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
66
+ False
67
+ )
68
+
69
+ api_key: Annotated[
70
+ str | None,
71
+ (
72
+ 'API key. If None, the key will be read from environment variable '
73
+ "'GROQ_API_KEY'."
74
+ ),
75
+ ] = None
76
+
77
+ def _on_bound(self):
78
+ super()._on_bound()
79
+ self._api_key = None
80
+ self.__dict__.pop('_api_initialized', None)
81
+
82
+ @functools.cached_property
83
+ def _api_initialized(self):
84
+ api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
85
+ if not api_key:
86
+ raise ValueError(
87
+ 'Please specify `api_key` during `__init__` or set environment '
88
+ 'variable `GROQ_API_KEY` with your Anthropic API key.'
89
+ )
90
+ self._api_key = api_key
91
+ return True
92
+
93
+ @property
94
+ def model_id(self) -> str:
95
+ """Returns a string to identify the model."""
96
+ return self.model
97
+
98
+ @property
99
+ def max_concurrency(self) -> int:
100
+ return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
101
+
102
+ def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
103
+ """Returns a dict as request arguments."""
104
+ # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
105
+ args = dict(
106
+ model=self.model,
107
+ n=options.n,
108
+ stream=False,
109
+ )
110
+
111
+ if options.temperature is not None:
112
+ args['temperature'] = options.temperature
113
+ if options.max_tokens is not None:
114
+ args['max_tokens'] = options.max_tokens
115
+ if options.top_p is not None:
116
+ args['top_p'] = options.top_p
117
+ if options.stop:
118
+ args['stop'] = options.stop
119
+ return args
120
+
121
+ def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
122
+ """Converts a message to OpenAI-compatible chat content (list of dicts)."""
123
+ # Refer: https://platform.openai.com/docs/api-reference/chat/create
124
+ content = []
125
+ for chunk in prompt.chunk():
126
+ if isinstance(chunk, str):
127
+ item = dict(type='text', text=chunk)
128
+ elif (
129
+ self.multimodal
130
+ and isinstance(chunk, lf_modalities.Image)
131
+ and chunk.uri
132
+ ):
133
+ # NOTE(daiyip): Groq only supports image URLs.
134
+ item = dict(type='image_url', image_url=chunk.uri)
135
+ else:
136
+ raise ValueError(f'Unsupported modality object: {chunk!r}.')
137
+ content.append(item)
138
+ return content
139
+
140
+ def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
141
+ """Converts an OpenAI-compatible response choice to a message."""
142
+ # Refer: https://platform.openai.com/docs/api-reference/chat/create
143
+ content = choice['message']['content']
144
+ if isinstance(content, str):
145
+ return lf.AIMessage(content)
146
+ return lf.AIMessage.from_chunks(
147
+ [x['text'] for x in content if x['type'] == 'text']
148
+ )
149
+
150
+ def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
151
+ """Parses Groq's OpenAI-compatible chat completion response."""
152
+ # Refer: https://platform.openai.com/docs/api-reference/chat/object
153
+ output = response.json()
154
+ if response.status_code == 200:
155
+ samples = [
156
+ lf.LMSample(self._message_from_choice(choice), score=0.0)
157
+ for choice in output['choices']
158
+ ]
159
+ usage = output['usage']
160
+ return lf.LMSamplingResult(
161
+ samples,
162
+ usage=lf.LMSamplingUsage(
163
+ prompt_tokens=usage['prompt_tokens'],
164
+ completion_tokens=usage['completion_tokens'],
165
+ total_tokens=usage['total_tokens'],
166
+ ),
167
+ )
168
+ else:
169
+ # https://platform.openai.com/docs/guides/error-codes/api-errors
170
+ if response.status_code == 429:
171
+ error_cls = RateLimitError
172
+ elif response.status_code in (500, 503):
173
+ error_cls = OverloadedError
174
+ else:
175
+ error_cls = GroqError
176
+ error = output['error']
177
+ raise error_cls(f'{error["type"]}: {error["message"]}')
178
+
179
+ def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
180
+ assert self._api_initialized
181
+ return self._parallel_execute_with_currency_control(
182
+ self._sample_single,
183
+ prompts,
184
+ retry_on_errors=(RateLimitError, OverloadedError),
185
+ )
186
+
187
+ def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
188
+ request = dict()
189
+ request.update(self._get_request_args(self.sampling_options))
190
+ request.update(
191
+ dict(
192
+ messages=[
193
+ dict(role='user', content=self._content_from_message(prompt))
194
+ ]
195
+ )
196
+ )
197
+ response = requests.post(
198
+ _CHAT_COMPLETE_API_ENDPOINT,
199
+ json=request,
200
+ headers={
201
+ 'Authorization': f'Bearer {self._api_key}',
202
+ 'Content-Type': 'application/json',
203
+ },
204
+ timeout=self.timeout,
205
+ )
206
+ return self._parse_response(response)
207
+
208
+
209
+ class GroqLlama3_8B(Groq): # pylint: disable=invalid-name
210
+ """Llama3-8B with 8K context window.
211
+
212
+ See: https://huggingface.co/meta-llama/Meta-Llama-3-8B
213
+ """
214
+
215
+ model = 'llama3-8b-8192'
216
+
217
+
218
+ class GroqLlama3_70B(Groq): # pylint: disable=invalid-name
219
+ """Llama3-70B with 8K context window.
220
+
221
+ See: https://huggingface.co/meta-llama/Meta-Llama-3-70B
222
+ """
223
+
224
+ model = 'llama3-70b-8192'
225
+
226
+
227
+ class GroqLlama2_70B(Groq): # pylint: disable=invalid-name
228
+ """Llama2-70B with 4K context window.
229
+
230
+ See: https://huggingface.co/meta-llama/Llama-2-70b
231
+ """
232
+
233
+ model = 'llama2-70b-4096'
234
+
235
+
236
+ class GroqMistral_8x7B(Groq): # pylint: disable=invalid-name
237
+ """Mixtral 8x7B with 32K context window.
238
+
239
+ See: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
240
+ """
241
+
242
+ model = 'mixtral-8x7b-32768'
243
+
244
+
245
+ class GroqGemma7B_IT(Groq): # pylint: disable=invalid-name
246
+ """Gemma 7B with 8K context window.
247
+
248
+ See: https://huggingface.co/google/gemma-1.1-7b-it
249
+ """
250
+
251
+ model = 'gemma-7b-it'
@@ -0,0 +1,170 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for Groq models."""
15
+
16
+ import os
17
+ from typing import Any
18
+ import unittest
19
+ from unittest import mock
20
+ from langfun.core import modalities as lf_modalities
21
+ from langfun.core.llms import groq
22
+ import pyglove as pg
23
+ import requests
24
+
25
+
26
+ def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
27
+ del url, kwargs
28
+
29
+ response = requests.Response()
30
+ response.status_code = 200
31
+ response._content = pg.to_json_str({
32
+ 'choices': [{
33
+ 'message': {
34
+ 'content': [{
35
+ 'type': 'text',
36
+ 'text': (
37
+ f'hello with temperature={json.get("temperature")}, '
38
+ f'top_p={json.get("top_p")}, '
39
+ f'max_tokens={json.get("max_tokens")}, '
40
+ f'stop={json.get("stop")}.'
41
+ ),
42
+ }],
43
+ }
44
+ }],
45
+ 'usage': {
46
+ 'prompt_tokens': 2,
47
+ 'completion_tokens': 1,
48
+ 'total_tokens': 3,
49
+ },
50
+ }).encode()
51
+ return response
52
+
53
+
54
+ def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
55
+ del url, kwargs
56
+ v = json['messages'][0]['content'][0]
57
+ image = lf_modalities.Image.from_uri(v['image_url'])
58
+
59
+ response = requests.Response()
60
+ response.status_code = 200
61
+ response._content = pg.to_json_str({
62
+ 'choices': [
63
+ {
64
+ 'message': {
65
+ 'content': [{
66
+ 'type': 'text',
67
+ 'text': image.uri,
68
+ }],
69
+ }
70
+ }
71
+ ],
72
+ 'usage': {
73
+ 'prompt_tokens': 2,
74
+ 'completion_tokens': 1,
75
+ 'total_tokens': 3,
76
+ },
77
+ }).encode()
78
+ return response
79
+
80
+
81
+ def mock_requests_post_error(status_code, error_type, error_message):
82
+ def _mock_requests(url: str, json: dict[str, Any], **kwargs):
83
+ del url, json, kwargs
84
+ response = requests.Response()
85
+ response.status_code = status_code
86
+ response._content = pg.to_json_str(
87
+ {
88
+ 'error': {
89
+ 'type': error_type,
90
+ 'message': error_message,
91
+ }
92
+ }
93
+ ).encode()
94
+ return response
95
+
96
+ return _mock_requests
97
+
98
+
99
+ class AuthropicTest(unittest.TestCase):
100
+
101
+ def test_basics(self):
102
+ self.assertEqual(groq.GroqMistral_8x7B().model_id, 'mixtral-8x7b-32768')
103
+ self.assertEqual(groq.GroqMistral_8x7B().max_concurrency, 16)
104
+
105
+ def test_api_key(self):
106
+ lm = groq.GroqMistral_8x7B()
107
+ with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
108
+ lm('hi')
109
+
110
+ with mock.patch('requests.post') as mock_request:
111
+ mock_request.side_effect = mock_requests_post
112
+
113
+ lm = groq.GroqMistral_8x7B(api_key='fake key')
114
+ self.assertRegex(lm('hi').text, 'hello.*')
115
+
116
+ os.environ['GROQ_API_KEY'] = 'abc'
117
+ lm = groq.GroqMistral_8x7B()
118
+ self.assertRegex(lm('hi').text, 'hello.*')
119
+ del os.environ['GROQ_API_KEY']
120
+
121
+ def test_call(self):
122
+ with mock.patch('requests.post') as mock_request:
123
+ mock_request.side_effect = mock_requests_post
124
+ lm = groq.GroqLlama3_70B(api_key='fake_key')
125
+ response = lm(
126
+ 'hello',
127
+ temperature=0.0,
128
+ max_tokens=1024,
129
+ top_k=0.1,
130
+ top_p=0.2,
131
+ stop=['\n'],
132
+ )
133
+ self.assertEqual(
134
+ response.text,
135
+ (
136
+ 'hello with temperature=0.0, top_p=0.2, '
137
+ "max_tokens=1024, stop=['\\n']."
138
+ ),
139
+ )
140
+ self.assertIsNotNone(response.usage)
141
+ self.assertIsNotNone(response.usage.prompt_tokens, 2)
142
+ self.assertIsNotNone(response.usage.completion_tokens, 1)
143
+ self.assertIsNotNone(response.usage.total_tokens, 3)
144
+
145
+ def test_mm_call(self):
146
+ with mock.patch('requests.post') as mock_mm_request:
147
+ mock_mm_request.side_effect = mock_mm_requests_post
148
+ lm = groq.GroqLlama3_70B(multimodal=True, api_key='fake_key')
149
+ response = lm(lf_modalities.Image.from_uri('https://fake/image.jpg'))
150
+ self.assertEqual(response.text, 'https://fake/image.jpg')
151
+
152
+ def test_call_errors(self):
153
+ for status_code, error_type, error_message in [
154
+ (429, 'rate_limit', 'Rate limit exceeded.'),
155
+ (503, 'service_unavailable', 'Service unavailable.'),
156
+ (500, 'bad_request', 'Bad request.'),
157
+ ]:
158
+ with mock.patch('requests.post') as mock_mm_request:
159
+ mock_mm_request.side_effect = mock_requests_post_error(
160
+ status_code, error_type, error_message
161
+ )
162
+ lm = groq.GroqLlama3_70B(api_key='fake_key')
163
+ with self.assertRaisesRegex(
164
+ Exception, f'{error_type}: {error_message}'
165
+ ):
166
+ lm('hello', lm=lm, max_attempts=1)
167
+
168
+
169
+ if __name__ == '__main__':
170
+ unittest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240420
3
+ Version: 0.0.2.dev20240421
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -46,13 +46,15 @@ langfun/core/eval/matching.py,sha256=aqNlYrlav7YmsB7rUlsdfoi1RLA5CYqn2RGPxRlPc78
46
46
  langfun/core/eval/matching_test.py,sha256=FFHYD7IDuKe5RMjkx74ksukiwUhO5a_SS340JaIPMws,4898
47
47
  langfun/core/eval/scoring.py,sha256=aKeanBJf1yO3Q9JEtgPWoiZk_3M_GiqwXVXX7x_g22w,6172
48
48
  langfun/core/eval/scoring_test.py,sha256=YH1cIxBWtfdKcAV9Fh10vLkV5J-gxk8b6nxW4Z2u5pk,4024
49
- langfun/core/llms/__init__.py,sha256=c_9lVKzFjnxHKgRjY_dUiJzBmW1jWALy3mtYv0uMyl0,2953
49
+ langfun/core/llms/__init__.py,sha256=1bPg1QI8duOZCYINm-jWi094x0JtLmsk4KX60qIC_gs,3245
50
50
  langfun/core/llms/anthropic.py,sha256=p-tjttvithBg2b4tgxIS2F-Zk5AYAh5e-lW-8e1p4wc,7865
51
51
  langfun/core/llms/anthropic_test.py,sha256=OuLDxeiPRdqsfKILS0R6jJLTRs3-1KCIotPPr7IbIDU,5502
52
52
  langfun/core/llms/fake.py,sha256=b-Xk5IPTbUt-elsyzd_i3n1tqzc_kgETXrEvgJruSMk,2824
53
53
  langfun/core/llms/fake_test.py,sha256=ZlDQgL41EX3eYTfBQNp2nB2LciqCmtoHgCsGvW4XhwI,4184
54
54
  langfun/core/llms/google_genai.py,sha256=n8zyJwh9UCTgb6-8LyvmjVNFGZQ4-zfzZ0ulkhHAnR8,8624
55
55
  langfun/core/llms/google_genai_test.py,sha256=_UcGTfl16-aDUlEWFC2W2F8y9jPUs53RBYA6MOCpGXw,7525
56
+ langfun/core/llms/groq.py,sha256=ZULexLJoU_IJ6vjQimMsmv0xnCOTPGrJVkPLbjfqC5w,7600
57
+ langfun/core/llms/groq_test.py,sha256=o95z76qwOwmsOxC2WhHJ4roFzxFRoVjkC7KETlfsVis,5250
56
58
  langfun/core/llms/llama_cpp.py,sha256=Y_KkMUf3Xfac49koMUtUslKl3h-HWp3-ntq7Jaa3bdo,2385
57
59
  langfun/core/llms/llama_cpp_test.py,sha256=ZxC6defGd_HX9SFRU9U4cJiQnBKundbOrchbXuC1Z2M,1683
58
60
  langfun/core/llms/openai.py,sha256=Z_pujF3B2QMzWBgOdV67DKAfZ8Wmyeb_6F9BkcGHyaE,12344
@@ -99,8 +101,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
99
101
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
100
102
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
101
103
  langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
102
- langfun-0.0.2.dev20240420.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
103
- langfun-0.0.2.dev20240420.dist-info/METADATA,sha256=R4bRp7OO2PSjDyKe48YvIbMptLTkeqesP98ZxJ17woc,3405
104
- langfun-0.0.2.dev20240420.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
105
- langfun-0.0.2.dev20240420.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
106
- langfun-0.0.2.dev20240420.dist-info/RECORD,,
104
+ langfun-0.0.2.dev20240421.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
105
+ langfun-0.0.2.dev20240421.dist-info/METADATA,sha256=-43mHV_1OLXkg277zBFm6nRAcaS8VMfePW1QzyrIc2o,3405
106
+ langfun-0.0.2.dev20240421.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
107
+ langfun-0.0.2.dev20240421.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
108
+ langfun-0.0.2.dev20240421.dist-info/RECORD,,