langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/llms/google_genai_test.py CHANGED
@@ -11,221 +11,28 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """Tests for Gemini models."""
+ """Tests for Google GenAI models."""

  import os
  import unittest
- from unittest import mock
-
- from google import generativeai as genai
- import langfun.core as lf
- from langfun.core import modalities as lf_modalities
  from langfun.core.llms import google_genai
- import pyglove as pg
-
-
- example_image = (
-     b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
-     b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
-     b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
-     b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
-     b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
-     b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
-     b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
-     b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
- )
-
-
- def mock_get_model(model_name, *args, **kwargs):
-   del args, kwargs
-   if 'gemini' in model_name:
-     method = 'generateContent'
-   elif 'chat' in model_name:
-     method = 'generateMessage'
-   else:
-     method = 'generateText'
-   return pg.Dict(supported_generation_methods=[method])
-
-
- def mock_generate_text(*, model, prompt, **kwargs):
-   return pg.Dict(
-       candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
-   )
-
-
- def mock_chat(*, model, messages, **kwargs):
-   return pg.Dict(
-       candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
-   )
-
-
- def mock_generate_content(content, generation_config, **kwargs):
-   del kwargs
-   c = generation_config
-   return genai.types.GenerateContentResponse(
-       done=True,
-       iterator=None,
-       chunks=[],
-       result=pg.Dict(
-           prompt_feedback=pg.Dict(block_reason=None),
-           candidates=[
-               pg.Dict(
-                   content=pg.Dict(
-                       parts=[
-                           pg.Dict(
-                               text=(
-                                   f'This is a response to {content[0]} with '
-                                   f'n={c.candidate_count}, '
-                                   f'temperature={c.temperature}, '
-                                   f'top_p={c.top_p}, '
-                                   f'top_k={c.top_k}, '
-                                   f'max_tokens={c.max_output_tokens}, '
-                                   f'stop={c.stop_sequences}.'
-                               )
-                           )
-                       ]
-                   ),
-               ),
-           ],
-       ),
-   )


  class GenAITest(unittest.TestCase):
-   """Tests for Google GenAI model."""
-
-   def test_content_from_message_text_only(self):
-     text = 'This is a beautiful day'
-     model = google_genai.GeminiPro()
-     chunks = model._content_from_message(lf.UserMessage(text))
-     self.assertEqual(chunks, [text])
-
-   def test_content_from_message_mm(self):
-     message = lf.UserMessage(
-         'This is an {{image}}, what is it?',
-         image=lf_modalities.Image.from_bytes(example_image),
-     )
+   """Tests for GenAI model."""

-     # Non-multimodal model.
-     with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
-       google_genai.GeminiPro()._content_from_message(message)
-
-     model = google_genai.GeminiProVision()
-     chunks = model._content_from_message(message)
-     self.maxDiff = None
-     self.assertEqual(
-         chunks,
-         [
-             'This is an',
-             genai.types.BlobDict(mime_type='image/png', data=example_image),
-             ', what is it?',
-         ],
-     )
-
-   def test_response_to_result_text_only(self):
-     response = genai.types.GenerateContentResponse(
-         done=True,
-         iterator=None,
-         chunks=[],
-         result=pg.Dict(
-             prompt_feedback=pg.Dict(block_reason=None),
-             candidates=[
-                 pg.Dict(
-                     content=pg.Dict(
-                         parts=[pg.Dict(text='This is response 1.')]
-                     ),
-                 ),
-                 pg.Dict(
-                     content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
-                 ),
-             ],
-         ),
-     )
-     model = google_genai.GeminiProVision()
-     result = model._response_to_result(response)
-     self.assertEqual(
-         result,
-         lf.LMSamplingResult([
-             lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
-             lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
-         ]),
-     )
-
-   def test_model_hub(self):
-     orig_get_model = genai.get_model
-     genai.get_model = mock_get_model
-
-     model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
-     self.assertIsNotNone(model)
-     self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
-
-     genai.get_model = orig_get_model
-
-   def test_api_key_check(self):
+   def test_basics(self):
      with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-       _ = google_genai.GeminiPro()._api_initialized
+       _ = google_genai.GeminiPro1_5().api_endpoint
+
+     self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)

-     self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
      os.environ['GOOGLE_API_KEY'] = 'abc'
-     self.assertTrue(google_genai.GeminiPro()._api_initialized)
+     lm = google_genai.GeminiPro1_5()
+     self.assertIsNotNone(lm.api_endpoint)
+     self.assertTrue(lm.model_id.startswith('GenAI('))
      del os.environ['GOOGLE_API_KEY']

-   def test_call(self):
-     with mock.patch(
-         'google.generativeai.GenerativeModel.generate_content',
-     ) as mock_generate:
-       orig_get_model = genai.get_model
-       genai.get_model = mock_get_model
-       mock_generate.side_effect = mock_generate_content
-
-       lm = google_genai.GeminiPro(api_key='test_key')
-       self.maxDiff = None
-       self.assertEqual(
-           lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
-           (
-               'This is a response to hello with n=1, temperature=2.0, '
-               'top_p=None, top_k=20, max_tokens=1024, stop=None.'
-           ),
-       )
-       genai.get_model = orig_get_model
-
-   def test_call_with_legacy_completion_model(self):
-     orig_get_model = genai.get_model
-     genai.get_model = mock_get_model
-     orig_generate_text = genai.generate_text
-     genai.generate_text = mock_generate_text
-
-     lm = google_genai.Palm2(api_key='test_key')
-     self.maxDiff = None
-     self.assertEqual(
-         lm('hello', temperature=2.0, top_k=20).text,
-         (
-             "hello to models/text-bison-001 with {'temperature': 2.0, "
-             "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
-             "'max_output_tokens': None, 'stop_sequences': None}"
-         ),
-     )
-     genai.get_model = orig_get_model
-     genai.generate_text = orig_generate_text
-
-   def test_call_with_legacy_chat_model(self):
-     orig_get_model = genai.get_model
-     genai.get_model = mock_get_model
-     orig_chat = genai.chat
-     genai.chat = mock_chat
-
-     lm = google_genai.Palm2_IT(api_key='test_key')
-     self.maxDiff = None
-     self.assertEqual(
-         lm('hello', temperature=2.0, top_k=20).text,
-         (
-             "hello to models/chat-bison-001 with {'temperature': 2.0, "
-             "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
-         ),
-     )
-     genai.get_model = orig_get_model
-     genai.chat = orig_chat
-

  if __name__ == '__main__':
    unittest.main()
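The 221-to-28 line shrink tracks the reimplementation of google_genai on top of the shared REST-based Gemini stack (langfun/core/llms/gemini.py, new in this release): the mocks for the google.generativeai SDK, the model hub, and the legacy PaLM 2 text/chat paths are gone, and the surface is now exercised through `api_endpoint` and `model_id`. A minimal sketch of that surface as `test_basics` uses it, assuming this wheel is installed (the key is a dummy value; nothing below issues a request):

import os

from langfun.core.llms import google_genai

# Without `api_key` or GOOGLE_API_KEY set, resolving the REST endpoint
# raises ValueError, as asserted in `test_basics` above.
os.environ['GOOGLE_API_KEY'] = 'abc'  # dummy value, mirrors the test
lm = google_genai.GeminiPro1_5()
assert lm.api_endpoint is not None       # REST endpoint; no SDK involved
assert lm.model_id.startswith('GenAI(')
del os.environ['GOOGLE_API_KEY']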
langfun/core/llms/groq.py CHANGED
@@ -13,43 +13,88 @@
  # limitations under the License.
  """Language models from Groq."""

- import functools
  import os
  from typing import Annotated, Any

  import langfun.core as lf
- from langfun.core import modalities as lf_modalities
+ from langfun.core.llms import openai_compatible
  import pyglove as pg
- import requests


  SUPPORTED_MODELS_AND_SETTINGS = {
      # Refer https://console.groq.com/docs/models
-     'llama3-8b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
-     'llama3-70b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
-     'llama2-70b-4096': pg.Dict(max_tokens=4096, max_concurrency=16),
-     'mixtral-8x7b-32768': pg.Dict(max_tokens=32768, max_concurrency=16),
-     'gemma-7b-it': pg.Dict(max_tokens=8192, max_concurrency=16),
+     # Price in US dollars at https://groq.com/pricing/ as of 2024-10-10.
+     'llama-3.2-3b-preview': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=64,
+         cost_per_1k_input_tokens=0.00006,
+         cost_per_1k_output_tokens=0.00006,
+     ),
+     'llama-3.2-1b-preview': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=64,
+         cost_per_1k_input_tokens=0.00004,
+         cost_per_1k_output_tokens=0.00004,
+     ),
+     'llama-3.1-70b-versatile': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=16,
+         cost_per_1k_input_tokens=0.00059,
+         cost_per_1k_output_tokens=0.00079,
+     ),
+     'llama-3.1-8b-instant': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=32,
+         cost_per_1k_input_tokens=0.00005,
+         cost_per_1k_output_tokens=0.00008,
+     ),
+     'llama3-70b-8192': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=16,
+         cost_per_1k_input_tokens=0.00059,
+         cost_per_1k_output_tokens=0.00079,
+     ),
+     'llama3-8b-8192': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=32,
+         cost_per_1k_input_tokens=0.00005,
+         cost_per_1k_output_tokens=0.00008,
+     ),
+     'llama2-70b-4096': pg.Dict(
+         max_tokens=4096,
+         max_concurrency=16,
+     ),
+     'mixtral-8x7b-32768': pg.Dict(
+         max_tokens=32768,
+         max_concurrency=16,
+         cost_per_1k_input_tokens=0.00024,
+         cost_per_1k_output_tokens=0.00024,
+     ),
+     'gemma2-9b-it': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=32,
+         cost_per_1k_input_tokens=0.0002,
+         cost_per_1k_output_tokens=0.0002,
+     ),
+     'gemma-7b-it': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=32,
+         cost_per_1k_input_tokens=0.00007,
+         cost_per_1k_output_tokens=0.00007,
+     ),
+     'whisper-large-v3': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=16,
+     ),
+     'whisper-large-v3-turbo': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=16,
+     )
  }


- class GroqError(Exception):  # pylint: disable=g-bad-exception-name
-   """Base class for Groq errors."""
-
-
- class RateLimitError(GroqError):
-   """Error for rate limit reached."""
-
-
- class OverloadedError(GroqError):
-   """Groq's server is temporarily overloaded."""
-
-
- _CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
-
-
  @lf.use_init_args(['model'])
- class Groq(lf.LanguageModel):
+ class Groq(openai_compatible.OpenAICompatible):
    """Groq LLMs through REST APIs (OpenAI compatible).

    See https://platform.openai.com/docs/api-reference/chat
@@ -62,10 +107,6 @@ class Groq(lf.LanguageModel):
        'The name of the model to use.',
    ]

-   multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
-       False
-   )
-
    api_key: Annotated[
        str | None,
        (
@@ -74,32 +115,21 @@
        ),
    ] = None

-   def _on_bound(self):
-     super()._on_bound()
-     self._api_key = None
-     self.__dict__.pop('_api_initialized', None)
-     self.__dict__.pop('_session', None)
+   api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'

-   @functools.cached_property
-   def _api_initialized(self):
+   @property
+   def headers(self) -> dict[str, Any]:
      api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
      if not api_key:
        raise ValueError(
            'Please specify `api_key` during `__init__` or set environment '
            'variable `GROQ_API_KEY` with your Groq API key.'
        )
-     self._api_key = api_key
-     return True
-
-   @functools.cached_property
-   def _session(self) -> requests.Session:
-     assert self._api_initialized
-     s = requests.Session()
-     s.headers.update({
-         'Authorization': f'Bearer {self._api_key}',
-         'Content-Type': 'application/json',
+     headers = super().headers
+     headers.update({
+         'Authorization': f'Bearer {api_key}',
      })
-     return s
+     return headers

    @property
    def model_id(self) -> str:
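With the move to `openai_compatible.OpenAICompatible`, request assembly, response parsing, and retry handling all live in the base classes; this file now contributes only the endpoint and the bearer token. A standalone sketch of what the `headers` override yields; the JSON content type coming from the base class is an assumption here, while the Authorization entry comes from the code above:

import os


def groq_headers(api_key: str | None = None) -> dict[str, str]:
  # Standalone illustration, not the library's code.
  key = api_key or os.environ.get('GROQ_API_KEY', None)
  if not key:
    raise ValueError('Please specify `api_key` or set GROQ_API_KEY.')
  headers = {'Content-Type': 'application/json'}  # assumed base headers
  headers.update({'Authorization': f'Bearer {key}'})
  return headers


print(groq_headers('fake-key'))
# {'Content-Type': 'application/json', 'Authorization': 'Bearer fake-key'}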
@@ -110,109 +140,50 @@
    def max_concurrency(self) -> int:
      return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency

-   def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+   def estimate_cost(
+       self,
+       num_input_tokens: int,
+       num_output_tokens: int
+   ) -> float | None:
+     """Estimate the cost based on usage."""
+     cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+         'cost_per_1k_input_tokens', None
+     )
+     cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+         'cost_per_1k_output_tokens', None
+     )
+     if cost_per_1k_input_tokens is None or cost_per_1k_output_tokens is None:
+       return None
+     return (
+         cost_per_1k_input_tokens * num_input_tokens
+         + cost_per_1k_output_tokens * num_output_tokens
+     ) / 1000
+
+   def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
      """Returns a dict as request arguments."""
      # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
-     args = dict(
-         model=self.model,
-         n=options.n,
-         stream=False,
-     )
-
-     if options.temperature is not None:
-       args['temperature'] = options.temperature
-     if options.max_tokens is not None:
-       args['max_tokens'] = options.max_tokens
-     if options.top_p is not None:
-       args['top_p'] = options.top_p
-     if options.stop:
-       args['stop'] = options.stop
+     args = super()._request_args(options)
+     args.pop('logprobs', None)
+     args.pop('top_logprobs', None)
      return args

-   def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
-     """Converts a message to Groq's content protocol (list of dicts)."""
-     # Refer: https://platform.openai.com/docs/api-reference/chat/create
-     content = []
-     for chunk in prompt.chunk():
-       if isinstance(chunk, str):
-         item = dict(type='text', text=chunk)
-       elif (
-           self.multimodal
-           and isinstance(chunk, lf_modalities.Image)
-           and chunk.uri
-       ):
-         # NOTE(daiyip): Groq only supports image URL.
-         item = dict(type='image_url', image_url=chunk.uri)
-       else:
-         raise ValueError(f'Unsupported modality object: {chunk!r}.')
-       content.append(item)
-     return content
-
-   def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
-     """Converts Groq's content protocol to message."""
-     # Refer: https://platform.openai.com/docs/api-reference/chat/create
-     content = choice['message']['content']
-     if isinstance(content, str):
-       return lf.AIMessage(content)
-     return lf.AIMessage.from_chunks(
-         [x['text'] for x in content if x['type'] == 'text']
-     )
-
-   def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
-     """Parses Groq's response."""
-     # Refer: https://platform.openai.com/docs/api-reference/chat/object
-     if response.status_code == 200:
-       output = response.json()
-       samples = [
-           lf.LMSample(self._message_from_choice(choice), score=0.0)
-           for choice in output['choices']
-       ]
-       usage = output['usage']
-       return lf.LMSamplingResult(
-           samples,
-           usage=lf.LMSamplingUsage(
-               prompt_tokens=usage['prompt_tokens'],
-               completion_tokens=usage['completion_tokens'],
-               total_tokens=usage['total_tokens'],
-           ),
-       )
-     else:
-       # https://platform.openai.com/docs/guides/error-codes/api-errors
-       if response.status_code == 429:
-         error_cls = RateLimitError
-       elif response.status_code in (500, 502, 503):
-         error_cls = OverloadedError
-       else:
-         error_cls = GroqError
-       raise error_cls(f'{response.status_code}: {response.content}')
-
-   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-     assert self._api_initialized
-     return self._parallel_execute_with_currency_control(
-         self._sample_single,
-         prompts,
-         retry_on_errors=(RateLimitError, OverloadedError),
-     )
-
-   def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-     request = dict()
-     request.update(self._get_request_args(self.sampling_options))
-     request.update(
-         dict(
-             messages=[
-                 dict(role='user', content=self._content_from_message(prompt))
-             ]
-         )
-     )
-     try:
-       response = self._session.post(
-           _CHAT_COMPLETE_API_ENDPOINT,
-           json=request,
-           timeout=self.timeout,
-       )
-       return self._parse_response(response)
-     except ConnectionError as e:
-       raise OverloadedError(str(e)) from e
+ class GroqLlama3_2_3B(Groq):  # pylint: disable=invalid-name
+   """Llama3.2-3B with 8K context window.
+
+   See: https://huggingface.co/meta-llama/Llama-3.2-3B
+   """
+
+   model = 'llama-3.2-3b-preview'
+
+
+ class GroqLlama3_2_1B(Groq):  # pylint: disable=invalid-name
+   """Llama3.2-1B with 8K context window.
+
+   See: https://huggingface.co/meta-llama/Llama-3.2-1B
+   """
+
+   model = 'llama-3.2-3b-preview'


  class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
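The `estimate_cost` method in the hunk above is plain per-1K pricing over the table at the top of the file. A worked example at the `llama-3.1-70b-versatile` rates (the token counts are invented for illustration):

# Worked example of the estimate_cost arithmetic; token counts are invented.
cost_per_1k_input = 0.00059   # llama-3.1-70b-versatile, $ per 1K input tokens
cost_per_1k_output = 0.00079  # $ per 1K output tokens
num_input_tokens, num_output_tokens = 2000, 500

cost = (
    cost_per_1k_input * num_input_tokens
    + cost_per_1k_output * num_output_tokens
) / 1000
assert abs(cost - 0.001575) < 1e-9  # about $0.0016 for the call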
@@ -224,6 +195,24 @@ class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
    model = 'llama3-8b-8192'


+ class GroqLlama3_1_70B(Groq):  # pylint: disable=invalid-name
+   """Llama3.1-70B with 8K context window.
+
+   See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md  # pylint: disable=line-too-long
+   """
+
+   model = 'llama-3.1-70b-versatile'
+
+
+ class GroqLlama3_1_8B(Groq):  # pylint: disable=invalid-name
+   """Llama3.1-8B with 8K context window.
+
+   See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md  # pylint: disable=line-too-long
+   """
+
+   model = 'llama-3.1-8b-instant'
+
+
  class GroqLlama3_70B(Groq):  # pylint: disable=invalid-name
    """Llama3-70B with 8K context window.

@@ -251,10 +240,37 @@ class GroqMistral_8x7B(Groq):  # pylint: disable=invalid-name
    model = 'mixtral-8x7b-32768'


- class GroqGemma7B_IT(Groq):  # pylint: disable=invalid-name
+ class GroqGemma2_9B_IT(Groq):  # pylint: disable=invalid-name
+   """Gemma2 9B with 8K context window.
+
+   See: https://huggingface.co/google/gemma-2-9b-it
+   """
+
+   model = 'gemma2-9b-it'
+
+
+ class GroqGemma_7B_IT(Groq):  # pylint: disable=invalid-name
    """Gemma 7B with 8K context window.

    See: https://huggingface.co/google/gemma-1.1-7b-it
    """

    model = 'gemma-7b-it'
+
+
+ class GroqWhisper_Large_v3(Groq):  # pylint: disable=invalid-name
+   """Whisper Large V3 with 8K context window.
+
+   See: https://huggingface.co/openai/whisper-large-v3
+   """
+
+   model = 'whisper-large-v3'
+
+
+ class GroqWhisper_Large_v3Turbo(Groq):  # pylint: disable=invalid-name
+   """Whisper Large V3 Turbo with 8K context window.
+
+   See: https://huggingface.co/openai/whisper-large-v3-turbo
+   """
+
+   model = 'whisper-large-v3-turbo'
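After this change each subclass is just a model-name binding on the shared OpenAI-compatible plumbing. A usage sketch, assuming this wheel is installed; the key is a placeholder and nothing below sends a request:

import os

from langfun.core.llms import groq

os.environ['GROQ_API_KEY'] = 'placeholder'  # headers need a key to build
lm = groq.GroqLlama3_1_8B()                 # binds model='llama-3.1-8b-instant'
print(lm.api_endpoint)     # https://api.groq.com/openai/v1/chat/completions
print(lm.max_concurrency)  # 32, from SUPPORTED_MODELS_AND_SETTINGS
print(lm.estimate_cost(num_input_tokens=2000, num_output_tokens=500))
# ~0.00014 at the 2024-10-10 rates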