langfun 0.1.2.dev202501080804__py3-none-any.whl → 0.1.2.dev202501240804__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (56)
  1. langfun/core/__init__.py +1 -6
  2. langfun/core/coding/python/__init__.py +5 -11
  3. langfun/core/coding/python/correction.py +4 -7
  4. langfun/core/coding/python/correction_test.py +2 -3
  5. langfun/core/coding/python/execution.py +22 -211
  6. langfun/core/coding/python/execution_test.py +11 -90
  7. langfun/core/coding/python/generation.py +3 -2
  8. langfun/core/coding/python/generation_test.py +2 -2
  9. langfun/core/coding/python/parsing.py +108 -194
  10. langfun/core/coding/python/parsing_test.py +2 -105
  11. langfun/core/component.py +11 -273
  12. langfun/core/component_test.py +2 -29
  13. langfun/core/concurrent.py +187 -82
  14. langfun/core/concurrent_test.py +28 -19
  15. langfun/core/console.py +7 -3
  16. langfun/core/eval/base.py +2 -3
  17. langfun/core/eval/v2/evaluation.py +3 -1
  18. langfun/core/eval/v2/reporting.py +8 -4
  19. langfun/core/language_model.py +84 -8
  20. langfun/core/language_model_test.py +84 -29
  21. langfun/core/llms/__init__.py +46 -11
  22. langfun/core/llms/anthropic.py +1 -123
  23. langfun/core/llms/anthropic_test.py +0 -48
  24. langfun/core/llms/deepseek.py +117 -0
  25. langfun/core/llms/deepseek_test.py +61 -0
  26. langfun/core/llms/gemini.py +1 -1
  27. langfun/core/llms/groq.py +12 -99
  28. langfun/core/llms/groq_test.py +31 -137
  29. langfun/core/llms/llama_cpp.py +17 -54
  30. langfun/core/llms/llama_cpp_test.py +2 -34
  31. langfun/core/llms/openai.py +9 -147
  32. langfun/core/llms/openai_compatible.py +179 -0
  33. langfun/core/llms/openai_compatible_test.py +495 -0
  34. langfun/core/llms/openai_test.py +13 -423
  35. langfun/core/llms/rest_test.py +1 -1
  36. langfun/core/llms/vertexai.py +387 -18
  37. langfun/core/llms/vertexai_test.py +52 -0
  38. langfun/core/message_test.py +3 -3
  39. langfun/core/modalities/mime.py +8 -0
  40. langfun/core/modalities/mime_test.py +19 -4
  41. langfun/core/modality_test.py +0 -1
  42. langfun/core/structured/mapping.py +13 -13
  43. langfun/core/structured/mapping_test.py +2 -2
  44. langfun/core/structured/schema.py +16 -8
  45. langfun/core/structured/schema_generation.py +1 -1
  46. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/METADATA +13 -2
  47. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/RECORD +50 -52
  48. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/WHEEL +1 -1
  49. langfun/core/coding/python/errors.py +0 -108
  50. langfun/core/coding/python/errors_test.py +0 -99
  51. langfun/core/coding/python/permissions.py +0 -90
  52. langfun/core/coding/python/permissions_test.py +0 -86
  53. langfun/core/text_formatting.py +0 -168
  54. langfun/core/text_formatting_test.py +0 -65
  55. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/LICENSE +0 -0
  56. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/top_level.txt +0 -0
langfun/core/llms/deepseek.py ADDED
@@ -0,0 +1,117 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Language models from DeepSeek."""
+
+ import os
+ from typing import Annotated, Any
+
+ import langfun.core as lf
+ from langfun.core.llms import openai_compatible
+ import pyglove as pg
+
+ SUPPORTED_MODELS_AND_SETTINGS = {
+     # pylint: disable=g-line-too-long
+     # TODO(yifenglu): The RPM and TPM are arbitrary numbers. Update them once DeepSeek provides concrete guidelines.
+     # DeepSeek doesn't control the rate limit at the moment: https://api-docs.deepseek.com/quick_start/rate_limit
+     # The cost is based on: https://api-docs.deepseek.com/quick_start/pricing
+     'deepseek-chat': pg.Dict(
+         in_service=True,
+         rpm=100,
+         tpm=1000000,
+         cost_per_1k_input_tokens=0.00014,
+         cost_per_1k_output_tokens=0.00028,
+     ),
+ }
+
+
+ # DeepSeek API uses an API format compatible with OpenAI.
+ # Reference: https://api-docs.deepseek.com/
+ @lf.use_init_args(['model'])
+ class DeepSeek(openai_compatible.OpenAICompatible):
+   """DeepSeek model."""
+
+   model: pg.typing.Annotated[
+       pg.typing.Enum(
+           pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+       ),
+       'The name of the model to use.',
+   ]
+
+   api_endpoint: str = 'https://api.deepseek.com/chat/completions'
+
+   api_key: Annotated[
+       str | None,
+       (
+           'API key. If None, the key will be read from environment variable '
+           "'DEEPSEEK_API_KEY'."
+       ),
+   ] = None
+
+   @property
+   def headers(self) -> dict[str, Any]:
+     api_key = self.api_key or os.environ.get('DEEPSEEK_API_KEY', None)
+     if not api_key:
+       raise ValueError(
+           'Please specify `api_key` during `__init__` or set environment '
+           'variable `DEEPSEEK_API_KEY` with your DeepSeek API key.'
+       )
+     headers = super().headers
+     headers.update({
+         'Authorization': f'Bearer {api_key}',
+     })
+     return headers
+
+   @property
+   def model_id(self) -> str:
+     """Returns a string to identify the model."""
+     return f'DeepSeek({self.model})'
+
+   @property
+   def max_concurrency(self) -> int:
+     rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+     tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+     return self.rate_to_max_concurrency(
+         requests_per_min=rpm, tokens_per_min=tpm
+     )
+
+   def estimate_cost(
+       self, num_input_tokens: int, num_output_tokens: int
+   ) -> float | None:
+     """Estimate the cost based on usage."""
+     cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+         'cost_per_1k_input_tokens', None
+     )
+     cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+         'cost_per_1k_output_tokens', None
+     )
+     if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None:
+       return None
+     return (
+         cost_per_1k_input_tokens * num_input_tokens
+         + cost_per_1k_output_tokens * num_output_tokens
+     ) / 1000
+
+   @classmethod
+   def dir(cls):
+     return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
+
+
+ class DeepSeekChat(DeepSeek):
+   """DeepSeek Chat model.
+
+   Currently, it is powered by the DeepSeek-V3 model, with a 64K input context
+   window and 8K max output tokens.
+   """
+
+   model = 'deepseek-chat'
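Since `DeepSeek` inherits its request/response handling from `OpenAICompatible`, the new model is used like any other langfun LLM. A minimal sketch (placeholder key and prompt; assumes langfun's standard calling convention, as exercised by the tests below):

    from langfun.core.llms import deepseek

    lm = deepseek.DeepSeekChat(api_key='sk-...')  # or set DEEPSEEK_API_KEY
    print(lm.model_id)       # 'DeepSeek(deepseek-chat)'
    response = lm('Hello!')  # sends an OpenAI-style chat completion request
    print(response.text)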
langfun/core/llms/deepseek_test.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright 2023 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import unittest
+ from langfun.core.llms import deepseek
+
+
+ class DeepSeekTest(unittest.TestCase):
+   """Tests for DeepSeek language model."""
+
+   def test_dir(self):
+     self.assertIn('deepseek-chat', deepseek.DeepSeek.dir())
+
+   def test_key(self):
+     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+       _ = deepseek.DeepSeekChat().headers
+     self.assertEqual(
+         deepseek.DeepSeekChat(api_key='test_key').headers,
+         {
+             'Content-Type': 'application/json',
+             'Authorization': 'Bearer test_key',
+         }
+     )
+
+   def test_model_id(self):
+     self.assertEqual(
+         deepseek.DeepSeekChat(api_key='test_key').model_id,
+         'DeepSeek(deepseek-chat)',
+     )
+
+   def test_resource_id(self):
+     self.assertEqual(
+         deepseek.DeepSeekChat(api_key='test_key').resource_id,
+         'DeepSeek(deepseek-chat)',
+     )
+
+   def test_max_concurrency(self):
+     self.assertGreater(
+         deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0
+     )
+
+   def test_estimate_cost(self):
+     self.assertEqual(
+         deepseek.DeepSeekChat(api_key='test_key').estimate_cost(
+             num_input_tokens=100, num_output_tokens=100
+         ),
+         4.2e-5
+     )
+
+ if __name__ == '__main__':
+   unittest.main()
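The expected value in `test_estimate_cost` follows directly from the pricing constants in deepseek.py ($0.00014 per 1K input tokens, $0.00028 per 1K output tokens). A quick check of the arithmetic:

    # 100 input and 100 output tokens:
    (0.00014 * 100 + 0.00028 * 100) / 1000
    # = (0.014 + 0.028) / 1000 = 0.042 / 1000 = 4.2e-05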
langfun/core/llms/gemini.py CHANGED
@@ -380,7 +380,7 @@ class Gemini(rest.REST):
      return (
          cost_per_1m_input_tokens * num_input_tokens
          + cost_per_1m_output_tokens * num_output_tokens
-     ) / 1000_1000
+     ) / 1000_000
 
    @property
    def model_id(self) -> str:
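The one-character fix above matters because underscores in Python integer literals are purely visual separators:

    1000_1000  # == 10001000, the old (wrong) divisor
    1000_000   # == 1000000, the intended per-million-token divisor

Before the fix, per-1M-token prices were divided by 10,001,000 instead of 1,000,000, so Gemini cost estimates came out roughly 10x too low.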
langfun/core/llms/groq.py CHANGED
@@ -17,8 +17,7 @@ import os
  from typing import Annotated, Any
 
  import langfun.core as lf
- from langfun.core import modalities as lf_modalities
- from langfun.core.llms import rest
+ from langfun.core.llms import openai_compatible
  import pyglove as pg
 
 
@@ -95,7 +94,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 
 
  @lf.use_init_args(['model'])
- class Groq(rest.REST):
+ class Groq(openai_compatible.OpenAICompatible):
    """Groq LLMs through REST APIs (OpenAI compatible).
 
    See https://platform.openai.com/docs/api-reference/chat
@@ -108,10 +107,6 @@ class Groq(rest.REST):
        'The name of the model to use.',
    ]
 
-   multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
-       False
-   )
-
    api_key: Annotated[
        str | None,
        (
@@ -122,25 +117,19 @@ class Groq(rest.REST):
 
    api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'
 
-   def _on_bound(self):
-     super()._on_bound()
-     self._api_key = None
-
-   def _initialize(self):
+   @property
+   def headers(self) -> dict[str, Any]:
      api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
      if not api_key:
        raise ValueError(
            'Please specify `api_key` during `__init__` or set environment '
            'variable `GROQ_API_KEY` with your Groq API key.'
        )
-     self._api_key = api_key
-
-   @property
-   def headers(self) -> dict[str, Any]:
-     return {
-         'Authorization': f'Bearer {self._api_key}',
-         'Content-Type': 'application/json',
-     }
+     headers = super().headers
+     headers.update({
+         'Authorization': f'Bearer {api_key}',
+     })
+     return headers
 
    @property
    def model_id(self) -> str:
@@ -170,90 +159,14 @@ class Groq(rest.REST):
          + cost_per_1k_output_tokens * num_output_tokens
      ) / 1000
 
-   def request(
-       self,
-       prompt: lf.Message,
-       sampling_options: lf.LMSamplingOptions
-   ) -> dict[str, Any]:
-     """Returns the JSON input for a message."""
-     request = dict()
-     request.update(self._request_args(sampling_options))
-     request.update(
-         dict(
-             messages=[
-                 dict(role='user', content=self._content_from_message(prompt))
-             ]
-         )
-     )
-     return request
-
    def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
      """Returns a dict as request arguments."""
      # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
-     args = dict(
-         model=self.model,
-         n=options.n,
-         stream=False,
-     )
-
-     if options.temperature is not None:
-       args['temperature'] = options.temperature
-     if options.max_tokens is not None:
-       args['max_tokens'] = options.max_tokens
-     if options.top_p is not None:
-       args['top_p'] = options.top_p
-     if options.stop:
-       args['stop'] = options.stop
+     args = super()._request_args(options)
+     args.pop('logprobs', None)
+     args.pop('top_logprobs', None)
      return args
 
-   def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
-     """Converts an message to Groq's content protocol (list of dicts)."""
-     # Refer: https://platform.openai.com/docs/api-reference/chat/create
-     content = []
-     for chunk in prompt.chunk():
-       if isinstance(chunk, str):
-         item = dict(type='text', text=chunk)
-       elif (
-           self.multimodal
-           and isinstance(chunk, lf_modalities.Image)
-           and chunk.uri
-       ):
-         # NOTE(daiyip): Groq only support image URL.
-         item = dict(type='image_url', image_url=chunk.uri)
-       else:
-         raise ValueError(f'Unsupported modality object: {chunk!r}.')
-       content.append(item)
-     return content
-
-   def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
-     samples = [
-         lf.LMSample(self._message_from_choice(choice), score=0.0)
-         for choice in json['choices']
-     ]
-     usage = json['usage']
-     return lf.LMSamplingResult(
-         samples,
-         usage=lf.LMSamplingUsage(
-             prompt_tokens=usage['prompt_tokens'],
-             completion_tokens=usage['completion_tokens'],
-             total_tokens=usage['total_tokens'],
-             estimated_cost=self.estimate_cost(
-                 num_input_tokens=usage['prompt_tokens'],
-                 num_output_tokens=usage['completion_tokens'],
-             ),
-         ),
-     )
-
-   def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
-     """Converts Groq's content protocol to message."""
-     # Refer: https://platform.openai.com/docs/api-reference/chat/create
-     content = choice['message']['content']
-     if isinstance(content, str):
-       return lf.AIMessage(content)
-     return lf.AIMessage.from_chunks(
-         [x['text'] for x in content if x['type'] == 'text']
-     )
-
 
  class GroqLlama3_2_3B(Groq):  # pylint: disable=invalid-name
    """Llama3.2-3B with 8K context window.
langfun/core/llms/groq_test.py CHANGED
@@ -11,89 +11,10 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """Tests for Groq models."""
-
  import os
- from typing import Any
  import unittest
- from unittest import mock
- from langfun.core import modalities as lf_modalities
+ import langfun.core as lf
  from langfun.core.llms import groq
- import pyglove as pg
- import requests
-
-
- def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
-   del url, kwargs
-
-   response = requests.Response()
-   response.status_code = 200
-   response._content = pg.to_json_str({
-       'choices': [{
-           'message': {
-               'content': [{
-                   'type': 'text',
-                   'text': (
-                       f'hello with temperature={json.get("temperature")}, '
-                       f'top_p={json.get("top_p")}, '
-                       f'max_tokens={json.get("max_tokens")}, '
-                       f'stop={json.get("stop")}.'
-                   ),
-               }],
-           }
-       }],
-       'usage': {
-           'prompt_tokens': 2,
-           'completion_tokens': 1,
-           'total_tokens': 3,
-       },
-   }).encode()
-   return response
-
-
- def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
-   del url, kwargs
-   v = json['messages'][0]['content'][0]
-   image = lf_modalities.Image.from_uri(v['image_url'])
-
-   response = requests.Response()
-   response.status_code = 200
-   response._content = pg.to_json_str({
-       'choices': [
-           {
-               'message': {
-                   'content': [{
-                       'type': 'text',
-                       'text': image.uri,
-                   }],
-               }
-           }
-       ],
-       'usage': {
-           'prompt_tokens': 2,
-           'completion_tokens': 1,
-           'total_tokens': 3,
-       },
-   }).encode()
-   return response
-
-
- def mock_requests_post_error(status_code, error_type, error_message):
-   def _mock_requests(url: str, json: dict[str, Any], **kwargs):
-     del url, json, kwargs
-     response = requests.Response()
-     response.status_code = status_code
-     response._content = pg.to_json_str(
-         {
-             'error': {
-                 'type': error_type,
-                 'message': error_message,
-             }
-         }
-     ).encode()
-     return response
-
-   return _mock_requests
 
 
  class AuthropicTest(unittest.TestCase):
@@ -101,69 +22,42 @@ class AuthropicTest(unittest.TestCase):
    def test_basics(self):
      self.assertEqual(groq.GroqMistral_8x7B().model_id, 'mixtral-8x7b-32768')
      self.assertEqual(groq.GroqMistral_8x7B().max_concurrency, 16)
+     self.assertEqual(groq.GroqMistral_8x7B().estimate_cost(100, 100), 4.8e-5)
+
+   def test_request_args(self):
+     args = groq.GroqMistral_8x7B()._request_args(
+         lf.LMSamplingOptions(
+             temperature=1.0, stop=['\n'], n=1, random_seed=123,
+             logprobs=True, top_logprobs=True
+         )
+     )
+     self.assertNotIn('logprobs', args)
+     self.assertNotIn('top_logprobs', args)
 
    def test_api_key(self):
      lm = groq.GroqMistral_8x7B()
      with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-       lm('hi')
-
-     with mock.patch('requests.Session.post') as mock_request:
-       mock_request.side_effect = mock_requests_post
-
-       lm = groq.GroqMistral_8x7B(api_key='fake key')
-       self.assertRegex(lm('hi').text, 'hello.*')
-
-       os.environ['GROQ_API_KEY'] = 'abc'
-       lm = groq.GroqMistral_8x7B()
-       self.assertRegex(lm('hi').text, 'hello.*')
-       del os.environ['GROQ_API_KEY']
-
-   def test_call(self):
-     with mock.patch('requests.Session.post') as mock_request:
-       mock_request.side_effect = mock_requests_post
-       lm = groq.GroqLlama3_70B(api_key='fake_key')
-       response = lm(
-           'hello',
-           temperature=0.0,
-           max_tokens=1024,
-           top_k=0.1,
-           top_p=0.2,
-           stop=['\n'],
-       )
-       self.assertEqual(
-           response.text,
-           (
-               'hello with temperature=0.0, top_p=0.2, '
-               "max_tokens=1024, stop=['\\n']."
-           ),
-       )
-       self.assertIsNotNone(response.usage)
-       self.assertIsNotNone(response.usage.prompt_tokens, 2)
-       self.assertIsNotNone(response.usage.completion_tokens, 1)
-       self.assertIsNotNone(response.usage.total_tokens, 3)
+       _ = lm.headers
 
-   def test_mm_call(self):
-     with mock.patch('requests.Session.post') as mock_mm_request:
-       mock_mm_request.side_effect = mock_mm_requests_post
-       lm = groq.GroqLlama3_70B(multimodal=True, api_key='fake_key')
-       response = lm(lf_modalities.Image.from_uri('https://fake/image.jpg'))
-       self.assertEqual(response.text, 'https://fake/image.jpg')
+     lm = groq.GroqMistral_8x7B(api_key='fake key')
+     self.assertEqual(
+         lm.headers,
+         {
+             'Content-Type': 'application/json',
+             'Authorization': 'Bearer fake key',
+         }
+     )
 
-   def test_call_errors(self):
-     for status_code, error_type, error_message in [
-         (429, 'rate_limit', 'Rate limit exceeded.'),
-         (503, 'service_unavailable', 'Service unavailable.'),
-         (500, 'bad_request', 'Bad request.'),
-     ]:
-       with mock.patch('requests.Session.post') as mock_mm_request:
-         mock_mm_request.side_effect = mock_requests_post_error(
-             status_code, error_type, error_message
-         )
-         lm = groq.GroqLlama3_70B(api_key='fake_key')
-         with self.assertRaisesRegex(
-             Exception, f'{status_code}:.*{error_type}'
-         ):
-           lm('hello', max_attempts=1)
+     os.environ['GROQ_API_KEY'] = 'abc'
+     lm = groq.GroqMistral_8x7B()
+     self.assertEqual(
+         lm.headers,
+         {
+             'Content-Type': 'application/json',
+             'Authorization': 'Bearer abc',
+         }
+     )
+     del os.environ['GROQ_API_KEY']
 
 
  if __name__ == '__main__':
langfun/core/llms/llama_cpp.py CHANGED
@@ -13,72 +13,35 @@
  # limitations under the License.
  """Language models from llama.cpp."""
 
- from typing import Any
-
- import langfun.core as lf
- from langfun.core.llms import rest
+ from typing import Annotated
+ from langfun.core.llms import openai_compatible
  import pyglove as pg
 
 
- class LlamaCppRemote(rest.REST):
+ @pg.use_init_args(['url', 'model'])
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+ class LlamaCppRemote(openai_compatible.OpenAICompatible):
    """The remote LLaMA C++ model.
 
    The Remote LLaMA C++ models can be launched via
    https://github.com/ggerganov/llama.cpp/tree/master/examples/server
    """
+   url: Annotated[
+       str,
+       'The URL of the LLaMA C++ server.',
+   ]
+
+   model: Annotated[
+       str,
+       'The name of the model to use.',
+   ] = ''
 
-   @pg.explicit_method_override
-   def __init__(self, url: str, model: str | None = None, **kwargs):
-     super().__init__(api_endpoint=f'{url}/completion', model=model, **kwargs)
+   @property
+   def api_endpoint(self) -> str:
+     return self.url + '/completion'
 
    @property
    def model_id(self) -> str:
      """Returns a string to identify the model."""
      return f'LLaMAC++({self.model or ""})'
 
-   def request(
-       self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
-   ) -> dict[str, Any]:
-     """Returns the JSON input for a message."""
-     request = dict()
-     request.update(self._request_args(sampling_options))
-     # NOTE(daiyip): multi-modal is current not supported.
-     request['prompt'] = prompt.text
-     return request
-
-   def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-     """Returns a dict as request arguments."""
-     args = dict(
-         n_predict=options.max_tokens or 1024,
-         top_k=options.top_k or 50,
-         top_p=options.top_p or 0.95,
-     )
-     if options.temperature is not None:
-       args['temperature'] = options.temperature
-     return args
-
-   def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
-     return lf.LMSamplingResult(
-         [lf.LMSample(item['content'], score=0.0) for item in json['items']]
-     )
-
-   def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-     request = self.request(prompt, self.sampling_options)
-
-     def _sample_one_example(request):
-       response = self._session.post(
-           self.api_endpoint,
-           json=request,
-           timeout=self.timeout,
-       )
-       if response.status_code == 200:
-         return response.json()
-       else:
-         error_cls = self._error_cls_from_status(response.status_code)
-         raise error_cls(f'{response.status_code}: {response.content}')
-
-     items = self._parallel_execute_with_currency_control(
-         _sample_one_example,
-         [request] * (self.sampling_options.n or 1),
-     )
-     return self.result(dict(items=items))
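With `api_endpoint` now frozen and computed from `url`, callers construct the model from the server URL alone. A minimal sketch (assumes a llama.cpp server running at the given address):

    from langfun.core.llms import llama_cpp

    lm = llama_cpp.LlamaCppRemote('http://127.0.0.1:8080', model='x')
    print(lm.api_endpoint)  # 'http://127.0.0.1:8080/completion'
    print(lm.model_id)      # 'LLaMAC++(x)'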
langfun/core/llms/llama_cpp_test.py CHANGED
@@ -11,48 +11,16 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """Tests for llama cpp models."""
-
- import typing
  import unittest
- from unittest import mock
-
  from langfun.core.llms import llama_cpp
 
 
- def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs):
-   del kwargs
-
-   class TEMP:
-     @property
-     def status_code(self):
-       return 200
-
-     def json(self):
-       return {"content": json["prompt"] + "\n" + url}
-
-   return TEMP()
-
-
  class LlamaCppRemoteTest(unittest.TestCase):
    """Tests for the LlamaCppRemote model."""
 
-   def test_call_completion(self):
-     with mock.patch("requests.Session.post") as mock_request:
-       mock_request.side_effect = mock_requests_post
-       lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
-       [result] = lm.sample(["hello"], n=2)
-       self.assertEqual(
-           len(result.samples),
-           2
-       )
-       self.assertEqual(
-           str(result.samples[0].response),
-           "hello\nhttp://127.0.0.1:8080/completion",
-       )
-
-   def test_model_id(self):
+   def test_basics(self):
      lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
+     self.assertEqual(lm.api_endpoint, "http://127.0.0.1:8080/completion")
      self.assertEqual(lm.model_id, "LLaMAC++()")
      lm = llama_cpp.LlamaCppRemote("xxx", model="x")
      self.assertEqual(lm.model_id, "LLaMAC++(x)")