langfun 0.1.2.dev202501060804__py3-none-any.whl → 0.1.2.dev202501100804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. langfun/core/__init__.py +0 -5
  2. langfun/core/coding/python/correction.py +4 -3
  3. langfun/core/coding/python/errors.py +10 -9
  4. langfun/core/coding/python/execution.py +23 -12
  5. langfun/core/coding/python/execution_test.py +21 -2
  6. langfun/core/coding/python/generation.py +18 -9
  7. langfun/core/concurrent.py +2 -3
  8. langfun/core/console.py +8 -3
  9. langfun/core/eval/base.py +2 -3
  10. langfun/core/eval/v2/reporting.py +8 -4
  11. langfun/core/language_model.py +7 -4
  12. langfun/core/language_model_test.py +15 -0
  13. langfun/core/llms/__init__.py +7 -0
  14. langfun/core/llms/deepseek.py +117 -0
  15. langfun/core/llms/deepseek_test.py +61 -0
  16. langfun/core/llms/google_genai.py +1 -0
  17. langfun/core/llms/groq.py +12 -99
  18. langfun/core/llms/groq_test.py +31 -137
  19. langfun/core/llms/llama_cpp.py +17 -54
  20. langfun/core/llms/llama_cpp_test.py +2 -34
  21. langfun/core/llms/openai.py +14 -147
  22. langfun/core/llms/openai_compatible.py +179 -0
  23. langfun/core/llms/openai_compatible_test.py +480 -0
  24. langfun/core/llms/openai_test.py +13 -423
  25. langfun/core/llms/vertexai.py +6 -2
  26. langfun/core/llms/vertexai_test.py +1 -1
  27. langfun/core/modalities/mime.py +8 -0
  28. langfun/core/modalities/mime_test.py +19 -4
  29. langfun/core/modality_test.py +0 -1
  30. langfun/core/structured/mapping.py +13 -13
  31. langfun/core/structured/mapping_test.py +2 -2
  32. langfun/core/structured/schema.py +16 -8
  33. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/METADATA +13 -2
  34. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/RECORD +37 -35
  35. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/WHEEL +1 -1
  36. langfun/core/text_formatting.py +0 -168
  37. langfun/core/text_formatting_test.py +0 -65
  38. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/LICENSE +0 -0
  39. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/top_level.txt +0 -0
langfun/core/llms/google_genai.py CHANGED
@@ -68,6 +68,7 @@ class GeminiFlash2_0ThinkingExp_20241219(GenAI):  # pylint: disable=invalid-name

   api_version = 'v1alpha'
   model = 'gemini-2.0-flash-thinking-exp-1219'
+  timeout = None


 class GeminiFlash2_0Exp(GenAI):  # pylint: disable=invalid-name
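The lone `google_genai.py` change gives the experimental thinking model `timeout = None`. Assuming `timeout` is the client-side HTTP deadline inherited from the REST base (this diff does not show where it is consumed), `None` would let the call wait indefinitely, which matters for thinking models whose responses can outlast a default deadline. A minimal sketch of that semantic, using plain `requests` purely for illustration:

```python
import requests

def post_with_deadline(url: str, payload: dict, timeout: float | None):
  # In `requests`, timeout=None blocks until the server answers, while a
  # number raises requests.exceptions.Timeout after that many seconds.
  # The diff presumably relies on the same convention for long-running
  # "thinking" responses.
  return requests.post(url, json=payload, timeout=timeout)
```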
langfun/core/llms/groq.py CHANGED
@@ -17,8 +17,7 @@ import os
 from typing import Annotated, Any

 import langfun.core as lf
-from langfun.core import modalities as lf_modalities
-from langfun.core.llms import rest
+from langfun.core.llms import openai_compatible
 import pyglove as pg


@@ -95,7 +94,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {


 @lf.use_init_args(['model'])
-class Groq(rest.REST):
+class Groq(openai_compatible.OpenAICompatible):
   """Groq LLMs through REST APIs (OpenAI compatible).

   See https://platform.openai.com/docs/api-reference/chat
@@ -108,10 +107,6 @@ class Groq(rest.REST):
       'The name of the model to use.',
   ]

-  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
-      False
-  )
-
   api_key: Annotated[
       str | None,
       (
@@ -122,25 +117,19 @@ class Groq(rest.REST):

   api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'

-  def _on_bound(self):
-    super()._on_bound()
-    self._api_key = None
-
-  def _initialize(self):
+  @property
+  def headers(self) -> dict[str, Any]:
     api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
     if not api_key:
       raise ValueError(
           'Please specify `api_key` during `__init__` or set environment '
           'variable `GROQ_API_KEY` with your Groq API key.'
       )
-    self._api_key = api_key
-
-  @property
-  def headers(self) -> dict[str, Any]:
-    return {
-        'Authorization': f'Bearer {self._api_key}',
-        'Content-Type': 'application/json',
-    }
+    headers = super().headers
+    headers.update({
+        'Authorization': f'Bearer {api_key}',
+    })
+    return headers

   @property
   def model_id(self) -> str:
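The new `headers` property folds key lookup and auth into a single request-time step, layering `Authorization` on top of whatever the shared base supplies. A sketch of the resolution order it implies; that the base contributes `Content-Type: application/json` is inferred from the test expectations further down, not read from the new `openai_compatible.py`:

```python
import os

def resolve_groq_headers(api_key: str | None) -> dict[str, str]:
  # Mirrors the new Groq.headers: an explicit key wins, then the env var.
  key = api_key or os.environ.get('GROQ_API_KEY', None)
  if not key:
    raise ValueError('Please specify `api_key` or set `GROQ_API_KEY`.')
  headers = {'Content-Type': 'application/json'}  # assumed base headers
  headers.update({'Authorization': f'Bearer {key}'})
  return headers

# An explicit key takes precedence over the environment.
os.environ['GROQ_API_KEY'] = 'env-key'
assert resolve_groq_headers('ctor-key')['Authorization'] == 'Bearer ctor-key'
assert resolve_groq_headers(None)['Authorization'] == 'Bearer env-key'
```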
@@ -170,90 +159,14 @@ class Groq(rest.REST):
         + cost_per_1k_output_tokens * num_output_tokens
     ) / 1000

-  def request(
-      self,
-      prompt: lf.Message,
-      sampling_options: lf.LMSamplingOptions
-  ) -> dict[str, Any]:
-    """Returns the JSON input for a message."""
-    request = dict()
-    request.update(self._request_args(sampling_options))
-    request.update(
-        dict(
-            messages=[
-                dict(role='user', content=self._content_from_message(prompt))
-            ]
-        )
-    )
-    return request
-
   def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
-    args = dict(
-        model=self.model,
-        n=options.n,
-        stream=False,
-    )
-
-    if options.temperature is not None:
-      args['temperature'] = options.temperature
-    if options.max_tokens is not None:
-      args['max_tokens'] = options.max_tokens
-    if options.top_p is not None:
-      args['top_p'] = options.top_p
-    if options.stop:
-      args['stop'] = options.stop
+    args = super()._request_args(options)
+    args.pop('logprobs', None)
+    args.pop('top_logprobs', None)
     return args

-  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
-    """Converts an message to Groq's content protocol (list of dicts)."""
-    # Refer: https://platform.openai.com/docs/api-reference/chat/create
-    content = []
-    for chunk in prompt.chunk():
-      if isinstance(chunk, str):
-        item = dict(type='text', text=chunk)
-      elif (
-          self.multimodal
-          and isinstance(chunk, lf_modalities.Image)
-          and chunk.uri
-      ):
-        # NOTE(daiyip): Groq only support image URL.
-        item = dict(type='image_url', image_url=chunk.uri)
-      else:
-        raise ValueError(f'Unsupported modality object: {chunk!r}.')
-      content.append(item)
-    return content
-
-  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
-    samples = [
-        lf.LMSample(self._message_from_choice(choice), score=0.0)
-        for choice in json['choices']
-    ]
-    usage = json['usage']
-    return lf.LMSamplingResult(
-        samples,
-        usage=lf.LMSamplingUsage(
-            prompt_tokens=usage['prompt_tokens'],
-            completion_tokens=usage['completion_tokens'],
-            total_tokens=usage['total_tokens'],
-            estimated_cost=self.estimate_cost(
-                num_input_tokens=usage['prompt_tokens'],
-                num_output_tokens=usage['completion_tokens'],
-            ),
-        ),
-    )
-
-  def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
-    """Converts Groq's content protocol to message."""
-    # Refer: https://platform.openai.com/docs/api-reference/chat/create
-    content = choice['message']['content']
-    if isinstance(content, str):
-      return lf.AIMessage(content)
-    return lf.AIMessage.from_chunks(
-        [x['text'] for x in content if x['type'] == 'text']
-    )
-

 class GroqLlama3_2_3B(Groq):  # pylint: disable=invalid-name
   """Llama3.2-3B with 8K context window.
langfun/core/llms/groq_test.py CHANGED
@@ -11,89 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for Groq models."""
-
 import os
-from typing import Any
 import unittest
-from unittest import mock
-from langfun.core import modalities as lf_modalities
+import langfun.core as lf
 from langfun.core.llms import groq
-import pyglove as pg
-import requests
-
-
-def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
-  del url, kwargs
-
-  response = requests.Response()
-  response.status_code = 200
-  response._content = pg.to_json_str({
-      'choices': [{
-          'message': {
-              'content': [{
-                  'type': 'text',
-                  'text': (
-                      f'hello with temperature={json.get("temperature")}, '
-                      f'top_p={json.get("top_p")}, '
-                      f'max_tokens={json.get("max_tokens")}, '
-                      f'stop={json.get("stop")}.'
-                  ),
-              }],
-          }
-      }],
-      'usage': {
-          'prompt_tokens': 2,
-          'completion_tokens': 1,
-          'total_tokens': 3,
-      },
-  }).encode()
-  return response
-
-
-def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
-  del url, kwargs
-  v = json['messages'][0]['content'][0]
-  image = lf_modalities.Image.from_uri(v['image_url'])
-
-  response = requests.Response()
-  response.status_code = 200
-  response._content = pg.to_json_str({
-      'choices': [
-          {
-              'message': {
-                  'content': [{
-                      'type': 'text',
-                      'text': image.uri,
-                  }],
-              }
-          }
-      ],
-      'usage': {
-          'prompt_tokens': 2,
-          'completion_tokens': 1,
-          'total_tokens': 3,
-      },
-  }).encode()
-  return response
-
-
-def mock_requests_post_error(status_code, error_type, error_message):
-  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
-    del url, json, kwargs
-    response = requests.Response()
-    response.status_code = status_code
-    response._content = pg.to_json_str(
-        {
-            'error': {
-                'type': error_type,
-                'message': error_message,
-            }
-        }
-    ).encode()
-    return response
-
-  return _mock_requests


 class AuthropicTest(unittest.TestCase):
@@ -101,69 +22,42 @@ class AuthropicTest(unittest.TestCase):
   def test_basics(self):
     self.assertEqual(groq.GroqMistral_8x7B().model_id, 'mixtral-8x7b-32768')
     self.assertEqual(groq.GroqMistral_8x7B().max_concurrency, 16)
+    self.assertEqual(groq.GroqMistral_8x7B().estimate_cost(100, 100), 4.8e-5)
+
+  def test_request_args(self):
+    args = groq.GroqMistral_8x7B()._request_args(
+        lf.LMSamplingOptions(
+            temperature=1.0, stop=['\n'], n=1, random_seed=123,
+            logprobs=True, top_logprobs=True
+        )
+    )
+    self.assertNotIn('logprobs', args)
+    self.assertNotIn('top_logprobs', args)

   def test_api_key(self):
     lm = groq.GroqMistral_8x7B()
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-      lm('hi')
-
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_requests_post
-
-      lm = groq.GroqMistral_8x7B(api_key='fake key')
-      self.assertRegex(lm('hi').text, 'hello.*')
-
-      os.environ['GROQ_API_KEY'] = 'abc'
-      lm = groq.GroqMistral_8x7B()
-      self.assertRegex(lm('hi').text, 'hello.*')
-      del os.environ['GROQ_API_KEY']
-
-  def test_call(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_requests_post
-      lm = groq.GroqLlama3_70B(api_key='fake_key')
-      response = lm(
-          'hello',
-          temperature=0.0,
-          max_tokens=1024,
-          top_k=0.1,
-          top_p=0.2,
-          stop=['\n'],
-      )
-      self.assertEqual(
-          response.text,
-          (
-              'hello with temperature=0.0, top_p=0.2, '
-              "max_tokens=1024, stop=['\\n']."
-          ),
-      )
-      self.assertIsNotNone(response.usage)
-      self.assertIsNotNone(response.usage.prompt_tokens, 2)
-      self.assertIsNotNone(response.usage.completion_tokens, 1)
-      self.assertIsNotNone(response.usage.total_tokens, 3)
+      _ = lm.headers

-  def test_mm_call(self):
-    with mock.patch('requests.Session.post') as mock_mm_request:
-      mock_mm_request.side_effect = mock_mm_requests_post
-      lm = groq.GroqLlama3_70B(multimodal=True, api_key='fake_key')
-      response = lm(lf_modalities.Image.from_uri('https://fake/image.jpg'))
-      self.assertEqual(response.text, 'https://fake/image.jpg')
+    lm = groq.GroqMistral_8x7B(api_key='fake key')
+    self.assertEqual(
+        lm.headers,
+        {
+            'Content-Type': 'application/json',
+            'Authorization': 'Bearer fake key',
+        }
+    )

-  def test_call_errors(self):
-    for status_code, error_type, error_message in [
-        (429, 'rate_limit', 'Rate limit exceeded.'),
-        (503, 'service_unavailable', 'Service unavailable.'),
-        (500, 'bad_request', 'Bad request.'),
-    ]:
-      with mock.patch('requests.Session.post') as mock_mm_request:
-        mock_mm_request.side_effect = mock_requests_post_error(
-            status_code, error_type, error_message
-        )
-        lm = groq.GroqLlama3_70B(api_key='fake_key')
-        with self.assertRaisesRegex(
-            Exception, f'{status_code}:.*{error_type}'
-        ):
-          lm('hello', max_attempts=1)
+    os.environ['GROQ_API_KEY'] = 'abc'
+    lm = groq.GroqMistral_8x7B()
+    self.assertEqual(
+        lm.headers,
+        {
+            'Content-Type': 'application/json',
+            'Authorization': 'Bearer abc',
+        }
+    )
+    del os.environ['GROQ_API_KEY']


 if __name__ == '__main__':
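The new `test_basics` assertion also pins down the cost formula retained in `groq.py`: `estimate_cost(100, 100) == 4.8e-5` together with the `/ 1000` arithmetic implies the per-1k input and output rates for Mixtral 8x7B sum to $0.00048 (the `SUPPORTED_MODELS_AND_SETTINGS` table itself is not shown in this diff). A quick check, assuming equal $0.00024 rates both ways:

```python
def estimate_cost(cost_per_1k_input: float, cost_per_1k_output: float,
                  num_input_tokens: int, num_output_tokens: int) -> float:
  # Same arithmetic as Groq.estimate_cost in the hunk above.
  return (cost_per_1k_input * num_input_tokens
          + cost_per_1k_output * num_output_tokens) / 1000

# (0.00024 * 100 + 0.00024 * 100) / 1000 == 4.8e-5
assert abs(estimate_cost(0.00024, 0.00024, 100, 100) - 4.8e-5) < 1e-12
```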
langfun/core/llms/llama_cpp.py CHANGED
@@ -13,72 +13,35 @@
 # limitations under the License.
 """Language models from llama.cpp."""

-from typing import Any
-
-import langfun.core as lf
-from langfun.core.llms import rest
+from typing import Annotated
+from langfun.core.llms import openai_compatible
 import pyglove as pg


-class LlamaCppRemote(rest.REST):
+@pg.use_init_args(['url', 'model'])
+@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+class LlamaCppRemote(openai_compatible.OpenAICompatible):
   """The remote LLaMA C++ model.

   The Remote LLaMA C++ models can be launched via
   https://github.com/ggerganov/llama.cpp/tree/master/examples/server
   """
+  url: Annotated[
+      str,
+      'The URL of the LLaMA C++ server.',
+  ]
+
+  model: Annotated[
+      str,
+      'The name of the model to use.',
+  ] = ''

-  @pg.explicit_method_override
-  def __init__(self, url: str, model: str | None = None, **kwargs):
-    super().__init__(api_endpoint=f'{url}/completion', model=model, **kwargs)
+  @property
+  def api_endpoint(self) -> str:
+    return self.url + '/completion'

   @property
   def model_id(self) -> str:
     """Returns a string to identify the model."""
     return f'LLaMAC++({self.model or ""})'

-  def request(
-      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
-  ) -> dict[str, Any]:
-    """Returns the JSON input for a message."""
-    request = dict()
-    request.update(self._request_args(sampling_options))
-    # NOTE(daiyip): multi-modal is current not supported.
-    request['prompt'] = prompt.text
-    return request
-
-  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-    """Returns a dict as request arguments."""
-    args = dict(
-        n_predict=options.max_tokens or 1024,
-        top_k=options.top_k or 50,
-        top_p=options.top_p or 0.95,
-    )
-    if options.temperature is not None:
-      args['temperature'] = options.temperature
-    return args
-
-  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
-    return lf.LMSamplingResult(
-        [lf.LMSample(item['content'], score=0.0) for item in json['items']]
-    )
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    request = self.request(prompt, self.sampling_options)
-
-    def _sample_one_example(request):
-      response = self._session.post(
-          self.api_endpoint,
-          json=request,
-          timeout=self.timeout,
-      )
-      if response.status_code == 200:
-        return response.json()
-      else:
-        error_cls = self._error_cls_from_status(response.status_code)
-        raise error_cls(f'{response.status_code}: {response.content}')
-
-    items = self._parallel_execute_with_currency_control(
-        _sample_one_example,
-        [request] * (self.sampling_options.n or 1),
-    )
-    return self.result(dict(items=items))
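With the explicit `__init__` override gone, `url` and `model` become symbolic fields: `pg.use_init_args(['url', 'model'])` preserves the old positional call signature, and `api_endpoint` is frozen to `''` so the property can derive it from `url` instead. The updated test below exercises exactly this; the same behavior as a usage sketch:

```python
from langfun.core.llms import llama_cpp

# `url` and `model` stay positional thanks to pg.use_init_args.
lm = llama_cpp.LlamaCppRemote('http://127.0.0.1:8080', 'llama3')
assert lm.api_endpoint == 'http://127.0.0.1:8080/completion'
assert lm.model_id == 'LLaMAC++(llama3)'
```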
langfun/core/llms/llama_cpp_test.py CHANGED
@@ -11,48 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for llama cpp models."""
-
-import typing
 import unittest
-from unittest import mock
-
 from langfun.core.llms import llama_cpp


-def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs):
-  del kwargs
-
-  class TEMP:
-    @property
-    def status_code(self):
-      return 200
-
-    def json(self):
-      return {"content": json["prompt"] + "\n" + url}
-
-  return TEMP()
-
-
 class LlamaCppRemoteTest(unittest.TestCase):
   """Tests for the LlamaCppRemote model."""

-  def test_call_completion(self):
-    with mock.patch("requests.Session.post") as mock_request:
-      mock_request.side_effect = mock_requests_post
-      lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
-      [result] = lm.sample(["hello"], n=2)
-      self.assertEqual(
-          len(result.samples),
-          2
-      )
-      self.assertEqual(
-          str(result.samples[0].response),
-          "hello\nhttp://127.0.0.1:8080/completion",
-      )
-
-  def test_model_id(self):
+  def test_basics(self):
     lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
+    self.assertEqual(lm.api_endpoint, "http://127.0.0.1:8080/completion")
     self.assertEqual(lm.model_id, "LLaMAC++()")
     lm = llama_cpp.LlamaCppRemote("xxx", model="x")
     self.assertEqual(lm.model_id, "LLaMAC++(x)")