langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic.
- langfun/__init__.py +7 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +15 -0
- langfun/core/eval/base.py +665 -95
- langfun/core/eval/base_test.py +224 -53
- langfun/core/eval/matching.py +48 -30
- langfun/core/eval/matching_test.py +25 -3
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +19 -10
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/langfunc.py +1 -22
- langfun/core/langfunc_test.py +10 -4
- langfun/core/language_model.py +130 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +27 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +34 -25
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai.py +8 -0
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +100 -81
- langfun/core/llms/openai_test.py +287 -60
- langfun/core/llms/vertexai.py +291 -0
- langfun/core/llms/vertexai_test.py +233 -0
- langfun/core/modalities/image.py +1 -3
- langfun/core/modalities/mime.py +6 -0
- langfun/core/modalities/video.py +6 -5
- langfun/core/structured/__init__.py +5 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +61 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +61 -12
- langfun/core/structured/prompting_test.py +122 -12
- langfun/core/structured/schema.py +38 -6
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +36 -7
- langfun/core/structured/scoring.py +4 -1
- langfun/core/structured/scoring_test.py +6 -0
- langfun/core/template.py +147 -11
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
- langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
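The headline changes in this release are the new model backends (langfun/core/llms/anthropic.py, groq.py, vertexai.py) and the new langfun/core/structured/function_generation.py module. As a quick orientation before the per-file diffs, here is a minimal sketch of how the Groq backend added in this release might be used; the module path, class name and the GROQ_API_KEY environment variable come from the diff itself, while the surrounding call style is an assumption rather than documented usage.

import os
from langfun.core.llms import groq  # module added in this release (see diff below)

# Sketch only: the API key can also be passed directly as groq.GroqLlama3_70B(api_key=...).
os.environ.setdefault('GROQ_API_KEY', '<your-groq-api-key>')
lm = groq.GroqLlama3_70B()   # 'llama3-70b-8192' per SUPPORTED_MODELS_AND_SETTINGS
print(lm('hi').text)         # lm(...) returns an lf.AIMessage; .text holds the reply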
langfun/core/llms/fake_test.py
CHANGED
@@ -25,7 +25,24 @@ class EchoTest(unittest.TestCase):
   def test_sample(self):
     lm = fakelm.Echo()
     self.assertEqual(
-        lm.sample(['hi']),
+        lm.sample(['hi']),
+        [
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'hi',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 2, 4),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                lf.LMSamplingUsage(2, 2, 4))
+        ]
     )

   def test_call(self):
@@ -34,8 +51,8 @@ class EchoTest(unittest.TestCase):
     with contextlib.redirect_stdout(string_io):
       self.assertEqual(lm('hi'), 'hi')
     debug_info = string_io.getvalue()
-    self.assertIn('[0] LM INFO
-    self.assertIn('[0] PROMPT SENT TO LM
+    self.assertIn('[0] LM INFO', debug_info)
+    self.assertIn('[0] PROMPT SENT TO LM', debug_info)
     self.assertIn('[0] LM RESPONSE', debug_info)

   def test_score(self):
@@ -53,11 +70,45 @@ class StaticResponseTest(unittest.TestCase):
     lm = fakelm.StaticResponse(canned_response)
     self.assertEqual(
         lm.sample(['hi']),
-        [
+        [
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            canned_response,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 38, 40),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(2, 38, 40)
+            )
+        ],
     )
     self.assertEqual(
         lm.sample(['Tell me a joke.']),
-        [
+        [
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            canned_response,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(15, 38, 53),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(15, 38, 53)
+            )
+        ],
     )

   def test_call(self):
@@ -69,8 +120,8 @@ class StaticResponseTest(unittest.TestCase):
       self.assertEqual(lm('hi'), canned_response)

     debug_info = string_io.getvalue()
-    self.assertIn('[0] LM INFO
-    self.assertIn('[0] PROMPT SENT TO LM
+    self.assertIn('[0] LM INFO', debug_info)
+    self.assertIn('[0] PROMPT SENT TO LM', debug_info)
     self.assertIn('[0] LM RESPONSE', debug_info)


@@ -85,8 +136,38 @@ class StaticMappingTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(['Hi', 'How are you?']),
         [
-            lf.LMSamplingResult(
-
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'Hello',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 5, 7),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(2, 5, 7)
+            ),
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'I am fine, how about you?',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(12, 25, 37),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(12, 25, 37)
+            )
         ]
     )
     with self.assertRaises(KeyError):
@@ -104,8 +185,38 @@ class StaticSequenceTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(['Hi', 'How are you?']),
         [
-            lf.LMSamplingResult(
-
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'Hello',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 5, 7),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(2, 5, 7)
+            ),
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'I am fine, how about you?',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(12, 25, 37),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(12, 25, 37)
+            )
         ]
     )
     with self.assertRaises(IndexError):
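The recurring change in these fake-LM tests is that every sampled message and every LMSamplingResult now carries an lf.LMSamplingUsage with prompt, completion and total token counts. A minimal sketch of reading that accounting with the Echo fake, assuming the attribute-style access to usage that groq_test.py uses later in this diff:

from langfun.core.llms import fake as fakelm  # aliased the same way as in fake_test.py

lm = fakelm.Echo()
response = lm('hi')    # an lf.AIMessage('hi'), per the expectations above
print(response.usage)  # lf.LMSamplingUsage(2, 2, 4), also per the test above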
langfun/core/llms/google_genai.py
CHANGED
@@ -34,6 +34,7 @@ class GenAI(lf.LanguageModel):
           'gemini-pro-vision',
           'text-bison-001',
           'chat-bison-001',
+          'gemini-1.5-pro-latest',
       ],
       'Model name.',
   ]
@@ -262,6 +263,13 @@ _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
 #


+class GeminiPro1_5(GenAI):  # pylint: disable=invalid-name
+  """Gemini Pro latest model."""
+
+  model = 'gemini-1.5-pro-latest'
+  multimodal = True
+
+
 class GeminiPro(GenAI):
   """Gemini Pro model."""

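A minimal sketch of picking up the newly registered Gemini 1.5 Pro model. The class name and the api_key argument are taken from this diff and its test; everything else is assumed usage.

from langfun.core.llms import google_genai

# Sketch only: GeminiPro1_5 is the class added above; api_key mirrors the
# constructor call used in google_genai_test.py.
lm = google_genai.GeminiPro1_5(api_key='<your-api-key>')  # multimodal = True
print(lm('hello').text)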
langfun/core/llms/google_genai_test.py
CHANGED
@@ -152,10 +152,15 @@ class GenAITest(unittest.TestCase):
     )

   def test_model_hub(self):
+    orig_get_model = genai.get_model
+    genai.get_model = mock_get_model
+
     model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
     self.assertIsNotNone(model)
     self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)

+    genai.get_model = orig_get_model
+
   def test_api_key_check(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
       _ = google_genai.GeminiPro()._api_initialized
@@ -167,7 +172,7 @@ class GenAITest(unittest.TestCase):

   def test_call(self):
     with mock.patch(
-        'google.generativeai.
+        'google.generativeai.GenerativeModel.generate_content',
     ) as mock_generate:
       orig_get_model = genai.get_model
       genai.get_model = mock_get_model
@@ -176,7 +181,7 @@ class GenAITest(unittest.TestCase):
       lm = google_genai.GeminiPro(api_key='test_key')
       self.maxDiff = None
       self.assertEqual(
-          lm('hello', temperature=2.0, top_k=20).text,
+          lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
           (
               'This is a response to hello with n=1, temperature=2.0, '
               'top_p=None, top_k=20, max_tokens=1024, stop=None.'
@@ -197,7 +202,7 @@ class GenAITest(unittest.TestCase):
           (
               "hello to models/text-bison-001 with {'temperature': 2.0, "
               "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
-              "'max_output_tokens':
+              "'max_output_tokens': None, 'stop_sequences': None}"
           ),
       )
       genai.get_model = orig_get_model
langfun/core/llms/groq.py
ADDED
@@ -0,0 +1,260 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Language models from Groq."""
+
+import functools
+import os
+from typing import Annotated, Any
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+import pyglove as pg
+import requests
+
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # Refer https://console.groq.com/docs/models
+    'llama3-8b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
+    'llama3-70b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
+    'llama2-70b-4096': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'mixtral-8x7b-32768': pg.Dict(max_tokens=32768, max_concurrency=16),
+    'gemma-7b-it': pg.Dict(max_tokens=8192, max_concurrency=16),
+}
+
+
+class GroqError(Exception):  # pylint: disable=g-bad-exception-name
+  """Base class for Groq errors."""
+
+
+class RateLimitError(GroqError):
+  """Error for rate limit reached."""
+
+
+class OverloadedError(GroqError):
+  """Groq's server is temporarily overloaded."""
+
+
+_CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
+
+
+@lf.use_init_args(['model'])
+class Groq(lf.LanguageModel):
+  """Groq LLMs through REST APIs (OpenAI compatible).
+
+  See https://platform.openai.com/docs/api-reference/chat
+  """
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      'The name of the model to use.',
+  ]
+
+  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
+      False
+  )
+
+  api_key: Annotated[
+      str | None,
+      (
+          'API key. If None, the key will be read from environment variable '
+          "'GROQ_API_KEY'."
+      ),
+  ] = None
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._api_key = None
+    self.__dict__.pop('_api_initialized', None)
+    self.__dict__.pop('_session', None)
+
+  @functools.cached_property
+  def _api_initialized(self):
+    api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
+    if not api_key:
+      raise ValueError(
+          'Please specify `api_key` during `__init__` or set environment '
+          'variable `GROQ_API_KEY` with your Groq API key.'
+      )
+    self._api_key = api_key
+    return True
+
+  @functools.cached_property
+  def _session(self) -> requests.Session:
+    assert self._api_initialized
+    s = requests.Session()
+    s.headers.update({
+        'Authorization': f'Bearer {self._api_key}',
+        'Content-Type': 'application/json',
+    })
+    return s
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model
+
+  @property
+  def max_concurrency(self) -> int:
+    return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
+
+  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
+    args = dict(
+        model=self.model,
+        n=options.n,
+        stream=False,
+    )
+
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    if options.max_tokens is not None:
+      args['max_tokens'] = options.max_tokens
+    if options.top_p is not None:
+      args['top_p'] = options.top_p
+    if options.stop:
+      args['stop'] = options.stop
+    return args
+
+  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
+    """Converts an message to Groq's content protocol (list of dicts)."""
+    # Refer: https://platform.openai.com/docs/api-reference/chat/create
+    content = []
+    for chunk in prompt.chunk():
+      if isinstance(chunk, str):
+        item = dict(type='text', text=chunk)
+      elif (
+          self.multimodal
+          and isinstance(chunk, lf_modalities.Image)
+          and chunk.uri
+      ):
+        # NOTE(daiyip): Groq only support image URL.
+        item = dict(type='image_url', image_url=chunk.uri)
+      else:
+        raise ValueError(f'Unsupported modality object: {chunk!r}.')
+      content.append(item)
+    return content
+
+  def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
+    """Converts Groq's content protocol to message."""
+    # Refer: https://platform.openai.com/docs/api-reference/chat/create
+    content = choice['message']['content']
+    if isinstance(content, str):
+      return lf.AIMessage(content)
+    return lf.AIMessage.from_chunks(
+        [x['text'] for x in content if x['type'] == 'text']
+    )
+
+  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
+    """Parses Groq's response."""
+    # Refer: https://platform.openai.com/docs/api-reference/chat/object
+    if response.status_code == 200:
+      output = response.json()
+      samples = [
+          lf.LMSample(self._message_from_choice(choice), score=0.0)
+          for choice in output['choices']
+      ]
+      usage = output['usage']
+      return lf.LMSamplingResult(
+          samples,
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=usage['prompt_tokens'],
+              completion_tokens=usage['completion_tokens'],
+              total_tokens=usage['total_tokens'],
+          ),
+      )
+    else:
+      # https://platform.openai.com/docs/guides/error-codes/api-errors
+      if response.status_code == 429:
+        error_cls = RateLimitError
+      elif response.status_code in (500, 502, 503):
+        error_cls = OverloadedError
+      else:
+        error_cls = GroqError
+      raise error_cls(f'{response.status_code}: {response.content}')
+
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    assert self._api_initialized
+    return self._parallel_execute_with_currency_control(
+        self._sample_single,
+        prompts,
+        retry_on_errors=(RateLimitError, OverloadedError),
+    )
+
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    request = dict()
+    request.update(self._get_request_args(self.sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
+    )
+    try:
+      response = self._session.post(
+          _CHAT_COMPLETE_API_ENDPOINT,
+          json=request,
+          timeout=self.timeout,
+      )
+      return self._parse_response(response)
+    except ConnectionError as e:
+      raise OverloadedError(str(e)) from e
+
+
+class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
+  """Llama3-8B with 8K context window.
+
+  See: https://huggingface.co/meta-llama/Meta-Llama-3-8B
+  """
+
+  model = 'llama3-8b-8192'
+
+
+class GroqLlama3_70B(Groq):  # pylint: disable=invalid-name
+  """Llama3-70B with 8K context window.
+
+  See: https://huggingface.co/meta-llama/Meta-Llama-3-70B
+  """
+
+  model = 'llama3-70b-8192'
+
+
+class GroqLlama2_70B(Groq):  # pylint: disable=invalid-name
+  """Llama2-70B with 4K context window.
+
+  See: https://huggingface.co/meta-llama/Llama-2-70b
+  """
+
+  model = 'llama2-70b-4096'
+
+
+class GroqMistral_8x7B(Groq):  # pylint: disable=invalid-name
+  """Mixtral 8x7B with 32K context window.
+
+  See: https://huggingface.co/meta-llama/Llama-2-70b
+  """
+
+  model = 'mixtral-8x7b-32768'
+
+
+class GroqGemma7B_IT(Groq):  # pylint: disable=invalid-name
+  """Gemma 7B with 8K context window.
+
+  See: https://huggingface.co/google/gemma-1.1-7b-it
+  """
+
+  model = 'gemma-7b-it'
langfun/core/llms/groq_test.py
ADDED
@@ -0,0 +1,170 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for Groq models."""
+
+import os
+from typing import Any
+import unittest
+from unittest import mock
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import groq
+import pyglove as pg
+import requests
+
+
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'choices': [{
+          'message': {
+              'content': [{
+                  'type': 'text',
+                  'text': (
+                      f'hello with temperature={json.get("temperature")}, '
+                      f'top_p={json.get("top_p")}, '
+                      f'max_tokens={json.get("max_tokens")}, '
+                      f'stop={json.get("stop")}.'
+                  ),
+              }],
+          }
+      }],
+      'usage': {
+          'prompt_tokens': 2,
+          'completion_tokens': 1,
+          'total_tokens': 3,
+      },
+  }).encode()
+  return response
+
+
+def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  v = json['messages'][0]['content'][0]
+  image = lf_modalities.Image.from_uri(v['image_url'])
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'choices': [
+          {
+              'message': {
+                  'content': [{
+                      'type': 'text',
+                      'text': image.uri,
+                  }],
+              }
+          }
+      ],
+      'usage': {
+          'prompt_tokens': 2,
+          'completion_tokens': 1,
+          'total_tokens': 3,
+      },
+  }).encode()
+  return response
+
+
+def mock_requests_post_error(status_code, error_type, error_message):
+  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
+    del url, json, kwargs
+    response = requests.Response()
+    response.status_code = status_code
+    response._content = pg.to_json_str(
+        {
+            'error': {
+                'type': error_type,
+                'message': error_message,
+            }
+        }
+    ).encode()
+    return response
+
+  return _mock_requests
+
+
+class AuthropicTest(unittest.TestCase):
+
+  def test_basics(self):
+    self.assertEqual(groq.GroqMistral_8x7B().model_id, 'mixtral-8x7b-32768')
+    self.assertEqual(groq.GroqMistral_8x7B().max_concurrency, 16)
+
+  def test_api_key(self):
+    lm = groq.GroqMistral_8x7B()
+    with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+      lm('hi')
+
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+
+      lm = groq.GroqMistral_8x7B(api_key='fake key')
+      self.assertRegex(lm('hi').text, 'hello.*')
+
+      os.environ['GROQ_API_KEY'] = 'abc'
+      lm = groq.GroqMistral_8x7B()
+      self.assertRegex(lm('hi').text, 'hello.*')
+      del os.environ['GROQ_API_KEY']
+
+  def test_call(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+      lm = groq.GroqLlama3_70B(api_key='fake_key')
+      response = lm(
+          'hello',
+          temperature=0.0,
+          max_tokens=1024,
+          top_k=0.1,
+          top_p=0.2,
+          stop=['\n'],
+      )
+      self.assertEqual(
+          response.text,
+          (
+              'hello with temperature=0.0, top_p=0.2, '
+              "max_tokens=1024, stop=['\\n']."
+          ),
+      )
+      self.assertIsNotNone(response.usage)
+      self.assertIsNotNone(response.usage.prompt_tokens, 2)
+      self.assertIsNotNone(response.usage.completion_tokens, 1)
+      self.assertIsNotNone(response.usage.total_tokens, 3)
+
+  def test_mm_call(self):
+    with mock.patch('requests.Session.post') as mock_mm_request:
+      mock_mm_request.side_effect = mock_mm_requests_post
+      lm = groq.GroqLlama3_70B(multimodal=True, api_key='fake_key')
+      response = lm(lf_modalities.Image.from_uri('https://fake/image.jpg'))
+      self.assertEqual(response.text, 'https://fake/image.jpg')
+
+  def test_call_errors(self):
+    for status_code, error_type, error_message in [
+        (429, 'rate_limit', 'Rate limit exceeded.'),
+        (503, 'service_unavailable', 'Service unavailable.'),
+        (500, 'bad_request', 'Bad request.'),
+    ]:
+      with mock.patch('requests.Session.post') as mock_mm_request:
+        mock_mm_request.side_effect = mock_requests_post_error(
+            status_code, error_type, error_message
+        )
+        lm = groq.GroqLlama3_70B(api_key='fake_key')
+        with self.assertRaisesRegex(
+            Exception, f'{status_code}:.*{error_type}'
+        ):
+          lm('hello', lm=lm, max_attempts=1)
+
+
+if __name__ == '__main__':
+  unittest.main()
langfun/core/llms/llama_cpp.py
CHANGED
@@ -51,10 +51,12 @@ class LlamaCppRemote(lf.LanguageModel):
     data = {
         "prompt": prompt.text,
         "n_predict": self.sampling_options.max_tokens,
-        "temperature": self.sampling_options.temperature,
        "top_k": self.sampling_options.top_k or 50,
        "top_p": self.sampling_options.top_p or 0.95,
    }
+    if self.sampling_options.temperature is not None:
+      data["temperature"] = self.sampling_options.temperature
+
    response = requests.post(
        f"{self.url}/completion",
        json=data,