langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502120804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/core/llms/groq.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The Langfun Authors
+# Copyright 2025 The Langfun Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,84 +13,249 @@
 # limitations under the License.
 """Language models from Groq."""
 
+import datetime
+import functools
 import os
-from typing import Annotated, Any
+from typing import Annotated, Any, Final
 
 import langfun.core as lf
 from langfun.core.llms import openai_compatible
 import pyglove as pg
 
 
-SUPPORTED_MODELS_AND_SETTINGS = {
-    # Refer https://console.groq.com/docs/models
-    # Price in US dollars at https://groq.com/pricing/ as of 2024-10-10.
-    'llama-3.2-3b-preview': pg.Dict(
-        max_tokens=8192,
-        max_concurrency=64,
-        cost_per_1k_input_tokens=0.00006,
-        cost_per_1k_output_tokens=0.00006,
+class GroqModelInfo(lf.ModelInfo):
+  """Groq model info."""
+
+  LINKS = dict(
+      models='https://console.groq.com/docs/models',
+      pricing='https://groq.com/pricing/',
+      rate_limits='https://console.groq.com/docs/rate-limits',
+      error_codes='https://console.groq.com/docs/errors',
+  )
+
+  provider: Final[str] = 'Groq'  # pylint: disable=invalid-name
+
+
+SUPPORTED_MODELS = [
+    #
+    # Llama models.
+    #
+    GroqModelInfo(
+        model_id='llama-3.3-70b-versatile',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.3 70B model on Groq (Production)',
+        url='https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=8_192,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.59,
+            cost_per_1m_output_tokens=0.79,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=1_000,
+            max_tokens_per_minute=120_000,
+        ),
     ),
-    'llama-3.2-1b-preview': pg.Dict(
-        max_tokens=8192,
-        max_concurrency=64,
-        cost_per_1k_input_tokens=0.00004,
-        cost_per_1k_output_tokens=0.00004,
+    GroqModelInfo(
+        model_id='llama-3.3-70b-specdec',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.3 70B model on Groq (Production)',
+        url='https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=8_192,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.59,
+            cost_per_1m_output_tokens=0.99,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=30_000,
+        ),
     ),
-    'llama-3.1-70b-versatile': pg.Dict(
-        max_tokens=8192,
-        max_concurrency=16,
-        cost_per_1k_input_tokens=0.00059,
-        cost_per_1k_output_tokens=0.00079,
+    GroqModelInfo(
+        model_id='llama-3.2-1b-preview',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.2 1B model on Groq (Preview)',
+        url='https://huggingface.co/meta-llama/Llama-3.2-1B',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.04,
+            cost_per_1m_output_tokens=0.04,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=30_000,
+        ),
     ),
-    'llama-3.1-8b-instant': pg.Dict(
-        max_tokens=8192,
-        max_concurrency=32,
-        cost_per_1k_input_tokens=0.00005,
-        cost_per_1k_output_tokens=0.00008,
+    GroqModelInfo(
+        model_id='llama-3.2-3b-preview',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.2 3B model on Groq (Preview)',
+        url='https://huggingface.co/meta-llama/Llama-3.2-3B',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.06,
+            cost_per_1m_output_tokens=0.06,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=30_000,
+        ),
     ),
-    'llama3-70b-8192': pg.Dict(
-        max_tokens=8192,
-        max_concurrency=16,
-        cost_per_1k_input_tokens=0.00059,
-        cost_per_1k_output_tokens=0.00079,
+    GroqModelInfo(
+        model_id='llama-3.2-11b-vision-preview',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.2 11B vision model on Groq (Preview)',
+        url='https://huggingface.co/meta-llama/Llama-3.2-11B-Vision',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.18,
+            cost_per_1m_output_tokens=0.18,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=30_000,
+        ),
     ),
-    'llama3-8b-8192': pg.Dict(
-        max_tokens=8192,
-        max_concurrency=32,
-        cost_per_1k_input_tokens=0.00005,
-        cost_per_1k_output_tokens=0.00008,
+    GroqModelInfo(
+        model_id='llama-3.2-90b-vision-preview',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Llama 3.2 90B vision model on Groq (Preview)',
+        url='https://huggingface.co/meta-llama/Llama-3.2-90B-Vision',
+        release_date=datetime.datetime(2024, 12, 6),
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.9,
+            cost_per_1m_output_tokens=0.9,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=30_000,
+        ),
     ),
-    'llama2-70b-4096': pg.Dict(
-        max_tokens=4096,
-        max_concurrency=16,
+    #
+    # DeepSeek models
+    #
+    GroqModelInfo(
+        model_id='deepseek-r1-distill-llama-70b',
+        in_service=True,
+        model_type='thinking',
+        description='DeepSeek R1 distilled from Llama 70B (Preview)',
+        url='https://console.groq.com/docs/model/deepseek-r1-distill-llama-70b',
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=16_384,
+        ),
+        # TODO(daiyip): Pricing needs to be computed based on the number of
+        # input/output tokens.
+        pricing=None,
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=1_000,
+            max_tokens_per_minute=120_000,
+        ),
     ),
-    'mixtral-8x7b-32768': pg.Dict(
-        max_tokens=32768,
-        max_concurrency=16,
-        cost_per_1k_input_tokens=0.00024,
-        cost_per_1k_output_tokens=0.00024,
+    GroqModelInfo(
+        model_id='deepseek-r1-distill-llama-70b-specdec',
+        in_service=True,
+        model_type='thinking',
+        description='DeepSeek R1 distilled from Llama 70B (Preview)',
+        url='https://console.groq.com/docs/model/deepseek-r1-distill-llama-70b',
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=128_000,
+            max_output_tokens=16_384,
+        ),
+        # TODO(daiyip): Pricing needs to be computed based on the number of
+        # input/output tokens.
+        pricing=None,
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=60_000,
+        ),
     ),
-    'gemma2-9b-it': pg.Dict(
-        max_tokens=8192,
-        max_concurrency=32,
-        cost_per_1k_input_tokens=0.0002,
-        cost_per_1k_output_tokens=0.0002,
+    #
+    # Gemma models.
+    #
+    GroqModelInfo(
+        model_id='gemma2-9b-it',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Google Gemma 2 9B model on Groq.',
+        url='https://huggingface.co/google/gemma-2-9b-it',
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=8_192,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.2,
+            cost_per_1m_output_tokens=0.2,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=200,
+            max_tokens_per_minute=30_000,
+        ),
     ),
-    'gemma-7b-it': pg.Dict(
-        max_tokens=8192,
-        max_concurrency=32,
-        cost_per_1k_input_tokens=0.00007,
-        cost_per_1k_output_tokens=0.00007,
+    #
+    # Mixtral models.
+    #
+    GroqModelInfo(
+        model_id='mixtral-8x7b-32768',
+        in_service=True,
+        model_type='instruction-tuned',
+        description='Mixtral 8x7B model on Groq (Production)',
+        url='https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1',
+        context_length=lf.ModelInfo.ContextLength(
+            max_input_tokens=32_768,
+            max_output_tokens=None,
+        ),
+        pricing=lf.ModelInfo.Pricing(
+            cost_per_1m_input_tokens=0.24,
+            cost_per_1m_output_tokens=0.24,
+        ),
+        rate_limits=lf.ModelInfo.RateLimits(
+            # Developer tier.
+            max_requests_per_minute=100,
+            max_tokens_per_minute=25_000,
+        ),
     ),
-    'whisper-large-v3': pg.Dict(
-        max_tokens=8192,
-        max_concurrency=16,
-    ),
-    'whisper-large-v3-turbo': pg.Dict(
-        max_tokens=8192,
-        max_concurrency=16,
-    )
-}
+]
+
+_SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
 
 
 @lf.use_init_args(['model'])
@@ -102,7 +267,7 @@ class Groq(openai_compatible.OpenAICompatible):
 
   model: pg.typing.Annotated[
       pg.typing.Enum(
-          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+          pg.MISSING_VALUE, [m.model_id for m in SUPPORTED_MODELS]
       ),
       'The name of the model to use.',
   ]
@@ -117,6 +282,10 @@ class Groq(openai_compatible.OpenAICompatible):
 
   api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'
 
+  @functools.cached_property
+  def model_info(self) -> lf.ModelInfo:
+    return _SUPPORTED_MODELS_BY_ID[self.model]
+
   @property
   def headers(self) -> dict[str, Any]:
     api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
@@ -131,34 +300,6 @@ class Groq(openai_compatible.OpenAICompatible):
     })
     return headers
 
-  @property
-  def model_id(self) -> str:
-    """Returns a string to identify the model."""
-    return self.model
-
-  @property
-  def max_concurrency(self) -> int:
-    return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
-
-  def estimate_cost(
-      self,
-      num_input_tokens: int,
-      num_output_tokens: int
-  ) -> float | None:
-    """Estimate the cost based on usage."""
-    cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-        'cost_per_1k_input_tokens', None
-    )
-    cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-        'cost_per_1k_output_tokens', None
-    )
-    if cost_per_1k_input_tokens is None or cost_per_1k_output_tokens is None:
-      return None
-    return (
-        cost_per_1k_input_tokens * num_input_tokens
-        + cost_per_1k_output_tokens * num_output_tokens
-    ) / 1000
-
   def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
@@ -168,109 +309,69 @@ class Groq(openai_compatible.OpenAICompatible):
     return args
 
 
-class GroqLlama3_2_3B(Groq):  # pylint: disable=invalid-name
-  """Llama3.2-3B with 8K context window.
+class GroqLlama33_70B_Versatile(Groq):  # pylint: disable=invalid-name
+  """Llama3.3-70B with 128K context window."""
+  model = 'llama-3.3-70b-versatile'
 
-  See: https://huggingface.co/meta-llama/Llama-3.2-3B
-  """
 
-  model = 'llama-3.2-3b-preview'
+class GroqLlama33_70B_SpecDec(Groq):  # pylint: disable=invalid-name
+  """Llama3.3-70B with 8K context window."""
+  model = 'llama-3.3-70b-specdec'
 
 
-class GroqLlama3_2_1B(Groq):  # pylint: disable=invalid-name
-  """Llama3.2-1B with 8K context window.
+class GroqLlama32_1B(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-1B."""
+  model = 'llama-3.2-1b-preview'
 
-  See: https://huggingface.co/meta-llama/Llama-3.2-1B
-  """
 
+class GroqLlama32_3B(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-3B."""
   model = 'llama-3.2-3b-preview'
 
 
-class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
-  """Llama3-8B with 8K context window.
-
-  See: https://huggingface.co/meta-llama/Meta-Llama-3-8B
-  """
-
-  model = 'llama3-8b-8192'
-
+class GroqLlama32_11B_Vision(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-11B vision."""
+  model = 'llama-3.2-11b-vision-preview'
 
-class GroqLlama3_1_70B(Groq):  # pylint: disable=invalid-name
-  """Llama3.1-70B with 8K context window.
 
-  See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md  # pylint: disable=line-too-long
-  """
-
-  model = 'llama-3.1-70b-versatile'
+class GroqLlama32_90B_Vision(Groq):  # pylint: disable=invalid-name
  """Llama3.2-90B vision."""
+  model = 'llama-3.2-90b-vision-preview'
 
 
-class GroqLlama3_1_8B(Groq):  # pylint: disable=invalid-name
-  """Llama3.1-8B with 8K context window.
+class GroqDeepSeekR1_DistillLlama_70B(Groq):  # pylint: disable=invalid-name
+  """DeepSeek R1 distilled from Llama 70B."""
+  model = 'deepseek-r1-distill-llama-70b'
 
-  See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md  # pylint: disable=line-too-long
-  """
-
-  model = 'llama-3.1-8b-instant'
-
-
-class GroqLlama3_70B(Groq):  # pylint: disable=invalid-name
-  """Llama3-70B with 8K context window.
-
-  See: https://huggingface.co/meta-llama/Meta-Llama-3-70B
-  """
 
-  model = 'llama3-70b-8192'
-
-
-class GroqLlama2_70B(Groq):  # pylint: disable=invalid-name
-  """Llama2-70B with 4K context window.
-
-  See: https://huggingface.co/meta-llama/Llama-2-70b
-  """
-
-  model = 'llama2-70b-4096'
+class GroqDeepSeekR1_DistillLlama_70B_SpecDec(Groq):  # pylint: disable=invalid-name
+  """DeepSeek R1 distilled from Llama 70B (SpecDec)."""
+  model = 'deepseek-r1-distill-llama-70b-specdec'
 
 
 class GroqMistral_8x7B(Groq):  # pylint: disable=invalid-name
-  """Mixtral 8x7B with 32K context window.
-
-  See: https://huggingface.co/meta-llama/Llama-2-70b
-  """
-
+  """Mixtral 8x7B."""
   model = 'mixtral-8x7b-32768'
 
 
 class GroqGemma2_9B_IT(Groq):  # pylint: disable=invalid-name
-  """Gemma2 9B with 8K context window.
-
-  See: https://huggingface.co/google/gemma-2-9b-it
-  """
-
+  """Gemma2 9B."""
   model = 'gemma2-9b-it'
 
 
-class GroqGemma_7B_IT(Groq):  # pylint: disable=invalid-name
-  """Gemma 7B with 8K context window.
-
-  See: https://huggingface.co/google/gemma-1.1-7b-it
-  """
-
-  model = 'gemma-7b-it'
-
-
-class GroqWhisper_Large_v3(Groq):  # pylint: disable=invalid-name
-  """Whisper Large V3 with 8K context window.
-
-  See: https://huggingface.co/openai/whisper-large-v3
-  """
+#
+# Register Groq models so they can be retrieved with LanguageModel.get().
+#
 
-  model = 'whisper-large-v3'
 
+def _groq_model(model: str, *args, **kwargs):
+  model = model.removeprefix('groq://')
+  return Groq(model, *args, **kwargs)
 
-class GroqWhisper_Large_v3Turbo(Groq):  # pylint: disable=invalid-name
-  """Whisper Large V3 Turbo with 8K context window.
 
-  See: https://huggingface.co/openai/whisper-large-v3-turbo
-  """
+def _register_groq_models():
+  """Registers Groq models."""
+  for m in SUPPORTED_MODELS:
+    lf.LanguageModel.register('groq://' + m.model_id, _groq_model)
 
-  model = 'whisper-large-v3-turbo'
+_register_groq_models()
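
In sum, groq.py replaces the flat SUPPORTED_MODELS_AND_SETTINGS dict with structured GroqModelInfo entries and registers each model under a groq:// URI. A minimal sketch of the new surface, inferred from the diff above (attribute-style access on the pg-backed ModelInfo fields is assumed; issuing real requests would also need GROQ_API_KEY set):

    import langfun.core as lf
    from langfun.core.llms import groq

    # Models registered by _register_groq_models() resolve by URI.
    lm = lf.LanguageModel.get('groq://llama-3.3-70b-versatile')
    assert isinstance(lm, groq.Groq)

    # Per-model metadata now lives on the cached model_info property.
    info = lm.model_info
    print(info.model_id)                          # 'llama-3.3-70b-versatile'
    print(info.context_length.max_input_tokens)   # 128000
    print(info.pricing.cost_per_1m_input_tokens)  # 0.59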
langfun/core/llms/groq_test.py CHANGED
@@ -17,12 +17,13 @@ import langfun.core as lf
 from langfun.core.llms import groq
 
 
-class AuthropicTest(unittest.TestCase):
+class GroqTest(unittest.TestCase):
 
   def test_basics(self):
     self.assertEqual(groq.GroqMistral_8x7B().model_id, 'mixtral-8x7b-32768')
-    self.assertEqual(groq.GroqMistral_8x7B().max_concurrency, 16)
-    self.assertEqual(groq.GroqMistral_8x7B().estimate_cost(100, 100), 4.8e-5)
+    self.assertEqual(
+        groq.GroqMistral_8x7B().resource_id, 'groq://mixtral-8x7b-32768'
+    )
 
   def test_request_args(self):
     args = groq.GroqMistral_8x7B()._request_args(
@@ -59,6 +60,11 @@ class AuthropicTest(unittest.TestCase):
     )
     del os.environ['GROQ_API_KEY']
 
+  def test_lm_get(self):
+    self.assertIsInstance(
+        lf.LanguageModel.get('groq://gemma2-9b-it'),
+        groq.Groq,
+    )
 
 if __name__ == '__main__':
   unittest.main()
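
The removed estimate_cost assertion doubles as a sanity check on the pricing migration: the old table priced mixtral-8x7b-32768 at $0.00024 per 1K tokens, and the new GroqModelInfo entry prices it at $0.24 per 1M tokens, which is the same rate. A standalone arithmetic check (not part of the package):

    # Old per-1K formula, as removed from Groq.estimate_cost():
    old_cost = (0.00024 * 100 + 0.00024 * 100) / 1000  # 4.8e-05, as asserted
    # New per-1M rate from the GroqModelInfo pricing entry:
    new_cost = (0.24 * 100 + 0.24 * 100) / 1_000_000   # 4.8e-05
    assert abs(old_cost - new_cost) < 1e-12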