langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of langfun might be problematic.

Files changed (59)
  1. langfun/__init__.py +7 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +15 -0
  7. langfun/core/eval/base.py +665 -95
  8. langfun/core/eval/base_test.py +224 -53
  9. langfun/core/eval/matching.py +48 -30
  10. langfun/core/eval/matching_test.py +25 -3
  11. langfun/core/eval/patching.py +130 -0
  12. langfun/core/eval/patching_test.py +170 -0
  13. langfun/core/eval/scoring.py +19 -10
  14. langfun/core/eval/scoring_test.py +21 -3
  15. langfun/core/langfunc.py +1 -22
  16. langfun/core/langfunc_test.py +10 -4
  17. langfun/core/language_model.py +130 -24
  18. langfun/core/language_model_test.py +249 -26
  19. langfun/core/llms/__init__.py +27 -2
  20. langfun/core/llms/anthropic.py +263 -0
  21. langfun/core/llms/anthropic_test.py +167 -0
  22. langfun/core/llms/cache/in_memory_test.py +37 -28
  23. langfun/core/llms/fake.py +34 -25
  24. langfun/core/llms/fake_test.py +122 -11
  25. langfun/core/llms/google_genai.py +8 -0
  26. langfun/core/llms/google_genai_test.py +8 -3
  27. langfun/core/llms/groq.py +260 -0
  28. langfun/core/llms/groq_test.py +170 -0
  29. langfun/core/llms/llama_cpp.py +3 -1
  30. langfun/core/llms/openai.py +100 -81
  31. langfun/core/llms/openai_test.py +287 -60
  32. langfun/core/llms/vertexai.py +291 -0
  33. langfun/core/llms/vertexai_test.py +233 -0
  34. langfun/core/modalities/image.py +1 -3
  35. langfun/core/modalities/mime.py +6 -0
  36. langfun/core/modalities/video.py +6 -5
  37. langfun/core/structured/__init__.py +5 -0
  38. langfun/core/structured/completion_test.py +2 -2
  39. langfun/core/structured/function_generation.py +245 -0
  40. langfun/core/structured/function_generation_test.py +329 -0
  41. langfun/core/structured/mapping.py +61 -3
  42. langfun/core/structured/mapping_test.py +17 -0
  43. langfun/core/structured/parsing_test.py +18 -13
  44. langfun/core/structured/prompting.py +61 -12
  45. langfun/core/structured/prompting_test.py +122 -12
  46. langfun/core/structured/schema.py +38 -6
  47. langfun/core/structured/schema_generation_test.py +2 -2
  48. langfun/core/structured/schema_test.py +36 -7
  49. langfun/core/structured/scoring.py +4 -1
  50. langfun/core/structured/scoring_test.py +6 -0
  51. langfun/core/template.py +147 -11
  52. langfun/core/template_test.py +75 -0
  53. langfun/core/templates/selfplay_test.py +6 -2
  54. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
  55. langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
  56. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  57. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
  58. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
  59. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/llms/anthropic.py ADDED
@@ -0,0 +1,263 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Language models from Anthropic."""
15
+
16
+ import base64
17
+ import functools
18
+ import os
19
+ from typing import Annotated, Any
20
+
21
+ import langfun.core as lf
22
+ from langfun.core import modalities as lf_modalities
23
+ import pyglove as pg
24
+ import requests
25
+
26
+
27
+ SUPPORTED_MODELS_AND_SETTINGS = {
28
+ # See https://docs.anthropic.com/claude/docs/models-overview
29
+ # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
30
+ # RPM/TPM for Claude-2.1, Claude-2.0, and Claude-Instant-1.2 estimated
31
+ # as RPM/TPM of the largest-available model (Claude-3-Opus).
32
+ 'claude-3-opus-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
33
+ 'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
34
+ 'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
35
+ 'claude-2.1': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
36
+ 'claude-2.0': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
37
+ 'claude-instant-1.2': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
38
+ }
39
+
40
+
41
+ class AnthropicError(Exception): # pylint: disable=g-bad-exception-name
42
+ """Base class for Anthropic errors."""
43
+
44
+
45
+ class RateLimitError(AnthropicError):
46
+ """Error for rate limit reached."""
47
+
48
+
49
+ class OverloadedError(AnthropicError):
50
+ """Anthropic's server is temporarily overloaded."""
51
+
52
+
53
+ _ANTHROPIC_MESSAGE_API_ENDPOINT = 'https://api.anthropic.com/v1/messages'
54
+ _ANTHROPIC_API_VERSION = '2023-06-01'
55
+
56
+
57
+ @lf.use_init_args(['model'])
58
+ class Anthropic(lf.LanguageModel):
59
+ """Anthropic LLMs (Claude) through REST APIs.
60
+
61
+ See https://docs.anthropic.com/claude/reference/messages_post
62
+ """
63
+
64
+ model: pg.typing.Annotated[
65
+ pg.typing.Enum(
66
+ pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
67
+ ),
68
+ 'The name of the model to use.',
69
+ ]
70
+
71
+ multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
72
+ True
73
+ )
74
+
75
+ api_key: Annotated[
76
+ str | None,
77
+ (
78
+ 'API key. If None, the key will be read from environment variable '
79
+ "'ANTHROPIC_API_KEY'."
80
+ ),
81
+ ] = None
82
+
83
+ def _on_bound(self):
84
+ super()._on_bound()
85
+ self._api_key = None
86
+ self.__dict__.pop('_api_initialized', None)
87
+ self.__dict__.pop('_session', None)
88
+
89
+ @functools.cached_property
90
+ def _api_initialized(self):
91
+ api_key = self.api_key or os.environ.get('ANTHROPIC_API_KEY', None)
92
+ if not api_key:
93
+ raise ValueError(
94
+ 'Please specify `api_key` during `__init__` or set environment '
95
+ 'variable `ANTHROPIC_API_KEY` with your Anthropic API key.'
96
+ )
97
+ self._api_key = api_key
98
+ return True
99
+
100
+ @functools.cached_property
101
+ def _session(self) -> requests.Session:
102
+ assert self._api_initialized
103
+ s = requests.Session()
104
+ s.headers.update({
105
+ 'x-api-key': self._api_key,
106
+ 'anthropic-version': _ANTHROPIC_API_VERSION,
107
+ 'content-type': 'application/json',
108
+ })
109
+ return s
110
+
111
+ @property
112
+ def model_id(self) -> str:
113
+ """Returns a string to identify the model."""
114
+ return self.model
115
+
116
+ @property
117
+ def max_concurrency(self) -> int:
118
+ rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
119
+ tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
120
+ return self.rate_to_max_concurrency(
121
+ requests_per_min=rpm, tokens_per_min=tpm
122
+ )
123
+
124
+ def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
125
+ assert self._api_initialized
126
+ return self._parallel_execute_with_currency_control(
127
+ self._sample_single, prompts, retry_on_errors=(RateLimitError)
128
+ )
129
+
130
+ def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
131
+ """Returns a dict as request arguments."""
132
+ # Anthropic requires `max_tokens` to be specified.
133
+ max_tokens = (
134
+ options.max_tokens
135
+ or SUPPORTED_MODELS_AND_SETTINGS[self.model].max_tokens
136
+ )
137
+ args = dict(
138
+ model=self.model,
139
+ max_tokens=max_tokens,
140
+ stream=False,
141
+ )
142
+ if options.stop:
143
+ args['stop_sequences'] = options.stop
144
+ if options.temperature is not None:
145
+ args['temperature'] = options.temperature
146
+ if options.top_k is not None:
147
+ args['top_k'] = options.top_k
148
+ if options.top_p is not None:
149
+ args['top_p'] = options.top_p
150
+ return args
151
+
152
+ def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
153
+ """Converts an message to Anthropic's content protocol (list of dicts)."""
154
+ # Refer: https://docs.anthropic.com/claude/reference/messages-examples
155
+ if self.multimodal:
156
+ content = []
157
+ for chunk in prompt.chunk():
158
+ if isinstance(chunk, str):
159
+ item = dict(type='text', text=chunk)
160
+ elif isinstance(chunk, lf_modalities.Image):
161
+ # NOTE(daiyip): Anthropic only supports inline image content, not image URLs.
162
+ item = dict(
163
+ type='image',
164
+ source=dict(
165
+ type='base64',
166
+ media_type=chunk.mime_type,
167
+ data=base64.b64encode(chunk.to_bytes()).decode(),
168
+ ),
169
+ )
170
+ else:
171
+ raise ValueError(f'Unsupported modality object: {chunk!r}.')
172
+ content.append(item)
173
+ return content
174
+ else:
175
+ return [dict(type='text', text=prompt.text)]
176
+
177
+ def _message_from_content(self, content: list[dict[str, Any]]) -> lf.Message:
178
+ """Converts Anthropic's content protocol to message."""
179
+ # Refer: https://docs.anthropic.com/claude/reference/messages-examples
180
+ return lf.AIMessage.from_chunks(
181
+ [x['text'] for x in content if x['type'] == 'text']
182
+ )
183
+
184
+ def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
185
+ """Parses Anthropic's response."""
186
+ # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
187
+ if response.status_code == 200:
188
+ output = response.json()
189
+ message = self._message_from_content(output['content'])
190
+ input_tokens = output['usage']['input_tokens']
191
+ output_tokens = output['usage']['output_tokens']
192
+ return lf.LMSamplingResult(
193
+ [lf.LMSample(message)],
194
+ usage=lf.LMSamplingUsage(
195
+ prompt_tokens=input_tokens,
196
+ completion_tokens=output_tokens,
197
+ total_tokens=input_tokens + output_tokens,
198
+ ),
199
+ )
200
+ else:
201
+ if response.status_code == 429:
202
+ error_cls = RateLimitError
203
+ elif response.status_code in (502, 529):
204
+ error_cls = OverloadedError
205
+ else:
206
+ error_cls = AnthropicError
207
+ raise error_cls(f'{response.status_code}: {response.content}')
208
+
209
+ def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
210
+ request = dict()
211
+ request.update(self._get_request_args(self.sampling_options))
212
+ request.update(
213
+ dict(
214
+ messages=[
215
+ dict(role='user', content=self._content_from_message(prompt))
216
+ ]
217
+ )
218
+ )
219
+ try:
220
+ response = self._session.post(
221
+ _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
222
+ )
223
+ return self._parse_response(response)
224
+ except ConnectionError as e:
225
+ raise OverloadedError(str(e)) from e
226
+
227
+
228
+ class Claude3(Anthropic):
229
+ """Base class for Claude 3 models. 200K input tokens and 4K output tokens."""
230
+ multimodal = True
231
+
232
+
233
+ class Claude3Opus(Claude3):
234
+ """Anthropic's most powerful model."""
235
+
236
+ model = 'claude-3-opus-20240229'
237
+
238
+
239
+ class Claude3Sonnet(Claude3):
240
+ """A balance between between Opus and Haiku."""
241
+
242
+ model = 'claude-3-sonnet-20240229'
243
+
244
+
245
+ class Claude3Haiku(Claude3):
246
+ """Anthropic's most compact model."""
247
+
248
+ model = 'claude-3-haiku-20240307'
249
+
250
+
251
+ class Claude2(Anthropic):
252
+ """Predecessor to Claude 3 with 100K context window.."""
253
+ model = 'claude-2.0'
254
+
255
+
256
+ class Claude21(Anthropic):
257
+ """Updated Claude 2 model with improved accuracy and 200K context window."""
258
+ model = 'claude-2.1'
259
+
260
+
261
+ class ClaudeInstant(Anthropic):
262
+ """Cheapest small and fast model, 100K context window."""
263
+ model = 'claude-instant-1.2'
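
For orientation, here is a minimal usage sketch of the new Anthropic backend added above. The prompt text is illustrative; the API key is read from the ANTHROPIC_API_KEY environment variable when `api_key` is omitted.

from langfun.core.llms import anthropic

# Sampling options such as temperature, top_k, top_p and stop are forwarded to
# Anthropic's Messages API; max_tokens falls back to the per-model default.
lm = anthropic.Claude3Haiku(api_key='my-api-key')
response = lm('Say hello in one short sentence.', temperature=0.0)
print(response.text)
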
langfun/core/llms/anthropic_test.py ADDED
@@ -0,0 +1,167 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for Anthropic models."""
15
+
16
+ import base64
17
+ import os
18
+ from typing import Any
19
+ import unittest
20
+ from unittest import mock
21
+ from langfun.core import modalities as lf_modalities
22
+ from langfun.core.llms import anthropic
23
+ import pyglove as pg
24
+ import requests
25
+
26
+
27
+ def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
28
+ del url, kwargs
29
+
30
+ response = requests.Response()
31
+ response.status_code = 200
32
+ response._content = pg.to_json_str({
33
+ 'content': [{
34
+ 'type': 'text',
35
+ 'text': (
36
+ f'hello with temperature={json.get("temperature")}, '
37
+ f'top_k={json.get("top_k")}, '
38
+ f'top_p={json.get("top_p")}, '
39
+ f'max_tokens={json.get("max_tokens")}, '
40
+ f'stop={json.get("stop_sequences")}.'
41
+ ),
42
+ }],
43
+ 'usage': {
44
+ 'input_tokens': 2,
45
+ 'output_tokens': 1,
46
+ },
47
+ }).encode()
48
+ return response
49
+
50
+
51
+ image_content = (
52
+ b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
53
+ b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
54
+ b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
55
+ b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
56
+ b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
57
+ b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
58
+ b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
59
+ b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
60
+ )
61
+
62
+
63
+ def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
64
+ del url, kwargs
65
+ v = json['messages'][0]['content'][0]
66
+ image = lf_modalities.Image.from_bytes(base64.b64decode(v['source']['data']))
67
+
68
+ response = requests.Response()
69
+ response.status_code = 200
70
+ response._content = pg.to_json_str({
71
+ 'content': [{
72
+ 'type': 'text',
73
+ 'text': f'{v["type"]}: {image.mime_type}',
74
+ }],
75
+ 'usage': {
76
+ 'input_tokens': 2,
77
+ 'output_tokens': 1,
78
+ },
79
+ }).encode()
80
+ return response
81
+
82
+
83
+ def mock_requests_post_error(status_code, error_type, error_message):
84
+ def _mock_requests(url: str, json: dict[str, Any], **kwargs):
85
+ del url, json, kwargs
86
+ response = requests.Response()
87
+ response.status_code = status_code
88
+ response._content = pg.to_json_str(
89
+ {
90
+ 'error': {
91
+ 'type': error_type,
92
+ 'message': error_message,
93
+ }
94
+ }
95
+ ).encode()
96
+ return response
97
+
98
+ return _mock_requests
99
+
100
+
101
+ class AnthropicTest(unittest.TestCase):
102
+
103
+ def test_basics(self):
104
+ self.assertEqual(
105
+ anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
106
+ )
107
+ self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)
108
+
109
+ def test_api_key(self):
110
+ lm = anthropic.Claude3Haiku()
111
+ with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
112
+ lm('hi')
113
+
114
+ with mock.patch('requests.Session.post') as mock_request:
115
+ mock_request.side_effect = mock_requests_post
116
+
117
+ lm = anthropic.Claude3Haiku(api_key='fake key')
118
+ self.assertRegex(lm('hi').text, 'hello.*')
119
+
120
+ os.environ['ANTHROPIC_API_KEY'] = 'abc'
121
+ lm = anthropic.Claude3Haiku()
122
+ self.assertRegex(lm('hi').text, 'hello.*')
123
+ del os.environ['ANTHROPIC_API_KEY']
124
+
125
+ def test_call(self):
126
+ with mock.patch('requests.Session.post') as mock_request:
127
+ mock_request.side_effect = mock_requests_post
128
+ lm = anthropic.Claude3Haiku(api_key='fake_key')
129
+ response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
130
+ self.assertEqual(
131
+ response.text,
132
+ (
133
+ 'hello with temperature=0.0, top_k=0.1, top_p=0.2, '
134
+ "max_tokens=4096, stop=['\\n']."
135
+ ),
136
+ )
137
+ self.assertIsNotNone(response.usage)
138
+ self.assertIsNotNone(response.usage.prompt_tokens, 2)
139
+ self.assertIsNotNone(response.usage.completion_tokens, 1)
140
+ self.assertIsNotNone(response.usage.total_tokens, 3)
141
+
142
+ def test_mm_call(self):
143
+ with mock.patch('requests.Session.post') as mock_mm_request:
144
+ mock_mm_request.side_effect = mock_mm_requests_post
145
+ lm = anthropic.Claude3Haiku(api_key='fake_key')
146
+ response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
147
+ self.assertEqual(response.text, 'image: image/png')
148
+
149
+ def test_call_errors(self):
150
+ for status_code, error_type, error_message in [
151
+ (429, 'rate_limit', 'Rate limit exceeded.'),
152
+ (529, 'service_unavailable', 'Service unavailable.'),
153
+ (500, 'bad_request', 'Bad request.'),
154
+ ]:
155
+ with mock.patch('requests.Session.post') as mock_mm_request:
156
+ mock_mm_request.side_effect = mock_requests_post_error(
157
+ status_code, error_type, error_message
158
+ )
159
+ lm = anthropic.Claude3Haiku(api_key='fake_key')
160
+ with self.assertRaisesRegex(
161
+ Exception, f'.*{status_code}: .*{error_message}'
162
+ ):
163
+ lm('hello', lm=lm, max_attempts=1)
164
+
165
+
166
+ if __name__ == '__main__':
167
+ unittest.main()
langfun/core/llms/cache/in_memory_test.py CHANGED
@@ -44,28 +44,37 @@ class InMemoryLMCacheTest(unittest.TestCase):
44
44
  self.assertEqual(
45
45
  list(cache.keys()),
46
46
  [
47
- ('a', (0.0, 1024, 1, 40, None, None), 0),
48
- ('a', (0.0, 1024, 1, 40, None, None), 1),
49
- ('b', (0.0, 1024, 1, 40, None, None), 0),
50
- ('c', (0.0, 1024, 1, 40, None, None), 0),
47
+ ('a', (None, None, 1, 40, None, None), 0),
48
+ ('a', (None, None, 1, 40, None, None), 1),
49
+ ('b', (None, None, 1, 40, None, None), 0),
50
+ ('c', (None, None, 1, 40, None, None), 0),
51
51
  ],
52
52
  )
53
53
  self.assertEqual(
54
54
  list(cache.keys('StaticSequence')),
55
55
  [
56
- ('a', (0.0, 1024, 1, 40, None, None), 0),
57
- ('a', (0.0, 1024, 1, 40, None, None), 1),
58
- ('b', (0.0, 1024, 1, 40, None, None), 0),
59
- ('c', (0.0, 1024, 1, 40, None, None), 0),
56
+ ('a', (None, None, 1, 40, None, None), 0),
57
+ ('a', (None, None, 1, 40, None, None), 1),
58
+ ('b', (None, None, 1, 40, None, None), 0),
59
+ ('c', (None, None, 1, 40, None, None), 0),
60
60
  ],
61
61
  )
62
62
 
63
63
  def cache_entry(response_text, cache_seed=0):
64
64
  return base.LMCacheEntry(
65
- lf.LMSamplingResult([
66
- lf.LMSample(
67
- lf.AIMessage(response_text, cache_seed=cache_seed), score=1.0)
68
- ])
65
+ lf.LMSamplingResult(
66
+ [
67
+ lf.LMSample(
68
+ lf.AIMessage(response_text, cache_seed=cache_seed),
69
+ score=1.0
70
+ )
71
+ ],
72
+ usage=lf.LMSamplingUsage(
73
+ 1,
74
+ len(response_text),
75
+ len(response_text) + 1,
76
+ )
77
+ )
69
78
  )
70
79
 
71
80
  self.assertEqual(
@@ -90,19 +99,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
90
99
  list(cache.items()),
91
100
  [
92
101
  (
93
- ('a', (0.0, 1024, 1, 40, None, None), 0),
102
+ ('a', (None, None, 1, 40, None, None), 0),
94
103
  cache_entry('1'),
95
104
  ),
96
105
  (
97
- ('a', (0.0, 1024, 1, 40, None, None), 1),
106
+ ('a', (None, None, 1, 40, None, None), 1),
98
107
  cache_entry('2', 1),
99
108
  ),
100
109
  (
101
- ('b', (0.0, 1024, 1, 40, None, None), 0),
110
+ ('b', (None, None, 1, 40, None, None), 0),
102
111
  cache_entry('3'),
103
112
  ),
104
113
  (
105
- ('c', (0.0, 1024, 1, 40, None, None), 0),
114
+ ('c', (None, None, 1, 40, None, None), 0),
106
115
  cache_entry('4'),
107
116
  ),
108
117
  ],
@@ -111,19 +120,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
111
120
  list(cache.items('StaticSequence')),
112
121
  [
113
122
  (
114
- ('a', (0.0, 1024, 1, 40, None, None), 0),
123
+ ('a', (None, None, 1, 40, None, None), 0),
115
124
  cache_entry('1'),
116
125
  ),
117
126
  (
118
- ('a', (0.0, 1024, 1, 40, None, None), 1),
127
+ ('a', (None, None, 1, 40, None, None), 1),
119
128
  cache_entry('2', 1),
120
129
  ),
121
130
  (
122
- ('b', (0.0, 1024, 1, 40, None, None), 0),
131
+ ('b', (None, None, 1, 40, None, None), 0),
123
132
  cache_entry('3'),
124
133
  ),
125
134
  (
126
- ('c', (0.0, 1024, 1, 40, None, None), 0),
135
+ ('c', (None, None, 1, 40, None, None), 0),
127
136
  cache_entry('4'),
128
137
  ),
129
138
  ],
@@ -161,15 +170,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
161
170
  self.assertEqual(
162
171
  list(cache.keys()),
163
172
  [
164
- ('a', (0.0, 1024, 1, 40, None, None), 0),
165
- ('a', (1.0, 1024, 1, 40, None, None), 0),
173
+ ('a', (None, None, 1, 40, None, None), 0),
174
+ ('a', (1.0, None, 1, 40, None, None), 0),
166
175
  ],
167
176
  )
168
177
 
169
178
  def test_different_model(self):
170
179
  cache = in_memory.InMemory()
171
- lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache)
172
- lm2 = fake.Echo(cache=cache)
180
+ lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache, temperature=0.0)
181
+ lm2 = fake.Echo(cache=cache, temperature=0.0)
173
182
 
174
183
  self.assertEqual(lm1('a'), '1')
175
184
  self.assertEqual(lm2('a'), 'a')
@@ -180,15 +189,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
180
189
  self.assertEqual(
181
190
  list(cache.keys('StaticSequence')),
182
191
  [
183
- ('a', (0.0, 1024, 1, 40, None, None), 0),
184
- ('b', (0.0, 1024, 1, 40, None, None), 0),
192
+ ('a', (0.0, None, 1, 40, None, None), 0),
193
+ ('b', (0.0, None, 1, 40, None, None), 0),
185
194
  ],
186
195
  )
187
196
  self.assertEqual(
188
197
  list(cache.keys('Echo')),
189
198
  [
190
- ('a', (0.0, 1024, 1, 40, None, None), 0),
191
- ('b', (0.0, 1024, 1, 40, None, None), 0),
199
+ ('a', (0.0, None, 1, 40, None, None), 0),
200
+ ('b', (0.0, None, 1, 40, None, None), 0),
192
201
  ],
193
202
  )
194
203
  self.assertEqual(len(cache), 4)
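
As a usage sketch based on the tests above (the assertion values follow the StaticSequence behavior shown there), the in-memory cache can be shared across models and is keyed by the prompt, the sampling-options tuple and the cache seed:

from langfun.core.llms import fake
from langfun.core.llms.cache import in_memory

cache = in_memory.InMemory()
lm = fake.StaticSequence(['1', '2'], cache=cache, temperature=0.0)

assert lm('a') == '1'  # Sampled from the model and written to the cache.
assert lm('a') == '1'  # A repeat call with the same options is served from the cache.
print(list(cache.keys('StaticSequence')))
# e.g. [('a', (0.0, None, 1, 40, None, None), 0)]
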
langfun/core/llms/fake.py CHANGED
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Fake LMs for testing."""
15
15
 
16
+ import abc
16
17
  from typing import Annotated
17
18
  import langfun.core as lf
18
19
 
@@ -23,15 +24,32 @@ class Fake(lf.LanguageModel):
23
24
  def _score(self, prompt: lf.Message, completions: list[lf.Message]):
24
25
  return [lf.LMScoringResult(score=-i * 1.0) for i in range(len(completions))]
25
26
 
27
+ def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
28
+ results = []
29
+ for prompt in prompts:
30
+ response = self._response_from(prompt)
31
+ results.append(
32
+ lf.LMSamplingResult(
33
+ [lf.LMSample(response, 1.0)],
34
+ usage=lf.LMSamplingUsage(
35
+ prompt_tokens=len(prompt.text),
36
+ completion_tokens=len(response.text),
37
+ total_tokens=len(prompt.text) + len(response.text),
38
+ )
39
+ )
40
+ )
41
+ return results
42
+
43
+ @abc.abstractmethod
44
+ def _response_from(self, prompt: lf.Message) -> lf.Message:
45
+ """Returns the response for the given prompt."""
46
+
26
47
 
27
48
  class Echo(Fake):
28
49
  """A simple echo language model for testing."""
29
50
 
30
- def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
31
- return [
32
- lf.LMSamplingResult([lf.LMSample(prompt.text, 1.0)])
33
- for prompt in prompts
34
- ]
51
+ def _response_from(self, prompt: lf.Message) -> lf.Message:
52
+ return lf.AIMessage(prompt.text)
35
53
 
36
54
 
37
55
  @lf.use_init_args(['response'])
@@ -39,15 +57,12 @@ class StaticResponse(Fake):
39
57
  """Language model that always gives the same canned response."""
40
58
 
41
59
  response: Annotated[
42
- str,
60
+ str | lf.Message,
43
61
  'A canned response that will be returned regardless of the prompt.'
44
62
  ]
45
63
 
46
- def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
47
- return [
48
- lf.LMSamplingResult([lf.LMSample(self.response, 1.0)])
49
- for _ in prompts
50
- ]
64
+ def _response_from(self, prompt: lf.Message) -> lf.Message:
65
+ return lf.AIMessage.from_value(self.response)
51
66
 
52
67
 
53
68
  @lf.use_init_args(['mapping'])
@@ -55,15 +70,12 @@ class StaticMapping(Fake):
55
70
  """A static mapping from prompt to response."""
56
71
 
57
72
  mapping: Annotated[
58
- dict[str, str],
73
+ dict[str, str | lf.Message],
59
74
  'A mapping from prompt to response.'
60
75
  ]
61
76
 
62
- def _sample(self, prompts: list[str]) -> list[lf.LMSamplingResult]:
63
- return [
64
- lf.LMSamplingResult([lf.LMSample(self.mapping[prompt], 1.0)])
65
- for prompt in prompts
66
- ]
77
+ def _response_from(self, prompt: lf.Message) -> lf.Message:
78
+ return lf.AIMessage.from_value(self.mapping[prompt])
67
79
 
68
80
 
69
81
  @lf.use_init_args(['sequence'])
@@ -71,7 +83,7 @@ class StaticSequence(Fake):
71
83
  """A static sequence of responses to use."""
72
84
 
73
85
  sequence: Annotated[
74
- list[str],
86
+ list[str | lf.Message],
75
87
  'A sequence of strings as the response.'
76
88
  ]
77
89
 
@@ -79,10 +91,7 @@ class StaticSequence(Fake):
79
91
  super()._on_bound()
80
92
  self._pos = 0
81
93
 
82
- def _sample(self, prompts: list[str]) -> list[lf.LMSamplingResult]:
83
- results = []
84
- for _ in prompts:
85
- results.append(lf.LMSamplingResult(
86
- [lf.LMSample(self.sequence[self._pos], 1.0)]))
87
- self._pos += 1
88
- return results
94
+ def _response_from(self, prompt: lf.Message) -> lf.Message:
95
+ r = lf.AIMessage.from_value(self.sequence[self._pos])
96
+ self._pos += 1
97
+ return r
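
With this refactor, fake models only implement `_response_from`, while `Fake._sample` handles sampling results and usage accounting in one place. A minimal sketch of a custom fake follows; the `ReversedEcho` class is hypothetical, for illustration only.

import langfun.core as lf
from langfun.core.llms import fake


class ReversedEcho(fake.Fake):
  """A toy fake LM that responds with the reversed prompt text."""

  def _response_from(self, prompt: lf.Message) -> lf.Message:
    return lf.AIMessage(prompt.text[::-1])


lm = ReversedEcho()
assert lm('abc') == 'cba'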