langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501090804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +0 -5
- langfun/core/coding/python/correction.py +4 -3
- langfun/core/coding/python/errors.py +10 -9
- langfun/core/coding/python/execution.py +23 -12
- langfun/core/coding/python/execution_test.py +21 -2
- langfun/core/coding/python/generation.py +18 -9
- langfun/core/concurrent.py +2 -3
- langfun/core/console.py +8 -3
- langfun/core/eval/base.py +2 -3
- langfun/core/eval/v2/reporting.py +15 -6
- langfun/core/language_model.py +7 -4
- langfun/core/language_model_test.py +15 -0
- langfun/core/llms/__init__.py +25 -26
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/deepseek.py +261 -0
- langfun/core/llms/deepseek_test.py +438 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +46 -320
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/openai.py +5 -0
- langfun/core/llms/vertexai.py +31 -359
- langfun/core/llms/vertexai_test.py +6 -166
- langfun/core/structured/mapping.py +13 -13
- langfun/core/structured/mapping_test.py +2 -2
- langfun/core/structured/schema.py +16 -8
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/METADATA +19 -14
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/RECORD +32 -30
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/WHEEL +1 -1
- langfun/core/text_formatting.py +0 -168
- langfun/core/text_formatting_test.py +0 -65
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/top_level.txt +0 -0
langfun/core/__init__.py
CHANGED
@@ -77,11 +77,6 @@ from langfun.core.concurrent import concurrent_map
|
|
77
77
|
from langfun.core.concurrent import with_context_access
|
78
78
|
from langfun.core.concurrent import with_retry
|
79
79
|
|
80
|
-
# Utility libraries for text formatting.
|
81
|
-
from langfun.core.text_formatting import colored
|
82
|
-
from langfun.core.text_formatting import colored_print as print # pylint: disable=redefined-builtin
|
83
|
-
from langfun.core.text_formatting import colored_template
|
84
|
-
|
85
80
|
# Interface for natural language formattable.
|
86
81
|
from langfun.core.natural_language import NaturalLanguageFormattable
|
87
82
|
|
@@ -40,6 +40,7 @@ def run_with_correction(
|
|
40
40
|
sandbox: bool | None = None,
|
41
41
|
timeout: int | None = 5,
|
42
42
|
returns_code: bool = False,
|
43
|
+
returns_stdout: bool = False,
|
43
44
|
outputs_intermediate: bool = False,
|
44
45
|
) -> Any | tuple[Any, str]:
|
45
46
|
"""Correct code with a language model via self-play.
|
@@ -62,6 +63,7 @@ def run_with_correction(
|
|
62
63
|
timeout. Applicable only when sandbox is set to True.
|
63
64
|
returns_code: If True, the return value is a tuple of (result, final code).
|
64
65
|
Otherwise the return value is the result only.
|
66
|
+
returns_stdout: If True, the stdout (a str) will be returned.
|
65
67
|
outputs_intermediate: If True, intermediate output will be outputted as a
|
66
68
|
dict, with the last line's value accessible by key '__result__'. Otherwise
|
67
69
|
the value of the last line will be returned.
|
@@ -87,6 +89,7 @@ def run_with_correction(
|
|
87
89
|
global_vars=global_vars,
|
88
90
|
sandbox=sandbox,
|
89
91
|
timeout=timeout,
|
92
|
+
returns_stdout=returns_stdout,
|
90
93
|
outputs_intermediate=outputs_intermediate,
|
91
94
|
)
|
92
95
|
)
|
@@ -189,9 +192,7 @@ def correct(
|
|
189
192
|
def _error_feedback_str(error: Exception) -> str:
|
190
193
|
"""Returns the error str for feedback."""
|
191
194
|
if isinstance(error, errors.CodeError):
|
192
|
-
return
|
193
|
-
error.format(include_complete_code=False)
|
194
|
-
)
|
195
|
+
return pg.decolor(error.format(include_complete_code=False))
|
195
196
|
else:
|
196
197
|
return f"Encountered {error.__class__.__name__}: {error}"
|
197
198
|
|
@@ -17,7 +17,8 @@ import io
|
|
17
17
|
import sys
|
18
18
|
import textwrap
|
19
19
|
import traceback
|
20
|
-
|
20
|
+
|
21
|
+
import pyglove as pg
|
21
22
|
|
22
23
|
|
23
24
|
class CodeError(RuntimeError):
|
@@ -62,13 +63,13 @@ class CodeError(RuntimeError):
|
|
62
63
|
if 'line' not in error_message and self.lineno is not None:
|
63
64
|
error_message += f' (<unknown>, line {self.lineno})'
|
64
65
|
r.write(
|
65
|
-
|
66
|
+
pg.colored(
|
66
67
|
f'{self.cause.__class__.__name__}: {error_message}', 'magenta'))
|
67
68
|
|
68
69
|
if self.lineno is not None:
|
69
70
|
r.write('\n\n')
|
70
71
|
r.write(textwrap.indent(
|
71
|
-
|
72
|
+
pg.colored(
|
72
73
|
self.code_lines(self.lineno - 1, self.end_lineno), 'magenta'),
|
73
74
|
' ' * 2
|
74
75
|
))
|
@@ -76,14 +77,14 @@ class CodeError(RuntimeError):
|
|
76
77
|
|
77
78
|
if include_complete_code:
|
78
79
|
r.write('\n')
|
79
|
-
r.write(
|
80
|
+
r.write(pg.colored('[Generated Code]', 'green', styles=['bold']))
|
80
81
|
r.write('\n\n')
|
81
|
-
r.write(
|
82
|
+
r.write(pg.colored(' ```python\n', 'green'))
|
82
83
|
r.write(textwrap.indent(
|
83
|
-
|
84
|
+
pg.colored(self.code, 'green'),
|
84
85
|
' ' * 2
|
85
86
|
))
|
86
|
-
r.write(
|
87
|
+
r.write(pg.colored('\n ```\n', 'green'))
|
87
88
|
return r.getvalue()
|
88
89
|
|
89
90
|
|
@@ -98,10 +99,10 @@ class SerializationError(RuntimeError):
|
|
98
99
|
r = io.StringIO()
|
99
100
|
cause_message = str(self.cause).rstrip()
|
100
101
|
if self.message:
|
101
|
-
r.write(
|
102
|
+
r.write(pg.colored(self.message, 'magenta'))
|
102
103
|
r.write('\n\n')
|
103
104
|
r.write(
|
104
|
-
|
105
|
+
pg.colored(
|
105
106
|
f'{self.cause.__class__.__name__}: {cause_message}', 'magenta'
|
106
107
|
)
|
107
108
|
)
|
@@ -57,6 +57,7 @@ def evaluate(
|
|
57
57
|
*,
|
58
58
|
global_vars: dict[str, Any] | None = None,
|
59
59
|
permission: permissions.CodePermission | None = None,
|
60
|
+
returns_stdout: bool = False,
|
60
61
|
outputs_intermediate: bool = False,
|
61
62
|
) -> Any | dict[str, Any]:
|
62
63
|
"""Executes Python code.
|
@@ -71,14 +72,17 @@ def evaluate(
|
|
71
72
|
global_vars: An optional dict as the globals that could be referenced by the
|
72
73
|
code.
|
73
74
|
permission: Permission for the Python code to run.
|
74
|
-
|
75
|
-
|
76
|
-
|
75
|
+
returns_stdout: If True, the stdout (a str) will be returned.
|
76
|
+
outputs_intermediate: Applicable when returns_stdout is False. If True,
|
77
|
+
intermediate output will be outputted as a dict, with the last line's
|
78
|
+
value accessible by key '__result__' and the std output accessible by
|
79
|
+
key '__stdout__'. Otherwise the value of the last line will be returned.
|
77
80
|
|
78
81
|
Returns:
|
79
|
-
The value of the last line of the code. Or a dict of variable
|
80
|
-
|
81
|
-
|
82
|
+
The value of the last line of the code block. Or a dict of variable
|
83
|
+
names of all locals to their evaluated values as the output of the code to
|
84
|
+
run. The value for the last line can be accessed by key '__result__'. Or the
|
85
|
+
stdout as a str.
|
82
86
|
"""
|
83
87
|
# Set up the permission and context.
|
84
88
|
permission = permission or permissions.get_permission()
|
@@ -136,6 +140,8 @@ def evaluate(
|
|
136
140
|
raise errors.CodeError(code, e) from e
|
137
141
|
global_vars[RESULT_KEY] = list(global_vars.values())[-1]
|
138
142
|
|
143
|
+
if returns_stdout:
|
144
|
+
return stdout.getvalue()
|
139
145
|
if outputs_intermediate:
|
140
146
|
outputs = {}
|
141
147
|
for k, v in global_vars.items():
|
@@ -258,6 +264,7 @@ def run(
|
|
258
264
|
*,
|
259
265
|
global_vars: dict[str, Any] | None = None,
|
260
266
|
permission: permissions.CodePermission | None = None,
|
267
|
+
returns_stdout: bool = False,
|
261
268
|
outputs_intermediate: bool = False,
|
262
269
|
sandbox: bool | None = None,
|
263
270
|
timeout: float | None = None,
|
@@ -273,9 +280,11 @@ def run(
|
|
273
280
|
code: Python code to run.
|
274
281
|
global_vars: An optional dict of
|
275
282
|
permission: Permission for the Python code to run.
|
276
|
-
|
277
|
-
|
278
|
-
|
283
|
+
returns_stdout: If True, the stdout (a str) will be returned.
|
284
|
+
outputs_intermediate: Applicable when returns_stdout is False. If True,
|
285
|
+
intermediate output will be outputted as a dict, with the last line's
|
286
|
+
value accessible by key '__result__' and the std output accessible by
|
287
|
+
key '__stdout__'. Otherwise the value of the last line will be returned.
|
279
288
|
sandbox: If True, run code in sandbox; If False, run code in current
|
280
289
|
process. If None, run in sandbox first, if the output could not be
|
281
290
|
serialized and pass to current process, run the code again in current
|
@@ -285,7 +294,8 @@ def run(
|
|
285
294
|
Returns:
|
286
295
|
The value of the last line of the code block. Or a dict of variable
|
287
296
|
names of all locals to their evaluated values as the output of the code to
|
288
|
-
run. The value for the last line can be accessed by key '__result__'.
|
297
|
+
run. The value for the last line can be accessed by key '__result__'. Or the
|
298
|
+
stdout as a str.
|
289
299
|
|
290
300
|
Raises:
|
291
301
|
TimeoutError: If the execution time exceeds the timeout.
|
@@ -293,5 +303,6 @@ def run(
|
|
293
303
|
"""
|
294
304
|
return call(
|
295
305
|
evaluate, code=code, global_vars=global_vars, permission=permission,
|
296
|
-
outputs_intermediate=outputs_intermediate,
|
297
|
-
sandbox=sandbox, timeout=timeout
|
306
|
+
returns_stdout=returns_stdout, outputs_intermediate=outputs_intermediate,
|
307
|
+
sandbox=sandbox, timeout=timeout
|
308
|
+
)
|
@@ -63,6 +63,15 @@ class EvaluateTest(unittest.TestCase):
|
|
63
63
|
),
|
64
64
|
3,
|
65
65
|
)
|
66
|
+
with self.assertRaisesRegex(errors.CodeError, 'ValueError'):
|
67
|
+
execution.evaluate(
|
68
|
+
"""
|
69
|
+
def foo():
|
70
|
+
raise ValueError("intentional error")
|
71
|
+
foo()
|
72
|
+
""",
|
73
|
+
permission=permissions.CodePermission.ALL
|
74
|
+
)
|
66
75
|
|
67
76
|
def test_class_def(self):
|
68
77
|
ret = execution.evaluate(
|
@@ -102,16 +111,20 @@ class EvaluateTest(unittest.TestCase):
|
|
102
111
|
self.assertIs(ret['__result__'], ret['bar'])
|
103
112
|
|
104
113
|
def test_function_def_and_call(self):
|
105
|
-
|
114
|
+
code = (
|
106
115
|
"""
|
107
116
|
def foo(x, y):
|
108
117
|
return x + y
|
109
118
|
|
110
119
|
def bar(z):
|
120
|
+
print(f'z is {z}')
|
111
121
|
return z + foo(z, z)
|
112
122
|
|
113
123
|
bar(1)
|
114
|
-
"""
|
124
|
+
"""
|
125
|
+
)
|
126
|
+
ret = execution.evaluate(
|
127
|
+
code,
|
115
128
|
permission=permissions.CodePermission.ALL,
|
116
129
|
outputs_intermediate=True,
|
117
130
|
)
|
@@ -119,6 +132,12 @@ class EvaluateTest(unittest.TestCase):
|
|
119
132
|
list(ret.keys()), ['foo', 'bar', '__result__', '__stdout__']
|
120
133
|
)
|
121
134
|
self.assertEqual(ret['__result__'], 3)
|
135
|
+
ret = execution.evaluate(
|
136
|
+
code,
|
137
|
+
permission=permissions.CodePermission.ALL,
|
138
|
+
returns_stdout=True,
|
139
|
+
)
|
140
|
+
self.assertEqual(ret, 'z is 1\n')
|
122
141
|
|
123
142
|
def test_complex(self):
|
124
143
|
ret = execution.evaluate(
|
@@ -88,6 +88,8 @@ class PythonCode(pg.Object):
|
|
88
88
|
sandbox: bool | None = None,
|
89
89
|
timeout: int | None = 5,
|
90
90
|
global_vars: dict[str, Any] | None = None,
|
91
|
+
returns_stdout: bool = False,
|
92
|
+
outputs_intermediate: bool = False,
|
91
93
|
autofix: int = 3,
|
92
94
|
autofix_lm: lf.LanguageModel | None = None,
|
93
95
|
) -> Any:
|
@@ -101,13 +103,22 @@ class PythonCode(pg.Object):
|
|
101
103
|
timeout: Timeout in seconds. If None, there is no timeout. Applicable when
|
102
104
|
sandbox is set to True.
|
103
105
|
global_vars: Global variables that could be accessed from the source code.
|
106
|
+
returns_stdout: If True, the stdout (a str) will be returned.
|
107
|
+
outputs_intermediate: Applicable when returns_stdout is False. If True,
|
108
|
+
intermediate output will be outputted as a dict, with the last line's
|
109
|
+
value accessible by key '__result__' and the std output accessible by
|
110
|
+
key '__stdout__'. Otherwise the value of the last line will be returned.
|
104
111
|
autofix: Number of attempts to auto fix the generated code. If 0, autofix
|
105
112
|
is disabled.
|
106
113
|
autofix_lm: Language model to be used. If not specified, it will try to
|
107
114
|
use the `lm` under `lf.context`.
|
108
115
|
|
109
116
|
Returns:
|
110
|
-
The value of the last expression in the source code.
|
117
|
+
The value of the last expression in the source code. Or a dict of local
|
118
|
+
variable names defined in the source code to their values if
|
119
|
+
`outputs_intermediate` is set to True. The value for the last line can be
|
120
|
+
accessed by key '__result__'. Or the stdout as a str if `returns_stdout`
|
121
|
+
is set to True.
|
111
122
|
|
112
123
|
Raises:
|
113
124
|
TimeoutError: If `sandbox` is True and timeout has reached.
|
@@ -121,6 +132,8 @@ class PythonCode(pg.Object):
|
|
121
132
|
max_attempts=autofix,
|
122
133
|
lm=autofix_lm,
|
123
134
|
returns_code=True,
|
135
|
+
returns_stdout=returns_stdout,
|
136
|
+
outputs_intermediate=outputs_intermediate,
|
124
137
|
)
|
125
138
|
self.rebind(source=updated_code)
|
126
139
|
return result
|
@@ -158,18 +171,14 @@ class PythonCode(pg.Object):
|
|
158
171
|
TimeoutError: If `sandbox` is True and timeout has reached.
|
159
172
|
Exception: Any errors that the source code has raised.
|
160
173
|
"""
|
161
|
-
|
162
|
-
self.source,
|
163
|
-
global_vars=global_vars,
|
174
|
+
return self(
|
164
175
|
sandbox=sandbox,
|
165
176
|
timeout=timeout,
|
177
|
+
global_vars=global_vars,
|
178
|
+
autofix=autofix,
|
179
|
+
autofix_lm=autofix_lm,
|
166
180
|
outputs_intermediate=True,
|
167
|
-
max_attempts=autofix,
|
168
|
-
lm=autofix_lm,
|
169
|
-
returns_code=True,
|
170
181
|
)
|
171
|
-
self.rebind(source=updated_code)
|
172
|
-
return result
|
173
182
|
|
174
183
|
|
175
184
|
class PythonFunction(pg.Object):
|
langfun/core/concurrent.py
CHANGED
@@ -25,7 +25,6 @@ import time
|
|
25
25
|
from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
|
26
26
|
|
27
27
|
from langfun.core import component
|
28
|
-
from langfun.core import text_formatting
|
29
28
|
import pyglove as pg
|
30
29
|
|
31
30
|
|
@@ -844,10 +843,10 @@ class _ConsoleProgressControl(_ProgressControl):
|
|
844
843
|
def refresh(self):
|
845
844
|
s = io.StringIO()
|
846
845
|
if self.label is not None:
|
847
|
-
s.write(
|
846
|
+
s.write(pg.colored(self.label, 'red', styles=['bold']))
|
848
847
|
s.write(': ')
|
849
848
|
s.write(
|
850
|
-
|
849
|
+
pg.colored(
|
851
850
|
'%d%% (%d/%d)' %
|
852
851
|
(
|
853
852
|
self._progress * 100 // self.total,
|
langfun/core/console.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
import sys
|
17
17
|
from typing import Any
|
18
|
-
|
18
|
+
import pyglove as pg
|
19
19
|
|
20
20
|
|
21
21
|
def write(
|
@@ -42,10 +42,15 @@ def write(
|
|
42
42
|
"""
|
43
43
|
# Print title if present.
|
44
44
|
if title is not None:
|
45
|
-
print(colored(title, styles=['bold']))
|
45
|
+
print(pg.colored(title, styles=['bold']))
|
46
46
|
|
47
47
|
# Print body.
|
48
|
-
print(
|
48
|
+
print(dir(pg.utils))
|
49
|
+
print(
|
50
|
+
pg.colored(
|
51
|
+
str(value), color=color, background=background, styles=styles
|
52
|
+
)
|
53
|
+
)
|
49
54
|
|
50
55
|
|
51
56
|
try:
|
langfun/core/eval/base.py
CHANGED
@@ -1298,7 +1298,7 @@ class Evaluation(Evaluable):
|
|
1298
1298
|
id=self.id,
|
1299
1299
|
dir=self.dir,
|
1300
1300
|
model=self.lm.model_id,
|
1301
|
-
prompt_template=
|
1301
|
+
prompt_template=pg.decolor(str(self.prompt)),
|
1302
1302
|
method=self.method,
|
1303
1303
|
schema_fn=str(self.schema_fn),
|
1304
1304
|
),
|
@@ -2110,8 +2110,7 @@ class Summary(pg.Object):
|
|
2110
2110
|
|
2111
2111
|
def _format_error(error: Exception):
|
2112
2112
|
"""Formats an error into a string."""
|
2113
|
-
return (f'({error.__class__.__name__}) '
|
2114
|
-
+ lf.text_formatting.decolored(str(error)))
|
2113
|
+
return (f'({error.__class__.__name__}) ' + pg.decolor(str(error)))
|
2115
2114
|
|
2116
2115
|
|
2117
2116
|
def _error_key(error: Exception) -> str:
|
@@ -51,6 +51,8 @@ class HtmlReporter(experiment_lib.Plugin):
|
|
51
51
|
self._update_thread = None
|
52
52
|
self._stop_update = False
|
53
53
|
self._stop_update_experiment_ids = set()
|
54
|
+
self._summary_lock = None
|
55
|
+
self._experiment_index_lock = None
|
54
56
|
|
55
57
|
def on_run_start(
|
56
58
|
self,
|
@@ -61,6 +63,10 @@ class HtmlReporter(experiment_lib.Plugin):
|
|
61
63
|
self._last_experiment_report_time = {leaf.id: 0 for leaf in root.leaf_nodes}
|
62
64
|
self._stop_update = False
|
63
65
|
self._stop_update_experiment_ids = set()
|
66
|
+
self._summary_lock = threading.Lock()
|
67
|
+
self._experiment_index_lock = {
|
68
|
+
leaf.id: threading.Lock() for leaf in root.leaf_nodes
|
69
|
+
}
|
64
70
|
self._update_thread = threading.Thread(
|
65
71
|
target=self._update_thread_func, args=(runner,)
|
66
72
|
)
|
@@ -133,21 +139,23 @@ class HtmlReporter(experiment_lib.Plugin):
|
|
133
139
|
"""Maybe update the summary of current run."""
|
134
140
|
run = runner.current_run
|
135
141
|
def _summary():
|
136
|
-
run.experiment.to_html(
|
142
|
+
html = run.experiment.to_html(
|
137
143
|
collapse_level=None,
|
138
144
|
extra_flags=dict(
|
139
145
|
current_run=run, interactive=False, card_view=True,
|
140
146
|
)
|
141
|
-
).save(
|
142
|
-
run.output_path_for(run.experiment, _SUMMARY_FILE)
|
143
147
|
)
|
148
|
+
with self._summary_lock:
|
149
|
+
html.save(
|
150
|
+
run.output_path_for(run.experiment, _SUMMARY_FILE)
|
151
|
+
)
|
144
152
|
|
145
153
|
if force or (time.time() - self._last_summary_time > self.summary_interval):
|
154
|
+
self._last_summary_time = time.time()
|
146
155
|
if background:
|
147
156
|
runner.background_run(_summary)
|
148
157
|
else:
|
149
158
|
_summary()
|
150
|
-
self._last_summary_time = time.time()
|
151
159
|
|
152
160
|
def _maybe_update_experiment_html(
|
153
161
|
self,
|
@@ -170,7 +178,8 @@ class HtmlReporter(experiment_lib.Plugin):
|
|
170
178
|
card_view=False,
|
171
179
|
),
|
172
180
|
)
|
173
|
-
|
181
|
+
with self._experiment_index_lock[experiment.id]:
|
182
|
+
html.save(index_html_path)
|
174
183
|
experiment.info(
|
175
184
|
f'Updated {index_html_path!r} in {t.elapse:.2f} seconds.',
|
176
185
|
)
|
@@ -185,11 +194,11 @@ class HtmlReporter(experiment_lib.Plugin):
|
|
185
194
|
time.time() - self._last_experiment_report_time[experiment.id]
|
186
195
|
> self.experiment_report_interval
|
187
196
|
):
|
197
|
+
self._last_experiment_report_time[experiment.id] = time.time()
|
188
198
|
if background:
|
189
199
|
runner.background_run(_save)
|
190
200
|
else:
|
191
201
|
_save()
|
192
|
-
self._last_experiment_report_time[experiment.id] = time.time()
|
193
202
|
|
194
203
|
def _save_example_html(
|
195
204
|
self, runner: Runner, experiment: Experiment, example: Example
|
langfun/core/language_model.py
CHANGED
@@ -434,7 +434,10 @@ class LanguageModel(component.Component):
|
|
434
434
|
def __init__(self, *args, **kwargs) -> None:
|
435
435
|
"""Overrides __init__ to pass through **kwargs to sampling options."""
|
436
436
|
|
437
|
-
sampling_options = kwargs.pop(
|
437
|
+
sampling_options = kwargs.pop(
|
438
|
+
'sampling_options',
|
439
|
+
pg.clone(self.__schema__.fields['sampling_options'].default_value)
|
440
|
+
)
|
438
441
|
sampling_options_delta = {}
|
439
442
|
|
440
443
|
for k, v in kwargs.items():
|
@@ -650,7 +653,7 @@ class LanguageModel(component.Component):
|
|
650
653
|
"""Outputs debugging information about the model."""
|
651
654
|
title_suffix = ''
|
652
655
|
if usage.total_tokens != 0:
|
653
|
-
title_suffix =
|
656
|
+
title_suffix = pg.colored(
|
654
657
|
f' (total {usage.total_tokens} tokens)', 'red'
|
655
658
|
)
|
656
659
|
|
@@ -669,7 +672,7 @@ class LanguageModel(component.Component):
|
|
669
672
|
"""Outputs debugging information about the prompt."""
|
670
673
|
title_suffix = ''
|
671
674
|
if usage.prompt_tokens != 0:
|
672
|
-
title_suffix =
|
675
|
+
title_suffix = pg.colored(f' ({usage.prompt_tokens} tokens)', 'red')
|
673
676
|
|
674
677
|
console.write(
|
675
678
|
# We use metadata 'formatted_text' for scenarios where the prompt text
|
@@ -700,7 +703,7 @@ class LanguageModel(component.Component):
|
|
700
703
|
if usage.completion_tokens != 0:
|
701
704
|
title_suffix += f'{usage.completion_tokens} tokens '
|
702
705
|
title_suffix += f'in {elapse:.2f} seconds)'
|
703
|
-
title_suffix =
|
706
|
+
title_suffix = pg.colored(title_suffix, 'red')
|
704
707
|
|
705
708
|
console.write(
|
706
709
|
str(response) + '\n',
|
@@ -117,6 +117,21 @@ class LanguageModelTest(unittest.TestCase):
|
|
117
117
|
self.assertEqual(lm.sampling_options.top_k, 2)
|
118
118
|
self.assertEqual(lm.max_attempts, 2)
|
119
119
|
|
120
|
+
def test_subclassing(self):
|
121
|
+
|
122
|
+
class ChildModel(lm_lib.LanguageModel):
|
123
|
+
|
124
|
+
sampling_options = lm_lib.LMSamplingOptions(
|
125
|
+
temperature=0.5, top_k=20
|
126
|
+
)
|
127
|
+
|
128
|
+
def _sample(self, *args, **kwargs):
|
129
|
+
pass
|
130
|
+
|
131
|
+
lm = ChildModel(top_k=10)
|
132
|
+
self.assertEqual(lm.sampling_options.temperature, 0.5)
|
133
|
+
self.assertEqual(lm.sampling_options.top_k, 10)
|
134
|
+
|
120
135
|
def test_sample(self):
|
121
136
|
lm = MockModel(top_k=1)
|
122
137
|
self.assertEqual(
|
langfun/core/llms/__init__.py
CHANGED
@@ -32,16 +32,30 @@ from langfun.core.llms.rest import REST
|
|
32
32
|
|
33
33
|
# Gemini models.
|
34
34
|
from langfun.core.llms.google_genai import GenAI
|
35
|
-
from langfun.core.llms.google_genai import
|
35
|
+
from langfun.core.llms.google_genai import GeminiFlash2_0ThinkingExp_20241219
|
36
36
|
from langfun.core.llms.google_genai import GeminiFlash2_0Exp
|
37
|
-
from langfun.core.llms.google_genai import GeminiExp_20241114
|
38
37
|
from langfun.core.llms.google_genai import GeminiExp_20241206
|
39
|
-
from langfun.core.llms.google_genai import
|
38
|
+
from langfun.core.llms.google_genai import GeminiExp_20241114
|
40
39
|
from langfun.core.llms.google_genai import GeminiPro1_5
|
41
|
-
from langfun.core.llms.google_genai import
|
42
|
-
from langfun.core.llms.google_genai import
|
43
|
-
from langfun.core.llms.google_genai import
|
44
|
-
from langfun.core.llms.google_genai import
|
40
|
+
from langfun.core.llms.google_genai import GeminiPro1_5_002
|
41
|
+
from langfun.core.llms.google_genai import GeminiPro1_5_001
|
42
|
+
from langfun.core.llms.google_genai import GeminiFlash1_5
|
43
|
+
from langfun.core.llms.google_genai import GeminiFlash1_5_002
|
44
|
+
from langfun.core.llms.google_genai import GeminiFlash1_5_001
|
45
|
+
from langfun.core.llms.google_genai import GeminiPro1
|
46
|
+
|
47
|
+
from langfun.core.llms.vertexai import VertexAI
|
48
|
+
from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp_20241219
|
49
|
+
from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
|
50
|
+
from langfun.core.llms.vertexai import VertexAIGeminiExp_20241206
|
51
|
+
from langfun.core.llms.vertexai import VertexAIGeminiExp_20241114
|
52
|
+
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
|
53
|
+
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
|
54
|
+
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
|
55
|
+
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
|
56
|
+
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
|
57
|
+
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
|
58
|
+
from langfun.core.llms.vertexai import VertexAIGeminiPro1
|
45
59
|
|
46
60
|
# OpenAI models.
|
47
61
|
from langfun.core.llms.openai import OpenAI
|
@@ -124,28 +138,13 @@ from langfun.core.llms.groq import GroqGemma_7B_IT
|
|
124
138
|
from langfun.core.llms.groq import GroqWhisper_Large_v3
|
125
139
|
from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
|
126
140
|
|
127
|
-
from langfun.core.llms.vertexai import VertexAI
|
128
|
-
from langfun.core.llms.vertexai import VertexAIGemini2_0
|
129
|
-
from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
|
130
|
-
from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp
|
131
|
-
from langfun.core.llms.vertexai import VertexAIGemini1_5
|
132
|
-
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
|
133
|
-
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
|
134
|
-
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
|
135
|
-
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0514
|
136
|
-
from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0409
|
137
|
-
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
|
138
|
-
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
|
139
|
-
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
|
140
|
-
from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514
|
141
|
-
from langfun.core.llms.vertexai import VertexAIGeminiPro1
|
142
|
-
from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
|
143
|
-
from langfun.core.llms.vertexai import VertexAIEndpoint
|
144
|
-
|
145
|
-
|
146
141
|
# LLaMA C++ models.
|
147
142
|
from langfun.core.llms.llama_cpp import LlamaCppRemote
|
148
143
|
|
144
|
+
# DeepSeek models.
|
145
|
+
from langfun.core.llms.deepseek import DeepSeek
|
146
|
+
from langfun.core.llms.deepseek import DeepSeekChat
|
147
|
+
|
149
148
|
# Placeholder for Google-internal imports.
|
150
149
|
|
151
150
|
# Include cache as sub-module.
|
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
import collections
|
17
17
|
import contextlib
|
18
|
+
import json
|
18
19
|
from typing import Annotated, Any, Iterator
|
19
20
|
import langfun.core as lf
|
20
21
|
from langfun.core.llms.cache import base
|
@@ -49,6 +50,11 @@ class InMemory(base.LMCacheBase):
|
|
49
50
|
"Creating a new cache as cache file '%s' does not exist.",
|
50
51
|
self.filename,
|
51
52
|
)
|
53
|
+
except json.JSONDecodeError:
|
54
|
+
pg.logging.warning(
|
55
|
+
"Creating a new cache as cache file '%s' is corrupted.",
|
56
|
+
self.filename,
|
57
|
+
)
|
52
58
|
|
53
59
|
def model_ids(self) -> list[str]:
|
54
60
|
"""Returns the model ids of cached queires."""
|
@@ -295,6 +295,11 @@ class InMemoryLMCacheTest(unittest.TestCase):
|
|
295
295
|
self.assertEqual(cache2.stats.num_updates, 2)
|
296
296
|
cache2.save()
|
297
297
|
|
298
|
+
# Corrupted file.
|
299
|
+
pg.io.writefile(path, 'bad_content')
|
300
|
+
cache3 = in_memory.InMemory(path)
|
301
|
+
self.assertEqual(len(cache3), 0)
|
302
|
+
|
298
303
|
|
299
304
|
class LmCacheTest(unittest.TestCase):
|
300
305
|
|