langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502120804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +6 -2
- langfun/core/language_model.py +365 -22
- langfun/core/language_model_test.py +123 -35
- langfun/core/llms/__init__.py +50 -57
- langfun/core/llms/anthropic.py +434 -163
- langfun/core/llms/anthropic_test.py +20 -1
- langfun/core/llms/deepseek.py +90 -51
- langfun/core/llms/deepseek_test.py +15 -16
- langfun/core/llms/fake.py +6 -0
- langfun/core/llms/gemini.py +480 -390
- langfun/core/llms/gemini_test.py +27 -7
- langfun/core/llms/google_genai.py +80 -50
- langfun/core/llms/google_genai_test.py +11 -4
- langfun/core/llms/groq.py +268 -167
- langfun/core/llms/groq_test.py +9 -3
- langfun/core/llms/openai.py +839 -328
- langfun/core/llms/openai_compatible.py +3 -18
- langfun/core/llms/openai_compatible_test.py +20 -5
- langfun/core/llms/openai_test.py +14 -4
- langfun/core/llms/rest.py +11 -6
- langfun/core/llms/vertexai.py +238 -240
- langfun/core/llms/vertexai_test.py +35 -8
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/RECORD +27 -27
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/top_level.txt +0 -0
langfun/core/llms/groq.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The Langfun Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,84 +13,249 @@
 # limitations under the License.
 """Language models from Groq."""
 
+import datetime
+import functools
 import os
-from typing import Annotated, Any
+from typing import Annotated, Any, Final
 
 import langfun.core as lf
 from langfun.core.llms import openai_compatible
 import pyglove as pg
 
 
+class GroqModelInfo(lf.ModelInfo):
+  """Groq model info."""
+
+  LINKS = dict(
+      models='https://console.groq.com/docs/models',
+      pricing='https://groq.com/pricing/',
+      rate_limits='https://console.groq.com/docs/rate-limits',
+      error_codes='https://console.groq.com/docs/errors',
+  )
+
+  provider: Final[str] = 'Groq'  # pylint: disable=invalid-name
+
+
+SUPPORTED_MODELS = [
+    #
+    # Llama models.
+    #
+    GroqModelInfo(
+        model_id='llama-3.3-70b-versatile',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.3 70B model on Groq (Production)',
+        url='https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=8_192,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.59,
+            cost_per_1m_output_tokens=0.79,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=1_000,
+            max_tokens_per_minute=120_000,
+        ),
     ),
+    GroqModelInfo(
+        model_id='llama-3.3-70b-specdec',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.3 70B model on Groq (Production)',
+        url='https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=8_192,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.59,
+            cost_per_1m_output_tokens=0.99,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=30_000,
+        ),
     ),
+    GroqModelInfo(
+        model_id='llama-3.2-1b-preview',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.2 1B model on Groq (Preview)',
+        url='https://huggingface.co/meta-llama/Llama-3.2-1B',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.04,
+            cost_per_1m_output_tokens=0.04,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=30_000,
+        ),
     ),
+    GroqModelInfo(
+        model_id='llama-3.2-3b-preview',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.2 3B model on Groq (Preview)',
+        url='https://huggingface.co/meta-llama/Llama-3.2-3B',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.06,
+            cost_per_1m_output_tokens=0.06,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=30_000,
+        ),
     ),
+    GroqModelInfo(
+        model_id='llama-3.2-11b-vision-preview',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.2 11B vision model on Groq (Preview)',
+        url='https://huggingface.co/meta-llama/Llama-3.2-11B-Vision',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.18,
+            cost_per_1m_output_tokens=0.18,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=30_000,
+        ),
     ),
+    GroqModelInfo(
+        model_id='llama-3.2-90b-vision-preview',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.2 90B vision model on Groq (Preview)',
+        url='https://huggingface.co/meta-llama/Llama-3.2-90B-Vision',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.9,
+            cost_per_1m_output_tokens=0.9,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=30_000,
+        ),
     ),
+    #
+    # DeepSeek models
+    #
+    GroqModelInfo(
+        model_id='deepseek-r1-distill-llama-70b',
+        in_service=True,
+        model_type='thinking',
+        description='DeepSeek R1 distilled from Llama 70B (Preview)',
+        url='https://console.groq.com/docs/model/deepseek-r1-distill-llama-70b',
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=16_384,
+        ),
+        # TODO(daiyip): Pricing needs to be computed based on the number of
+        # input/output tokens.
+        pricing=None,
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=1_000,
+            max_tokens_per_minute=120_000,
+        ),
     ),
+    GroqModelInfo(
+        model_id='deepseek-r1-distill-llama-70b-specdec',
+        in_service=True,
+        model_type='thinking',
+        description='DeepSeek R1 distilled from Llama 70B (Preview)',
+        url='https://console.groq.com/docs/model/deepseek-r1-distill-llama-70b',
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=16_384,
+        ),
+        # TODO(daiyip): Pricing needs to be computed based on the number of
+        # input/output tokens.
+        pricing=None,
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=60_000,
+        ),
     ),
+    #
+    # Gemma models.
+    #
+    GroqModelInfo(
+        model_id='gemma2-9b-it',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Google Gemma 2 9B model on Groq.',
+        url='https://huggingface.co/google/gemma-2-9b-it',
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=8_192,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.2,
+            cost_per_1m_output_tokens=0.2,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=200,
+            max_tokens_per_minute=30_000,
+        ),
     ),
+    #
+    # Mixtral models.
+    #
+    GroqModelInfo(
+        model_id='mixtral-8x7b-32768',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Mixtral 8x7B model on Groq (Production)',
+        url='https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1',
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=32_768,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.24,
+            cost_per_1m_output_tokens=0.24,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=25_000,
+        ),
     ),
-    ),
-    'whisper-large-v3-turbo': pg.Dict(
-        max_tokens=8192,
-        max_concurrency=16,
-    )
-}
+]
+
+_SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
@@ -102,7 +267,7 @@ class Groq(openai_compatible.OpenAICompatible):
 
   model: pg.typing.Annotated[
       pg.typing.Enum(
-          pg.MISSING_VALUE,
+          pg.MISSING_VALUE, [m.model_id for m in SUPPORTED_MODELS]
      ),
      'The name of the model to use.',
   ]
@@ -117,6 +282,10 @@ class Groq(openai_compatible.OpenAICompatible):
 
   api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'
 
+  @functools.cached_property
+  def model_info(self) -> lf.ModelInfo:
+    return _SUPPORTED_MODELS_BY_ID[self.model]
+
   @property
   def headers(self) -> dict[str, Any]:
     api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
@@ -131,34 +300,6 @@ class Groq(openai_compatible.OpenAICompatible):
     })
     return headers
 
-  @property
-  def model_id(self) -> str:
-    """Returns a string to identify the model."""
-    return self.model
-
-  @property
-  def max_concurrency(self) -> int:
-    return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
-
-  def estimate_cost(
-      self,
-      num_input_tokens: int,
-      num_output_tokens: int
-  ) -> float | None:
-    """Estimate the cost based on usage."""
-    cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-        'cost_per_1k_input_tokens', None
-    )
-    cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-        'cost_per_1k_output_tokens', None
-    )
-    if cost_per_1k_input_tokens is None or cost_per_1k_output_tokens is None:
-      return None
-    return (
-        cost_per_1k_input_tokens * num_input_tokens
-        + cost_per_1k_output_tokens * num_output_tokens
-    ) / 1000
-
   def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
@@ -168,109 +309,69 @@ class Groq(openai_compatible.OpenAICompatible):
     return args
 
 
-class
-  """Llama3.2-3B with
+class GroqLlama33_70B_Versatile(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-3B with 128K context window."""
+  model = 'llama-3.3-70b-versatile'
 
-  See: https://huggingface.co/meta-llama/Llama-3.2-3B
-  """
 
+class GroqLlama33_70B_SpecDec(Groq):  # pylint: disable=invalid-name
+  """Llama3.3-70B with 8K context window."""
+  model = 'llama-3.3-70b-specdec'
 
 
-class
-  """Llama3.2-1B
+class GroqLlama32_1B(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-1B."""
+  model = 'llama-3.2-1b-preview'
 
-  See: https://huggingface.co/meta-llama/Llama-3.2-1B
-  """
 
+class GroqLlama32_3B(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-3B."""
   model = 'llama-3.2-3b-preview'
 
 
-class
-  """Llama3-
-
-  See: https://huggingface.co/meta-llama/Meta-Llama-3-8B
-  """
-
-  model = 'llama3-8b-8192'
-
+class GroqLlama32_11B_Vision(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-11B vision."""
+  model = 'llama-3.2-11b-vision-preview'
 
-class GroqLlama3_1_70B(Groq):  # pylint: disable=invalid-name
-  """Llama3.1-70B with 8K context window.
 
-
-  """
-
-  model = 'llama-3.1-70b-versatile'
+class GroqLlama32_90B_Vision(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-90B vision."""
+  model = 'llama-3.2-90b-vision-preview'
 
 
-class
-  """
+class GroqDeepSeekR1_DistillLlama_70B(Groq):  # pylint: disable=invalid-name
+  """DeepSeek R1 distilled from Llama 70B."""
+  model = 'deepseek-r1-distill-llama-70b'
 
-  See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md  # pylint: disable=line-too-long
-  """
-
-  model = 'llama-3.1-8b-instant'
-
-
-class GroqLlama3_70B(Groq):  # pylint: disable=invalid-name
-  """Llama3-70B with 8K context window.
-
-  See: https://huggingface.co/meta-llama/Meta-Llama-3-70B
-  """
 
-
-
-class GroqLlama2_70B(Groq):  # pylint: disable=invalid-name
-  """Llama2-70B with 4K context window.
-
-  See: https://huggingface.co/meta-llama/Llama-2-70b
-  """
-
-  model = 'llama2-70b-4096'
+class GroqDeepSeekR1_DistillLlama_70B_SpecDec(Groq):  # pylint: disable=invalid-name
+  """DeepSeek R1 distilled from Llama 70B (SpecDec)."""
+  model = 'deepseek-r1-distill-llama-70b-specdec'
 
 
 class GroqMistral_8x7B(Groq):  # pylint: disable=invalid-name
-  """Mixtral 8x7B
-
-  See: https://huggingface.co/meta-llama/Llama-2-70b
-  """
-
+  """Mixtral 8x7B."""
   model = 'mixtral-8x7b-32768'
 
 
 class GroqGemma2_9B_IT(Groq):  # pylint: disable=invalid-name
-  """Gemma2 9B
-
-  See: https://huggingface.co/google/gemma-2-9b-it
-  """
-
+  """Gemma2 9B."""
   model = 'gemma2-9b-it'
 
 
-
-  See: https://huggingface.co/google/gemma-1.1-7b-it
-  """
-
-  model = 'gemma-7b-it'
-
-
-class GroqWhisper_Large_v3(Groq):  # pylint: disable=invalid-name
-  """Whisper Large V3 with 8K context window.
-
-  See: https://huggingface.co/openai/whisper-large-v3
-  """
+#
+# Register Groq models so they can be retrieved with LanguageModel.get().
+#
 
-  model = 'whisper-large-v3'
 
+def _groq_model(model: str, *args, **kwargs):
+  model = model.removeprefix('groq://')
+  return Groq(model, *args, **kwargs)
 
-class GroqWhisper_Large_v3Turbo(Groq):  # pylint: disable=invalid-name
-  """Whisper Large V3 Turbo with 8K context window.
 
-
-  """
+def _register_groq_models():
+  """Registers Groq models."""
+  for m in SUPPORTED_MODELS:
+    lf.LanguageModel.register('groq://' + m.model_id, _groq_model)
 
-
+_register_groq_models()
langfun/core/llms/groq_test.py
CHANGED
@@ -17,12 +17,13 @@ import langfun.core as lf
 from langfun.core.llms import groq
 
 
-class AuthropicTest(unittest.TestCase):
+class GroqTest(unittest.TestCase):
 
   def test_basics(self):
     self.assertEqual(groq.GroqMistral_8x7B().model_id, 'mixtral-8x7b-32768')
-    self.assertEqual(
+    self.assertEqual(
+        groq.GroqMistral_8x7B().resource_id, 'groq://mixtral-8x7b-32768'
+    )
 
   def test_request_args(self):
     args = groq.GroqMistral_8x7B()._request_args(
@@ -59,6 +60,11 @@ class AuthropicTest(unittest.TestCase):
     )
     del os.environ['GROQ_API_KEY']
 
+  def test_lm_get(self):
+    self.assertIsInstance(
+        lf.LanguageModel.get('groq://gemma2-9b-it'),
+        groq.Groq,
+    )
 
 if __name__ == '__main__':
   unittest.main()