langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of langfun was flagged as a potentially problematic release.

Files changed (59)
  1. langfun/__init__.py +7 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +15 -0
  7. langfun/core/eval/base.py +665 -95
  8. langfun/core/eval/base_test.py +224 -53
  9. langfun/core/eval/matching.py +48 -30
  10. langfun/core/eval/matching_test.py +25 -3
  11. langfun/core/eval/patching.py +130 -0
  12. langfun/core/eval/patching_test.py +170 -0
  13. langfun/core/eval/scoring.py +19 -10
  14. langfun/core/eval/scoring_test.py +21 -3
  15. langfun/core/langfunc.py +1 -22
  16. langfun/core/langfunc_test.py +10 -4
  17. langfun/core/language_model.py +130 -24
  18. langfun/core/language_model_test.py +249 -26
  19. langfun/core/llms/__init__.py +27 -2
  20. langfun/core/llms/anthropic.py +263 -0
  21. langfun/core/llms/anthropic_test.py +167 -0
  22. langfun/core/llms/cache/in_memory_test.py +37 -28
  23. langfun/core/llms/fake.py +34 -25
  24. langfun/core/llms/fake_test.py +122 -11
  25. langfun/core/llms/google_genai.py +8 -0
  26. langfun/core/llms/google_genai_test.py +8 -3
  27. langfun/core/llms/groq.py +260 -0
  28. langfun/core/llms/groq_test.py +170 -0
  29. langfun/core/llms/llama_cpp.py +3 -1
  30. langfun/core/llms/openai.py +100 -81
  31. langfun/core/llms/openai_test.py +287 -60
  32. langfun/core/llms/vertexai.py +291 -0
  33. langfun/core/llms/vertexai_test.py +233 -0
  34. langfun/core/modalities/image.py +1 -3
  35. langfun/core/modalities/mime.py +6 -0
  36. langfun/core/modalities/video.py +6 -5
  37. langfun/core/structured/__init__.py +5 -0
  38. langfun/core/structured/completion_test.py +2 -2
  39. langfun/core/structured/function_generation.py +245 -0
  40. langfun/core/structured/function_generation_test.py +329 -0
  41. langfun/core/structured/mapping.py +61 -3
  42. langfun/core/structured/mapping_test.py +17 -0
  43. langfun/core/structured/parsing_test.py +18 -13
  44. langfun/core/structured/prompting.py +61 -12
  45. langfun/core/structured/prompting_test.py +122 -12
  46. langfun/core/structured/schema.py +38 -6
  47. langfun/core/structured/schema_generation_test.py +2 -2
  48. langfun/core/structured/schema_test.py +36 -7
  49. langfun/core/structured/scoring.py +4 -1
  50. langfun/core/structured/scoring_test.py +6 -0
  51. langfun/core/template.py +147 -11
  52. langfun/core/template_test.py +75 -0
  53. langfun/core/templates/selfplay_test.py +6 -2
  54. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
  55. langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
  56. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  57. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
  58. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
  59. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/llms/fake_test.py

@@ -25,7 +25,24 @@ class EchoTest(unittest.TestCase):
  def test_sample(self):
  lm = fakelm.Echo()
  self.assertEqual(
- lm.sample(['hi']), [lf.LMSamplingResult([lf.LMSample('hi', 1.0)])]
+ lm.sample(['hi']),
+ [
+ lf.LMSamplingResult(
+ [
+ lf.LMSample(
+ lf.AIMessage(
+ 'hi',
+ score=1.0,
+ logprobs=None,
+ usage=lf.LMSamplingUsage(2, 2, 4),
+ tags=[lf.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ )
+ ],
+ lf.LMSamplingUsage(2, 2, 4))
+ ]
  )

  def test_call(self):
@@ -34,8 +51,8 @@ class EchoTest(unittest.TestCase):
  with contextlib.redirect_stdout(string_io):
  self.assertEqual(lm('hi'), 'hi')
  debug_info = string_io.getvalue()
- self.assertIn('[0] LM INFO:', debug_info)
- self.assertIn('[0] PROMPT SENT TO LM:', debug_info)
+ self.assertIn('[0] LM INFO', debug_info)
+ self.assertIn('[0] PROMPT SENT TO LM', debug_info)
  self.assertIn('[0] LM RESPONSE', debug_info)

  def test_score(self):
@@ -53,11 +70,45 @@ class StaticResponseTest(unittest.TestCase):
  lm = fakelm.StaticResponse(canned_response)
  self.assertEqual(
  lm.sample(['hi']),
- [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
+ [
+ lf.LMSamplingResult(
+ [
+ lf.LMSample(
+ lf.AIMessage(
+ canned_response,
+ score=1.0,
+ logprobs=None,
+ usage=lf.LMSamplingUsage(2, 38, 40),
+ tags=[lf.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lf.LMSamplingUsage(2, 38, 40)
+ )
+ ],
  )
  self.assertEqual(
  lm.sample(['Tell me a joke.']),
- [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
+ [
+ lf.LMSamplingResult(
+ [
+ lf.LMSample(
+ lf.AIMessage(
+ canned_response,
+ score=1.0,
+ logprobs=None,
+ usage=lf.LMSamplingUsage(15, 38, 53),
+ tags=[lf.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lf.LMSamplingUsage(15, 38, 53)
+ )
+ ],
  )

  def test_call(self):
@@ -69,8 +120,8 @@ class StaticResponseTest(unittest.TestCase):
  self.assertEqual(lm('hi'), canned_response)

  debug_info = string_io.getvalue()
- self.assertIn('[0] LM INFO:', debug_info)
- self.assertIn('[0] PROMPT SENT TO LM:', debug_info)
+ self.assertIn('[0] LM INFO', debug_info)
+ self.assertIn('[0] PROMPT SENT TO LM', debug_info)
  self.assertIn('[0] LM RESPONSE', debug_info)


@@ -85,8 +136,38 @@ class StaticMappingTest(unittest.TestCase):
  self.assertEqual(
  lm.sample(['Hi', 'How are you?']),
  [
- lf.LMSamplingResult([lf.LMSample('Hello', 1.0)]),
- lf.LMSamplingResult([lf.LMSample('I am fine, how about you?', 1.0)])
+ lf.LMSamplingResult(
+ [
+ lf.LMSample(
+ lf.AIMessage(
+ 'Hello',
+ score=1.0,
+ logprobs=None,
+ usage=lf.LMSamplingUsage(2, 5, 7),
+ tags=[lf.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lf.LMSamplingUsage(2, 5, 7)
+ ),
+ lf.LMSamplingResult(
+ [
+ lf.LMSample(
+ lf.AIMessage(
+ 'I am fine, how about you?',
+ score=1.0,
+ logprobs=None,
+ usage=lf.LMSamplingUsage(12, 25, 37),
+ tags=[lf.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lf.LMSamplingUsage(12, 25, 37)
+ )
  ]
  )
  with self.assertRaises(KeyError):
@@ -104,8 +185,38 @@ class StaticSequenceTest(unittest.TestCase):
  self.assertEqual(
  lm.sample(['Hi', 'How are you?']),
  [
- lf.LMSamplingResult([lf.LMSample('Hello', 1.0)]),
- lf.LMSamplingResult([lf.LMSample('I am fine, how about you?', 1.0)])
+ lf.LMSamplingResult(
+ [
+ lf.LMSample(
+ lf.AIMessage(
+ 'Hello',
+ score=1.0,
+ logprobs=None,
+ usage=lf.LMSamplingUsage(2, 5, 7),
+ tags=[lf.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lf.LMSamplingUsage(2, 5, 7)
+ ),
+ lf.LMSamplingResult(
+ [
+ lf.LMSample(
+ lf.AIMessage(
+ 'I am fine, how about you?',
+ score=1.0,
+ logprobs=None,
+ usage=lf.LMSamplingUsage(12, 25, 37),
+ tags=[lf.Message.TAG_LM_RESPONSE],
+ ),
+ score=1.0,
+ logprobs=None,
+ )
+ ],
+ usage=lf.LMSamplingUsage(12, 25, 37)
+ )
  ]
  )
  with self.assertRaises(IndexError):
langfun/core/llms/google_genai.py

@@ -34,6 +34,7 @@ class GenAI(lf.LanguageModel):
  'gemini-pro-vision',
  'text-bison-001',
  'chat-bison-001',
+ 'gemini-1.5-pro-latest',
  ],
  'Model name.',
  ]
@@ -262,6 +263,13 @@ _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
  #


+ class GeminiPro1_5(GenAI): # pylint: disable=invalid-name
+ """Gemini Pro latest model."""
+
+ model = 'gemini-1.5-pro-latest'
+ multimodal = True
+
+
  class GeminiPro(GenAI):
  """Gemini Pro model."""

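For context, a minimal usage sketch of the newly added GeminiPro1_5 class; the API key and prompt are placeholders, and the constructor/call signature mirrors how GeminiPro is exercised in google_genai_test.py below:

```python
# Hypothetical usage of the new Gemini 1.5 Pro class added above.
# Assumes a valid Google Generative AI API key; mirrors the way
# GeminiPro is instantiated in the tests that follow.
from langfun.core.llms import google_genai

lm = google_genai.GeminiPro1_5(api_key='my_api_key')  # placeholder key
print(lm('Hello!').text)
```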
langfun/core/llms/google_genai_test.py

@@ -152,10 +152,15 @@ class GenAITest(unittest.TestCase):
  )

  def test_model_hub(self):
+ orig_get_model = genai.get_model
+ genai.get_model = mock_get_model
+
  model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
  self.assertIsNotNone(model)
  self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)

+ genai.get_model = orig_get_model
+
  def test_api_key_check(self):
  with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
  _ = google_genai.GeminiPro()._api_initialized
@@ -167,7 +172,7 @@ class GenAITest(unittest.TestCase):

  def test_call(self):
  with mock.patch(
- 'google.generativeai.generative_models.GenerativeModel.generate_content'
+ 'google.generativeai.GenerativeModel.generate_content',
  ) as mock_generate:
  orig_get_model = genai.get_model
  genai.get_model = mock_get_model
@@ -176,7 +181,7 @@ class GenAITest(unittest.TestCase):
  lm = google_genai.GeminiPro(api_key='test_key')
  self.maxDiff = None
  self.assertEqual(
- lm('hello', temperature=2.0, top_k=20).text,
+ lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
  (
  'This is a response to hello with n=1, temperature=2.0, '
  'top_p=None, top_k=20, max_tokens=1024, stop=None.'
@@ -197,7 +202,7 @@ class GenAITest(unittest.TestCase):
  (
  "hello to models/text-bison-001 with {'temperature': 2.0, "
  "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
- "'max_output_tokens': 1024, 'stop_sequences': None}"
+ "'max_output_tokens': None, 'stop_sequences': None}"
  ),
  )
  genai.get_model = orig_get_model
langfun/core/llms/groq.py

@@ -0,0 +1,260 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Language models from Groq."""
+
+ import functools
+ import os
+ from typing import Annotated, Any
+
+ import langfun.core as lf
+ from langfun.core import modalities as lf_modalities
+ import pyglove as pg
+ import requests
+
+
+ SUPPORTED_MODELS_AND_SETTINGS = {
+ # Refer https://console.groq.com/docs/models
+ 'llama3-8b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
+ 'llama3-70b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
+ 'llama2-70b-4096': pg.Dict(max_tokens=4096, max_concurrency=16),
+ 'mixtral-8x7b-32768': pg.Dict(max_tokens=32768, max_concurrency=16),
+ 'gemma-7b-it': pg.Dict(max_tokens=8192, max_concurrency=16),
+ }
+
+
+ class GroqError(Exception): # pylint: disable=g-bad-exception-name
+ """Base class for Groq errors."""
+
+
+ class RateLimitError(GroqError):
+ """Error for rate limit reached."""
+
+
+ class OverloadedError(GroqError):
+ """Groq's server is temporarily overloaded."""
+
+
+ _CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
+
+
+ @lf.use_init_args(['model'])
+ class Groq(lf.LanguageModel):
+ """Groq LLMs through REST APIs (OpenAI compatible).
+
+ See https://platform.openai.com/docs/api-reference/chat
+ """
+
+ model: pg.typing.Annotated[
+ pg.typing.Enum(
+ pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+ ),
+ 'The name of the model to use.',
+ ]
+
+ multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
+ False
+ )
+
+ api_key: Annotated[
+ str | None,
+ (
+ 'API key. If None, the key will be read from environment variable '
+ "'GROQ_API_KEY'."
+ ),
+ ] = None
+
+ def _on_bound(self):
+ super()._on_bound()
+ self._api_key = None
+ self.__dict__.pop('_api_initialized', None)
+ self.__dict__.pop('_session', None)
+
+ @functools.cached_property
+ def _api_initialized(self):
+ api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
+ if not api_key:
+ raise ValueError(
+ 'Please specify `api_key` during `__init__` or set environment '
+ 'variable `GROQ_API_KEY` with your Groq API key.'
+ )
+ self._api_key = api_key
+ return True
+
+ @functools.cached_property
+ def _session(self) -> requests.Session:
+ assert self._api_initialized
+ s = requests.Session()
+ s.headers.update({
+ 'Authorization': f'Bearer {self._api_key}',
+ 'Content-Type': 'application/json',
+ })
+ return s
+
+ @property
+ def model_id(self) -> str:
+ """Returns a string to identify the model."""
+ return self.model
+
+ @property
+ def max_concurrency(self) -> int:
+ return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
+
+ def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+ """Returns a dict as request arguments."""
+ # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
+ args = dict(
+ model=self.model,
+ n=options.n,
+ stream=False,
+ )
+
+ if options.temperature is not None:
+ args['temperature'] = options.temperature
+ if options.max_tokens is not None:
+ args['max_tokens'] = options.max_tokens
+ if options.top_p is not None:
+ args['top_p'] = options.top_p
+ if options.stop:
+ args['stop'] = options.stop
+ return args
+
+ def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
+ """Converts an message to Groq's content protocol (list of dicts)."""
+ # Refer: https://platform.openai.com/docs/api-reference/chat/create
+ content = []
+ for chunk in prompt.chunk():
+ if isinstance(chunk, str):
+ item = dict(type='text', text=chunk)
+ elif (
+ self.multimodal
+ and isinstance(chunk, lf_modalities.Image)
+ and chunk.uri
+ ):
+ # NOTE(daiyip): Groq only support image URL.
+ item = dict(type='image_url', image_url=chunk.uri)
+ else:
+ raise ValueError(f'Unsupported modality object: {chunk!r}.')
+ content.append(item)
+ return content
+
+ def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
+ """Converts Groq's content protocol to message."""
+ # Refer: https://platform.openai.com/docs/api-reference/chat/create
+ content = choice['message']['content']
+ if isinstance(content, str):
+ return lf.AIMessage(content)
+ return lf.AIMessage.from_chunks(
+ [x['text'] for x in content if x['type'] == 'text']
+ )
+
+ def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
+ """Parses Groq's response."""
+ # Refer: https://platform.openai.com/docs/api-reference/chat/object
+ if response.status_code == 200:
+ output = response.json()
+ samples = [
+ lf.LMSample(self._message_from_choice(choice), score=0.0)
+ for choice in output['choices']
+ ]
+ usage = output['usage']
+ return lf.LMSamplingResult(
+ samples,
+ usage=lf.LMSamplingUsage(
+ prompt_tokens=usage['prompt_tokens'],
+ completion_tokens=usage['completion_tokens'],
+ total_tokens=usage['total_tokens'],
+ ),
+ )
+ else:
+ # https://platform.openai.com/docs/guides/error-codes/api-errors
+ if response.status_code == 429:
+ error_cls = RateLimitError
+ elif response.status_code in (500, 502, 503):
+ error_cls = OverloadedError
+ else:
+ error_cls = GroqError
+ raise error_cls(f'{response.status_code}: {response.content}')
+
+ def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+ assert self._api_initialized
+ return self._parallel_execute_with_currency_control(
+ self._sample_single,
+ prompts,
+ retry_on_errors=(RateLimitError, OverloadedError),
+ )
+
+ def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+ request = dict()
+ request.update(self._get_request_args(self.sampling_options))
+ request.update(
+ dict(
+ messages=[
+ dict(role='user', content=self._content_from_message(prompt))
+ ]
+ )
+ )
+ try:
+ response = self._session.post(
+ _CHAT_COMPLETE_API_ENDPOINT,
+ json=request,
+ timeout=self.timeout,
+ )
+ return self._parse_response(response)
+ except ConnectionError as e:
+ raise OverloadedError(str(e)) from e
+
+
+ class GroqLlama3_8B(Groq): # pylint: disable=invalid-name
+ """Llama3-8B with 8K context window.
+
+ See: https://huggingface.co/meta-llama/Meta-Llama-3-8B
+ """
+
+ model = 'llama3-8b-8192'
+
+
+ class GroqLlama3_70B(Groq): # pylint: disable=invalid-name
+ """Llama3-70B with 8K context window.
+
+ See: https://huggingface.co/meta-llama/Meta-Llama-3-70B
+ """
+
+ model = 'llama3-70b-8192'
+
+
+ class GroqLlama2_70B(Groq): # pylint: disable=invalid-name
+ """Llama2-70B with 4K context window.
+
+ See: https://huggingface.co/meta-llama/Llama-2-70b
+ """
+
+ model = 'llama2-70b-4096'
+
+
+ class GroqMistral_8x7B(Groq): # pylint: disable=invalid-name
+ """Mixtral 8x7B with 32K context window.
+
+ See: https://huggingface.co/meta-llama/Llama-2-70b
+ """
+
+ model = 'mixtral-8x7b-32768'
+
+
+ class GroqGemma7B_IT(Groq): # pylint: disable=invalid-name
+ """Gemma 7B with 8K context window.
+
+ See: https://huggingface.co/google/gemma-1.1-7b-it
+ """
+
+ model = 'gemma-7b-it'
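A minimal usage sketch of the new Groq backend, assuming a valid GROQ_API_KEY; it mirrors how the tests below drive the class, and the key and prompt are placeholders:

```python
# Hypothetical usage of the Groq backend introduced in this release.
# Assumes a valid Groq API key; model names come from
# SUPPORTED_MODELS_AND_SETTINGS in groq.py above.
import os
from langfun.core.llms import groq

os.environ['GROQ_API_KEY'] = 'my-groq-key'  # placeholder key
lm = groq.GroqLlama3_70B()
response = lm('Hello, Groq!', temperature=0.0, max_tokens=128)
print(response.text)
print(response.usage)  # prompt/completion/total token counts from the API
```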
langfun/core/llms/groq_test.py

@@ -0,0 +1,170 @@
+ # Copyright 2023 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tests for Groq models."""
+
+ import os
+ from typing import Any
+ import unittest
+ from unittest import mock
+ from langfun.core import modalities as lf_modalities
+ from langfun.core.llms import groq
+ import pyglove as pg
+ import requests
+
+
+ def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+ del url, kwargs
+
+ response = requests.Response()
+ response.status_code = 200
+ response._content = pg.to_json_str({
+ 'choices': [{
+ 'message': {
+ 'content': [{
+ 'type': 'text',
+ 'text': (
+ f'hello with temperature={json.get("temperature")}, '
+ f'top_p={json.get("top_p")}, '
+ f'max_tokens={json.get("max_tokens")}, '
+ f'stop={json.get("stop")}.'
+ ),
+ }],
+ }
+ }],
+ 'usage': {
+ 'prompt_tokens': 2,
+ 'completion_tokens': 1,
+ 'total_tokens': 3,
+ },
+ }).encode()
+ return response
+
+
+ def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
+ del url, kwargs
+ v = json['messages'][0]['content'][0]
+ image = lf_modalities.Image.from_uri(v['image_url'])
+
+ response = requests.Response()
+ response.status_code = 200
+ response._content = pg.to_json_str({
+ 'choices': [
+ {
+ 'message': {
+ 'content': [{
+ 'type': 'text',
+ 'text': image.uri,
+ }],
+ }
+ }
+ ],
+ 'usage': {
+ 'prompt_tokens': 2,
+ 'completion_tokens': 1,
+ 'total_tokens': 3,
+ },
+ }).encode()
+ return response
+
+
+ def mock_requests_post_error(status_code, error_type, error_message):
+ def _mock_requests(url: str, json: dict[str, Any], **kwargs):
+ del url, json, kwargs
+ response = requests.Response()
+ response.status_code = status_code
+ response._content = pg.to_json_str(
+ {
+ 'error': {
+ 'type': error_type,
+ 'message': error_message,
+ }
+ }
+ ).encode()
+ return response
+
+ return _mock_requests
+
+
+ class AuthropicTest(unittest.TestCase):
+
+ def test_basics(self):
+ self.assertEqual(groq.GroqMistral_8x7B().model_id, 'mixtral-8x7b-32768')
+ self.assertEqual(groq.GroqMistral_8x7B().max_concurrency, 16)
+
+ def test_api_key(self):
+ lm = groq.GroqMistral_8x7B()
+ with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+ lm('hi')
+
+ with mock.patch('requests.Session.post') as mock_request:
+ mock_request.side_effect = mock_requests_post
+
+ lm = groq.GroqMistral_8x7B(api_key='fake key')
+ self.assertRegex(lm('hi').text, 'hello.*')
+
+ os.environ['GROQ_API_KEY'] = 'abc'
+ lm = groq.GroqMistral_8x7B()
+ self.assertRegex(lm('hi').text, 'hello.*')
+ del os.environ['GROQ_API_KEY']
+
+ def test_call(self):
+ with mock.patch('requests.Session.post') as mock_request:
+ mock_request.side_effect = mock_requests_post
+ lm = groq.GroqLlama3_70B(api_key='fake_key')
+ response = lm(
+ 'hello',
+ temperature=0.0,
+ max_tokens=1024,
+ top_k=0.1,
+ top_p=0.2,
+ stop=['\n'],
+ )
+ self.assertEqual(
+ response.text,
+ (
+ 'hello with temperature=0.0, top_p=0.2, '
+ "max_tokens=1024, stop=['\\n']."
+ ),
+ )
+ self.assertIsNotNone(response.usage)
+ self.assertIsNotNone(response.usage.prompt_tokens, 2)
+ self.assertIsNotNone(response.usage.completion_tokens, 1)
+ self.assertIsNotNone(response.usage.total_tokens, 3)
+
+ def test_mm_call(self):
+ with mock.patch('requests.Session.post') as mock_mm_request:
+ mock_mm_request.side_effect = mock_mm_requests_post
+ lm = groq.GroqLlama3_70B(multimodal=True, api_key='fake_key')
+ response = lm(lf_modalities.Image.from_uri('https://fake/image.jpg'))
+ self.assertEqual(response.text, 'https://fake/image.jpg')
+
+ def test_call_errors(self):
+ for status_code, error_type, error_message in [
+ (429, 'rate_limit', 'Rate limit exceeded.'),
+ (503, 'service_unavailable', 'Service unavailable.'),
+ (500, 'bad_request', 'Bad request.'),
+ ]:
+ with mock.patch('requests.Session.post') as mock_mm_request:
+ mock_mm_request.side_effect = mock_requests_post_error(
+ status_code, error_type, error_message
+ )
+ lm = groq.GroqLlama3_70B(api_key='fake_key')
+ with self.assertRaisesRegex(
+ Exception, f'{status_code}:.*{error_type}'
+ ):
+ lm('hello', lm=lm, max_attempts=1)
+
+
+ if __name__ == '__main__':
+ unittest.main()
langfun/core/llms/llama_cpp.py

@@ -51,10 +51,12 @@ class LlamaCppRemote(lf.LanguageModel):
  data = {
  "prompt": prompt.text,
  "n_predict": self.sampling_options.max_tokens,
- "temperature": self.sampling_options.temperature,
  "top_k": self.sampling_options.top_k or 50,
  "top_p": self.sampling_options.top_p or 0.95,
  }
+ if self.sampling_options.temperature is not None:
+ data["temperature"] = self.sampling_options.temperature
+
  response = requests.post(
  f"{self.url}/completion",
  json=data,