langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501090804__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (34)
  1. langfun/core/__init__.py +0 -5
  2. langfun/core/coding/python/correction.py +4 -3
  3. langfun/core/coding/python/errors.py +10 -9
  4. langfun/core/coding/python/execution.py +23 -12
  5. langfun/core/coding/python/execution_test.py +21 -2
  6. langfun/core/coding/python/generation.py +18 -9
  7. langfun/core/concurrent.py +2 -3
  8. langfun/core/console.py +8 -3
  9. langfun/core/eval/base.py +2 -3
  10. langfun/core/eval/v2/reporting.py +15 -6
  11. langfun/core/language_model.py +7 -4
  12. langfun/core/language_model_test.py +15 -0
  13. langfun/core/llms/__init__.py +25 -26
  14. langfun/core/llms/cache/in_memory.py +6 -0
  15. langfun/core/llms/cache/in_memory_test.py +5 -0
  16. langfun/core/llms/deepseek.py +261 -0
  17. langfun/core/llms/deepseek_test.py +438 -0
  18. langfun/core/llms/gemini.py +507 -0
  19. langfun/core/llms/gemini_test.py +195 -0
  20. langfun/core/llms/google_genai.py +46 -320
  21. langfun/core/llms/google_genai_test.py +9 -204
  22. langfun/core/llms/openai.py +5 -0
  23. langfun/core/llms/vertexai.py +31 -359
  24. langfun/core/llms/vertexai_test.py +6 -166
  25. langfun/core/structured/mapping.py +13 -13
  26. langfun/core/structured/mapping_test.py +2 -2
  27. langfun/core/structured/schema.py +16 -8
  28. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/METADATA +19 -14
  29. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/RECORD +32 -30
  30. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/WHEEL +1 -1
  31. langfun/core/text_formatting.py +0 -168
  32. langfun/core/text_formatting_test.py +0 -65
  33. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/LICENSE +0 -0
  34. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/top_level.txt +0 -0
langfun/core/__init__.py CHANGED
@@ -77,11 +77,6 @@ from langfun.core.concurrent import concurrent_map
  from langfun.core.concurrent import with_context_access
  from langfun.core.concurrent import with_retry

- # Utility libraries for text formatting.
- from langfun.core.text_formatting import colored
- from langfun.core.text_formatting import colored_print as print  # pylint: disable=redefined-builtin
- from langfun.core.text_formatting import colored_template
-
  # Interface for natural language formattable.
  from langfun.core.natural_language import NaturalLanguageFormattable

langfun/core/coding/python/correction.py CHANGED
@@ -40,6 +40,7 @@ def run_with_correction(
  sandbox: bool | None = None,
  timeout: int | None = 5,
  returns_code: bool = False,
+ returns_stdout: bool = False,
  outputs_intermediate: bool = False,
  ) -> Any | tuple[Any, str]:
  """Correct code with a language model via self-play.
@@ -62,6 +63,7 @@
  timeout. Applicable only when sandbox is set to True.
  returns_code: If True, the return value is a tuple of (result, final code).
  Otherwise the return value is the result only.
+ returns_stdout: If True, the stdout (a str) will be returned.
  outputs_intermediate: If True, intermediate output will be outputted as a
  dict, with the last line's value accessible by key '__result__'. Otherwise
  the value of the last line will be returned.
@@ -87,6 +89,7 @@
  global_vars=global_vars,
  sandbox=sandbox,
  timeout=timeout,
+ returns_stdout=returns_stdout,
  outputs_intermediate=outputs_intermediate,
  )
  )
@@ -189,9 +192,7 @@
  def _error_feedback_str(error: Exception) -> str:
  """Returns the error str for feedback."""
  if isinstance(error, errors.CodeError):
- return lf.text_formatting.decolored(
- error.format(include_complete_code=False)
- )
+ return pg.decolor(error.format(include_complete_code=False))
  else:
  return f"Encountered {error.__class__.__name__}: {error}"

langfun/core/coding/python/errors.py CHANGED
@@ -17,7 +17,8 @@ import io
  import sys
  import textwrap
  import traceback
- import langfun.core as lf
+
+ import pyglove as pg


  class CodeError(RuntimeError):
@@ -62,13 +63,13 @@
  if 'line' not in error_message and self.lineno is not None:
  error_message += f' (<unknown>, line {self.lineno})'
  r.write(
- lf.colored(
+ pg.colored(
  f'{self.cause.__class__.__name__}: {error_message}', 'magenta'))

  if self.lineno is not None:
  r.write('\n\n')
  r.write(textwrap.indent(
- lf.colored(
+ pg.colored(
  self.code_lines(self.lineno - 1, self.end_lineno), 'magenta'),
  ' ' * 2
  ))
@@ -76,14 +77,14 @@

  if include_complete_code:
  r.write('\n')
- r.write(lf.colored('[Generated Code]', 'green', styles=['bold']))
+ r.write(pg.colored('[Generated Code]', 'green', styles=['bold']))
  r.write('\n\n')
- r.write(lf.colored(' ```python\n', 'green'))
+ r.write(pg.colored(' ```python\n', 'green'))
  r.write(textwrap.indent(
- lf.colored(self.code, 'green'),
+ pg.colored(self.code, 'green'),
  ' ' * 2
  ))
- r.write(lf.colored('\n ```\n', 'green'))
+ r.write(pg.colored('\n ```\n', 'green'))
  return r.getvalue()


@@ -98,10 +99,10 @@
  r = io.StringIO()
  cause_message = str(self.cause).rstrip()
  if self.message:
- r.write(lf.colored(self.message, 'magenta'))
+ r.write(pg.colored(self.message, 'magenta'))
  r.write('\n\n')
  r.write(
- lf.colored(
+ pg.colored(
  f'{self.cause.__class__.__name__}: {cause_message}', 'magenta'
  )
  )
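
Taken together, these hunks replace langfun's own text-formatting helpers with their PyGlove equivalents. A minimal migration sketch (ours, not part of the diff), using only the pg.colored and pg.decolor calls the hunks themselves rely on:

import pyglove as pg

# pg.colored wraps text in ANSI color codes (when the terminal supports them).
s = pg.colored('hello', 'green', styles=['bold'])
# pg.decolor strips any ANSI codes, recovering the plain text.
assert pg.decolor(s) == 'hello'
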
langfun/core/coding/python/execution.py CHANGED
@@ -57,6 +57,7 @@ def evaluate(
  *,
  global_vars: dict[str, Any] | None = None,
  permission: permissions.CodePermission | None = None,
+ returns_stdout: bool = False,
  outputs_intermediate: bool = False,
  ) -> Any | dict[str, Any]:
  """Executes Python code.
@@ -71,14 +72,17 @@
  global_vars: An optional dict as the globals that could be referenced by the
  code.
  permission: Permission for the Python code to run.
- outputs_intermediate: If True, intermediate output will be outputted as a
- dict, with the last line's value accessible by key '__result__'. Otherwise
- the value of the last line will be returned.
+ returns_stdout: If True, the stdout (a str) will be returned.
+ outputs_intermediate: Applicable when returns_stdout is False. If True,
+ intermediate output will be outputted as a dict, with the last line's
+ value accessible by key '__result__' and the std output accessible by
+ key '__stdout__'. Otherwise the value of the last line will be returned.

  Returns:
- The value of the last line of the code. Or a dict of variable name to
- their values if `outputs_intermediate` is set to True, with the final result
- accessible by key '__result__'.
+ The value of the last line of the code block. Or a dict of variable
+ names of all locals to their evaluated values as the output of the code to
+ run. The value for the last line can be accessed by key '__result__'. Or the
+ stdout as a str.
  """
  # Set up the permission and context.
  permission = permission or permissions.get_permission()
@@ -136,6 +140,8 @@
  raise errors.CodeError(code, e) from e
  global_vars[RESULT_KEY] = list(global_vars.values())[-1]

+ if returns_stdout:
+ return stdout.getvalue()
  if outputs_intermediate:
  outputs = {}
  for k, v in global_vars.items():
@@ -258,6 +264,7 @@ def run(
  *,
  global_vars: dict[str, Any] | None = None,
  permission: permissions.CodePermission | None = None,
+ returns_stdout: bool = False,
  outputs_intermediate: bool = False,
  sandbox: bool | None = None,
  timeout: float | None = None,
@@ -273,9 +280,11 @@
  code: Python code to run.
  global_vars: An optional dict of
  permission: Permission for the Python code to run.
- outputs_intermediate: If True, all variables created as locals will be
- returned, with the final result accessible by key '__result__'. Otherwise
- only the final result will be returned.
+ returns_stdout: If True, the stdout (a str) will be returned.
+ outputs_intermediate: Applicable when returns_stdout is False. If True,
+ intermediate output will be outputted as a dict, with the last line's
+ value accessible by key '__result__' and the std output accessible by
+ key '__stdout__'. Otherwise the value of the last line will be returned.
  sandbox: If True, run code in sandbox; If False, run code in current
  process. If None, run in sandbox first, if the output could not be
  serialized and pass to current process, run the code again in current
@@ -285,7 +294,8 @@
  Returns:
  The value of the last line of the code block. Or a dict of variable
  names of all locals to their evaluated values as the output of the code to
- run. The value for the last line can be accessed by key '__result__'.
+ run. The value for the last line can be accessed by key '__result__'. Or the
+ stdout as a str.

  Raises:
  TimeoutError: If the execution time exceeds the timeout.
@@ -293,5 +303,6 @@
  """
  return call(
  evaluate, code=code, global_vars=global_vars, permission=permission,
- outputs_intermediate=outputs_intermediate,
- sandbox=sandbox, timeout=timeout)
+ returns_stdout=returns_stdout, outputs_intermediate=outputs_intermediate,
+ sandbox=sandbox, timeout=timeout
+ )
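
For orientation, a small usage sketch of how the new flag composes with the existing one (this snippet is ours, not part of the diff; imports mirror those used in execution_test.py below):

from langfun.core.coding.python import execution
from langfun.core.coding.python import permissions

code = """
x = 1
print(f'x is {x}')
x + 1
"""

# Default: the value of the last line.
assert execution.evaluate(code, permission=permissions.CodePermission.ALL) == 2

# returns_stdout=True: the captured stdout is returned instead.
assert execution.evaluate(
    code, permission=permissions.CodePermission.ALL, returns_stdout=True
) == 'x is 1\n'

# outputs_intermediate=True (effective only when returns_stdout is False):
# a dict of locals plus '__result__' and '__stdout__' entries.
ret = execution.evaluate(
    code, permission=permissions.CodePermission.ALL, outputs_intermediate=True
)
assert ret['__result__'] == 2 and ret['__stdout__'] == 'x is 1\n'
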
langfun/core/coding/python/execution_test.py CHANGED
@@ -63,6 +63,15 @@ class EvaluateTest(unittest.TestCase):
  ),
  3,
  )
+ with self.assertRaisesRegex(errors.CodeError, 'ValueError'):
+ execution.evaluate(
+ """
+ def foo():
+ raise ValueError("intentional error")
+ foo()
+ """,
+ permission=permissions.CodePermission.ALL
+ )

  def test_class_def(self):
  ret = execution.evaluate(
@@ -102,16 +111,20 @@
  self.assertIs(ret['__result__'], ret['bar'])

  def test_function_def_and_call(self):
- ret = execution.evaluate(
+ code = (
  """
  def foo(x, y):
  return x + y

  def bar(z):
+ print(f'z is {z}')
  return z + foo(z, z)

  bar(1)
- """,
+ """
+ )
+ ret = execution.evaluate(
+ code,
  permission=permissions.CodePermission.ALL,
  outputs_intermediate=True,
  )
@@ -119,6 +132,12 @@
  list(ret.keys()), ['foo', 'bar', '__result__', '__stdout__']
  )
  self.assertEqual(ret['__result__'], 3)
+ ret = execution.evaluate(
+ code,
+ permission=permissions.CodePermission.ALL,
+ returns_stdout=True,
+ )
+ self.assertEqual(ret, 'z is 1\n')

  def test_complex(self):
  ret = execution.evaluate(
langfun/core/coding/python/generation.py CHANGED
@@ -88,6 +88,8 @@ class PythonCode(pg.Object):
  sandbox: bool | None = None,
  timeout: int | None = 5,
  global_vars: dict[str, Any] | None = None,
+ returns_stdout: bool = False,
+ outputs_intermediate: bool = False,
  autofix: int = 3,
  autofix_lm: lf.LanguageModel | None = None,
  ) -> Any:
@@ -101,13 +103,22 @@
  timeout: Timeout in seconds. If None, there is no timeout. Applicable when
  sandbox is set to True.
  global_vars: Global variables that could be accessed from the source code.
+ returns_stdout: If True, the stdout (a str) will be returned.
+ outputs_intermediate: Applicable when returns_stdout is False. If True,
+ intermediate output will be outputted as a dict, with the last line's
+ value accessible by key '__result__' and the std output accessible by
+ key '__stdout__'. Otherwise the value of the last line will be returned.
  autofix: Number of attempts to auto fix the generated code. If 0, autofix
  is disabled.
  autofix_lm: Language model to be used. If not specified, it will try to
  use the `lm` under `lf.context`.

  Returns:
- The value of the last expression in the source code.
+ The value of the last expression in the source code. Or a dict of local
+ variable names defined in the source code to their values if
+ `outputs_intermediate` is set to True. The value for the last line can be
+ accessed by key '__result__'. Or the stdout as a str if `returns_stdout`
+ is set to True.

  Raises:
  TimeoutError: If `sandbox` is True and timeout has reached.
@@ -121,6 +132,8 @@
  max_attempts=autofix,
  lm=autofix_lm,
  returns_code=True,
+ returns_stdout=returns_stdout,
+ outputs_intermediate=outputs_intermediate,
  )
  self.rebind(source=updated_code)
  return result
@@ -158,18 +171,14 @@
  TimeoutError: If `sandbox` is True and timeout has reached.
  Exception: Any errors that the source code has raised.
  """
- result, updated_code = correction.run_with_correction(
- self.source,
- global_vars=global_vars,
+ return self(
  sandbox=sandbox,
  timeout=timeout,
+ global_vars=global_vars,
+ autofix=autofix,
+ autofix_lm=autofix_lm,
  outputs_intermediate=True,
- max_attempts=autofix,
- lm=autofix_lm,
- returns_code=True,
  )
- self.rebind(source=updated_code)
- return result


  class PythonFunction(pg.Object):
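
A sketch of the widened PythonCode call surface (ours, not from the diff; it assumes PythonCode is re-exported at the langfun top level — otherwise import it from langfun.core.coding.python.generation):

import langfun as lf

code = lf.PythonCode("""
total = 0
for i in range(3):
  print(i)
  total += i
total
""")

# autofix=0 disables LM-based correction, so no language model is needed here.
assert code(autofix=0) == 3                  # Value of the last expression.
assert code(autofix=0, returns_stdout=True) == '0\n1\n2\n'  # Captured stdout.
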
langfun/core/concurrent.py CHANGED
@@ -25,7 +25,6 @@ import time
  from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union

  from langfun.core import component
- from langfun.core import text_formatting
  import pyglove as pg


@@ -844,10 +843,10 @@ class _ConsoleProgressControl(_ProgressControl):
  def refresh(self):
  s = io.StringIO()
  if self.label is not None:
- s.write(text_formatting.colored(self.label, 'red', styles=['bold']))
+ s.write(pg.colored(self.label, 'red', styles=['bold']))
  s.write(': ')
  s.write(
- text_formatting.colored(
+ pg.colored(
  '%d%% (%d/%d)' %
  (
  self._progress * 100 // self.total,
langfun/core/console.py CHANGED
@@ -15,7 +15,7 @@

  import sys
  from typing import Any
- from langfun.core.text_formatting import colored
+ import pyglove as pg


  def write(
@@ -42,10 +42,15 @@ def write(
  """
  # Print title if present.
  if title is not None:
- print(colored(title, styles=['bold']))
+ print(pg.colored(title, styles=['bold']))

  # Print body.
- print(colored(str(value), color=color, background=background, styles=styles))
+ print(dir(pg.utils))
+ print(
+ pg.colored(
+ str(value), color=color, background=background, styles=styles
+ )
+ )


  try:
langfun/core/eval/base.py CHANGED
@@ -1298,7 +1298,7 @@ class Evaluation(Evaluable):
  id=self.id,
  dir=self.dir,
  model=self.lm.model_id,
- prompt_template=lf.text_formatting.decolored(str(self.prompt)),
+ prompt_template=pg.decolor(str(self.prompt)),
  method=self.method,
  schema_fn=str(self.schema_fn),
  ),
@@ -2110,8 +2110,7 @@

  def _format_error(error: Exception):
  """Formats an error into a string."""
- return (f'({error.__class__.__name__}) '
- + lf.text_formatting.decolored(str(error)))
+ return (f'({error.__class__.__name__}) ' + pg.decolor(str(error)))


  def _error_key(error: Exception) -> str:
langfun/core/eval/v2/reporting.py CHANGED
@@ -51,6 +51,8 @@ class HtmlReporter(experiment_lib.Plugin):
  self._update_thread = None
  self._stop_update = False
  self._stop_update_experiment_ids = set()
+ self._summary_lock = None
+ self._experiment_index_lock = None

  def on_run_start(
  self,
@@ -61,6 +63,10 @@
  self._last_experiment_report_time = {leaf.id: 0 for leaf in root.leaf_nodes}
  self._stop_update = False
  self._stop_update_experiment_ids = set()
+ self._summary_lock = threading.Lock()
+ self._experiment_index_lock = {
+ leaf.id: threading.Lock() for leaf in root.leaf_nodes
+ }
  self._update_thread = threading.Thread(
  target=self._update_thread_func, args=(runner,)
  )
@@ -133,21 +139,23 @@
  """Maybe update the summary of current run."""
  run = runner.current_run
  def _summary():
- run.experiment.to_html(
+ html = run.experiment.to_html(
  collapse_level=None,
  extra_flags=dict(
  current_run=run, interactive=False, card_view=True,
  )
- ).save(
- run.output_path_for(run.experiment, _SUMMARY_FILE)
  )
+ with self._summary_lock:
+ html.save(
+ run.output_path_for(run.experiment, _SUMMARY_FILE)
+ )

  if force or (time.time() - self._last_summary_time > self.summary_interval):
+ self._last_summary_time = time.time()
  if background:
  runner.background_run(_summary)
  else:
  _summary()
- self._last_summary_time = time.time()

  def _maybe_update_experiment_html(
  self,
@@ -170,7 +178,8 @@
  card_view=False,
  ),
  )
- html.save(index_html_path)
+ with self._experiment_index_lock[experiment.id]:
+ html.save(index_html_path)
  experiment.info(
  f'Updated {index_html_path!r} in {t.elapse:.2f} seconds.',
  )
@@ -185,11 +194,11 @@
  time.time() - self._last_experiment_report_time[experiment.id]
  > self.experiment_report_interval
  ):
+ self._last_experiment_report_time[experiment.id] = time.time()
  if background:
  runner.background_run(_save)
  else:
  _save()
- self._last_experiment_report_time[experiment.id] = time.time()

  def _save_example_html(
  self, runner: Runner, experiment: Experiment, example: Example
langfun/core/language_model.py CHANGED
@@ -434,7 +434,10 @@ class LanguageModel(component.Component):
  def __init__(self, *args, **kwargs) -> None:
  """Overrides __init__ to pass through **kwargs to sampling options."""

- sampling_options = kwargs.pop('sampling_options', LMSamplingOptions())
+ sampling_options = kwargs.pop(
+ 'sampling_options',
+ pg.clone(self.__schema__.fields['sampling_options'].default_value)
+ )
  sampling_options_delta = {}

  for k, v in kwargs.items():
@@ -650,7 +653,7 @@
  """Outputs debugging information about the model."""
  title_suffix = ''
  if usage.total_tokens != 0:
- title_suffix = console.colored(
+ title_suffix = pg.colored(
  f' (total {usage.total_tokens} tokens)', 'red'
  )

@@ -669,7 +672,7 @@
  """Outputs debugging information about the prompt."""
  title_suffix = ''
  if usage.prompt_tokens != 0:
- title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+ title_suffix = pg.colored(f' ({usage.prompt_tokens} tokens)', 'red')

  console.write(
  # We use metadata 'formatted_text' for scenarios where the prompt text
@@ -700,7 +703,7 @@
  if usage.completion_tokens != 0:
  title_suffix += f'{usage.completion_tokens} tokens '
  title_suffix += f'in {elapse:.2f} seconds)'
- title_suffix = console.colored(title_suffix, 'red')
+ title_suffix = pg.colored(title_suffix, 'red')

  console.write(
  str(response) + '\n',
langfun/core/language_model_test.py CHANGED
@@ -117,6 +117,21 @@ class LanguageModelTest(unittest.TestCase):
  self.assertEqual(lm.sampling_options.top_k, 2)
  self.assertEqual(lm.max_attempts, 2)

+ def test_subclassing(self):
+
+ class ChildModel(lm_lib.LanguageModel):
+
+ sampling_options = lm_lib.LMSamplingOptions(
+ temperature=0.5, top_k=20
+ )
+
+ def _sample(self, *args, **kwargs):
+ pass
+
+ lm = ChildModel(top_k=10)
+ self.assertEqual(lm.sampling_options.temperature, 0.5)
+ self.assertEqual(lm.sampling_options.top_k, 10)
+
  def test_sample(self):
  lm = MockModel(top_k=1)
  self.assertEqual(
langfun/core/llms/__init__.py CHANGED
@@ -32,16 +32,30 @@ from langfun.core.llms.rest import REST

  # Gemini models.
  from langfun.core.llms.google_genai import GenAI
- from langfun.core.llms.google_genai import GeminiFlash2_0ThinkingExp
+ from langfun.core.llms.google_genai import GeminiFlash2_0ThinkingExp_20241219
  from langfun.core.llms.google_genai import GeminiFlash2_0Exp
- from langfun.core.llms.google_genai import GeminiExp_20241114
  from langfun.core.llms.google_genai import GeminiExp_20241206
- from langfun.core.llms.google_genai import GeminiFlash1_5
+ from langfun.core.llms.google_genai import GeminiExp_20241114
  from langfun.core.llms.google_genai import GeminiPro1_5
- from langfun.core.llms.google_genai import GeminiPro
- from langfun.core.llms.google_genai import GeminiProVision
- from langfun.core.llms.google_genai import Palm2
- from langfun.core.llms.google_genai import Palm2_IT
+ from langfun.core.llms.google_genai import GeminiPro1_5_002
+ from langfun.core.llms.google_genai import GeminiPro1_5_001
+ from langfun.core.llms.google_genai import GeminiFlash1_5
+ from langfun.core.llms.google_genai import GeminiFlash1_5_002
+ from langfun.core.llms.google_genai import GeminiFlash1_5_001
+ from langfun.core.llms.google_genai import GeminiPro1
+
+ from langfun.core.llms.vertexai import VertexAI
+ from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp_20241219
+ from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
+ from langfun.core.llms.vertexai import VertexAIGeminiExp_20241206
+ from langfun.core.llms.vertexai import VertexAIGeminiExp_20241114
+ from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
+ from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
+ from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
+ from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
+ from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
+ from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
+ from langfun.core.llms.vertexai import VertexAIGeminiPro1

  # OpenAI models.
  from langfun.core.llms.openai import OpenAI
@@ -124,28 +138,13 @@ from langfun.core.llms.groq import GroqGemma_7B_IT
  from langfun.core.llms.groq import GroqWhisper_Large_v3
  from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo

- from langfun.core.llms.vertexai import VertexAI
- from langfun.core.llms.vertexai import VertexAIGemini2_0
- from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
- from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp
- from langfun.core.llms.vertexai import VertexAIGemini1_5
- from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
- from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
- from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_002
- from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0514
- from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_0409
- from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5
- from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
- from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
- from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_0514
- from langfun.core.llms.vertexai import VertexAIGeminiPro1
- from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
- from langfun.core.llms.vertexai import VertexAIEndpoint
-
-
  # LLaMA C++ models.
  from langfun.core.llms.llama_cpp import LlamaCppRemote

+ # DeepSeek models.
+ from langfun.core.llms.deepseek import DeepSeek
+ from langfun.core.llms.deepseek import DeepSeekChat
+
  # Placeholder for Google-internal imports.

  # Include cache as sub-module.
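
A hypothetical smoke test for the newly exported DeepSeek classes (illustrative only; the api_key argument follows the pattern of langfun's other OpenAI-compatible models, and lf.query is langfun's structured-query entry point):

import langfun as lf

# DeepSeekChat is the chat model added by deepseek.py in this release.
lm = lf.llms.DeepSeekChat(api_key='<your-deepseek-api-key>')
print(lf.query('Compute 1 + 1.', int, lm=lm))
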
langfun/core/llms/cache/in_memory.py CHANGED
@@ -15,6 +15,7 @@

  import collections
  import contextlib
+ import json
  from typing import Annotated, Any, Iterator
  import langfun.core as lf
  from langfun.core.llms.cache import base
@@ -49,6 +50,11 @@ class InMemory(base.LMCacheBase):
  "Creating a new cache as cache file '%s' does not exist.",
  self.filename,
  )
+ except json.JSONDecodeError:
+ pg.logging.warning(
+ "Creating a new cache as cache file '%s' is corrupted.",
+ self.filename,
+ )

  def model_ids(self) -> list[str]:
  """Returns the model ids of cached queires."""
langfun/core/llms/cache/in_memory_test.py CHANGED
@@ -295,6 +295,11 @@ class InMemoryLMCacheTest(unittest.TestCase):
  self.assertEqual(cache2.stats.num_updates, 2)
  cache2.save()

+ # Corrupted file.
+ pg.io.writefile(path, 'bad_content')
+ cache3 = in_memory.InMemory(path)
+ self.assertEqual(len(cache3), 0)
+

  class LmCacheTest(unittest.TestCase):
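
The behavior under test, sketched as user-facing code (illustrative; Echo is langfun's fake model and the path is an example only):

import langfun as lf

# If the file exists but contains invalid JSON, InMemory now logs a warning
# and starts empty instead of raising json.JSONDecodeError.
cache = lf.llms.cache.InMemory('/tmp/lm_cache.json')
lm = lf.llms.Echo(cache=cache)
print(lm('hello'))   # The response is recorded in the in-memory cache.
cache.save()         # Persists entries back to the JSON file.
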