langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +17 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,10 @@ import os
18
18
  from typing import Any
19
19
  import unittest
20
20
  from unittest import mock
21
+
22
+ from google.auth import exceptions
23
+ from langfun.core import language_model
24
+ from langfun.core import message as lf_message
21
25
  from langfun.core import modalities as lf_modalities
22
26
  from langfun.core.llms import anthropic
23
27
  import pyglove as pg
@@ -59,18 +63,30 @@ image_content = (
59
63
  b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
60
64
  )
61
65
 
66
# A minimal single-page PDF document (one /Page, Helvetica text "Hello, PDF
# content!"), used as the payload for the PDF-input test.
pdf_content = b''.join([
    b'%PDF-1.4\n1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n2 0 obj\n<<',
    b' /Type /Pages /Count 1 /Kids [3 0 R] >>\nendobj\n3 0 obj\n<< /Type /Page',
    b' /Parent 2 0 R /MediaBox [0 0 612 792] /Contents 4 0 R >>\nendobj\n4 0',
    b' obj\n<< /Length 44 >>\nstream\nBT /F1 24 Tf 100 700 Td (Hello, PDF',
    b' content!) Tj ET\nendstream\nendobj\n5 0 obj\n<< /Type /Font /Subtype',
    b' /Type1 /BaseFont /Helvetica >>\nendobj\nxref\n0 6\n0000000000 65535 f',
    b' \n0000000010 00000 n \n0000000079 00000 n \n0000000178 00000 n',
    b' \n0000000278 00000 n \n0000000407 00000 n \ntrailer\n<< /Size 6 /Root 1',
    b' 0 R >>\nstartxref\n517\n%%EOF',
])
77
+
62
78
 
63
79
  def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
64
80
  del url, kwargs
65
81
  v = json['messages'][0]['content'][0]
66
- image = lf_modalities.Image.from_bytes(base64.b64decode(v['source']['data']))
82
+ content = lf_modalities.Mime.from_bytes(base64.b64decode(v['source']['data']))
67
83
 
68
84
  response = requests.Response()
69
85
  response.status_code = 200
70
86
  response._content = pg.to_json_str({
71
87
  'content': [{
72
88
  'type': 'text',
73
- 'text': f'{v["type"]}: {image.mime_type}',
89
+ 'text': f'{v["type"]}: {content.mime_type}',
74
90
  }],
75
91
  'usage': {
76
92
  'input_tokens': 2,
@@ -146,6 +162,13 @@ class AnthropicTest(unittest.TestCase):
146
162
  response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
147
163
  self.assertEqual(response.text, 'image: image/png')
148
164
 
165
def test_pdf_call(self):
  """A PDF input is sent as a `document` part and echoed back by the mock."""
  with mock.patch('requests.Session.post') as mock_mm_request:
    mock_mm_request.side_effect = mock_mm_requests_post
    lm = anthropic.Claude3Haiku(api_key='fake_key')
    pdf = lf_modalities.PDF.from_bytes(pdf_content)
    response = lm(pdf, lm=lm)
    self.assertEqual(response.text, 'document: application/pdf')
171
+
149
172
  def test_call_errors(self):
150
173
  for status_code, error_type, error_message in [
151
174
  (429, 'rate_limit', 'Rate limit exceeded.'),
@@ -160,7 +183,52 @@ class AnthropicTest(unittest.TestCase):
160
183
  with self.assertRaisesRegex(
161
184
  Exception, f'.*{status_code}: .*{error_message}'
162
185
  ):
163
- lm('hello', lm=lm, max_attempts=1)
186
+ lm('hello', max_attempts=1)
187
+
188
+
189
class VertexAIAnthropicTest(unittest.TestCase):
  """Tests for VertexAI Anthropic models."""

  def test_basics(self):
    # Calling a Vertex AI model without a `project` must fail.
    with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
      lm = anthropic.VertexAIClaude3_5_Sonnet_20241022()
      lm('hi')

    model = anthropic.VertexAIClaude3_5_Sonnet_20241022(project='langfun')

    # NOTE(daiyip): For OSS users, default credentials are not available unless
    # users have already set up their GCP project. Therefore we ignore the
    # exception here.
    try:
      model._initialize()
    except exceptions.DefaultCredentialsError:
      pass

    expected_endpoint = (
        'https://us-east5-aiplatform.googleapis.com/v1/projects/'
        'langfun/locations/us-east5/publishers/anthropic/'
        'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
    )
    self.assertEqual(model.api_endpoint, expected_endpoint)

    expected_request = {
        'anthropic_version': 'vertex-2023-10-16',
        'max_tokens': 8192,
        'messages': [
            {'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
        ],
        'stream': False,
        'temperature': 0.0,
        'top_k': 40,
    }
    request = model.request(
        lf_message.UserMessage('hi'),
        language_model.LMSamplingOptions(temperature=0.0),
    )
    self.assertEqual(request, expected_request)
164
232
 
165
233
 
166
234
  if __name__ == '__main__':
@@ -60,13 +60,16 @@ class LMCacheBase(lf.LMCache):
60
60
  self, lm: lf.LanguageModel, prompt: lf.Message, seed: int
61
61
  ) -> lf.LMSamplingResult | None:
62
62
  """Gets the cached result of a prompt generated by a language model."""
63
- entry = self._get(lm.model_id, self._key(lm, prompt, seed))
63
+ key = self._key(lm, prompt, seed)
64
+ entry = self._get(lm.model_id, key)
64
65
  self._stats.num_queries += 1
65
66
  if entry is None:
66
67
  self._stats.num_misses += 1
67
68
  return None
68
69
  if entry.expire is not None and entry.expire < datetime.datetime.now():
69
70
  self._stats.num_hit_expires += 1
71
+ self._stats.num_deletes += 1
72
+ assert self._delete(lm.model_id, key)
70
73
  return None
71
74
  self._stats.num_hits += 1
72
75
  return entry.result
@@ -86,6 +89,18 @@ class LMCacheBase(lf.LMCache):
86
89
  self._put(lm.model_id, self._key(lm, prompt, seed), entry)
87
90
  self._stats.num_updates += 1
88
91
 
92
def delete(
    self,
    lm: lf.LanguageModel,
    prompt: lf.Message,
    seed: int,
) -> bool:
  """Removes the cached result of `prompt` sampled by `lm` with `seed`.

  Args:
    lm: The language model whose cache partition (`lm.model_id`) is used.
    prompt: The prompt whose cached result should be removed.
    seed: The cache seed that was used when the result was stored.

  Returns:
    The value returned by `_delete`: truthy if an entry was removed (which is
    also counted in `stats.num_deletes`), falsy otherwise.
  """
  key = self._key(lm, prompt, seed)
  removed = self._delete(lm.model_id, key)
  if not removed:
    return removed
  self._stats.num_deletes += 1
  return removed
103
+
89
104
  @abc.abstractmethod
90
105
  def _get(self, model_id: str, key: str) -> LMCacheEntry | None:
91
106
  """Returns a LM cache entry associated with the key."""
@@ -94,6 +109,10 @@ class LMCacheBase(lf.LMCache):
94
109
  def _put(self, model_id: str, key: str, entry: LMCacheEntry) -> None:
95
110
  """Puts a LM cache entry associated with the key."""
96
111
 
112
@abc.abstractmethod
def _delete(self, model_id: str, key: str) -> bool:
  """Removes the cache entry stored under `key` for `model_id`.

  Implementations return True when an entry was removed and False when no
  such entry existed; `delete` uses this result to update its statistics.
  """
115
+
97
116
  def _sym_clone(self, deep: bool, memo: Any = None) -> 'LMCacheBase':
98
117
  v = super()._sym_clone(deep, memo)
99
118
  v._stats = self._stats # pylint: disable=protected-access
@@ -102,4 +121,4 @@ class LMCacheBase(lf.LMCache):
102
121
 
103
122
  def default_key(lm: lf.LanguageModel, prompt: lf.Message, seed: int) -> Any:
104
123
  """Default key for LM cache."""
105
- return (prompt.text, lm.sampling_options.cache_key(), seed)
124
+ return (prompt.text_with_modality_hash, lm.sampling_options.cache_key(), seed)
@@ -15,6 +15,7 @@
15
15
 
16
16
  import collections
17
17
  import contextlib
18
+ import json
18
19
  from typing import Annotated, Any, Iterator
19
20
  import langfun.core as lf
20
21
  from langfun.core.llms.cache import base
@@ -49,6 +50,11 @@ class InMemory(base.LMCacheBase):
49
50
  "Creating a new cache as cache file '%s' does not exist.",
50
51
  self.filename,
51
52
  )
53
+ except json.JSONDecodeError:
54
+ pg.logging.warning(
55
+ "Creating a new cache as cache file '%s' is corrupted.",
56
+ self.filename,
57
+ )
52
58
 
53
59
  def model_ids(self) -> list[str]:
54
60
  """Returns the model ids of cached queires."""
@@ -99,6 +105,13 @@ class InMemory(base.LMCacheBase):
99
105
  """Puts a LM cache entry associated with the key."""
100
106
  self._cache[model_id][key] = entry
101
107
 
108
def _delete(self, model_id: str, key: str) -> bool:
  """Removes `key` from the in-memory cache of `model_id`.

  The per-model dict is fetched with `.get`, so an unknown `model_id` is
  reported as a miss without indexing `self._cache[model_id]` directly.

  Returns:
    True if a non-None entry was popped, False otherwise.
  """
  entries = self._cache.get(model_id, None)
  return entries is not None and entries.pop(key, None) is not None
114
+
102
115
  def reset(self, model_id: str | None = None) -> None:
103
116
  """Resets the cache."""
104
117
  if model_id is not None:
@@ -66,14 +66,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
66
66
  [
67
67
  lf.LMSample(
68
68
  lf.AIMessage(response_text, cache_seed=cache_seed),
69
- score=1.0
69
+ score=1.0,
70
70
  )
71
71
  ],
72
72
  usage=lf.LMSamplingUsage(
73
73
  1,
74
74
  len(response_text),
75
75
  len(response_text) + 1,
76
- )
76
+ ),
77
+ is_cached=True,
77
78
  )
78
79
  )
79
80
 
@@ -148,6 +149,50 @@ class InMemoryLMCacheTest(unittest.TestCase):
148
149
  self.assertIs(copy.deepcopy(cache)._cache, cache._cache)
149
150
  self.assertIs(copy.deepcopy(cache)._stats, cache._stats)
150
151
 
152
+ self.assertFalse(
153
+ cache.delete(fake.StaticResponse('hi'), lf.UserMessage('c'), seed=0)
154
+ )
155
+ self.assertFalse(cache.delete(lm, lf.UserMessage('c'), seed=1))
156
+ self.assertFalse(cache.delete(lm, lf.UserMessage('d'), seed=0))
157
+ self.assertTrue(cache.delete(lm, lf.UserMessage('c'), seed=0))
158
+ self.assertEqual(
159
+ list(cache.keys('StaticSequence')),
160
+ [
161
+ ('a', (None, None, 1, 40, None, None), 0),
162
+ ('a', (None, None, 1, 40, None, None), 1),
163
+ ('b', (None, None, 1, 40, None, None), 0),
164
+ ],
165
+ )
166
+ self.assertEqual(cache.stats.num_deletes, 1)
167
+
168
def test_cache_with_modalities(self):
  """Prompts that differ only in modality content get distinct cache keys."""

  class CustomModality(lf.Modality):
    content: str

    def to_bytes(self):
      return self.content.encode()

  cache = in_memory.InMemory()
  lm = fake.StaticSequence(['1', '2', '3', '4', '5', '6'], cache=cache)
  for payload in ('foo', 'bar'):
    lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality(payload)))

  # The 8-hex-digit tags equal the leading digits of md5(b'foo') and
  # md5(b'bar'), so the key reflects the modality bytes, not just the text.
  options_key = (None, None, 1, 40, None, None)
  self.assertEqual(
      list(cache.keys()),
      [
          ('hi <<[[image]]>><image>acbd18db</image>', options_key, 0),
          ('hi <<[[image]]>><image>37b51d19</image>', options_key, 0),
      ],
  )
195
+
151
196
  def test_ttl(self):
152
197
  cache = in_memory.InMemory(ttl=1)
153
198
  lm = fake.StaticSequence(['1', '2', '3'], cache=cache)
@@ -160,6 +205,7 @@ class InMemoryLMCacheTest(unittest.TestCase):
160
205
  self.assertEqual(cache.stats.num_hits, 1)
161
206
  self.assertEqual(cache.stats.num_hit_expires, 1)
162
207
  self.assertEqual(cache.stats.num_misses, 1)
208
+ self.assertEqual(cache.stats.num_deletes, 1)
163
209
 
164
210
  def test_different_sampling_options(self):
165
211
  cache = in_memory.InMemory()
@@ -249,6 +295,11 @@ class InMemoryLMCacheTest(unittest.TestCase):
249
295
  self.assertEqual(cache2.stats.num_updates, 2)
250
296
  cache2.save()
251
297
 
298
+ # Corrupted file.
299
+ pg.io.writefile(path, 'bad_content')
300
+ cache3 = in_memory.InMemory(path)
301
+ self.assertEqual(len(cache3), 0)
302
+
252
303
 
253
304
  class LmCacheTest(unittest.TestCase):
254
305
 
@@ -0,0 +1,101 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Compositions of different LLM models."""
15
+ import random
16
+ from typing import Annotated
17
+
18
+ import langfun.core as lf
19
+ import pyglove as pg
20
+
21
+
22
@pg.use_init_args(['candidates', 'seed'])
class RandomChoice(lf.LanguageModel):
  """Random choice of a list of LLM models.

  Each public entry point (`sample`, `__call__`, `score` and `tokenize`)
  picks one candidate using a `random.Random` seeded with `seed` and forwards
  the call to it. Sampling options explicitly set on this model are copied
  onto every candidate when the model is bound.
  """

  candidates: Annotated[
      list[lf.LanguageModel],
      (
          'A list of LLMs as candidates to choose from.'
      )
  ]

  seed: Annotated[
      int,
      (
          'The random seed to use for the random choice.'
      )
  ] = 0

  def _on_bound(self):
    super()._on_bound()
    # A freshly seeded generator makes the sequence of choices reproducible
    # for a newly constructed (or cloned) instance.
    self._rand = random.Random(self.seed)
    # Apply sampling options explicitly set on this model (e.g. `temperature`)
    # to all candidates, so whichever LLM is chosen honors them.
    parent_non_default = self.sampling_options.sym_nondefault()
    if parent_non_default:
      for c in self.candidates:
        c.sampling_options.rebind(
            parent_non_default, notify_parents=False, raise_on_no_change=False
        )

  @property
  def model_id(self) -> str:
    """Returns `RandomChoice(<sorted candidate model ids>)`."""
    model_ids = ', '.join(
        sorted(c.model_id for c in self.candidates)
    )
    return f'RandomChoice({model_ids})'

  @property
  def resource_id(self) -> str:
    """Returns `RandomChoice(<sorted candidate resource ids>)`."""
    resource_ids = ', '.join(
        sorted(c.resource_id for c in self.candidates)
    )
    return f'RandomChoice({resource_ids})'

  def _select_lm(self) -> lf.LanguageModel:
    """Selects a random LLM from the candidates."""
    return self._rand.choice(self.candidates)

  def sample(
      self,
      prompts: list[str | lf.Message],
      *,
      cache_seed: int = 0,
      **kwargs,
  ) -> list[lf.LMSamplingResult]:
    """Samples all `prompts` with a single randomly chosen candidate."""
    return self._select_lm().sample(
        prompts, cache_seed=cache_seed, **kwargs
    )

  def __call__(
      self, prompt: lf.Message, *, cache_seed: int = 0, **kwargs
  ) -> lf.Message:
    """Calls a randomly chosen candidate with `prompt`."""
    return self._select_lm()(prompt, cache_seed=cache_seed, **kwargs)

  def score(
      self,
      prompt: str | lf.Message | list[lf.Message],
      completions: list[str | lf.Message],
      **kwargs,
  ) -> list[lf.LMScoringResult]:
    """Scores `completions` for `prompt` with a randomly chosen candidate."""
    return self._select_lm().score(prompt, completions, **kwargs)

  def tokenize(
      self,
      prompt: str | lf.Message,
      **kwargs,
  ) -> list[tuple[str | bytes, int]]:
    """Tokenizes `prompt` with a randomly chosen candidate."""
    return self._select_lm().tokenize(prompt, **kwargs)

  def _sample(self, *args, **kwargs):
    """Not used: every public entry point delegates to a candidate.

    Raises:
      NotImplementedError: always. Unlike `assert False`, this still fires
        when Python runs with `-O`, instead of silently returning None.
    """
    del args, kwargs  # Unused.
    raise NotImplementedError(
        'RandomChoice delegates to its candidates; `_sample` should never '
        'be called.'
    )
@@ -0,0 +1,73 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for compositional models."""
15
+ import unittest
16
+
17
+ import langfun.core as lf
18
+ from langfun.core.llms import compositional
19
+ from langfun.core.llms import fake
20
+
21
+
22
class RandomChoiceTest(unittest.TestCase):
  """Tests for `compositional.RandomChoice`."""

  def test_basic(self):
    lm = compositional.RandomChoice([
        fake.StaticResponse('hi'),
        fake.StaticSequence(['hello', 'world']),
    ])
    expected_id = 'RandomChoice(StaticResponse, StaticSequence)'
    self.assertEqual(lm.model_id, expected_id)
    self.assertEqual(lm.resource_id, expected_id)

    # With the default seed the candidate choices are deterministic.
    self.assertEqual(
        [lm(p) for p in ('a', 'b', 'c')],
        ['hello', 'world', 'hi'],
    )

    # A clone replays the same choice sequence through `sample`.
    lm = lm.clone()
    responses = [
        lm.sample([p])[0].samples[0].response for p in ('a', 'b', 'c')
    ]
    self.assertEqual(responses, ['hello', 'world', 'hi'])

    self.assertEqual(
        lm.score('hello', ['world']), [lf.LMScoringResult(0.0)]
    )
    self.assertEqual(lm.tokenize('hello'), [('hello', 0)])

  def test_sampling_options(self):
    candidates = [
        fake.StaticResponse('hi'),
        fake.StaticSequence(['hello', 'world']),
    ]
    lm = compositional.RandomChoice(candidates, temperature=0.5)
    # Options set on the composite are propagated to its candidates.
    self.assertEqual(lm.candidates[0].sampling_options.temperature, 0.5)


if __name__ == '__main__':
  unittest.main()
@@ -0,0 +1,117 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Language models from DeepSeek."""
15
+
16
+ import os
17
+ from typing import Annotated, Any
18
+
19
+ import langfun.core as lf
20
+ from langfun.core.llms import openai_compatible
21
+ import pyglove as pg
22
+
23
# Per-model settings keyed by DeepSeek model name. Each entry records whether
# the model is in service, its assumed request/token rate limits (per minute)
# and its price per 1K input/output tokens.
SUPPORTED_MODELS_AND_SETTINGS = {
    # pylint: disable=g-line-too-long
    # TODO(yifenglu): The RPM and TPM are arbitrary numbers. Update them once DeepSeek provides concrete guidelines.
    # DeepSeek doesn't control the rate limit at the moment: https://api-docs.deepseek.com/quick_start/rate_limit
    # The cost is based on: https://api-docs.deepseek.com/quick_start/pricing
    'deepseek-chat': pg.Dict(
        in_service=True,
        rpm=100,
        tpm=1_000_000,
        cost_per_1k_input_tokens=0.00014,
        cost_per_1k_output_tokens=0.00028,
    ),
}
36
+
37
+
38
+ # DeepSeek API uses an API format compatible with OpenAI.
39
+ # Reference: https://api-docs.deepseek.com/
40
@lf.use_init_args(['model'])
class DeepSeek(openai_compatible.OpenAICompatible):
  """DeepSeek model served through its OpenAI-compatible chat endpoint.

  Rate limits and pricing are read from `SUPPORTED_MODELS_AND_SETTINGS`
  using the selected `model` name.
  """

  model: pg.typing.Annotated[
      pg.typing.Enum(
          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
      ),
      'The name of the model to use.',
  ]

  api_endpoint: str = 'https://api.deepseek.com/chat/completions'

  api_key: Annotated[
      str | None,
      (
          'API key. If None, the key will be read from environment variable '
          "'DEEPSEEK_API_KEY'."
      ),
  ] = None

  @property
  def headers(self) -> dict[str, Any]:
    """Returns the base headers plus a Bearer `Authorization` header.

    Raises:
      ValueError: if no API key is given and `DEEPSEEK_API_KEY` is unset.
    """
    api_key = self.api_key or os.environ.get('DEEPSEEK_API_KEY', None)
    if not api_key:
      raise ValueError(
          'Please specify `api_key` during `__init__` or set environment '
          'variable `DEEPSEEK_API_KEY` with your DeepSeek API key.'
      )
    result = super().headers
    result.update({'Authorization': f'Bearer {api_key}'})
    return result

  @property
  def model_id(self) -> str:
    """Returns a string to identify the model."""
    return f'DeepSeek({self.model})'

  @property
  def max_concurrency(self) -> int:
    """Derives max concurrency from the model's RPM/TPM settings."""
    settings = SUPPORTED_MODELS_AND_SETTINGS[self.model]
    return self.rate_to_max_concurrency(
        requests_per_min=settings.get('rpm', 0),
        tokens_per_min=settings.get('tpm', 0),
    )

  def estimate_cost(
      self, num_input_tokens: int, num_output_tokens: int
  ) -> float | None:
    """Estimates the cost in USD, or None if pricing is unknown."""
    settings = SUPPORTED_MODELS_AND_SETTINGS[self.model]
    input_rate = settings.get('cost_per_1k_input_tokens', None)
    output_rate = settings.get('cost_per_1k_output_tokens', None)
    if input_rate is None or output_rate is None:
      return None
    total = input_rate * num_input_tokens + output_rate * num_output_tokens
    return total / 1000

  @classmethod
  def dir(cls):
    """Returns the names of supported models that are in service."""
    return [
        name
        for name, settings in SUPPORTED_MODELS_AND_SETTINGS.items()
        if settings.in_service
    ]
108
+
109
+
110
class DeepSeekChat(DeepSeek):
  """DeepSeek Chat model.

  Currently, it is powered by the DeepSeek-V3 model, with a 64K input context
  window and 8K max output tokens.
  """

  model = 'deepseek-chat'
@@ -0,0 +1,61 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
import os
import unittest
from unittest import mock

from langfun.core.llms import deepseek
16
+
17
+
18
class DeepSeekTest(unittest.TestCase):
  """Tests for DeepSeek language model."""

  def test_dir(self):
    self.assertIn('deepseek-chat', deepseek.DeepSeek.dir())

  def test_key(self):
    # Remove DEEPSEEK_API_KEY for the missing-key check so the test does not
    # fail when the variable is set in the developer's environment;
    # `patch.dict` restores the environment afterwards.
    with mock.patch.dict(os.environ):
      os.environ.pop('DEEPSEEK_API_KEY', None)
      with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
        _ = deepseek.DeepSeekChat().headers
    self.assertEqual(
        deepseek.DeepSeekChat(api_key='test_key').headers,
        {
            'Content-Type': 'application/json',
            'Authorization': 'Bearer test_key',
        }
    )

  def test_model_id(self):
    self.assertEqual(
        deepseek.DeepSeekChat(api_key='test_key').model_id,
        'DeepSeek(deepseek-chat)',
    )

  def test_resource_id(self):
    self.assertEqual(
        deepseek.DeepSeekChat(api_key='test_key').resource_id,
        'DeepSeek(deepseek-chat)',
    )

  def test_max_concurrency(self):
    self.assertGreater(
        deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0
    )

  def test_estimate_cost(self):
    # (100 * 0.00014 + 100 * 0.00028) / 1000 == 4.2e-5 mathematically, but the
    # result is a float computation, so compare with a tight tolerance rather
    # than exact equality.
    self.assertAlmostEqual(
        deepseek.DeepSeekChat(api_key='test_key').estimate_cost(
            num_input_tokens=100, num_output_tokens=100
        ),
        4.2e-5,
        places=12,
    )


if __name__ == '__main__':
  unittest.main()