langfun 0.1.2.dev202411090804__py3-none-any.whl → 0.1.2.dev202411140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic; see the package registry's advisory page for more details.

Files changed (36)
  1. langfun/core/console.py +10 -2
  2. langfun/core/console_test.py +17 -0
  3. langfun/core/eval/__init__.py +2 -0
  4. langfun/core/eval/v2/__init__.py +38 -0
  5. langfun/core/eval/v2/checkpointing.py +135 -0
  6. langfun/core/eval/v2/checkpointing_test.py +89 -0
  7. langfun/core/eval/v2/evaluation.py +627 -0
  8. langfun/core/eval/v2/evaluation_test.py +156 -0
  9. langfun/core/eval/v2/example.py +295 -0
  10. langfun/core/eval/v2/example_test.py +114 -0
  11. langfun/core/eval/v2/experiment.py +949 -0
  12. langfun/core/eval/v2/experiment_test.py +304 -0
  13. langfun/core/eval/v2/metric_values.py +156 -0
  14. langfun/core/eval/v2/metric_values_test.py +80 -0
  15. langfun/core/eval/v2/metrics.py +357 -0
  16. langfun/core/eval/v2/metrics_test.py +203 -0
  17. langfun/core/eval/v2/progress.py +348 -0
  18. langfun/core/eval/v2/progress_test.py +82 -0
  19. langfun/core/eval/v2/progress_tracking.py +209 -0
  20. langfun/core/eval/v2/progress_tracking_test.py +56 -0
  21. langfun/core/eval/v2/reporting.py +144 -0
  22. langfun/core/eval/v2/reporting_test.py +41 -0
  23. langfun/core/eval/v2/runners.py +417 -0
  24. langfun/core/eval/v2/runners_test.py +311 -0
  25. langfun/core/eval/v2/test_helper.py +80 -0
  26. langfun/core/language_model.py +122 -11
  27. langfun/core/language_model_test.py +97 -4
  28. langfun/core/llms/__init__.py +3 -0
  29. langfun/core/llms/compositional.py +101 -0
  30. langfun/core/llms/compositional_test.py +73 -0
  31. langfun/core/llms/vertexai.py +4 -4
  32. {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/METADATA +1 -1
  33. {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/RECORD +36 -12
  34. {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/WHEEL +1 -1
  35. {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/LICENSE +0 -0
  36. {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/top_level.txt +0 -0
langfun/core/console.py CHANGED
@@ -59,12 +59,20 @@ def under_notebook() -> bool:
59
59
  return bool(_notebook)
60
60
 
61
61
 
62
def display(value: Any, clear: bool = False) -> Any:  # pylint: disable=redefined-outer-name
  """Renders `value` in the current notebook cell.

  Args:
    value: Object to render.
    clear: Whether to clear the cell's existing output before rendering.

  Returns:
    The return value of the notebook's `display` call, or None when not
    running under a notebook.
  """
  if _notebook is None:
    # Outside a notebook there is nothing to render into.
    return None
  if clear:
    _notebook.clear_output()
  return _notebook.display(value)
69
+
70
+
71
def run_script(javascript: str) -> Any:
  """Runs JavaScript in the current notebook cell.

  Args:
    javascript: JavaScript source, wrapped in the notebook's `Javascript`
      object before being displayed.

  Returns:
    The return value of the notebook's `display` call for the wrapped script,
    or None when not running under a notebook.
  """
  if _notebook is not None:
    return _notebook.display(_notebook.Javascript(javascript))
  # Explicit None, matching the non-notebook return of `display()`.
  return None
68
76
 
69
77
 
70
78
  def clear() -> None:
@@ -18,6 +18,7 @@ import io
18
18
  import unittest
19
19
 
20
20
  from langfun.core import console
21
+ import pyglove as pg
21
22
 
22
23
 
23
24
  class ConsoleTest(unittest.TestCase):
@@ -32,6 +33,22 @@ class ConsoleTest(unittest.TestCase):
32
33
 
33
34
  def test_under_notebook(self):
34
35
  self.assertFalse(console.under_notebook())
36
+ console._notebook = True
37
+ self.assertTrue(console.under_notebook())
38
+ console._notebook = None
39
+
40
+ def test_notebook_interaction(self):
41
+ console._notebook = pg.Dict(
42
+ display=lambda x: x, Javascript=lambda x: x, clear_output=lambda: None)
43
+ self.assertEqual(console.display('hi', clear=True), 'hi')
44
+ self.assertEqual(
45
+ console.run_script('console.log("hi")'),
46
+ 'console.log("hi")'
47
+ )
48
+ console.clear()
49
+ console._notebook = None
50
+ self.assertIsNone(console.display('hi'))
51
+ self.assertIsNone(console.run_script('console.log("hi")'))
35
52
 
36
53
 
37
54
  if __name__ == '__main__':
@@ -16,6 +16,8 @@
16
16
  # pylint: disable=g-importing-member
17
17
  # pylint: disable=g-bad-import-order
18
18
 
19
+ from langfun.core.eval import v2
20
+
19
21
  from langfun.core.eval.base import register
20
22
  from langfun.core.eval.base import registered_names
21
23
  from langfun.core.eval.base import get_evaluations
@@ -0,0 +1,38 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """langfun eval framework v2."""
15
+
16
+ # pylint: disable=g-importing-member
17
+ # pylint: disable=g-bad-import-order
18
+ from langfun.core.eval.v2.experiment import Experiment
19
+ from langfun.core.eval.v2.experiment import Suite
20
+ from langfun.core.eval.v2.evaluation import Evaluation
21
+
22
+ from langfun.core.eval.v2.example import Example
23
+ from langfun.core.eval.v2.progress import Progress
24
+
25
+ from langfun.core.eval.v2.metric_values import MetricValue
26
+ from langfun.core.eval.v2.metric_values import Rate
27
+ from langfun.core.eval.v2.metric_values import Average
28
+ from langfun.core.eval.v2.metrics import Metric
29
+ from langfun.core.eval.v2 import metrics
30
+
31
+ from langfun.core.eval.v2.experiment import Plugin
32
+
33
+ from langfun.core.eval.v2.experiment import Runner
34
+ from langfun.core.eval.v2 import runners
35
+
36
+
37
+ # pylint: enable=g-bad-import-order
38
+ # pylint: enable=g-importing-member
@@ -0,0 +1,135 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Checkpointing evaluation runs."""
15
+ import threading
16
+
17
+ from langfun.core.eval.v2 import example as example_lib
18
+ from langfun.core.eval.v2 import experiment as experiment_lib
19
+ import pyglove as pg
20
+
21
Example = example_lib.Example
Experiment = experiment_lib.Experiment
Runner = experiment_lib.Runner


class Checkpointer(experiment_lib.Plugin):
  """Plugin for checkpointing evaluation runs.

  For every leaf experiment, a `StateWriter` is opened on
  `checkpoint_filename` under the run's output path, and each example that
  completes without error is appended to it in the background. Unless the
  run is a refresh run, previously saved state is loaded from the run's
  input path when the experiment starts.
  """

  # Name of the per-experiment checkpoint file; resolved to a full path via
  # `runner.current_run.input_path_for` / `output_path_for`.
  checkpoint_filename: str = 'checkpoint.bagz'

  def _on_bound(self):
    super()._on_bound()
    # Guards `_state_writer`, which experiment callbacks mutate and example
    # callbacks read.
    self._lock = threading.Lock()
    # Maps leaf experiment id -> StateWriter while a run is active; None
    # until the first run starts.
    self._state_writer = None

  def on_run_start(
      self,
      runner: Runner,
      root: Experiment,
  ) -> None:
    """Starts a run with an empty writer mapping."""
    with self._lock:
      self._state_writer = {}

  def on_run_abort(
      self,
      runner: Runner,
      root: Experiment,
      error: BaseException
  ) -> None:
    """Drops all writers of the aborted run.

    Each writer closes its file once it is no longer referenced (see
    `StateWriter.__del__`).
    """
    with self._lock:
      if self._state_writer is not None:
        self._state_writer.clear()

  def on_run_complete(
      self,
      runner: Runner,
      root: Experiment,
  ) -> None:
    """Checks that every leaf experiment released its writer."""
    with self._lock:
      assert self._state_writer is not None and not self._state_writer

  def on_experiment_start(
      self,
      runner: Runner,
      experiment: Experiment,
  ) -> None:
    """Creates the checkpoint file."""
    if not experiment.is_leaf:
      return
    # For refresh runs, we don't want to load the previous state.
    if not runner.current_run.refresh:
      experiment.load_state(
          runner.current_run.input_path_for(
              experiment, self.checkpoint_filename
          ),
          raise_if_not_exist=False
      )
    state_writer = StateWriter(
        runner.current_run.output_path_for(
            experiment, self.checkpoint_filename
        )
    )
    with self._lock:
      if self._state_writer is not None:
        self._state_writer[experiment.id] = state_writer

  def on_experiment_complete(
      self,
      runner: Runner,
      experiment: Experiment,
  ) -> None:
    """Closes the checkpoint file."""
    if not experiment.is_leaf:
      return
    with self._lock:
      # Check membership under the lock, and only when the mapping exists;
      # testing `in` on a None mapping would raise TypeError.
      if self._state_writer is not None:
        assert experiment.id in self._state_writer
        del self._state_writer[experiment.id]

  def on_example_complete(
      self,
      runner: Runner,
      experiment: Experiment,
      example: Example,
  ) -> None:
    """Saves the example to the checkpoint file."""
    with self._lock:
      # Read the mapping under the lock so a concurrent `on_run_abort` or
      # `on_experiment_complete` cannot mutate it between check and lookup.
      assert (
          self._state_writer is not None
          and experiment.id in self._state_writer
      )
      state_writer = self._state_writer[experiment.id]
    if not example.has_error:
      # The bound method keeps `state_writer` alive for as long as the runner
      # holds the pending task, even after the mapping drops it.
      runner.background_run(state_writer.add, example)
109
+
110
+
111
class StateWriter:
  """Thread safe state writer.

  Appends JSON-serialized examples to a sequence file opened in 'w' mode.
  The file is closed by an explicit `close()` call, or otherwise when the
  writer is garbage collected.
  """

  def __init__(self, path: str):
    self._lock = threading.Lock()
    # Pre-initialize so `close()`/`__del__` stay safe if opening fails below.
    self._sequence_writer = None
    self._sequence_writer = pg.io.open_sequence(path, 'w')

  def add(self, example: Example):
    """Appends `example` to the file; a no-op once the writer is closed."""
    # Serialize outside the lock so concurrent callers only contend on I/O.
    example_blob = pg.to_json_str(
        example,
        hide_default_values=True,
        save_ref_value=True,
        exclude_input=True
    )
    with self._lock:
      if self._sequence_writer is None:
        return
      self._sequence_writer.add(example_blob)

  def close(self) -> None:
    """Closes the underlying file. Safe to call more than once."""
    # Holding the lock ensures no write is in progress while closing.
    with self._lock:
      if self._sequence_writer is None:
        return
      self._sequence_writer.close()
      self._sequence_writer = None

  def __del__(self):
    self.close()
@@ -0,0 +1,89 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import tempfile
16
+ import unittest
17
+
18
+ from langfun.core.eval.v2 import checkpointing
19
+ from langfun.core.eval.v2 import example as example_lib
20
+ from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
21
+ from langfun.core.eval.v2 import test_helper
22
+ import pyglove as pg
23
+
24
Example = example_lib.Example


class StateWriterTest(unittest.TestCase):

  def test_basic(self):
    # A fresh directory ensures a file left over from an earlier run cannot
    # make the existence check pass vacuously.
    with tempfile.TemporaryDirectory() as root_dir:
      file = os.path.join(root_dir, 'test.jsonl')
      writer = checkpointing.StateWriter(file)
      writer.add(Example(id=1, input=pg.Dict(x=1), output=2))
      # Dropping the last reference closes the sequence file.
      del writer
      self.assertTrue(pg.io.path_exists(file))
      with pg.io.open_sequence(file, 'r') as reader:
        self.assertEqual(len(list(iter(reader))), 1)

  def test_error_handling(self):
    with tempfile.TemporaryDirectory() as root_dir:
      file = os.path.join(root_dir, 'test_error_handling.jsonl')
      writer = checkpointing.StateWriter(file)
      writer.add(Example(id=1, input=pg.Dict(x=1), output=2))

      def make_example():
        raise ValueError('Intentional error')

      # An error while producing an example must not lose records that were
      # already written.
      with self.assertRaises(ValueError):
        writer.add(make_example())
      del writer

      self.assertTrue(pg.io.path_exists(file))
      with pg.io.open_sequence(file, 'r') as reader:
        self.assertEqual(len(list(iter(reader))), 1)
53
+
54
+
55
class CheckpointingTest(unittest.TestCase):

  def test_checkpointing(self):
    root_dir = os.path.join(tempfile.gettempdir(), 'test_checkpointing')
    experiment = test_helper.test_experiment()
    checkpoint_filename = 'checkpoint.jsonl'
    checkpointer = checkpointing.Checkpointer(checkpoint_filename)

    def run_with(run_id):
      # Both runs share the same root directory and checkpointer instance.
      return experiment.run(
          root_dir, run_id, runner='sequential', plugins=[checkpointer]
      )

    # First run: every completed, non-failed example is checkpointed.
    run = run_with('new')
    self.assertEqual(len(checkpointer._state_writer), 0)
    num_processed = {}
    for leaf in experiment.leaf_nodes:
      ckpt_path = run.output_path_for(leaf, checkpoint_filename)
      self.assertTrue(pg.io.path_exists(ckpt_path))
      with pg.io.open_sequence(ckpt_path) as reader:
        num_records = len(list(iter(reader)))
        expected = leaf.progress.num_completed - leaf.progress.num_failed
        self.assertEqual(num_records, expected)
      if leaf.id not in num_processed:
        self.assertEqual(leaf.progress.num_skipped, 0)
        num_processed[leaf.id] = leaf.progress.num_processed

    # Second run resumes from the latest run and skips checkpointed examples.
    run_with('latest')
    self.assertEqual(len(checkpointer._state_writer), 0)
    for leaf in experiment.leaf_nodes:
      self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])


if __name__ == '__main__':
  unittest.main()