langfun 0.1.2.dev202501060804__py3-none-any.whl → 0.1.2.dev202501100804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. langfun/core/__init__.py +0 -5
  2. langfun/core/coding/python/correction.py +4 -3
  3. langfun/core/coding/python/errors.py +10 -9
  4. langfun/core/coding/python/execution.py +23 -12
  5. langfun/core/coding/python/execution_test.py +21 -2
  6. langfun/core/coding/python/generation.py +18 -9
  7. langfun/core/concurrent.py +2 -3
  8. langfun/core/console.py +8 -3
  9. langfun/core/eval/base.py +2 -3
  10. langfun/core/eval/v2/reporting.py +8 -4
  11. langfun/core/language_model.py +7 -4
  12. langfun/core/language_model_test.py +15 -0
  13. langfun/core/llms/__init__.py +7 -0
  14. langfun/core/llms/deepseek.py +117 -0
  15. langfun/core/llms/deepseek_test.py +61 -0
  16. langfun/core/llms/google_genai.py +1 -0
  17. langfun/core/llms/groq.py +12 -99
  18. langfun/core/llms/groq_test.py +31 -137
  19. langfun/core/llms/llama_cpp.py +17 -54
  20. langfun/core/llms/llama_cpp_test.py +2 -34
  21. langfun/core/llms/openai.py +14 -147
  22. langfun/core/llms/openai_compatible.py +179 -0
  23. langfun/core/llms/openai_compatible_test.py +480 -0
  24. langfun/core/llms/openai_test.py +13 -423
  25. langfun/core/llms/vertexai.py +6 -2
  26. langfun/core/llms/vertexai_test.py +1 -1
  27. langfun/core/modalities/mime.py +8 -0
  28. langfun/core/modalities/mime_test.py +19 -4
  29. langfun/core/modality_test.py +0 -1
  30. langfun/core/structured/mapping.py +13 -13
  31. langfun/core/structured/mapping_test.py +2 -2
  32. langfun/core/structured/schema.py +16 -8
  33. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/METADATA +13 -2
  34. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/RECORD +37 -35
  35. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/WHEEL +1 -1
  36. langfun/core/text_formatting.py +0 -168
  37. langfun/core/text_formatting_test.py +0 -65
  38. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/LICENSE +0 -0
  39. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/top_level.txt +0 -0
langfun/core/__init__.py CHANGED
@@ -77,11 +77,6 @@ from langfun.core.concurrent import concurrent_map
  from langfun.core.concurrent import with_context_access
  from langfun.core.concurrent import with_retry

- # Utility libraries for text formatting.
- from langfun.core.text_formatting import colored
- from langfun.core.text_formatting import colored_print as print # pylint: disable=redefined-builtin
- from langfun.core.text_formatting import colored_template
-
  # Interface for natural language formattable.
  from langfun.core.natural_language import NaturalLanguageFormattable
langfun/core/coding/python/correction.py CHANGED
@@ -40,6 +40,7 @@ def run_with_correction(
  sandbox: bool | None = None,
  timeout: int | None = 5,
  returns_code: bool = False,
+ returns_stdout: bool = False,
  outputs_intermediate: bool = False,
  ) -> Any | tuple[Any, str]:
  """Correct code with a language model via self-play.
@@ -62,6 +63,7 @@ def run_with_correction(
  timeout. Applicable only when sandbox is set to True.
  returns_code: If True, the return value is a tuple of (result, final code).
  Otherwise the return value is the result only.
+ returns_stdout: If True, the stdout (a str) will be returned.
  outputs_intermediate: If True, intermediate output will be outputted as a
  dict, with the last line's value accessible by key '__result__'. Otherwise
  the value of the last line will be returned.
@@ -87,6 +89,7 @@ def run_with_correction(
  global_vars=global_vars,
  sandbox=sandbox,
  timeout=timeout,
+ returns_stdout=returns_stdout,
  outputs_intermediate=outputs_intermediate,
  )
  )
@@ -189,9 +192,7 @@ def correct(
  def _error_feedback_str(error: Exception) -> str:
  """Returns the error str for feedback."""
  if isinstance(error, errors.CodeError):
- return lf.text_formatting.decolored(
- error.format(include_complete_code=False)
- )
+ return pg.decolor(error.format(include_complete_code=False))
  else:
  return f"Encountered {error.__class__.__name__}: {error}"
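The new `returns_stdout` flag is threaded from `run_with_correction` down into `evaluate`. A minimal sketch of calling it, assuming the module path shown in this diff and `max_attempts=0` so no language model is needed:

```python
# Hypothetical usage of the new flag; parameter names are taken from this diff.
from langfun.core.coding.python import correction

stdout = correction.run_with_correction(
    'for i in range(3):\n  print(i)\n',
    sandbox=False,        # Run in the current process for this sketch.
    max_attempts=0,       # Disable LLM self-correction, so no `lm` is required.
    returns_stdout=True,  # New in this release: return the captured stdout.
)
# Expected: stdout == '0\n1\n2\n'
```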
langfun/core/coding/python/errors.py CHANGED
@@ -17,7 +17,8 @@ import io
  import sys
  import textwrap
  import traceback
- import langfun.core as lf
+
+ import pyglove as pg


  class CodeError(RuntimeError):
@@ -62,13 +63,13 @@ class CodeError(RuntimeError):
  if 'line' not in error_message and self.lineno is not None:
  error_message += f' (<unknown>, line {self.lineno})'
  r.write(
- lf.colored(
+ pg.colored(
  f'{self.cause.__class__.__name__}: {error_message}', 'magenta'))

  if self.lineno is not None:
  r.write('\n\n')
  r.write(textwrap.indent(
- lf.colored(
+ pg.colored(
  self.code_lines(self.lineno - 1, self.end_lineno), 'magenta'),
  ' ' * 2
  ))
@@ -76,14 +77,14 @@

  if include_complete_code:
  r.write('\n')
- r.write(lf.colored('[Generated Code]', 'green', styles=['bold']))
+ r.write(pg.colored('[Generated Code]', 'green', styles=['bold']))
  r.write('\n\n')
- r.write(lf.colored(' ```python\n', 'green'))
+ r.write(pg.colored(' ```python\n', 'green'))
  r.write(textwrap.indent(
- lf.colored(self.code, 'green'),
+ pg.colored(self.code, 'green'),
  ' ' * 2
  ))
- r.write(lf.colored('\n ```\n', 'green'))
+ r.write(pg.colored('\n ```\n', 'green'))
  return r.getvalue()


@@ -98,10 +99,10 @@ class SerializationError(RuntimeError):
  r = io.StringIO()
  cause_message = str(self.cause).rstrip()
  if self.message:
- r.write(lf.colored(self.message, 'magenta'))
+ r.write(pg.colored(self.message, 'magenta'))
  r.write('\n\n')
  r.write(
- lf.colored(
+ pg.colored(
  f'{self.cause.__class__.__name__}: {cause_message}', 'magenta'
  )
  )
langfun/core/coding/python/execution.py CHANGED
@@ -57,6 +57,7 @@ def evaluate(
  *,
  global_vars: dict[str, Any] | None = None,
  permission: permissions.CodePermission | None = None,
+ returns_stdout: bool = False,
  outputs_intermediate: bool = False,
  ) -> Any | dict[str, Any]:
  """Executes Python code.
@@ -71,14 +72,17 @@ def evaluate(
  global_vars: An optional dict as the globals that could be referenced by the
  code.
  permission: Permission for the Python code to run.
- outputs_intermediate: If True, intermediate output will be outputted as a
- dict, with the last line's value accessible by key '__result__'. Otherwise
- the value of the last line will be returned.
+ returns_stdout: If True, the stdout (a str) will be returned.
+ outputs_intermediate: Applicable when returns_stdout is False. If True,
+ intermediate output will be outputted as a dict, with the last line's
+ value accessible by key '__result__' and the std output accessible by
+ key '__stdout__'. Otherwise the value of the last line will be returned.

  Returns:
- The value of the last line of the code. Or a dict of variable name to
- their values if `outputs_intermediate` is set to True, with the final result
- accessible by key '__result__'.
+ The value of the last line of the code block. Or a dict of variable
+ names of all locals to their evaluated values as the output of the code to
+ run. The value for the last line can be accessed by key '__result__'. Or the
+ stdout as a str.
  """
  # Set up the permission and context.
  permission = permission or permissions.get_permission()
@@ -136,6 +140,8 @@ def evaluate(
  raise errors.CodeError(code, e) from e
  global_vars[RESULT_KEY] = list(global_vars.values())[-1]

+ if returns_stdout:
+ return stdout.getvalue()
  if outputs_intermediate:
  outputs = {}
  for k, v in global_vars.items():
@@ -258,6 +264,7 @@ def run(
  *,
  global_vars: dict[str, Any] | None = None,
  permission: permissions.CodePermission | None = None,
+ returns_stdout: bool = False,
  outputs_intermediate: bool = False,
  sandbox: bool | None = None,
  timeout: float | None = None,
@@ -273,9 +280,11 @@
  code: Python code to run.
  global_vars: An optional dict of
  permission: Permission for the Python code to run.
- outputs_intermediate: If True, all variables created as locals will be
- returned, with the final result accessible by key '__result__'. Otherwise
- only the final result will be returned.
+ returns_stdout: If True, the stdout (a str) will be returned.
+ outputs_intermediate: Applicable when returns_stdout is False. If True,
+ intermediate output will be outputted as a dict, with the last line's
+ value accessible by key '__result__' and the std output accessible by
+ key '__stdout__'. Otherwise the value of the last line will be returned.
  sandbox: If True, run code in sandbox; If False, run code in current
  process. If None, run in sandbox first, if the output could not be
  serialized and pass to current process, run the code again in current
@@ -285,7 +294,8 @@ def run(
  Returns:
  The value of the last line of the code block. Or a dict of variable
  names of all locals to their evaluated values as the output of the code to
- run. The value for the last line can be accessed by key '__result__'.
+ run. The value for the last line can be accessed by key '__result__'. Or the
+ stdout as a str.

  Raises:
  TimeoutError: If the execution time exceeds the timeout.
@@ -293,5 +303,6 @@
  """
  return call(
  evaluate, code=code, global_vars=global_vars, permission=permission,
- outputs_intermediate=outputs_intermediate,
- sandbox=sandbox, timeout=timeout)
+ returns_stdout=returns_stdout, outputs_intermediate=outputs_intermediate,
+ sandbox=sandbox, timeout=timeout
+ )
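Per the updated docstrings, `run` (and `evaluate`) now supports three output modes. A quick sketch with illustrative outputs, using the permission enum from the tests below:

```python
from langfun.core.coding.python import execution
from langfun.core.coding.python import permissions

code = 'x = 1\nprint("hi")\nx + 1\n'
perm = permissions.CodePermission.ALL

# Default: the value of the last line.
result = execution.run(code, permission=perm, sandbox=False)
# result == 2

# outputs_intermediate=True: locals plus '__result__' and '__stdout__'.
ret = execution.run(
    code, permission=perm, outputs_intermediate=True, sandbox=False
)
# ret == {'x': 1, '__result__': 2, '__stdout__': 'hi\n'}

# returns_stdout=True (new in this release): just the captured stdout string.
out = execution.run(code, permission=perm, returns_stdout=True, sandbox=False)
# out == 'hi\n'
```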
langfun/core/coding/python/execution_test.py CHANGED
@@ -63,6 +63,15 @@ class EvaluateTest(unittest.TestCase):
  ),
  3,
  )
+ with self.assertRaisesRegex(errors.CodeError, 'ValueError'):
+ execution.evaluate(
+ """
+ def foo():
+ raise ValueError("intentional error")
+ foo()
+ """,
+ permission=permissions.CodePermission.ALL
+ )

  def test_class_def(self):
  ret = execution.evaluate(
@@ -102,16 +111,20 @@ class EvaluateTest(unittest.TestCase):
  self.assertIs(ret['__result__'], ret['bar'])

  def test_function_def_and_call(self):
- ret = execution.evaluate(
+ code = (
  """
  def foo(x, y):
  return x + y

  def bar(z):
+ print(f'z is {z}')
  return z + foo(z, z)

  bar(1)
- """,
+ """
+ )
+ ret = execution.evaluate(
+ code,
  permission=permissions.CodePermission.ALL,
  outputs_intermediate=True,
  )
@@ -119,6 +132,12 @@ class EvaluateTest(unittest.TestCase):
  list(ret.keys()), ['foo', 'bar', '__result__', '__stdout__']
  )
  self.assertEqual(ret['__result__'], 3)
+ ret = execution.evaluate(
+ code,
+ permission=permissions.CodePermission.ALL,
+ returns_stdout=True,
+ )
+ self.assertEqual(ret, 'z is 1\n')

  def test_complex(self):
  ret = execution.evaluate(
langfun/core/coding/python/generation.py CHANGED
@@ -88,6 +88,8 @@ class PythonCode(pg.Object):
  sandbox: bool | None = None,
  timeout: int | None = 5,
  global_vars: dict[str, Any] | None = None,
+ returns_stdout: bool = False,
+ outputs_intermediate: bool = False,
  autofix: int = 3,
  autofix_lm: lf.LanguageModel | None = None,
  ) -> Any:
@@ -101,13 +103,22 @@ class PythonCode(pg.Object):
  timeout: Timeout in seconds. If None, there is no timeout. Applicable when
  sandbox is set to True.
  global_vars: Global variables that could be accessed from the source code.
+ returns_stdout: If True, the stdout (a str) will be returned.
+ outputs_intermediate: Applicable when returns_stdout is False. If True,
+ intermediate output will be outputted as a dict, with the last line's
+ value accessible by key '__result__' and the std output accessible by
+ key '__stdout__'. Otherwise the value of the last line will be returned.
  autofix: Number of attempts to auto fix the generated code. If 0, autofix
  is disabled.
  autofix_lm: Language model to be used. If not specified, it will try to
  use the `lm` under `lf.context`.

  Returns:
- The value of the last expression in the source code.
+ The value of the last expression in the source code. Or a dict of local
+ variable names defined in the source code to their values if
+ `outputs_intermediate` is set to True. The value for the last line can be
+ accessed by key '__result__'. Or the stdout as a str if `returns_stdout`
+ is set to True.

  Raises:
  TimeoutError: If `sandbox` is True and timeout has reached.
@@ -121,6 +132,8 @@ class PythonCode(pg.Object):
  max_attempts=autofix,
  lm=autofix_lm,
  returns_code=True,
+ returns_stdout=returns_stdout,
+ outputs_intermediate=outputs_intermediate,
  )
  self.rebind(source=updated_code)
  return result
@@ -158,18 +171,14 @@ class PythonCode(pg.Object):
  TimeoutError: If `sandbox` is True and timeout has reached.
  Exception: Any errors that the source code has raised.
  """
- result, updated_code = correction.run_with_correction(
- self.source,
- global_vars=global_vars,
+ return self(
  sandbox=sandbox,
  timeout=timeout,
+ global_vars=global_vars,
+ autofix=autofix,
+ autofix_lm=autofix_lm,
  outputs_intermediate=True,
- max_attempts=autofix,
- lm=autofix_lm,
- returns_code=True,
  )
- self.rebind(source=updated_code)
- return result


  class PythonFunction(pg.Object):
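With the flags now plumbed through `PythonCode.__call__`, a sketch of the call sites (assuming `source` is the positional init arg, and `autofix=0` so no language model is needed):

```python
from langfun.core.coding.python import generation

code = generation.PythonCode('import math\nprint(math.pi)\nmath.pi * 2\n')

value = code(sandbox=False, autofix=0)                        # last expression
stdout = code(sandbox=False, autofix=0, returns_stdout=True)  # captured stdout
```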
langfun/core/concurrent.py CHANGED
@@ -25,7 +25,6 @@ import time
  from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union

  from langfun.core import component
- from langfun.core import text_formatting
  import pyglove as pg


@@ -844,10 +843,10 @@ class _ConsoleProgressControl(_ProgressControl):
  def refresh(self):
  s = io.StringIO()
  if self.label is not None:
- s.write(text_formatting.colored(self.label, 'red', styles=['bold']))
+ s.write(pg.colored(self.label, 'red', styles=['bold']))
  s.write(': ')
  s.write(
- text_formatting.colored(
+ pg.colored(
  '%d%% (%d/%d)' %
  (
  self._progress * 100 // self.total,
langfun/core/console.py CHANGED
@@ -15,7 +15,7 @@

  import sys
  from typing import Any
- from langfun.core.text_formatting import colored
+ import pyglove as pg


  def write(
@@ -42,10 +42,15 @@ def write(
  """
  # Print title if present.
  if title is not None:
- print(colored(title, styles=['bold']))
+ print(pg.colored(title, styles=['bold']))

  # Print body.
- print(colored(str(value), color=color, background=background, styles=styles))
49
+ print(
50
+ pg.colored(
51
+ str(value), color=color, background=background, styles=styles
52
+ )
53
+ )
49
54
 
50
55
 
51
56
  try:
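The recurring change in this release: langfun's own `text_formatting` helpers are replaced by their PyGlove equivalents. The two calls, as they appear throughout this diff:

```python
import pyglove as pg

# Styled terminal text, mirroring the call sites above.
print(pg.colored('[Generated Code]', 'green', styles=['bold']))

# The inverse, used in eval/base.py below: strip ANSI color codes.
print(pg.decolor(pg.colored('WARNING', 'red')))  # -> 'WARNING'
```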
langfun/core/eval/base.py CHANGED
@@ -1298,7 +1298,7 @@ class Evaluation(Evaluable):
  id=self.id,
  dir=self.dir,
  model=self.lm.model_id,
- prompt_template=lf.text_formatting.decolored(str(self.prompt)),
+ prompt_template=pg.decolor(str(self.prompt)),
  method=self.method,
  schema_fn=str(self.schema_fn),
  ),
@@ -2110,8 +2110,7 @@ class Summary(pg.Object):

  def _format_error(error: Exception):
  """Formats an error into a string."""
- return (f'({error.__class__.__name__}) '
- + lf.text_formatting.decolored(str(error)))
+ return (f'({error.__class__.__name__}) ' + pg.decolor(str(error)))


  def _error_key(error: Exception) -> str:
langfun/core/eval/v2/reporting.py CHANGED
@@ -51,6 +51,7 @@ class HtmlReporter(experiment_lib.Plugin):
  self._update_thread = None
  self._stop_update = False
  self._stop_update_experiment_ids = set()
+ self._summary_lock = None
  self._experiment_index_lock = None

  def on_run_start(
@@ -62,6 +63,7 @@ class HtmlReporter(experiment_lib.Plugin):
  self._last_experiment_report_time = {leaf.id: 0 for leaf in root.leaf_nodes}
  self._stop_update = False
  self._stop_update_experiment_ids = set()
+ self._summary_lock = threading.Lock()
  self._experiment_index_lock = {
  leaf.id: threading.Lock() for leaf in root.leaf_nodes
  }
@@ -137,21 +139,23 @@ class HtmlReporter(experiment_lib.Plugin):
  """Maybe update the summary of current run."""
  run = runner.current_run
  def _summary():
- run.experiment.to_html(
+ html = run.experiment.to_html(
  collapse_level=None,
  extra_flags=dict(
  current_run=run, interactive=False, card_view=True,
  )
- ).save(
- run.output_path_for(run.experiment, _SUMMARY_FILE)
  )
+ with self._summary_lock:
+ html.save(
+ run.output_path_for(run.experiment, _SUMMARY_FILE)
+ )

  if force or (time.time() - self._last_summary_time > self.summary_interval):
+ self._last_summary_time = time.time()
  if background:
  runner.background_run(_summary)
  else:
  _summary()
- self._last_summary_time = time.time()

  def _maybe_update_experiment_html(
  self,
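Two concurrency fixes land in this hunk: the debounce timestamp is now stamped before the (possibly background) save is scheduled, and summary writes are serialized behind the new `_summary_lock`. Distilled into a standalone sketch (names are illustrative, not the plugin's API):

```python
import threading
import time

_write_lock = threading.Lock()
_last_write = 0.0

def maybe_write(save_fn, interval=60.0, force=False):
  """Debounced, thread-safe writer."""
  global _last_write
  if force or time.time() - _last_write > interval:
    _last_write = time.time()  # Stamp first so concurrent callers skip early.
    with _write_lock:          # Only one writer touches the file at a time.
      save_fn()
```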
langfun/core/language_model.py CHANGED
@@ -434,7 +434,10 @@ class LanguageModel(component.Component):
  def __init__(self, *args, **kwargs) -> None:
  """Overrides __init__ to pass through **kwargs to sampling options."""

- sampling_options = kwargs.pop('sampling_options', LMSamplingOptions())
+ sampling_options = kwargs.pop(
+ 'sampling_options',
+ pg.clone(self.__schema__.fields['sampling_options'].default_value)
+ )
  sampling_options_delta = {}

  for k, v in kwargs.items():
@@ -650,7 +653,7 @@ class LanguageModel(component.Component):
  """Outputs debugging information about the model."""
  title_suffix = ''
  if usage.total_tokens != 0:
- title_suffix = console.colored(
+ title_suffix = pg.colored(
  f' (total {usage.total_tokens} tokens)', 'red'
  )

@@ -669,7 +672,7 @@ class LanguageModel(component.Component):
  """Outputs debugging information about the prompt."""
  title_suffix = ''
  if usage.prompt_tokens != 0:
- title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+ title_suffix = pg.colored(f' ({usage.prompt_tokens} tokens)', 'red')

  console.write(
  # We use metadata 'formatted_text' for scenarios where the prompt text
@@ -700,7 +703,7 @@ class LanguageModel(component.Component):
  if usage.completion_tokens != 0:
  title_suffix += f'{usage.completion_tokens} tokens '
  title_suffix += f'in {elapse:.2f} seconds)'
- title_suffix = console.colored(title_suffix, 'red')
+ title_suffix = pg.colored(title_suffix, 'red')

  console.write(
  str(response) + '\n',
langfun/core/language_model_test.py CHANGED
@@ -117,6 +117,21 @@ class LanguageModelTest(unittest.TestCase):
  self.assertEqual(lm.sampling_options.top_k, 2)
  self.assertEqual(lm.max_attempts, 2)

+ def test_subclassing(self):
+
+ class ChildModel(lm_lib.LanguageModel):
+
+ sampling_options = lm_lib.LMSamplingOptions(
+ temperature=0.5, top_k=20
+ )
+
+ def _sample(self, *args, **kwargs):
+ pass
+
+ lm = ChildModel(top_k=10)
+ self.assertEqual(lm.sampling_options.temperature, 0.5)
+ self.assertEqual(lm.sampling_options.top_k, 10)

  def test_sample(self):
  lm = MockModel(top_k=1)
  self.assertEqual(
langfun/core/llms/__init__.py CHANGED
@@ -57,6 +57,9 @@ from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
  from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
  from langfun.core.llms.vertexai import VertexAIGeminiPro1

+ # Base for OpenAI-compatible models.
+ from langfun.core.llms.openai_compatible import OpenAICompatible
+
  # OpenAI models.
  from langfun.core.llms.openai import OpenAI

@@ -141,6 +144,10 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
  # LLaMA C++ models.
  from langfun.core.llms.llama_cpp import LlamaCppRemote

+ # DeepSeek models.
+ from langfun.core.llms.deepseek import DeepSeek
+ from langfun.core.llms.deepseek import DeepSeekChat
+
  # Placeholder for Google-internal imports.

  # Include cache as sub-module.
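`OpenAICompatible` is the shared base that `OpenAI`, `Groq`, `LlamaCppRemote`, and the new `DeepSeek` classes refactor onto in this release. A hypothetical sketch of pointing it at a self-hosted OpenAI-style endpoint; the URL and model name are made up, and the `api_endpoint`/`model` fields are inferred from the `DeepSeek` subclass below:

```python
from langfun.core.llms import OpenAICompatible

lm = OpenAICompatible(
    api_endpoint='http://localhost:8000/v1/chat/completions',  # hypothetical
    model='my-local-model',                                    # hypothetical
)
print(lm('Write a haiku about diffs.'))
```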
langfun/core/llms/deepseek.py ADDED
@@ -0,0 +1,117 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Language models from DeepSeek."""
+
+ import os
+ from typing import Annotated, Any
+
+ import langfun.core as lf
+ from langfun.core.llms import openai_compatible
+ import pyglove as pg
+
+ SUPPORTED_MODELS_AND_SETTINGS = {
+ # pylint: disable=g-line-too-long
+ # TODO(yifenglu): The RPM and TPM are arbitrary numbers. Update them once DeepSeek provides concrete guidelines.
+ # DeepSeek doesn't control the rate limit at the moment: https://api-docs.deepseek.com/quick_start/rate_limit
+ # The cost is based on: https://api-docs.deepseek.com/quick_start/pricing
+ 'deepseek-chat': pg.Dict(
+ in_service=True,
+ rpm=100,
+ tpm=1000000,
+ cost_per_1k_input_tokens=0.00014,
+ cost_per_1k_output_tokens=0.00028,
+ ),
+ }
+
+
+ # DeepSeek API uses an API format compatible with OpenAI.
+ # Reference: https://api-docs.deepseek.com/
+ @lf.use_init_args(['model'])
+ class DeepSeek(openai_compatible.OpenAICompatible):
+ """DeepSeek model."""
+
+ model: pg.typing.Annotated[
+ pg.typing.Enum(
+ pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+ ),
+ 'The name of the model to use.',
+ ]
+
+ api_endpoint: str = 'https://api.deepseek.com/chat/completions'
+
+ api_key: Annotated[
+ str | None,
+ (
+ 'API key. If None, the key will be read from environment variable '
+ "'DEEPSEEK_API_KEY'."
+ ),
+ ] = None
+
+ @property
+ def headers(self) -> dict[str, Any]:
+ api_key = self.api_key or os.environ.get('DEEPSEEK_API_KEY', None)
+ if not api_key:
+ raise ValueError(
+ 'Please specify `api_key` during `__init__` or set environment '
+ 'variable `DEEPSEEK_API_KEY` with your DeepSeek API key.'
+ )
+ headers = super().headers
+ headers.update({
+ 'Authorization': f'Bearer {api_key}',
+ })
+ return headers
+
+ @property
+ def model_id(self) -> str:
+ """Returns a string to identify the model."""
+ return f'DeepSeek({self.model})'
+
+ @property
+ def max_concurrency(self) -> int:
+ rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+ tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+ return self.rate_to_max_concurrency(
+ requests_per_min=rpm, tokens_per_min=tpm
+ )
+
+ def estimate_cost(
+ self, num_input_tokens: int, num_output_tokens: int
+ ) -> float | None:
+ """Estimate the cost based on usage."""
+ cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+ 'cost_per_1k_input_tokens', None
+ )
+ cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+ 'cost_per_1k_output_tokens', None
+ )
+ if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None:
+ return None
+ return (
+ cost_per_1k_input_tokens * num_input_tokens
+ + cost_per_1k_output_tokens * num_output_tokens
+ ) / 1000
+
+ @classmethod
+ def dir(cls):
+ return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
+
+
+ class DeepSeekChat(DeepSeek):
+ """DeepSeek Chat model.
+
+ Currently, it is powered by the DeepSeek-V3 model, with a 64K input context
+ window and 8K max output tokens.
+ """
+
+ model = 'deepseek-chat'
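A sketch of using the new integration, with the cost arithmetic from the pricing table above (the API key is a placeholder):

```python
from langfun.core.llms import deepseek

lm = deepseek.DeepSeekChat(api_key='sk-...')  # or set DEEPSEEK_API_KEY

# Per the table: (0.00014 * 100 + 0.00028 * 100) / 1000 = 4.2e-5 dollars.
print(lm.estimate_cost(num_input_tokens=100, num_output_tokens=100))
```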
langfun/core/llms/deepseek_test.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright 2023 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import unittest
+ from langfun.core.llms import deepseek
+
+
+ class DeepSeekTest(unittest.TestCase):
+ """Tests for DeepSeek language model."""
+
+ def test_dir(self):
+ self.assertIn('deepseek-chat', deepseek.DeepSeek.dir())
+
+ def test_key(self):
+ with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+ _ = deepseek.DeepSeekChat().headers
+ self.assertEqual(
+ deepseek.DeepSeekChat(api_key='test_key').headers,
+ {
+ 'Content-Type': 'application/json',
+ 'Authorization': 'Bearer test_key',
+ }
+ )
+
+ def test_model_id(self):
+ self.assertEqual(
+ deepseek.DeepSeekChat(api_key='test_key').model_id,
+ 'DeepSeek(deepseek-chat)',
+ )
+
+ def test_resource_id(self):
+ self.assertEqual(
+ deepseek.DeepSeekChat(api_key='test_key').resource_id,
+ 'DeepSeek(deepseek-chat)',
+ )
+
+ def test_max_concurrency(self):
+ self.assertGreater(
+ deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0
+ )
+
+ def test_estimate_cost(self):
+ self.assertEqual(
+ deepseek.DeepSeekChat(api_key='test_key').estimate_cost(
+ num_input_tokens=100, num_output_tokens=100
+ ),
+ 4.2e-5
+ )
+
+ if __name__ == '__main__':
+ unittest.main()