langfun 0.1.2.dev202501060804__py3-none-any.whl → 0.1.2.dev202501100804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +0 -5
- langfun/core/coding/python/correction.py +4 -3
- langfun/core/coding/python/errors.py +10 -9
- langfun/core/coding/python/execution.py +23 -12
- langfun/core/coding/python/execution_test.py +21 -2
- langfun/core/coding/python/generation.py +18 -9
- langfun/core/concurrent.py +2 -3
- langfun/core/console.py +8 -3
- langfun/core/eval/base.py +2 -3
- langfun/core/eval/v2/reporting.py +8 -4
- langfun/core/language_model.py +7 -4
- langfun/core/language_model_test.py +15 -0
- langfun/core/llms/__init__.py +7 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/google_genai.py +1 -0
- langfun/core/llms/groq.py +12 -99
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +17 -54
- langfun/core/llms/llama_cpp_test.py +2 -34
- langfun/core/llms/openai.py +14 -147
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +480 -0
- langfun/core/llms/openai_test.py +13 -423
- langfun/core/llms/vertexai.py +6 -2
- langfun/core/llms/vertexai_test.py +1 -1
- langfun/core/modalities/mime.py +8 -0
- langfun/core/modalities/mime_test.py +19 -4
- langfun/core/modality_test.py +0 -1
- langfun/core/structured/mapping.py +13 -13
- langfun/core/structured/mapping_test.py +2 -2
- langfun/core/structured/schema.py +16 -8
- {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/METADATA +13 -2
- {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/RECORD +37 -35
- {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/WHEEL +1 -1
- langfun/core/text_formatting.py +0 -168
- langfun/core/text_formatting_test.py +0 -65
- {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/top_level.txt +0 -0
langfun/core/__init__.py
CHANGED
@@ -77,11 +77,6 @@ from langfun.core.concurrent import concurrent_map
|
|
77
77
|
from langfun.core.concurrent import with_context_access
|
78
78
|
from langfun.core.concurrent import with_retry
|
79
79
|
|
80
|
-
# Utility libraries for text formatting.
|
81
|
-
from langfun.core.text_formatting import colored
|
82
|
-
from langfun.core.text_formatting import colored_print as print # pylint: disable=redefined-builtin
|
83
|
-
from langfun.core.text_formatting import colored_template
|
84
|
-
|
85
80
|
# Interface for natural language formattable.
|
86
81
|
from langfun.core.natural_language import NaturalLanguageFormattable
|
87
82
|
|
@@ -40,6 +40,7 @@ def run_with_correction(
|
|
40
40
|
sandbox: bool | None = None,
|
41
41
|
timeout: int | None = 5,
|
42
42
|
returns_code: bool = False,
|
43
|
+
returns_stdout: bool = False,
|
43
44
|
outputs_intermediate: bool = False,
|
44
45
|
) -> Any | tuple[Any, str]:
|
45
46
|
"""Correct code with a language model via self-play.
|
@@ -62,6 +63,7 @@ def run_with_correction(
|
|
62
63
|
timeout. Applicable only when sandbox is set to True.
|
63
64
|
returns_code: If True, the return value is a tuple of (result, final code).
|
64
65
|
Otherwise the return value is the result only.
|
66
|
+
returns_stdout: If True, the stdout (a str) will be returned.
|
65
67
|
outputs_intermediate: If True, intermediate output will be outputted as a
|
66
68
|
dict, with the last line's value accessible by key '__result__'. Otherwise
|
67
69
|
the value of the last line will be returned.
|
@@ -87,6 +89,7 @@ def run_with_correction(
|
|
87
89
|
global_vars=global_vars,
|
88
90
|
sandbox=sandbox,
|
89
91
|
timeout=timeout,
|
92
|
+
returns_stdout=returns_stdout,
|
90
93
|
outputs_intermediate=outputs_intermediate,
|
91
94
|
)
|
92
95
|
)
|
@@ -189,9 +192,7 @@ def correct(
|
|
189
192
|
def _error_feedback_str(error: Exception) -> str:
|
190
193
|
"""Returns the error str for feedback."""
|
191
194
|
if isinstance(error, errors.CodeError):
|
192
|
-
return
|
193
|
-
error.format(include_complete_code=False)
|
194
|
-
)
|
195
|
+
return pg.decolor(error.format(include_complete_code=False))
|
195
196
|
else:
|
196
197
|
return f"Encountered {error.__class__.__name__}: {error}"
|
197
198
|
|
@@ -17,7 +17,8 @@ import io
|
|
17
17
|
import sys
|
18
18
|
import textwrap
|
19
19
|
import traceback
|
20
|
-
|
20
|
+
|
21
|
+
import pyglove as pg
|
21
22
|
|
22
23
|
|
23
24
|
class CodeError(RuntimeError):
|
@@ -62,13 +63,13 @@ class CodeError(RuntimeError):
|
|
62
63
|
if 'line' not in error_message and self.lineno is not None:
|
63
64
|
error_message += f' (<unknown>, line {self.lineno})'
|
64
65
|
r.write(
|
65
|
-
|
66
|
+
pg.colored(
|
66
67
|
f'{self.cause.__class__.__name__}: {error_message}', 'magenta'))
|
67
68
|
|
68
69
|
if self.lineno is not None:
|
69
70
|
r.write('\n\n')
|
70
71
|
r.write(textwrap.indent(
|
71
|
-
|
72
|
+
pg.colored(
|
72
73
|
self.code_lines(self.lineno - 1, self.end_lineno), 'magenta'),
|
73
74
|
' ' * 2
|
74
75
|
))
|
@@ -76,14 +77,14 @@ class CodeError(RuntimeError):
|
|
76
77
|
|
77
78
|
if include_complete_code:
|
78
79
|
r.write('\n')
|
79
|
-
r.write(
|
80
|
+
r.write(pg.colored('[Generated Code]', 'green', styles=['bold']))
|
80
81
|
r.write('\n\n')
|
81
|
-
r.write(
|
82
|
+
r.write(pg.colored(' ```python\n', 'green'))
|
82
83
|
r.write(textwrap.indent(
|
83
|
-
|
84
|
+
pg.colored(self.code, 'green'),
|
84
85
|
' ' * 2
|
85
86
|
))
|
86
|
-
r.write(
|
87
|
+
r.write(pg.colored('\n ```\n', 'green'))
|
87
88
|
return r.getvalue()
|
88
89
|
|
89
90
|
|
@@ -98,10 +99,10 @@ class SerializationError(RuntimeError):
|
|
98
99
|
r = io.StringIO()
|
99
100
|
cause_message = str(self.cause).rstrip()
|
100
101
|
if self.message:
|
101
|
-
r.write(
|
102
|
+
r.write(pg.colored(self.message, 'magenta'))
|
102
103
|
r.write('\n\n')
|
103
104
|
r.write(
|
104
|
-
|
105
|
+
pg.colored(
|
105
106
|
f'{self.cause.__class__.__name__}: {cause_message}', 'magenta'
|
106
107
|
)
|
107
108
|
)
|
@@ -57,6 +57,7 @@ def evaluate(
|
|
57
57
|
*,
|
58
58
|
global_vars: dict[str, Any] | None = None,
|
59
59
|
permission: permissions.CodePermission | None = None,
|
60
|
+
returns_stdout: bool = False,
|
60
61
|
outputs_intermediate: bool = False,
|
61
62
|
) -> Any | dict[str, Any]:
|
62
63
|
"""Executes Python code.
|
@@ -71,14 +72,17 @@ def evaluate(
|
|
71
72
|
global_vars: An optional dict as the globals that could be referenced by the
|
72
73
|
code.
|
73
74
|
permission: Permission for the Python code to run.
|
74
|
-
|
75
|
-
|
76
|
-
|
75
|
+
returns_stdout: If True, the stdout (a str) will be returned.
|
76
|
+
outputs_intermediate: Applicable when returns_stdout is False. If True,
|
77
|
+
intermediate output will be outputted as a dict, with the last line's
|
78
|
+
value accessible by key '__result__' and the std output accessible by
|
79
|
+
key '__stdout__'. Otherwise the value of the last line will be returned.
|
77
80
|
|
78
81
|
Returns:
|
79
|
-
The value of the last line of the code. Or a dict of variable
|
80
|
-
|
81
|
-
|
82
|
+
The value of the last line of the code block. Or a dict of variable
|
83
|
+
names of all locals to their evaluated values as the output of the code to
|
84
|
+
run. The value for the last line can be accessed by key '__result__'. Or the
|
85
|
+
stdout as a str.
|
82
86
|
"""
|
83
87
|
# Set up the permission and context.
|
84
88
|
permission = permission or permissions.get_permission()
|
@@ -136,6 +140,8 @@ def evaluate(
|
|
136
140
|
raise errors.CodeError(code, e) from e
|
137
141
|
global_vars[RESULT_KEY] = list(global_vars.values())[-1]
|
138
142
|
|
143
|
+
if returns_stdout:
|
144
|
+
return stdout.getvalue()
|
139
145
|
if outputs_intermediate:
|
140
146
|
outputs = {}
|
141
147
|
for k, v in global_vars.items():
|
@@ -258,6 +264,7 @@ def run(
|
|
258
264
|
*,
|
259
265
|
global_vars: dict[str, Any] | None = None,
|
260
266
|
permission: permissions.CodePermission | None = None,
|
267
|
+
returns_stdout: bool = False,
|
261
268
|
outputs_intermediate: bool = False,
|
262
269
|
sandbox: bool | None = None,
|
263
270
|
timeout: float | None = None,
|
@@ -273,9 +280,11 @@ def run(
|
|
273
280
|
code: Python code to run.
|
274
281
|
global_vars: An optional dict as the globals that could be referenced by the code.
|
275
282
|
permission: Permission for the Python code to run.
|
276
|
-
|
277
|
-
|
278
|
-
|
283
|
+
returns_stdout: If True, the stdout (a str) will be returned.
|
284
|
+
outputs_intermediate: Applicable when returns_stdout is False. If True,
|
285
|
+
intermediate output will be outputted as a dict, with the last line's
|
286
|
+
value accessible by key '__result__' and the std output accessible by
|
287
|
+
key '__stdout__'. Otherwise the value of the last line will be returned.
|
279
288
|
sandbox: If True, run code in sandbox; If False, run code in current
|
280
289
|
process. If None, run in sandbox first, if the output could not be
|
281
290
|
serialized and pass to current process, run the code again in current
|
@@ -285,7 +294,8 @@ def run(
|
|
285
294
|
Returns:
|
286
295
|
The value of the last line of the code block. Or a dict of variable
|
287
296
|
names of all locals to their evaluated values as the output of the code to
|
288
|
-
run. The value for the last line can be accessed by key '__result__'.
|
297
|
+
run. The value for the last line can be accessed by key '__result__'. Or the
|
298
|
+
stdout as a str.
|
289
299
|
|
290
300
|
Raises:
|
291
301
|
TimeoutError: If the execution time exceeds the timeout.
|
@@ -293,5 +303,6 @@ def run(
|
|
293
303
|
"""
|
294
304
|
return call(
|
295
305
|
evaluate, code=code, global_vars=global_vars, permission=permission,
|
296
|
-
outputs_intermediate=outputs_intermediate,
|
297
|
-
sandbox=sandbox, timeout=timeout
|
306
|
+
returns_stdout=returns_stdout, outputs_intermediate=outputs_intermediate,
|
307
|
+
sandbox=sandbox, timeout=timeout
|
308
|
+
)
|
@@ -63,6 +63,15 @@ class EvaluateTest(unittest.TestCase):
|
|
63
63
|
),
|
64
64
|
3,
|
65
65
|
)
|
66
|
+
with self.assertRaisesRegex(errors.CodeError, 'ValueError'):
|
67
|
+
execution.evaluate(
|
68
|
+
"""
|
69
|
+
def foo():
|
70
|
+
raise ValueError("intentional error")
|
71
|
+
foo()
|
72
|
+
""",
|
73
|
+
permission=permissions.CodePermission.ALL
|
74
|
+
)
|
66
75
|
|
67
76
|
def test_class_def(self):
|
68
77
|
ret = execution.evaluate(
|
@@ -102,16 +111,20 @@ class EvaluateTest(unittest.TestCase):
|
|
102
111
|
self.assertIs(ret['__result__'], ret['bar'])
|
103
112
|
|
104
113
|
def test_function_def_and_call(self):
|
105
|
-
|
114
|
+
code = (
|
106
115
|
"""
|
107
116
|
def foo(x, y):
|
108
117
|
return x + y
|
109
118
|
|
110
119
|
def bar(z):
|
120
|
+
print(f'z is {z}')
|
111
121
|
return z + foo(z, z)
|
112
122
|
|
113
123
|
bar(1)
|
114
|
-
"""
|
124
|
+
"""
|
125
|
+
)
|
126
|
+
ret = execution.evaluate(
|
127
|
+
code,
|
115
128
|
permission=permissions.CodePermission.ALL,
|
116
129
|
outputs_intermediate=True,
|
117
130
|
)
|
@@ -119,6 +132,12 @@ class EvaluateTest(unittest.TestCase):
|
|
119
132
|
list(ret.keys()), ['foo', 'bar', '__result__', '__stdout__']
|
120
133
|
)
|
121
134
|
self.assertEqual(ret['__result__'], 3)
|
135
|
+
ret = execution.evaluate(
|
136
|
+
code,
|
137
|
+
permission=permissions.CodePermission.ALL,
|
138
|
+
returns_stdout=True,
|
139
|
+
)
|
140
|
+
self.assertEqual(ret, 'z is 1\n')
|
122
141
|
|
123
142
|
def test_complex(self):
|
124
143
|
ret = execution.evaluate(
|
@@ -88,6 +88,8 @@ class PythonCode(pg.Object):
|
|
88
88
|
sandbox: bool | None = None,
|
89
89
|
timeout: int | None = 5,
|
90
90
|
global_vars: dict[str, Any] | None = None,
|
91
|
+
returns_stdout: bool = False,
|
92
|
+
outputs_intermediate: bool = False,
|
91
93
|
autofix: int = 3,
|
92
94
|
autofix_lm: lf.LanguageModel | None = None,
|
93
95
|
) -> Any:
|
@@ -101,13 +103,22 @@ class PythonCode(pg.Object):
|
|
101
103
|
timeout: Timeout in seconds. If None, there is no timeout. Applicable when
|
102
104
|
sandbox is set to True.
|
103
105
|
global_vars: Global variables that could be accessed from the source code.
|
106
|
+
returns_stdout: If True, the stdout (a str) will be returned.
|
107
|
+
outputs_intermediate: Applicable when returns_stdout is False. If True,
|
108
|
+
intermediate output will be outputted as a dict, with the last line's
|
109
|
+
value accessible by key '__result__' and the std output accessible by
|
110
|
+
key '__stdout__'. Otherwise the value of the last line will be returned.
|
104
111
|
autofix: Number of attempts to auto fix the generated code. If 0, autofix
|
105
112
|
is disabled.
|
106
113
|
autofix_lm: Language model to be used. If not specified, it will try to
|
107
114
|
use the `lm` under `lf.context`.
|
108
115
|
|
109
116
|
Returns:
|
110
|
-
The value of the last expression in the source code.
|
117
|
+
The value of the last expression in the source code. Or a dict of local
|
118
|
+
variable names defined in the source code to their values if
|
119
|
+
`outputs_intermediate` is set to True. The value for the last line can be
|
120
|
+
accessed by key '__result__'. Or the stdout as a str if `returns_stdout`
|
121
|
+
is set to True.
|
111
122
|
|
112
123
|
Raises:
|
113
124
|
TimeoutError: If `sandbox` is True and timeout has reached.
|
@@ -121,6 +132,8 @@ class PythonCode(pg.Object):
|
|
121
132
|
max_attempts=autofix,
|
122
133
|
lm=autofix_lm,
|
123
134
|
returns_code=True,
|
135
|
+
returns_stdout=returns_stdout,
|
136
|
+
outputs_intermediate=outputs_intermediate,
|
124
137
|
)
|
125
138
|
self.rebind(source=updated_code)
|
126
139
|
return result
|
@@ -158,18 +171,14 @@ class PythonCode(pg.Object):
|
|
158
171
|
TimeoutError: If `sandbox` is True and timeout has reached.
|
159
172
|
Exception: Any errors that the source code has raised.
|
160
173
|
"""
|
161
|
-
|
162
|
-
self.source,
|
163
|
-
global_vars=global_vars,
|
174
|
+
return self(
|
164
175
|
sandbox=sandbox,
|
165
176
|
timeout=timeout,
|
177
|
+
global_vars=global_vars,
|
178
|
+
autofix=autofix,
|
179
|
+
autofix_lm=autofix_lm,
|
166
180
|
outputs_intermediate=True,
|
167
|
-
max_attempts=autofix,
|
168
|
-
lm=autofix_lm,
|
169
|
-
returns_code=True,
|
170
181
|
)
|
171
|
-
self.rebind(source=updated_code)
|
172
|
-
return result
|
173
182
|
|
174
183
|
|
175
184
|
class PythonFunction(pg.Object):
|
langfun/core/concurrent.py
CHANGED
@@ -25,7 +25,6 @@ import time
|
|
25
25
|
from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
|
26
26
|
|
27
27
|
from langfun.core import component
|
28
|
-
from langfun.core import text_formatting
|
29
28
|
import pyglove as pg
|
30
29
|
|
31
30
|
|
@@ -844,10 +843,10 @@ class _ConsoleProgressControl(_ProgressControl):
|
|
844
843
|
def refresh(self):
|
845
844
|
s = io.StringIO()
|
846
845
|
if self.label is not None:
|
847
|
-
s.write(
|
846
|
+
s.write(pg.colored(self.label, 'red', styles=['bold']))
|
848
847
|
s.write(': ')
|
849
848
|
s.write(
|
850
|
-
|
849
|
+
pg.colored(
|
851
850
|
'%d%% (%d/%d)' %
|
852
851
|
(
|
853
852
|
self._progress * 100 // self.total,
|
langfun/core/console.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
import sys
|
17
17
|
from typing import Any
|
18
|
-
|
18
|
+
import pyglove as pg
|
19
19
|
|
20
20
|
|
21
21
|
def write(
|
@@ -42,10 +42,15 @@ def write(
|
|
42
42
|
"""
|
43
43
|
# Print title if present.
|
44
44
|
if title is not None:
|
45
|
-
print(colored(title, styles=['bold']))
|
45
|
+
print(pg.colored(title, styles=['bold']))
|
46
46
|
|
47
47
|
# Print body.
|
48
|
-
print(
|
48
|
+
print(dir(pg.utils))
|
49
|
+
print(
|
50
|
+
pg.colored(
|
51
|
+
str(value), color=color, background=background, styles=styles
|
52
|
+
)
|
53
|
+
)
|
49
54
|
|
50
55
|
|
51
56
|
try:
|
langfun/core/eval/base.py
CHANGED
@@ -1298,7 +1298,7 @@ class Evaluation(Evaluable):
|
|
1298
1298
|
id=self.id,
|
1299
1299
|
dir=self.dir,
|
1300
1300
|
model=self.lm.model_id,
|
1301
|
-
prompt_template=
|
1301
|
+
prompt_template=pg.decolor(str(self.prompt)),
|
1302
1302
|
method=self.method,
|
1303
1303
|
schema_fn=str(self.schema_fn),
|
1304
1304
|
),
|
@@ -2110,8 +2110,7 @@ class Summary(pg.Object):
|
|
2110
2110
|
|
2111
2111
|
def _format_error(error: Exception):
|
2112
2112
|
"""Formats an error into a string."""
|
2113
|
-
return (f'({error.__class__.__name__}) '
|
2114
|
-
+ lf.text_formatting.decolored(str(error)))
|
2113
|
+
return (f'({error.__class__.__name__}) ' + pg.decolor(str(error)))
|
2115
2114
|
|
2116
2115
|
|
2117
2116
|
def _error_key(error: Exception) -> str:
|
@@ -51,6 +51,7 @@ class HtmlReporter(experiment_lib.Plugin):
|
|
51
51
|
self._update_thread = None
|
52
52
|
self._stop_update = False
|
53
53
|
self._stop_update_experiment_ids = set()
|
54
|
+
self._summary_lock = None
|
54
55
|
self._experiment_index_lock = None
|
55
56
|
|
56
57
|
def on_run_start(
|
@@ -62,6 +63,7 @@ class HtmlReporter(experiment_lib.Plugin):
|
|
62
63
|
self._last_experiment_report_time = {leaf.id: 0 for leaf in root.leaf_nodes}
|
63
64
|
self._stop_update = False
|
64
65
|
self._stop_update_experiment_ids = set()
|
66
|
+
self._summary_lock = threading.Lock()
|
65
67
|
self._experiment_index_lock = {
|
66
68
|
leaf.id: threading.Lock() for leaf in root.leaf_nodes
|
67
69
|
}
|
@@ -137,21 +139,23 @@ class HtmlReporter(experiment_lib.Plugin):
|
|
137
139
|
"""Maybe update the summary of current run."""
|
138
140
|
run = runner.current_run
|
139
141
|
def _summary():
|
140
|
-
run.experiment.to_html(
|
142
|
+
html = run.experiment.to_html(
|
141
143
|
collapse_level=None,
|
142
144
|
extra_flags=dict(
|
143
145
|
current_run=run, interactive=False, card_view=True,
|
144
146
|
)
|
145
|
-
).save(
|
146
|
-
run.output_path_for(run.experiment, _SUMMARY_FILE)
|
147
147
|
)
|
148
|
+
with self._summary_lock:
|
149
|
+
html.save(
|
150
|
+
run.output_path_for(run.experiment, _SUMMARY_FILE)
|
151
|
+
)
|
148
152
|
|
149
153
|
if force or (time.time() - self._last_summary_time > self.summary_interval):
|
154
|
+
self._last_summary_time = time.time()
|
150
155
|
if background:
|
151
156
|
runner.background_run(_summary)
|
152
157
|
else:
|
153
158
|
_summary()
|
154
|
-
self._last_summary_time = time.time()
|
155
159
|
|
156
160
|
def _maybe_update_experiment_html(
|
157
161
|
self,
|
langfun/core/language_model.py
CHANGED
@@ -434,7 +434,10 @@ class LanguageModel(component.Component):
|
|
434
434
|
def __init__(self, *args, **kwargs) -> None:
|
435
435
|
"""Overrides __init__ to pass through **kwargs to sampling options."""
|
436
436
|
|
437
|
-
sampling_options = kwargs.pop(
|
437
|
+
sampling_options = kwargs.pop(
|
438
|
+
'sampling_options',
|
439
|
+
pg.clone(self.__schema__.fields['sampling_options'].default_value)
|
440
|
+
)
|
438
441
|
sampling_options_delta = {}
|
439
442
|
|
440
443
|
for k, v in kwargs.items():
|
@@ -650,7 +653,7 @@ class LanguageModel(component.Component):
|
|
650
653
|
"""Outputs debugging information about the model."""
|
651
654
|
title_suffix = ''
|
652
655
|
if usage.total_tokens != 0:
|
653
|
-
title_suffix =
|
656
|
+
title_suffix = pg.colored(
|
654
657
|
f' (total {usage.total_tokens} tokens)', 'red'
|
655
658
|
)
|
656
659
|
|
@@ -669,7 +672,7 @@ class LanguageModel(component.Component):
|
|
669
672
|
"""Outputs debugging information about the prompt."""
|
670
673
|
title_suffix = ''
|
671
674
|
if usage.prompt_tokens != 0:
|
672
|
-
title_suffix =
|
675
|
+
title_suffix = pg.colored(f' ({usage.prompt_tokens} tokens)', 'red')
|
673
676
|
|
674
677
|
console.write(
|
675
678
|
# We use metadata 'formatted_text' for scenarios where the prompt text
|
@@ -700,7 +703,7 @@ class LanguageModel(component.Component):
|
|
700
703
|
if usage.completion_tokens != 0:
|
701
704
|
title_suffix += f'{usage.completion_tokens} tokens '
|
702
705
|
title_suffix += f'in {elapse:.2f} seconds)'
|
703
|
-
title_suffix =
|
706
|
+
title_suffix = pg.colored(title_suffix, 'red')
|
704
707
|
|
705
708
|
console.write(
|
706
709
|
str(response) + '\n',
|
@@ -117,6 +117,21 @@ class LanguageModelTest(unittest.TestCase):
|
|
117
117
|
self.assertEqual(lm.sampling_options.top_k, 2)
|
118
118
|
self.assertEqual(lm.max_attempts, 2)
|
119
119
|
|
120
|
+
def test_subclassing(self):
|
121
|
+
|
122
|
+
class ChildModel(lm_lib.LanguageModel):
|
123
|
+
|
124
|
+
sampling_options = lm_lib.LMSamplingOptions(
|
125
|
+
temperature=0.5, top_k=20
|
126
|
+
)
|
127
|
+
|
128
|
+
def _sample(self, *args, **kwargs):
|
129
|
+
pass
|
130
|
+
|
131
|
+
lm = ChildModel(top_k=10)
|
132
|
+
self.assertEqual(lm.sampling_options.temperature, 0.5)
|
133
|
+
self.assertEqual(lm.sampling_options.top_k, 10)
|
134
|
+
|
120
135
|
def test_sample(self):
|
121
136
|
lm = MockModel(top_k=1)
|
122
137
|
self.assertEqual(
|
langfun/core/llms/__init__.py
CHANGED
@@ -57,6 +57,9 @@ from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
|
|
57
57
|
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
|
58
58
|
from langfun.core.llms.vertexai import VertexAIGeminiPro1
|
59
59
|
|
60
|
+
# Base for OpenAI-compatible models.
|
61
|
+
from langfun.core.llms.openai_compatible import OpenAICompatible
|
62
|
+
|
60
63
|
# OpenAI models.
|
61
64
|
from langfun.core.llms.openai import OpenAI
|
62
65
|
|
@@ -141,6 +144,10 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
|
|
141
144
|
# LLaMA C++ models.
|
142
145
|
from langfun.core.llms.llama_cpp import LlamaCppRemote
|
143
146
|
|
147
|
+
# DeepSeek models.
|
148
|
+
from langfun.core.llms.deepseek import DeepSeek
|
149
|
+
from langfun.core.llms.deepseek import DeepSeekChat
|
150
|
+
|
144
151
|
# Placeholder for Google-internal imports.
|
145
152
|
|
146
153
|
# Include cache as sub-module.
|
@@ -0,0 +1,117 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Language models from DeepSeek."""
|
15
|
+
|
16
|
+
import os
|
17
|
+
from typing import Annotated, Any
|
18
|
+
|
19
|
+
import langfun.core as lf
|
20
|
+
from langfun.core.llms import openai_compatible
|
21
|
+
import pyglove as pg
|
22
|
+
|
23
|
+
SUPPORTED_MODELS_AND_SETTINGS = {
|
24
|
+
# pylint: disable=g-line-too-long
|
25
|
+
# TODO(yifenglu): The RPM and TPM are arbitrary numbers. Update them once DeepSeek provides concrete guidelines.
|
26
|
+
# DeepSeek doesn't control the rate limit at the moment: https://api-docs.deepseek.com/quick_start/rate_limit
|
27
|
+
# The cost is based on: https://api-docs.deepseek.com/quick_start/pricing
|
28
|
+
'deepseek-chat': pg.Dict(
|
29
|
+
in_service=True,
|
30
|
+
rpm=100,
|
31
|
+
tpm=1000000,
|
32
|
+
cost_per_1k_input_tokens=0.00014,
|
33
|
+
cost_per_1k_output_tokens=0.00028,
|
34
|
+
),
|
35
|
+
}
|
36
|
+
|
37
|
+
|
38
|
+
# DeepSeek API uses an API format compatible with OpenAI.
|
39
|
+
# Reference: https://api-docs.deepseek.com/
|
40
|
+
@lf.use_init_args(['model'])
|
41
|
+
class DeepSeek(openai_compatible.OpenAICompatible):
|
42
|
+
"""DeepSeek model."""
|
43
|
+
|
44
|
+
model: pg.typing.Annotated[
|
45
|
+
pg.typing.Enum(
|
46
|
+
pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
|
47
|
+
),
|
48
|
+
'The name of the model to use.',
|
49
|
+
]
|
50
|
+
|
51
|
+
api_endpoint: str = 'https://api.deepseek.com/chat/completions'
|
52
|
+
|
53
|
+
api_key: Annotated[
|
54
|
+
str | None,
|
55
|
+
(
|
56
|
+
'API key. If None, the key will be read from environment variable '
|
57
|
+
"'DEEPSEEK_API_KEY'."
|
58
|
+
),
|
59
|
+
] = None
|
60
|
+
|
61
|
+
@property
|
62
|
+
def headers(self) -> dict[str, Any]:
|
63
|
+
api_key = self.api_key or os.environ.get('DEEPSEEK_API_KEY', None)
|
64
|
+
if not api_key:
|
65
|
+
raise ValueError(
|
66
|
+
'Please specify `api_key` during `__init__` or set environment '
|
67
|
+
'variable `DEEPSEEK_API_KEY` with your DeepSeek API key.'
|
68
|
+
)
|
69
|
+
headers = super().headers
|
70
|
+
headers.update({
|
71
|
+
'Authorization': f'Bearer {api_key}',
|
72
|
+
})
|
73
|
+
return headers
|
74
|
+
|
75
|
+
@property
|
76
|
+
def model_id(self) -> str:
|
77
|
+
"""Returns a string to identify the model."""
|
78
|
+
return f'DeepSeek({self.model})'
|
79
|
+
|
80
|
+
@property
|
81
|
+
def max_concurrency(self) -> int:
|
82
|
+
rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
|
83
|
+
tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
|
84
|
+
return self.rate_to_max_concurrency(
|
85
|
+
requests_per_min=rpm, tokens_per_min=tpm
|
86
|
+
)
|
87
|
+
|
88
|
+
def estimate_cost(
|
89
|
+
self, num_input_tokens: int, num_output_tokens: int
|
90
|
+
) -> float | None:
|
91
|
+
"""Estimate the cost based on usage."""
|
92
|
+
cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
|
93
|
+
'cost_per_1k_input_tokens', None
|
94
|
+
)
|
95
|
+
cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
|
96
|
+
'cost_per_1k_output_tokens', None
|
97
|
+
)
|
98
|
+
if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None:
|
99
|
+
return None
|
100
|
+
return (
|
101
|
+
cost_per_1k_input_tokens * num_input_tokens
|
102
|
+
+ cost_per_1k_output_tokens * num_output_tokens
|
103
|
+
) / 1000
|
104
|
+
|
105
|
+
@classmethod
|
106
|
+
def dir(cls):
|
107
|
+
return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
|
108
|
+
|
109
|
+
|
110
|
+
class DeepSeekChat(DeepSeek):
|
111
|
+
"""DeepSeek Chat model.
|
112
|
+
|
113
|
+
Currently, it is powered by DeepSeek-V3 model, 64K input context window and
|
114
|
+
8k max output tokens.
|
115
|
+
"""
|
116
|
+
|
117
|
+
model = 'deepseek-chat'
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright 2023 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import unittest
|
15
|
+
from langfun.core.llms import deepseek
|
16
|
+
|
17
|
+
|
18
|
+
class DeepSeekTest(unittest.TestCase):
|
19
|
+
"""Tests for DeepSeek language model."""
|
20
|
+
|
21
|
+
def test_dir(self):
|
22
|
+
self.assertIn('deepseek-chat', deepseek.DeepSeek.dir())
|
23
|
+
|
24
|
+
def test_key(self):
|
25
|
+
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
26
|
+
_ = deepseek.DeepSeekChat().headers
|
27
|
+
self.assertEqual(
|
28
|
+
deepseek.DeepSeekChat(api_key='test_key').headers,
|
29
|
+
{
|
30
|
+
'Content-Type': 'application/json',
|
31
|
+
'Authorization': 'Bearer test_key',
|
32
|
+
}
|
33
|
+
)
|
34
|
+
|
35
|
+
def test_model_id(self):
|
36
|
+
self.assertEqual(
|
37
|
+
deepseek.DeepSeekChat(api_key='test_key').model_id,
|
38
|
+
'DeepSeek(deepseek-chat)',
|
39
|
+
)
|
40
|
+
|
41
|
+
def test_resource_id(self):
|
42
|
+
self.assertEqual(
|
43
|
+
deepseek.DeepSeekChat(api_key='test_key').resource_id,
|
44
|
+
'DeepSeek(deepseek-chat)',
|
45
|
+
)
|
46
|
+
|
47
|
+
def test_max_concurrency(self):
|
48
|
+
self.assertGreater(
|
49
|
+
deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0
|
50
|
+
)
|
51
|
+
|
52
|
+
def test_estimate_cost(self):
|
53
|
+
self.assertEqual(
|
54
|
+
deepseek.DeepSeekChat(api_key='test_key').estimate_cost(
|
55
|
+
num_input_tokens=100, num_output_tokens=100
|
56
|
+
),
|
57
|
+
4.2e-5
|
58
|
+
)
|
59
|
+
|
60
|
+
if __name__ == '__main__':
|
61
|
+
unittest.main()
|