langfun 0.1.2.dev202411090804__py3-none-any.whl → 0.1.2.dev202411140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/console.py +10 -2
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/v2/__init__.py +38 -0
- langfun/core/eval/v2/checkpointing.py +135 -0
- langfun/core/eval/v2/checkpointing_test.py +89 -0
- langfun/core/eval/v2/evaluation.py +627 -0
- langfun/core/eval/v2/evaluation_test.py +156 -0
- langfun/core/eval/v2/example.py +295 -0
- langfun/core/eval/v2/example_test.py +114 -0
- langfun/core/eval/v2/experiment.py +949 -0
- langfun/core/eval/v2/experiment_test.py +304 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +209 -0
- langfun/core/eval/v2/progress_tracking_test.py +56 -0
- langfun/core/eval/v2/reporting.py +144 -0
- langfun/core/eval/v2/reporting_test.py +41 -0
- langfun/core/eval/v2/runners.py +417 -0
- langfun/core/eval/v2/runners_test.py +311 -0
- langfun/core/eval/v2/test_helper.py +80 -0
- langfun/core/language_model.py +122 -11
- langfun/core/language_model_test.py +97 -4
- langfun/core/llms/__init__.py +3 -0
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/vertexai.py +4 -4
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/RECORD +36 -12
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/top_level.txt +0 -0
langfun/core/console.py
CHANGED
@@ -59,12 +59,20 @@ def under_notebook() -> bool:
|
|
59
59
|
return bool(_notebook)
|
60
60
|
|
61
61
|
|
62
|
-
def display(value: Any, clear: bool = False) ->
|
62
|
+
def display(value: Any, clear: bool = False) -> Any:  # pylint: disable=redefined-outer-name
  """Shows `value` in the current notebook cell, if a notebook is bound.

  Args:
    value: The object handed to the notebook's `display` function.
    clear: When True, the cell output is cleared before `value` is shown.

  Returns:
    The result of the notebook's `display` call, or None when `_notebook`
    is None.
  """
  if _notebook is None:
    return None
  if clear:
    _notebook.clear_output()
  return _notebook.display(value)
|
69
|
+
|
70
|
+
|
71
|
+
def run_script(javascript: str) -> Any:
  """Executes a JavaScript snippet in the current notebook cell.

  Args:
    javascript: Source text wrapped in the notebook's `Javascript` object.

  Returns:
    The result of the notebook's `display` call on the wrapped script, or
    None when `_notebook` is None.
  """
  if _notebook is None:
    return None
  script = _notebook.Javascript(javascript)
  return _notebook.display(script)
|
68
76
|
|
69
77
|
|
70
78
|
def clear() -> None:
|
langfun/core/console_test.py
CHANGED
@@ -18,6 +18,7 @@ import io
|
|
18
18
|
import unittest
|
19
19
|
|
20
20
|
from langfun.core import console
|
21
|
+
import pyglove as pg
|
21
22
|
|
22
23
|
|
23
24
|
class ConsoleTest(unittest.TestCase):
|
@@ -32,6 +33,22 @@ class ConsoleTest(unittest.TestCase):
|
|
32
33
|
|
33
34
|
def test_under_notebook(self):
  """Checks that `under_notebook` reflects the module-level `_notebook`."""
  self.assertFalse(console.under_notebook())
  # Restore the module-level handle even when an assertion below fails, so a
  # failure here cannot leak a fake notebook into the remaining tests.
  self.addCleanup(setattr, console, '_notebook', console._notebook)
  console._notebook = True
  self.assertTrue(console.under_notebook())
|
39
|
+
|
40
|
+
def test_notebook_interaction(self):
  """Exercises display/run_script/clear with and without a notebook."""
  # Restore the module-level handle even when an assertion below fails, so a
  # failure here cannot leak the fake notebook into the remaining tests.
  self.addCleanup(setattr, console, '_notebook', console._notebook)

  # A stand-in notebook whose `display` and `Javascript` echo their argument,
  # which makes the return values of the console helpers observable.
  console._notebook = pg.Dict(
      display=lambda x: x, Javascript=lambda x: x, clear_output=lambda: None)
  self.assertEqual(console.display('hi', clear=True), 'hi')
  self.assertEqual(
      console.run_script('console.log("hi")'),
      'console.log("hi")'
  )
  console.clear()

  # Without a notebook both helpers return None.
  console._notebook = None
  self.assertIsNone(console.display('hi'))
  self.assertIsNone(console.run_script('console.log("hi")'))
|
35
52
|
|
36
53
|
|
37
54
|
if __name__ == '__main__':
|
langfun/core/eval/__init__.py
CHANGED
@@ -16,6 +16,8 @@
|
|
16
16
|
# pylint: disable=g-importing-member
|
17
17
|
# pylint: disable=g-bad-import-order
|
18
18
|
|
19
|
+
from langfun.core.eval import v2
|
20
|
+
|
19
21
|
from langfun.core.eval.base import register
|
20
22
|
from langfun.core.eval.base import registered_names
|
21
23
|
from langfun.core.eval.base import get_evaluations
|
@@ -0,0 +1,38 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""langfun eval framework v2."""
|
15
|
+
|
16
|
+
# pylint: disable=g-importing-member
|
17
|
+
# pylint: disable=g-bad-import-order
|
18
|
+
from langfun.core.eval.v2.experiment import Experiment
|
19
|
+
from langfun.core.eval.v2.experiment import Suite
|
20
|
+
from langfun.core.eval.v2.evaluation import Evaluation
|
21
|
+
|
22
|
+
from langfun.core.eval.v2.example import Example
|
23
|
+
from langfun.core.eval.v2.progress import Progress
|
24
|
+
|
25
|
+
from langfun.core.eval.v2.metric_values import MetricValue
|
26
|
+
from langfun.core.eval.v2.metric_values import Rate
|
27
|
+
from langfun.core.eval.v2.metric_values import Average
|
28
|
+
from langfun.core.eval.v2.metrics import Metric
|
29
|
+
from langfun.core.eval.v2 import metrics
|
30
|
+
|
31
|
+
from langfun.core.eval.v2.experiment import Plugin
|
32
|
+
|
33
|
+
from langfun.core.eval.v2.experiment import Runner
|
34
|
+
from langfun.core.eval.v2 import runners
|
35
|
+
|
36
|
+
|
37
|
+
# pylint: enable=g-bad-import-order
|
38
|
+
# pylint: enable=g-importing-member
|
@@ -0,0 +1,135 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Checkpointing evaluation runs."""
|
15
|
+
import threading
|
16
|
+
|
17
|
+
from langfun.core.eval.v2 import example as example_lib
|
18
|
+
from langfun.core.eval.v2 import experiment as experiment_lib
|
19
|
+
import pyglove as pg
|
20
|
+
|
21
|
+
Example = example_lib.Example
|
22
|
+
Experiment = experiment_lib.Experiment
|
23
|
+
Runner = experiment_lib.Runner
|
24
|
+
|
25
|
+
|
26
|
+
class Checkpointer(experiment_lib.Plugin):
  """Plugin for checkpointing evaluation runs.

  For each leaf experiment the plugin opens a `StateWriter` on the run's
  output checkpoint path and appends every example that completes without
  error. Unless the run is a refresh run, previously saved state is loaded
  from the run's input checkpoint path before the experiment starts.
  """

  # File name of the checkpoint, resolved per experiment through the current
  # run's `input_path_for` / `output_path_for`.
  checkpoint_filename: str = 'checkpoint.bagz'

  def _on_bound(self):
    """Initializes the lock and the (not yet created) writer table."""
    super()._on_bound()
    # Guards mutations of `self._state_writer`.
    self._lock = threading.Lock()
    # None until `on_run_start`; afterwards a dict keyed by experiment id
    # whose values are the open `StateWriter` objects.
    self._state_writer = None

  def on_run_start(
      self,
      runner: Runner,
      root: Experiment,
  ) -> None:
    """Resets the writer table to an empty dict for the new run."""
    self._state_writer = {}

  def on_run_abort(
      self,
      runner: Runner,
      root: Experiment,
      error: BaseException
  ) -> None:
    """Drops all registered writers when the run is aborted."""
    with self._lock:
      if self._state_writer is not None:
        # Removing the entries releases this plugin's references to the
        # writers; `StateWriter` closes its file in `__del__`.
        self._state_writer.clear()

  def on_run_complete(
      self,
      runner: Runner,
      root: Experiment,
  ) -> None:
    """Asserts that every writer has been released by the end of the run."""
    with self._lock:
      assert self._state_writer is not None and not self._state_writer

  def on_experiment_start(
      self,
      runner: Runner,
      experiment: Experiment,
  ) -> None:
    """Creates the checkpoint file."""
    # Only leaf experiments get a checkpoint file.
    if not experiment.is_leaf:
      return
    # For refresh runs, we don't want to load the previous state.
    if not runner.current_run.refresh:
      experiment.load_state(
          runner.current_run.input_path_for(
              experiment, self.checkpoint_filename
          ),
          raise_if_not_exist=False
      )
    state_writer = StateWriter(
        runner.current_run.output_path_for(
            experiment, self.checkpoint_filename
        )
    )
    with self._lock:
      # The writer is registered only when a run has been started; otherwise
      # the local reference is simply dropped on return.
      if self._state_writer is not None:
        self._state_writer[experiment.id] = state_writer

  def on_experiment_complete(
      self,
      runner: Runner,
      experiment: Experiment,
  ) -> None:
    """Closes the checkpoint file."""
    if not experiment.is_leaf:
      return
    # NOTE(review): this membership check runs outside `self._lock` and
    # raises TypeError if `_state_writer` is still None - confirm the runner
    # never calls this hook before `on_run_start` or after `on_run_abort`.
    assert experiment.id in self._state_writer
    with self._lock:
      if self._state_writer is not None:
        # There is no explicit close call: the entry is removed and the file
        # is closed by `StateWriter.__del__` once the last reference is gone.
        del self._state_writer[experiment.id]

  def on_example_complete(
      self,
      runner: Runner,
      experiment: Experiment,
      example: Example,
  ) -> None:
    """Saves the example to the checkpoint file."""
    # NOTE(review): like the check in `on_experiment_complete`, this lookup
    # is not protected by `self._lock` - confirm that is intended.
    assert experiment.id in self._state_writer
    # Examples that ended with an error are not written to the checkpoint.
    if not example.has_error:
      # The write is handed to the runner via `background_run`; the bound
      # method passed here keeps the writer referenced until the call is done.
      runner.background_run(self._state_writer[experiment.id].add, example)
|
109
|
+
|
110
|
+
|
111
|
+
class StateWriter:
  """Thread safe state writer.

  Serializes examples to JSON and appends them to a sequence file. The file
  is closed by `close()`, which is also invoked when the writer is garbage
  collected.
  """

  def __init__(self, path: str):
    """Opens the sequence file at `path` for writing.

    Args:
      path: Location of the sequence file to create.
    """
    self._lock = threading.Lock()
    self._sequence_writer = pg.io.open_sequence(path, 'w')

  def add(self, example: 'Example'):
    """Appends `example` to the file; a no-op once the writer is closed.

    Args:
      example: The example to serialize and append.
    """
    # Serialize outside the lock so concurrent callers only contend on the
    # actual file append.
    example_blob = pg.to_json_str(
        example,
        hide_default_values=True,
        save_ref_value=True,
        exclude_input=True
    )
    with self._lock:
      if self._sequence_writer is None:
        return
      self._sequence_writer.add(example_blob)

  def close(self) -> None:
    """Closes the underlying file. Safe to call more than once."""
    # `__del__` also runs for instances whose `__init__` did not finish, in
    # which case the attributes below may be missing.
    lock = getattr(self, '_lock', None)
    if lock is None:
      return
    # Make sure there is no write in progress.
    with lock:
      writer = getattr(self, '_sequence_writer', None)
      if writer is None:
        return
      try:
        writer.close()
      finally:
        # Never retry a close, even if the first attempt raised.
        self._sequence_writer = None

  def __del__(self):
    """Closes the file when the last reference to the writer goes away."""
    self.close()
|
@@ -0,0 +1,89 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import os
|
15
|
+
import tempfile
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
from langfun.core.eval.v2 import checkpointing
|
19
|
+
from langfun.core.eval.v2 import example as example_lib
|
20
|
+
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
|
21
|
+
from langfun.core.eval.v2 import test_helper
|
22
|
+
import pyglove as pg
|
23
|
+
|
24
|
+
Example = example_lib.Example
|
25
|
+
|
26
|
+
|
27
|
+
class StateWriterTest(unittest.TestCase):
  """Tests for `checkpointing.StateWriter`."""

  def _output_file(self, filename: str) -> str:
    """Returns a path for `filename` inside a fresh per-test directory.

    A private directory keeps a stale file from an earlier run from
    satisfying the existence checks, and avoids collisions between
    concurrent test runs that share the system temp directory.
    """
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    return os.path.join(tmp_dir.name, filename)

  def test_basic(self):
    """A written example leaves a file behind once the writer is released."""
    file = self._output_file('test.jsonl')
    writer = checkpointing.StateWriter(file)
    example = Example(id=1, input=pg.Dict(x=1), output=2)
    writer.add(example)
    # Dropping the last reference closes the file.
    del writer
    self.assertTrue(pg.io.path_exists(file))

  def test_error_handling(self):
    """An error raised while preparing a record does not corrupt the file."""
    file = self._output_file('test_error_handling.jsonl')
    writer = checkpointing.StateWriter(file)
    writer.add(Example(id=1, input=pg.Dict(x=1), output=2))

    def f():
      raise ValueError('Intentional error')

    # `f()` raises while the argument is evaluated, so `add` is never entered
    # and only the first record reaches the file.
    with self.assertRaises(ValueError):
      writer.add(f())
    del writer

    self.assertTrue(pg.io.path_exists(file))
    with pg.io.open_sequence(file, 'r') as records:
      self.assertEqual(len(list(records)), 1)
|
53
|
+
|
54
|
+
|
55
|
+
class CheckpointingTest(unittest.TestCase):
  """End-to-end test of the `Checkpointer` plugin on a test experiment."""

  def test_checkpointing(self):
    """Runs an experiment twice and checks that the second run skips work."""
    # NOTE(review): the root directory is a fixed name under the system temp
    # dir, so earlier runs may have left state there - confirm that the 'new'
    # run id isolates this test from it.
    root_dir = os.path.join(tempfile.gettempdir(), 'test_checkpointing')
    experiment = test_helper.test_experiment()
    checkpoint_filename = 'checkpoint.jsonl'
    checkpointer = checkpointing.Checkpointer(checkpoint_filename)
    run = experiment.run(
        root_dir, 'new', runner='sequential', plugins=[checkpointer]
    )
    # All writers must have been released by the end of the run.
    self.assertEqual(len(checkpointer._state_writer), 0)
    # Maps leaf id -> number of examples processed in the first run.
    num_processed = {}
    for leaf in experiment.leaf_nodes:
      ckpt = run.output_path_for(leaf, checkpoint_filename)
      self.assertTrue(pg.io.path_exists(ckpt))
      with pg.io.open_sequence(ckpt) as f:
        # The checkpointer only writes examples without errors, so the
        # record count is expected to equal completed minus failed.
        self.assertEqual(
            len(list(iter(f))),
            leaf.progress.num_completed - leaf.progress.num_failed
        )
      if leaf.id not in num_processed:
        self.assertEqual(leaf.progress.num_skipped, 0)
        num_processed[leaf.id] = leaf.progress.num_processed

    # Run again, should skip existing.
    # Presumably 'latest' resolves to the run created above - verify against
    # the runner's run-id handling.
    _ = experiment.run(
        root_dir, 'latest', runner='sequential', plugins=[checkpointer]
    )
    self.assertEqual(len(checkpointer._state_writer), 0)
    for leaf in experiment.leaf_nodes:
      self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
|
86
|
+
|
87
|
+
|
88
|
+
if __name__ == '__main__':
|
89
|
+
unittest.main()
|