langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +17 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,10 @@ import os
|
|
18
18
|
from typing import Any
|
19
19
|
import unittest
|
20
20
|
from unittest import mock
|
21
|
+
|
22
|
+
from google.auth import exceptions
|
23
|
+
from langfun.core import language_model
|
24
|
+
from langfun.core import message as lf_message
|
21
25
|
from langfun.core import modalities as lf_modalities
|
22
26
|
from langfun.core.llms import anthropic
|
23
27
|
import pyglove as pg
|
@@ -59,18 +63,30 @@ image_content = (
|
|
59
63
|
b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
|
60
64
|
)
|
61
65
|
|
66
|
+
pdf_content = (
|
67
|
+
b'%PDF-1.4\n1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n2 0 obj\n<<'
|
68
|
+
b' /Type /Pages /Count 1 /Kids [3 0 R] >>\nendobj\n3 0 obj\n<< /Type /Page'
|
69
|
+
b' /Parent 2 0 R /MediaBox [0 0 612 792] /Contents 4 0 R >>\nendobj\n4 0'
|
70
|
+
b' obj\n<< /Length 44 >>\nstream\nBT /F1 24 Tf 100 700 Td (Hello, PDF'
|
71
|
+
b' content!) Tj ET\nendstream\nendobj\n5 0 obj\n<< /Type /Font /Subtype'
|
72
|
+
b' /Type1 /BaseFont /Helvetica >>\nendobj\nxref\n0 6\n0000000000 65535 f'
|
73
|
+
b' \n0000000010 00000 n \n0000000079 00000 n \n0000000178 00000 n'
|
74
|
+
b' \n0000000278 00000 n \n0000000407 00000 n \ntrailer\n<< /Size 6 /Root 1'
|
75
|
+
b' 0 R >>\nstartxref\n517\n%%EOF'
|
76
|
+
)
|
77
|
+
|
62
78
|
|
63
79
|
def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
|
64
80
|
del url, kwargs
|
65
81
|
v = json['messages'][0]['content'][0]
|
66
|
-
|
82
|
+
content = lf_modalities.Mime.from_bytes(base64.b64decode(v['source']['data']))
|
67
83
|
|
68
84
|
response = requests.Response()
|
69
85
|
response.status_code = 200
|
70
86
|
response._content = pg.to_json_str({
|
71
87
|
'content': [{
|
72
88
|
'type': 'text',
|
73
|
-
'text': f'{v["type"]}: {
|
89
|
+
'text': f'{v["type"]}: {content.mime_type}',
|
74
90
|
}],
|
75
91
|
'usage': {
|
76
92
|
'input_tokens': 2,
|
@@ -146,6 +162,13 @@ class AnthropicTest(unittest.TestCase):
|
|
146
162
|
response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
|
147
163
|
self.assertEqual(response.text, 'image: image/png')
|
148
164
|
|
165
|
+
def test_pdf_call(self):
|
166
|
+
with mock.patch('requests.Session.post') as mock_mm_request:
|
167
|
+
mock_mm_request.side_effect = mock_mm_requests_post
|
168
|
+
lm = anthropic.Claude3Haiku(api_key='fake_key')
|
169
|
+
response = lm(lf_modalities.PDF.from_bytes(pdf_content), lm=lm)
|
170
|
+
self.assertEqual(response.text, 'document: application/pdf')
|
171
|
+
|
149
172
|
def test_call_errors(self):
|
150
173
|
for status_code, error_type, error_message in [
|
151
174
|
(429, 'rate_limit', 'Rate limit exceeded.'),
|
@@ -160,7 +183,52 @@ class AnthropicTest(unittest.TestCase):
|
|
160
183
|
with self.assertRaisesRegex(
|
161
184
|
Exception, f'.*{status_code}: .*{error_message}'
|
162
185
|
):
|
163
|
-
lm('hello',
|
186
|
+
lm('hello', max_attempts=1)
|
187
|
+
|
188
|
+
|
189
|
+
class VertexAIAnthropicTest(unittest.TestCase):
|
190
|
+
"""Tests for VertexAI Anthropic models."""
|
191
|
+
|
192
|
+
def test_basics(self):
|
193
|
+
with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
|
194
|
+
lm = anthropic.VertexAIClaude3_5_Sonnet_20241022()
|
195
|
+
lm('hi')
|
196
|
+
|
197
|
+
model = anthropic.VertexAIClaude3_5_Sonnet_20241022(project='langfun')
|
198
|
+
|
199
|
+
# NOTE(daiyip): For OSS users, default credentials are not available unless
|
200
|
+
# users have already set up their GCP project. Therefore we ignore the
|
201
|
+
# exception here.
|
202
|
+
try:
|
203
|
+
model._initialize()
|
204
|
+
except exceptions.DefaultCredentialsError:
|
205
|
+
pass
|
206
|
+
|
207
|
+
self.assertEqual(
|
208
|
+
model.api_endpoint,
|
209
|
+
(
|
210
|
+
'https://us-east5-aiplatform.googleapis.com/v1/projects/'
|
211
|
+
'langfun/locations/us-east5/publishers/anthropic/'
|
212
|
+
'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
|
213
|
+
)
|
214
|
+
)
|
215
|
+
request = model.request(
|
216
|
+
lf_message.UserMessage('hi'),
|
217
|
+
language_model.LMSamplingOptions(temperature=0.0),
|
218
|
+
)
|
219
|
+
self.assertEqual(
|
220
|
+
request,
|
221
|
+
{
|
222
|
+
'anthropic_version': 'vertex-2023-10-16',
|
223
|
+
'max_tokens': 8192,
|
224
|
+
'messages': [
|
225
|
+
{'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
|
226
|
+
],
|
227
|
+
'stream': False,
|
228
|
+
'temperature': 0.0,
|
229
|
+
'top_k': 40,
|
230
|
+
},
|
231
|
+
)
|
164
232
|
|
165
233
|
|
166
234
|
if __name__ == '__main__':
|
langfun/core/llms/cache/base.py
CHANGED
@@ -60,13 +60,16 @@ class LMCacheBase(lf.LMCache):
|
|
60
60
|
self, lm: lf.LanguageModel, prompt: lf.Message, seed: int
|
61
61
|
) -> lf.LMSamplingResult | None:
|
62
62
|
"""Gets the cached result of a prompt generated by a language model."""
|
63
|
-
|
63
|
+
key = self._key(lm, prompt, seed)
|
64
|
+
entry = self._get(lm.model_id, key)
|
64
65
|
self._stats.num_queries += 1
|
65
66
|
if entry is None:
|
66
67
|
self._stats.num_misses += 1
|
67
68
|
return None
|
68
69
|
if entry.expire is not None and entry.expire < datetime.datetime.now():
|
69
70
|
self._stats.num_hit_expires += 1
|
71
|
+
self._stats.num_deletes += 1
|
72
|
+
assert self._delete(lm.model_id, key)
|
70
73
|
return None
|
71
74
|
self._stats.num_hits += 1
|
72
75
|
return entry.result
|
@@ -86,6 +89,18 @@ class LMCacheBase(lf.LMCache):
|
|
86
89
|
self._put(lm.model_id, self._key(lm, prompt, seed), entry)
|
87
90
|
self._stats.num_updates += 1
|
88
91
|
|
92
|
+
def delete(
|
93
|
+
self,
|
94
|
+
lm: lf.LanguageModel,
|
95
|
+
prompt: lf.Message,
|
96
|
+
seed: int,
|
97
|
+
) -> bool:
|
98
|
+
"""Deletes the result of a prompt generated by a language model in cache."""
|
99
|
+
deleted = self._delete(lm.model_id, self._key(lm, prompt, seed))
|
100
|
+
if deleted:
|
101
|
+
self._stats.num_deletes += 1
|
102
|
+
return deleted
|
103
|
+
|
89
104
|
@abc.abstractmethod
|
90
105
|
def _get(self, model_id: str, key: str) -> LMCacheEntry | None:
|
91
106
|
"""Returns a LM cache entry associated with the key."""
|
@@ -94,6 +109,10 @@ class LMCacheBase(lf.LMCache):
|
|
94
109
|
def _put(self, model_id: str, key: str, entry: LMCacheEntry) -> None:
|
95
110
|
"""Puts a LM cache entry associated with the key."""
|
96
111
|
|
112
|
+
@abc.abstractmethod
|
113
|
+
def _delete(self, model_id: str, key: str) -> bool:
|
114
|
+
"""Deletes a LM cache entry associated with the key."""
|
115
|
+
|
97
116
|
def _sym_clone(self, deep: bool, memo: Any = None) -> 'LMCacheBase':
|
98
117
|
v = super()._sym_clone(deep, memo)
|
99
118
|
v._stats = self._stats # pylint: disable=protected-access
|
@@ -102,4 +121,4 @@ class LMCacheBase(lf.LMCache):
|
|
102
121
|
|
103
122
|
def default_key(lm: lf.LanguageModel, prompt: lf.Message, seed: int) -> Any:
|
104
123
|
"""Default key for LM cache."""
|
105
|
-
return (prompt.
|
124
|
+
return (prompt.text_with_modality_hash, lm.sampling_options.cache_key(), seed)
|
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
import collections
|
17
17
|
import contextlib
|
18
|
+
import json
|
18
19
|
from typing import Annotated, Any, Iterator
|
19
20
|
import langfun.core as lf
|
20
21
|
from langfun.core.llms.cache import base
|
@@ -49,6 +50,11 @@ class InMemory(base.LMCacheBase):
|
|
49
50
|
"Creating a new cache as cache file '%s' does not exist.",
|
50
51
|
self.filename,
|
51
52
|
)
|
53
|
+
except json.JSONDecodeError:
|
54
|
+
pg.logging.warning(
|
55
|
+
"Creating a new cache as cache file '%s' is corrupted.",
|
56
|
+
self.filename,
|
57
|
+
)
|
52
58
|
|
53
59
|
def model_ids(self) -> list[str]:
|
54
60
|
"""Returns the model ids of cached queires."""
|
@@ -99,6 +105,13 @@ class InMemory(base.LMCacheBase):
|
|
99
105
|
"""Puts a LM cache entry associated with the key."""
|
100
106
|
self._cache[model_id][key] = entry
|
101
107
|
|
108
|
+
def _delete(self, model_id: str, key: str) -> bool:
|
109
|
+
"""Deletes a LM cache entry associated with the key."""
|
110
|
+
model_cache = self._cache.get(model_id, None)
|
111
|
+
if model_cache is None:
|
112
|
+
return False
|
113
|
+
return model_cache.pop(key, None) is not None
|
114
|
+
|
102
115
|
def reset(self, model_id: str | None = None) -> None:
|
103
116
|
"""Resets the cache."""
|
104
117
|
if model_id is not None:
|
@@ -66,14 +66,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
|
|
66
66
|
[
|
67
67
|
lf.LMSample(
|
68
68
|
lf.AIMessage(response_text, cache_seed=cache_seed),
|
69
|
-
score=1.0
|
69
|
+
score=1.0,
|
70
70
|
)
|
71
71
|
],
|
72
72
|
usage=lf.LMSamplingUsage(
|
73
73
|
1,
|
74
74
|
len(response_text),
|
75
75
|
len(response_text) + 1,
|
76
|
-
)
|
76
|
+
),
|
77
|
+
is_cached=True,
|
77
78
|
)
|
78
79
|
)
|
79
80
|
|
@@ -148,6 +149,50 @@ class InMemoryLMCacheTest(unittest.TestCase):
|
|
148
149
|
self.assertIs(copy.deepcopy(cache)._cache, cache._cache)
|
149
150
|
self.assertIs(copy.deepcopy(cache)._stats, cache._stats)
|
150
151
|
|
152
|
+
self.assertFalse(
|
153
|
+
cache.delete(fake.StaticResponse('hi'), lf.UserMessage('c'), seed=0)
|
154
|
+
)
|
155
|
+
self.assertFalse(cache.delete(lm, lf.UserMessage('c'), seed=1))
|
156
|
+
self.assertFalse(cache.delete(lm, lf.UserMessage('d'), seed=0))
|
157
|
+
self.assertTrue(cache.delete(lm, lf.UserMessage('c'), seed=0))
|
158
|
+
self.assertEqual(
|
159
|
+
list(cache.keys('StaticSequence')),
|
160
|
+
[
|
161
|
+
('a', (None, None, 1, 40, None, None), 0),
|
162
|
+
('a', (None, None, 1, 40, None, None), 1),
|
163
|
+
('b', (None, None, 1, 40, None, None), 0),
|
164
|
+
],
|
165
|
+
)
|
166
|
+
self.assertEqual(cache.stats.num_deletes, 1)
|
167
|
+
|
168
|
+
def test_cache_with_modalities(self):
|
169
|
+
|
170
|
+
class CustomModality(lf.Modality):
|
171
|
+
content: str
|
172
|
+
|
173
|
+
def to_bytes(self):
|
174
|
+
return self.content.encode()
|
175
|
+
|
176
|
+
cache = in_memory.InMemory()
|
177
|
+
lm = fake.StaticSequence(['1', '2', '3', '4', '5', '6'], cache=cache)
|
178
|
+
lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality('foo')))
|
179
|
+
lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality('bar')))
|
180
|
+
self.assertEqual(
|
181
|
+
list(cache.keys()),
|
182
|
+
[
|
183
|
+
(
|
184
|
+
'hi <<[[image]]>><image>acbd18db</image>',
|
185
|
+
(None, None, 1, 40, None, None),
|
186
|
+
0,
|
187
|
+
),
|
188
|
+
(
|
189
|
+
'hi <<[[image]]>><image>37b51d19</image>',
|
190
|
+
(None, None, 1, 40, None, None),
|
191
|
+
0,
|
192
|
+
),
|
193
|
+
],
|
194
|
+
)
|
195
|
+
|
151
196
|
def test_ttl(self):
|
152
197
|
cache = in_memory.InMemory(ttl=1)
|
153
198
|
lm = fake.StaticSequence(['1', '2', '3'], cache=cache)
|
@@ -160,6 +205,7 @@ class InMemoryLMCacheTest(unittest.TestCase):
|
|
160
205
|
self.assertEqual(cache.stats.num_hits, 1)
|
161
206
|
self.assertEqual(cache.stats.num_hit_expires, 1)
|
162
207
|
self.assertEqual(cache.stats.num_misses, 1)
|
208
|
+
self.assertEqual(cache.stats.num_deletes, 1)
|
163
209
|
|
164
210
|
def test_different_sampling_options(self):
|
165
211
|
cache = in_memory.InMemory()
|
@@ -249,6 +295,11 @@ class InMemoryLMCacheTest(unittest.TestCase):
|
|
249
295
|
self.assertEqual(cache2.stats.num_updates, 2)
|
250
296
|
cache2.save()
|
251
297
|
|
298
|
+
# Corrupted file.
|
299
|
+
pg.io.writefile(path, 'bad_content')
|
300
|
+
cache3 = in_memory.InMemory(path)
|
301
|
+
self.assertEqual(len(cache3), 0)
|
302
|
+
|
252
303
|
|
253
304
|
class LmCacheTest(unittest.TestCase):
|
254
305
|
|
@@ -0,0 +1,101 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Compositions of different LLM models."""
|
15
|
+
import random
|
16
|
+
from typing import Annotated
|
17
|
+
|
18
|
+
import langfun.core as lf
|
19
|
+
import pyglove as pg
|
20
|
+
|
21
|
+
|
22
|
+
@pg.use_init_args(['candidates', 'seed'])
|
23
|
+
class RandomChoice(lf.LanguageModel):
|
24
|
+
"""Random choice of a list of LLM models."""
|
25
|
+
|
26
|
+
candidates: Annotated[
|
27
|
+
list[lf.LanguageModel],
|
28
|
+
(
|
29
|
+
'A list of LLMs as candidates to choose from.'
|
30
|
+
)
|
31
|
+
]
|
32
|
+
|
33
|
+
seed: Annotated[
|
34
|
+
int,
|
35
|
+
(
|
36
|
+
'The random seed to use for the random choice.'
|
37
|
+
)
|
38
|
+
] = 0
|
39
|
+
|
40
|
+
def _on_bound(self):
|
41
|
+
super()._on_bound()
|
42
|
+
self._rand = random.Random(self.seed)
|
43
|
+
# Applying sampling options to all candidates.
|
44
|
+
parent_non_default = self.sampling_options.sym_nondefault()
|
45
|
+
if parent_non_default:
|
46
|
+
for c in self.candidates:
|
47
|
+
c.sampling_options.rebind(
|
48
|
+
parent_non_default, notify_parents=False, raise_on_no_change=False
|
49
|
+
)
|
50
|
+
|
51
|
+
@property
|
52
|
+
def model_id(self) -> str:
|
53
|
+
model_ids = ', '.join(
|
54
|
+
sorted(c.model_id for c in self.candidates)
|
55
|
+
)
|
56
|
+
return f'RandomChoice({model_ids})'
|
57
|
+
|
58
|
+
@property
|
59
|
+
def resource_id(self) -> str:
|
60
|
+
resource_ids = ', '.join(
|
61
|
+
sorted(c.resource_id for c in self.candidates)
|
62
|
+
)
|
63
|
+
return f'RandomChoice({resource_ids})'
|
64
|
+
|
65
|
+
def _select_lm(self) -> lf.LanguageModel:
|
66
|
+
"""Selects a random LLM from the candidates."""
|
67
|
+
return self._rand.choice(self.candidates)
|
68
|
+
|
69
|
+
def sample(
|
70
|
+
self,
|
71
|
+
prompts: list[str | lf.Message],
|
72
|
+
*,
|
73
|
+
cache_seed: int = 0,
|
74
|
+
**kwargs,
|
75
|
+
) -> list[lf.LMSamplingResult]:
|
76
|
+
return self._select_lm().sample(
|
77
|
+
prompts, cache_seed=cache_seed, **kwargs
|
78
|
+
)
|
79
|
+
|
80
|
+
def __call__(
|
81
|
+
self, prompt: lf.Message, *, cache_seed: int = 0, **kwargs
|
82
|
+
) -> lf.Message:
|
83
|
+
return self._select_lm()(prompt, cache_seed=cache_seed, **kwargs)
|
84
|
+
|
85
|
+
def score(
|
86
|
+
self,
|
87
|
+
prompt: str | lf.Message | list[lf.Message],
|
88
|
+
completions: list[str | lf.Message],
|
89
|
+
**kwargs,
|
90
|
+
) -> list[lf.LMScoringResult]:
|
91
|
+
return self._select_lm().score(prompt, completions, **kwargs)
|
92
|
+
|
93
|
+
def tokenize(
|
94
|
+
self,
|
95
|
+
prompt: str | lf.Message,
|
96
|
+
**kwargs,
|
97
|
+
) -> list[tuple[str | bytes, int]]:
|
98
|
+
return self._select_lm().tokenize(prompt, **kwargs)
|
99
|
+
|
100
|
+
def _sample(self, *arg, **kwargs):
|
101
|
+
assert False, 'Should never trigger.'
|
@@ -0,0 +1,73 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for compositional models."""
|
15
|
+
import unittest
|
16
|
+
|
17
|
+
import langfun.core as lf
|
18
|
+
from langfun.core.llms import compositional
|
19
|
+
from langfun.core.llms import fake
|
20
|
+
|
21
|
+
|
22
|
+
class RandomChoiceTest(unittest.TestCase):
|
23
|
+
|
24
|
+
def test_basic(self):
|
25
|
+
lm = compositional.RandomChoice([
|
26
|
+
fake.StaticResponse('hi'),
|
27
|
+
fake.StaticSequence(['hello', 'world'])
|
28
|
+
])
|
29
|
+
self.assertEqual(
|
30
|
+
lm.model_id,
|
31
|
+
'RandomChoice(StaticResponse, StaticSequence)'
|
32
|
+
)
|
33
|
+
self.assertEqual(
|
34
|
+
lm.resource_id,
|
35
|
+
'RandomChoice(StaticResponse, StaticSequence)'
|
36
|
+
)
|
37
|
+
self.assertEqual(
|
38
|
+
[lm('a'), lm('b'), lm('c')],
|
39
|
+
['hello', 'world', 'hi']
|
40
|
+
)
|
41
|
+
lm = lm.clone()
|
42
|
+
self.assertEqual(
|
43
|
+
[
|
44
|
+
x.samples[0].response for x in [
|
45
|
+
lm.sample(['a'])[0],
|
46
|
+
lm.sample(['b'])[0],
|
47
|
+
lm.sample(['c'])[0],
|
48
|
+
]
|
49
|
+
],
|
50
|
+
['hello', 'world', 'hi']
|
51
|
+
)
|
52
|
+
self.assertEqual(
|
53
|
+
lm.score('hello', ['world']),
|
54
|
+
[lf.LMScoringResult(0.0)]
|
55
|
+
)
|
56
|
+
self.assertEqual(
|
57
|
+
lm.tokenize('hello'),
|
58
|
+
[('hello', 0)]
|
59
|
+
)
|
60
|
+
|
61
|
+
def test_sampling_options(self):
|
62
|
+
lm = compositional.RandomChoice([
|
63
|
+
fake.StaticResponse('hi'),
|
64
|
+
fake.StaticSequence(['hello', 'world'])
|
65
|
+
], temperature=0.5)
|
66
|
+
self.assertEqual(
|
67
|
+
lm.candidates[0].sampling_options.temperature,
|
68
|
+
0.5
|
69
|
+
)
|
70
|
+
|
71
|
+
|
72
|
+
if __name__ == '__main__':
|
73
|
+
unittest.main()
|
@@ -0,0 +1,117 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Language models from DeepSeek."""
|
15
|
+
|
16
|
+
import os
|
17
|
+
from typing import Annotated, Any
|
18
|
+
|
19
|
+
import langfun.core as lf
|
20
|
+
from langfun.core.llms import openai_compatible
|
21
|
+
import pyglove as pg
|
22
|
+
|
23
|
+
SUPPORTED_MODELS_AND_SETTINGS = {
|
24
|
+
# pylint: disable=g-line-too-long
|
25
|
+
# TODO(yifenglu): The RPM and TPM are arbitrary numbers. Update them once DeepSeek provides concrete guidelines.
|
26
|
+
# DeepSeek doesn't control the rate limit at the moment: https://api-docs.deepseek.com/quick_start/rate_limit
|
27
|
+
# The cost is based on: https://api-docs.deepseek.com/quick_start/pricing
|
28
|
+
'deepseek-chat': pg.Dict(
|
29
|
+
in_service=True,
|
30
|
+
rpm=100,
|
31
|
+
tpm=1000000,
|
32
|
+
cost_per_1k_input_tokens=0.00014,
|
33
|
+
cost_per_1k_output_tokens=0.00028,
|
34
|
+
),
|
35
|
+
}
|
36
|
+
|
37
|
+
|
38
|
+
# DeepSeek API uses an API format compatible with OpenAI.
|
39
|
+
# Reference: https://api-docs.deepseek.com/
|
40
|
+
@lf.use_init_args(['model'])
|
41
|
+
class DeepSeek(openai_compatible.OpenAICompatible):
|
42
|
+
"""DeepSeek model."""
|
43
|
+
|
44
|
+
model: pg.typing.Annotated[
|
45
|
+
pg.typing.Enum(
|
46
|
+
pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
|
47
|
+
),
|
48
|
+
'The name of the model to use.',
|
49
|
+
]
|
50
|
+
|
51
|
+
api_endpoint: str = 'https://api.deepseek.com/chat/completions'
|
52
|
+
|
53
|
+
api_key: Annotated[
|
54
|
+
str | None,
|
55
|
+
(
|
56
|
+
'API key. If None, the key will be read from environment variable '
|
57
|
+
"'DEEPSEEK_API_KEY'."
|
58
|
+
),
|
59
|
+
] = None
|
60
|
+
|
61
|
+
@property
|
62
|
+
def headers(self) -> dict[str, Any]:
|
63
|
+
api_key = self.api_key or os.environ.get('DEEPSEEK_API_KEY', None)
|
64
|
+
if not api_key:
|
65
|
+
raise ValueError(
|
66
|
+
'Please specify `api_key` during `__init__` or set environment '
|
67
|
+
'variable `DEEPSEEK_API_KEY` with your DeepSeek API key.'
|
68
|
+
)
|
69
|
+
headers = super().headers
|
70
|
+
headers.update({
|
71
|
+
'Authorization': f'Bearer {api_key}',
|
72
|
+
})
|
73
|
+
return headers
|
74
|
+
|
75
|
+
@property
|
76
|
+
def model_id(self) -> str:
|
77
|
+
"""Returns a string to identify the model."""
|
78
|
+
return f'DeepSeek({self.model})'
|
79
|
+
|
80
|
+
@property
|
81
|
+
def max_concurrency(self) -> int:
|
82
|
+
rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
|
83
|
+
tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
|
84
|
+
return self.rate_to_max_concurrency(
|
85
|
+
requests_per_min=rpm, tokens_per_min=tpm
|
86
|
+
)
|
87
|
+
|
88
|
+
def estimate_cost(
|
89
|
+
self, num_input_tokens: int, num_output_tokens: int
|
90
|
+
) -> float | None:
|
91
|
+
"""Estimate the cost based on usage."""
|
92
|
+
cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
|
93
|
+
'cost_per_1k_input_tokens', None
|
94
|
+
)
|
95
|
+
cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
|
96
|
+
'cost_per_1k_output_tokens', None
|
97
|
+
)
|
98
|
+
if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None:
|
99
|
+
return None
|
100
|
+
return (
|
101
|
+
cost_per_1k_input_tokens * num_input_tokens
|
102
|
+
+ cost_per_1k_output_tokens * num_output_tokens
|
103
|
+
) / 1000
|
104
|
+
|
105
|
+
@classmethod
|
106
|
+
def dir(cls):
|
107
|
+
return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
|
108
|
+
|
109
|
+
|
110
|
+
class DeepSeekChat(DeepSeek):
|
111
|
+
"""DeepSeek Chat model.
|
112
|
+
|
113
|
+
Currently, it is powered by DeepSeek-V3 model, 64K input contenxt window and
|
114
|
+
8k max output tokens.
|
115
|
+
"""
|
116
|
+
|
117
|
+
model = 'deepseek-chat'
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright 2023 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import unittest
|
15
|
+
from langfun.core.llms import deepseek
|
16
|
+
|
17
|
+
|
18
|
+
class DeepSeekTest(unittest.TestCase):
|
19
|
+
"""Tests for DeepSeek language model."""
|
20
|
+
|
21
|
+
def test_dir(self):
|
22
|
+
self.assertIn('deepseek-chat', deepseek.DeepSeek.dir())
|
23
|
+
|
24
|
+
def test_key(self):
|
25
|
+
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
26
|
+
_ = deepseek.DeepSeekChat().headers
|
27
|
+
self.assertEqual(
|
28
|
+
deepseek.DeepSeekChat(api_key='test_key').headers,
|
29
|
+
{
|
30
|
+
'Content-Type': 'application/json',
|
31
|
+
'Authorization': 'Bearer test_key',
|
32
|
+
}
|
33
|
+
)
|
34
|
+
|
35
|
+
def test_model_id(self):
|
36
|
+
self.assertEqual(
|
37
|
+
deepseek.DeepSeekChat(api_key='test_key').model_id,
|
38
|
+
'DeepSeek(deepseek-chat)',
|
39
|
+
)
|
40
|
+
|
41
|
+
def test_resource_id(self):
|
42
|
+
self.assertEqual(
|
43
|
+
deepseek.DeepSeekChat(api_key='test_key').resource_id,
|
44
|
+
'DeepSeek(deepseek-chat)',
|
45
|
+
)
|
46
|
+
|
47
|
+
def test_max_concurrency(self):
|
48
|
+
self.assertGreater(
|
49
|
+
deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0
|
50
|
+
)
|
51
|
+
|
52
|
+
def test_estimate_cost(self):
|
53
|
+
self.assertEqual(
|
54
|
+
deepseek.DeepSeekChat(api_key='test_key').estimate_cost(
|
55
|
+
num_input_tokens=100, num_output_tokens=100
|
56
|
+
),
|
57
|
+
4.2e-5
|
58
|
+
)
|
59
|
+
|
60
|
+
if __name__ == '__main__':
|
61
|
+
unittest.main()
|