langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
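
The list above reflects a substantial restructuring: a new `agentic` package, an eval `v2` framework, a shared `openai_compatible.py` REST layer (now backing Groq, DeepSeek, and the OpenAI client), and the replacement of `structured/prompting.py` by `structured/querying.py` (with `prompting_test.py` renamed to `querying_test.py`). Below is a minimal sketch of the structured-query entry point that `querying.py` serves, assuming the rename keeps the public `lf.query` API; the model class and API key are illustrative placeholders:

```python
import langfun as lf
from langfun.core.llms import google_genai

# lf.query renders the template with the keyword variables, calls the LM,
# and parses the reply into the requested type (int here).
answer = lf.query(
    'What is {{x}} plus {{y}}?',
    int,
    lm=google_genai.GeminiPro1_5(api_key='<YOUR_API_KEY>'),  # placeholder key
    x=1,
    y=2,
)
assert answer == 3
```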
langfun/core/llms/google_genai_test.py CHANGED

```diff
@@ -11,221 +11,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for …
+"""Tests for Google GenAI models."""
 
 import os
 import unittest
-from unittest import mock
-
-from google import generativeai as genai
-import langfun.core as lf
-from langfun.core import modalities as lf_modalities
 from langfun.core.llms import google_genai
-import pyglove as pg
-
-
-example_image = (
-    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
-    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
-    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
-    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
-    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
-    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
-    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
-    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
-)
-
-
-def mock_get_model(model_name, *args, **kwargs):
-  del args, kwargs
-  if 'gemini' in model_name:
-    method = 'generateContent'
-  elif 'chat' in model_name:
-    method = 'generateMessage'
-  else:
-    method = 'generateText'
-  return pg.Dict(supported_generation_methods=[method])
-
-
-def mock_generate_text(*, model, prompt, **kwargs):
-  return pg.Dict(
-      candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
-  )
-
-
-def mock_chat(*, model, messages, **kwargs):
-  return pg.Dict(
-      candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
-  )
-
-
-def mock_generate_content(content, generation_config, **kwargs):
-  del kwargs
-  c = generation_config
-  return genai.types.GenerateContentResponse(
-      done=True,
-      iterator=None,
-      chunks=[],
-      result=pg.Dict(
-          prompt_feedback=pg.Dict(block_reason=None),
-          candidates=[
-              pg.Dict(
-                  content=pg.Dict(
-                      parts=[
-                          pg.Dict(
-                              text=(
-                                  f'This is a response to {content[0]} with '
-                                  f'n={c.candidate_count}, '
-                                  f'temperature={c.temperature}, '
-                                  f'top_p={c.top_p}, '
-                                  f'top_k={c.top_k}, '
-                                  f'max_tokens={c.max_output_tokens}, '
-                                  f'stop={c.stop_sequences}.'
-                              )
-                          )
-                      ]
-                  ),
-              ),
-          ],
-      ),
-  )
 
 
 class GenAITest(unittest.TestCase):
-  """Tests for …
-
-  def test_content_from_message_text_only(self):
-    text = 'This is a beautiful day'
-    model = google_genai.GeminiPro()
-    chunks = model._content_from_message(lf.UserMessage(text))
-    self.assertEqual(chunks, [text])
-
-  def test_content_from_message_mm(self):
-    message = lf.UserMessage(
-        'This is an {{image}}, what is it?',
-        image=lf_modalities.Image.from_bytes(example_image),
-    )
+  """Tests for GenAI model."""
 
-
-    with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
-      google_genai.GeminiPro()._content_from_message(message)
-
-    model = google_genai.GeminiProVision()
-    chunks = model._content_from_message(message)
-    self.maxDiff = None
-    self.assertEqual(
-        chunks,
-        [
-            'This is an',
-            genai.types.BlobDict(mime_type='image/png', data=example_image),
-            ', what is it?',
-        ],
-    )
-
-  def test_response_to_result_text_only(self):
-    response = genai.types.GenerateContentResponse(
-        done=True,
-        iterator=None,
-        chunks=[],
-        result=pg.Dict(
-            prompt_feedback=pg.Dict(block_reason=None),
-            candidates=[
-                pg.Dict(
-                    content=pg.Dict(
-                        parts=[pg.Dict(text='This is response 1.')]
-                    ),
-                ),
-                pg.Dict(
-                    content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
-                ),
-            ],
-        ),
-    )
-    model = google_genai.GeminiProVision()
-    result = model._response_to_result(response)
-    self.assertEqual(
-        result,
-        lf.LMSamplingResult([
-            lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
-            lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
-        ]),
-    )
-
-  def test_model_hub(self):
-    orig_get_model = genai.get_model
-    genai.get_model = mock_get_model
-
-    model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
-    self.assertIsNotNone(model)
-    self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
-
-    genai.get_model = orig_get_model
-
-  def test_api_key_check(self):
+  def test_basics(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-      _ = google_genai.…
+      _ = google_genai.GeminiPro1_5().api_endpoint
+
+    self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)
 
-    self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
     os.environ['GOOGLE_API_KEY'] = 'abc'
-
+    lm = google_genai.GeminiPro1_5()
+    self.assertIsNotNone(lm.api_endpoint)
+    self.assertTrue(lm.model_id.startswith('GenAI('))
     del os.environ['GOOGLE_API_KEY']
 
-  def test_call(self):
-    with mock.patch(
-        'google.generativeai.GenerativeModel.generate_content',
-    ) as mock_generate:
-      orig_get_model = genai.get_model
-      genai.get_model = mock_get_model
-      mock_generate.side_effect = mock_generate_content
-
-      lm = google_genai.GeminiPro(api_key='test_key')
-      self.maxDiff = None
-      self.assertEqual(
-          lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
-          (
-              'This is a response to hello with n=1, temperature=2.0, '
-              'top_p=None, top_k=20, max_tokens=1024, stop=None.'
-          ),
-      )
-      genai.get_model = orig_get_model
-
-  def test_call_with_legacy_completion_model(self):
-    orig_get_model = genai.get_model
-    genai.get_model = mock_get_model
-    orig_generate_text = genai.generate_text
-    genai.generate_text = mock_generate_text
-
-    lm = google_genai.Palm2(api_key='test_key')
-    self.maxDiff = None
-    self.assertEqual(
-        lm('hello', temperature=2.0, top_k=20).text,
-        (
-            "hello to models/text-bison-001 with {'temperature': 2.0, "
-            "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
-            "'max_output_tokens': None, 'stop_sequences': None}"
-        ),
-    )
-    genai.get_model = orig_get_model
-    genai.generate_text = orig_generate_text
-
-  def test_call_with_legacy_chat_model(self):
-    orig_get_model = genai.get_model
-    genai.get_model = mock_get_model
-    orig_chat = genai.chat
-    genai.chat = mock_chat
-
-    lm = google_genai.Palm2_IT(api_key='test_key')
-    self.maxDiff = None
-    self.assertEqual(
-        lm('hello', temperature=2.0, top_k=20).text,
-        (
-            "hello to models/chat-bison-001 with {'temperature': 2.0, "
-            "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
-        ),
-    )
-    genai.get_model = orig_get_model
-    genai.chat = orig_chat
-
 
 if __name__ == '__main__':
   unittest.main()
```
langfun/core/llms/groq.py CHANGED

```diff
@@ -13,43 +13,88 @@
 # limitations under the License.
 """Language models from Groq."""
 
-import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
-from langfun.core import …
+from langfun.core.llms import openai_compatible
 import pyglove as pg
-import requests
 
 
 SUPPORTED_MODELS_AND_SETTINGS = {
     # Refer https://console.groq.com/docs/models
-    …
+    # Price in US dollars at https://groq.com/pricing/ as of 2024-10-10.
+    'llama-3.2-3b-preview': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=64,
+        cost_per_1k_input_tokens=0.00006,
+        cost_per_1k_output_tokens=0.00006,
+    ),
+    'llama-3.2-1b-preview': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=64,
+        cost_per_1k_input_tokens=0.00004,
+        cost_per_1k_output_tokens=0.00004,
+    ),
+    'llama-3.1-70b-versatile': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=16,
+        cost_per_1k_input_tokens=0.00059,
+        cost_per_1k_output_tokens=0.00079,
+    ),
+    'llama-3.1-8b-instant': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=32,
+        cost_per_1k_input_tokens=0.00005,
+        cost_per_1k_output_tokens=0.00008,
+    ),
+    'llama3-70b-8192': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=16,
+        cost_per_1k_input_tokens=0.00059,
+        cost_per_1k_output_tokens=0.00079,
+    ),
+    'llama3-8b-8192': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=32,
+        cost_per_1k_input_tokens=0.00005,
+        cost_per_1k_output_tokens=0.00008,
+    ),
+    'llama2-70b-4096': pg.Dict(
+        max_tokens=4096,
+        max_concurrency=16,
+    ),
+    'mixtral-8x7b-32768': pg.Dict(
+        max_tokens=32768,
+        max_concurrency=16,
+        cost_per_1k_input_tokens=0.00024,
+        cost_per_1k_output_tokens=0.00024,
+    ),
+    'gemma2-9b-it': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=32,
+        cost_per_1k_input_tokens=0.0002,
+        cost_per_1k_output_tokens=0.0002,
+    ),
+    'gemma-7b-it': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=32,
+        cost_per_1k_input_tokens=0.00007,
+        cost_per_1k_output_tokens=0.00007,
+    ),
+    'whisper-large-v3': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=16,
+    ),
+    'whisper-large-v3-turbo': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=16,
+    )
 }
 
 
-class GroqError(Exception):  # pylint: disable=g-bad-exception-name
-  """Base class for Groq errors."""
-
-
-class RateLimitError(GroqError):
-  """Error for rate limit reached."""
-
-
-class OverloadedError(GroqError):
-  """Groq's server is temporarily overloaded."""
-
-
-_CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
-
-
 @lf.use_init_args(['model'])
-class Groq(lf.LanguageModel):
+class Groq(openai_compatible.OpenAICompatible):
   """Groq LLMs through REST APIs (OpenAI compatible).
 
   See https://platform.openai.com/docs/api-reference/chat
@@ -62,10 +107,6 @@ class Groq(lf.LanguageModel):
       'The name of the model to use.',
   ]
 
-  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
-      False
-  )
-
   api_key: Annotated[
       str | None,
       (
@@ -74,32 +115,21 @@ class Groq(lf.LanguageModel):
       ),
   ] = None
 
-  def _on_bound(self):
-    super()._on_bound()
-    self._api_key = None
-    self.__dict__.pop('_api_initialized', None)
-    self.__dict__.pop('_session', None)
+  api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'
 
-  @…
-  def …
+  @property
+  def headers(self) -> dict[str, Any]:
     api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
     if not api_key:
       raise ValueError(
           'Please specify `api_key` during `__init__` or set environment '
           'variable `GROQ_API_KEY` with your Groq API key.'
      )
-    …
-
-
-  @functools.cached_property
-  def _session(self) -> requests.Session:
-    assert self._api_initialized
-    s = requests.Session()
-    s.headers.update({
-        'Authorization': f'Bearer {self._api_key}',
-        'Content-Type': 'application/json',
+    headers = super().headers
+    headers.update({
+        'Authorization': f'Bearer {api_key}',
     })
-    return s
+    return headers
 
   @property
   def model_id(self) -> str:
@@ -110,109 +140,50 @@ class Groq(lf.LanguageModel):
   def max_concurrency(self) -> int:
     return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
 
-  def …
+  def estimate_cost(
+      self,
+      num_input_tokens: int,
+      num_output_tokens: int
+  ) -> float | None:
+    """Estimate the cost based on usage."""
+    cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_input_tokens', None
+    )
+    cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_output_tokens', None
+    )
+    if cost_per_1k_input_tokens is None or cost_per_1k_output_tokens is None:
+      return None
+    return (
+        cost_per_1k_input_tokens * num_input_tokens
+        + cost_per_1k_output_tokens * num_output_tokens
+    ) / 1000
+
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
-    args = …
-
-
-        stream=False,
-    )
-
-    if options.temperature is not None:
-      args['temperature'] = options.temperature
-    if options.max_tokens is not None:
-      args['max_tokens'] = options.max_tokens
-    if options.top_p is not None:
-      args['top_p'] = options.top_p
-    if options.stop:
-      args['stop'] = options.stop
+    args = super()._request_args(options)
+    args.pop('logprobs', None)
+    args.pop('top_logprobs', None)
    return args
 
-  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
-    """Converts an message to Groq's content protocol (list of dicts)."""
-    # Refer: https://platform.openai.com/docs/api-reference/chat/create
-    content = []
-    for chunk in prompt.chunk():
-      if isinstance(chunk, str):
-        item = dict(type='text', text=chunk)
-      elif (
-          self.multimodal
-          and isinstance(chunk, lf_modalities.Image)
-          and chunk.uri
-      ):
-        # NOTE(daiyip): Groq only support image URL.
-        item = dict(type='image_url', image_url=chunk.uri)
-      else:
-        raise ValueError(f'Unsupported modality object: {chunk!r}.')
-      content.append(item)
-    return content
-
-  def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
-    """Converts Groq's content protocol to message."""
-    # Refer: https://platform.openai.com/docs/api-reference/chat/create
-    content = choice['message']['content']
-    if isinstance(content, str):
-      return lf.AIMessage(content)
-    return lf.AIMessage.from_chunks(
-        [x['text'] for x in content if x['type'] == 'text']
-    )
 
-
-  …
-    # Refer: https://platform.openai.com/docs/api-reference/chat/object
-    if response.status_code == 200:
-      output = response.json()
-      samples = [
-          lf.LMSample(self._message_from_choice(choice), score=0.0)
-          for choice in output['choices']
-      ]
-      usage = output['usage']
-      return lf.LMSamplingResult(
-          samples,
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=usage['prompt_tokens'],
-              completion_tokens=usage['completion_tokens'],
-              total_tokens=usage['total_tokens'],
-          ),
-      )
-    else:
-      # https://platform.openai.com/docs/guides/error-codes/api-errors
-      if response.status_code == 429:
-        error_cls = RateLimitError
-      elif response.status_code in (500, 502, 503):
-        error_cls = OverloadedError
-      else:
-        error_cls = GroqError
-      raise error_cls(f'{response.status_code}: {response.content}')
-
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized
-    return self._parallel_execute_with_currency_control(
-        self._sample_single,
-        prompts,
-        retry_on_errors=(RateLimitError, OverloadedError),
-    )
+class GroqLlama3_2_3B(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-3B with 8K context window.
 
-  …
-          json=request,
-          timeout=self.timeout,
-      )
-      return self._parse_response(response)
-    except ConnectionError as e:
-      raise OverloadedError(str(e)) from e
+  See: https://huggingface.co/meta-llama/Llama-3.2-3B
+  """
+
+  model = 'llama-3.2-3b-preview'
+
+
+class GroqLlama3_2_1B(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-1B with 8K context window.
+
+  See: https://huggingface.co/meta-llama/Llama-3.2-1B
+  """
+
+  model = 'llama-3.2-1b-preview'
 
 
 class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
@@ -224,6 +195,24 @@ class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
   model = 'llama3-8b-8192'
 
 
+class GroqLlama3_1_70B(Groq):  # pylint: disable=invalid-name
+  """Llama3.1-70B with 8K context window.
+
+  See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md  # pylint: disable=line-too-long
+  """
+
+  model = 'llama-3.1-70b-versatile'
+
+
+class GroqLlama3_1_8B(Groq):  # pylint: disable=invalid-name
+  """Llama3.1-8B with 8K context window.
+
+  See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md  # pylint: disable=line-too-long
+  """
+
+  model = 'llama-3.1-8b-instant'
+
+
 class GroqLlama3_70B(Groq):  # pylint: disable=invalid-name
   """Llama3-70B with 8K context window.
 
@@ -251,10 +240,37 @@ class GroqMistral_8x7B(Groq):  # pylint: disable=invalid-name
   model = 'mixtral-8x7b-32768'
 
 
-class …
+class GroqGemma2_9B_IT(Groq):  # pylint: disable=invalid-name
+  """Gemma2 9B with 8K context window.
+
+  See: https://huggingface.co/google/gemma-2-9b-it
+  """
+
+  model = 'gemma2-9b-it'
+
+
+class GroqGemma_7B_IT(Groq):  # pylint: disable=invalid-name
   """Gemma 7B with 8K context window.
 
   See: https://huggingface.co/google/gemma-1.1-7b-it
   """
 
   model = 'gemma-7b-it'
+
+
+class GroqWhisper_Large_v3(Groq):  # pylint: disable=invalid-name
+  """Whisper Large V3 with 8K context window.
+
+  See: https://huggingface.co/openai/whisper-large-v3
+  """
+
+  model = 'whisper-large-v3'
+
+
+class GroqWhisper_Large_v3Turbo(Groq):  # pylint: disable=invalid-name
+  """Whisper Large V3 Turbo with 8K context window.
+
+  See: https://huggingface.co/openai/whisper-large-v3-turbo
+  """
+
+  model = 'whisper-large-v3-turbo'
```