langfun 0.1.2.dev202411090804__py3-none-any.whl → 0.1.2.dev202411140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/console.py +10 -2
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/v2/__init__.py +38 -0
- langfun/core/eval/v2/checkpointing.py +135 -0
- langfun/core/eval/v2/checkpointing_test.py +89 -0
- langfun/core/eval/v2/evaluation.py +627 -0
- langfun/core/eval/v2/evaluation_test.py +156 -0
- langfun/core/eval/v2/example.py +295 -0
- langfun/core/eval/v2/example_test.py +114 -0
- langfun/core/eval/v2/experiment.py +949 -0
- langfun/core/eval/v2/experiment_test.py +304 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +209 -0
- langfun/core/eval/v2/progress_tracking_test.py +56 -0
- langfun/core/eval/v2/reporting.py +144 -0
- langfun/core/eval/v2/reporting_test.py +41 -0
- langfun/core/eval/v2/runners.py +417 -0
- langfun/core/eval/v2/runners_test.py +311 -0
- langfun/core/eval/v2/test_helper.py +80 -0
- langfun/core/language_model.py +122 -11
- langfun/core/language_model_test.py +97 -4
- langfun/core/llms/__init__.py +3 -0
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/vertexai.py +4 -4
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/RECORD +36 -12
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/top_level.txt +0 -0
langfun/core/eval/v2/runners_test.py
ADDED
@@ -0,0 +1,311 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import threading
+import time
+from typing import Any
+import unittest
+
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import experiment as experiment_lib
+from langfun.core.eval.v2 import runners as runners_lib  # pylint: disable=unused-import
+from langfun.core.eval.v2 import test_helper
+import pyglove as pg
+
+
+Runner = experiment_lib.Runner
+Example = example_lib.Example
+Experiment = experiment_lib.Experiment
+Suite = experiment_lib.Suite
+Plugin = experiment_lib.Plugin
+
+
+class TestPlugin(Plugin):
+  started_experiments: list[Experiment] = []
+  completed_experiments: list[Experiment] = []
+  skipped_experiments: list[Experiment] = []
+  started_example_ids: list[int] = []
+  completed_example_ids: list[int] = []
+  skipped_example_ids: list[int] = []
+  start_time: float | None = None
+  complete_time: float | None = None
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._lock = threading.Lock()
+
+  def on_run_start(self, runner: Runner, root: Experiment):
+    del root
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      self.start_time = time.time()
+
+  def on_run_complete(self, runner: Runner, root: Experiment):
+    del root
+    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
+      self.complete_time = time.time()
+
+  def on_experiment_start(self, runner: Runner, experiment: Experiment):
+    del runner
+    with pg.notify_on_change(False), self._lock:
+      self.started_experiments.append(pg.Ref(experiment))
+
+  def on_experiment_skipped(self, runner: Runner, experiment: Experiment):
+    del runner
+    with pg.notify_on_change(False), self._lock:
+      self.skipped_experiments.append(pg.Ref(experiment))
+
+  def on_experiment_complete(self, runner: Runner, experiment: Experiment):
+    del runner
+    with pg.notify_on_change(False), self._lock:
+      self.completed_experiments.append(pg.Ref(experiment))
+
+  def on_example_start(
+      self, runner: Runner, experiment: Experiment, example: Example):
+    del runner, experiment
+    with pg.notify_on_change(False), self._lock:
+      self.started_example_ids.append(example.id)
+
+  def on_example_skipped(
+      self, runner: Runner, experiment: Experiment, example: Example):
+    del runner, experiment
+    with pg.notify_on_change(False), self._lock:
+      self.skipped_example_ids.append(example.id)
+
+  def on_example_complete(
+      self, runner: Runner, experiment: Experiment, example: Example):
+    del runner, experiment
+    with pg.notify_on_change(False), self._lock:
+      self.completed_example_ids.append(example.id)
+
+
+class RunnerTest(unittest.TestCase):
+
+  def assert_same_list(self, actual: list[Any], expected: list[Any]):
+    self.assertEqual(len(actual), len(expected))
+    for i, (x, y) in enumerate(zip(actual, expected)):
+      if x is not y:
+        print(i, pg.diff(x, y))
+      self.assertIs(x, y)
+
+  def test_basic(self):
+    plugin = TestPlugin()
+    exp = test_helper.test_experiment()
+    root_dir = os.path.join(tempfile.gettempdir(), 'test_sequential_runner')
+    run = exp.run(root_dir, runner='sequential', plugins=[plugin])
+
+    self.assertIsNotNone(plugin.start_time)
+    self.assertIsNotNone(plugin.complete_time)
+    self.assertGreater(plugin.complete_time, plugin.start_time)
+
+    self.assert_same_list(
+        plugin.started_experiments,
+        exp.nonleaf_nodes + exp.leaf_nodes
+    )
+    self.assert_same_list(
+        plugin.completed_experiments,
+        exp.leaf_nodes + list(reversed(exp.nonleaf_nodes))
+    )
+    self.assert_same_list(
+        plugin.started_example_ids, list(range(1, 11)) * 6
+    )
+    self.assert_same_list(
+        plugin.completed_example_ids, list(range(1, 11)) * 6
+    )
+    self.assert_same_list(plugin.skipped_experiments, [])
+    self.assert_same_list(plugin.skipped_example_ids, [])
+    self.assertTrue(
+        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
+    )
+
+    for node in exp.nodes:
+      self.assertTrue(node.progress.is_started)
+      self.assertTrue(node.progress.is_completed)
+      if node.is_leaf:
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_completed, 10)
+        self.assertEqual(node.progress.num_failed, 1)
+      else:
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_failed, 0)
+        self.assertEqual(node.progress.num_processed, node.progress.num_total)
+
+  def test_raise_if_has_error(self):
+    root_dir = os.path.join(tempfile.gettempdir(), 'test_raise_if_has_error')
+    exp = test_helper.TestEvaluation()
+    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
+      exp.run(
+          root_dir, runner='sequential', plugins=[], raise_if_has_error=True
+      )
+
+    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
+      exp.run(root_dir, runner='parallel', plugins=[], raise_if_has_error=True)
+
+  def test_example_ids(self):
+    root_dir = os.path.join(tempfile.gettempdir(), 'test_example_ids')
+    exp = test_helper.test_experiment()
+    plugin = TestPlugin()
+    _ = exp.run(
+        root_dir, runner='sequential', plugins=[plugin], example_ids=[5, 7, 9]
+    )
+    self.assertEqual(plugin.started_example_ids, [5, 7, 9] * 6)
+    self.assertEqual(plugin.completed_example_ids, [5, 7, 9] * 6)
+
+  def test_filter(self):
+    plugin = TestPlugin()
+    exp = test_helper.test_experiment()
+    root_dir = os.path.join(tempfile.gettempdir(), 'test_filter')
+
+    _ = exp.run(
+        root_dir, runner='sequential', plugins=[plugin],
+        filter=lambda e: e.lm.offset != 0
+    )
+    self.assert_same_list(
+        plugin.started_experiments,
+        exp.nonleaf_nodes + exp.leaf_nodes[2:]
+    )
+    self.assert_same_list(
+        plugin.skipped_experiments, exp.leaf_nodes[:2]
+    )
+    self.assert_same_list(
+        plugin.completed_experiments,
+        exp.leaf_nodes[2:] + [exp.children[1], exp]
+    )
+
+  def test_use_cache(self):
+    @pg.functor()
+    def test_inputs(num_examples: int = 10):
+      return [
+          pg.Dict(
+              x=i // 2, y=(i // 2) ** 2,
+              groundtruth=(i // 2 + (i // 2) ** 2)
+          ) for i in range(num_examples)
+      ]
+
+    exp = test_helper.TestEvaluation(
+        inputs=test_inputs(num_examples=pg.oneof([2, 4]))
+    )
+    # Global cache.
+    root_dir = os.path.join(tempfile.gettempdir(), 'global_cache')
+    run = exp.run(root_dir, runner='sequential', use_cache='global', plugins=[])
+    self.assertTrue(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
+    self.assertEqual(exp.usage_summary.cached.total.num_requests, 4)
+    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 2)
+
+    # Per-dataset cache.
+    root_dir = os.path.join(tempfile.gettempdir(), 'per_dataset')
+    run = exp.run(
+        root_dir, runner='sequential', use_cache='per_dataset', plugins=[]
+    )
+    for leaf in exp.leaf_nodes:
+      self.assertTrue(
+          pg.io.path_exists(run.output_path_for(leaf, 'cache.json'))
+      )
+    self.assertEqual(exp.usage_summary.cached.total.num_requests, 3)
+    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 3)
+
+    # No cache.
+    root_dir = os.path.join(tempfile.gettempdir(), 'no')
+    run = exp.run(root_dir, runner='sequential', use_cache='no', plugins=[])
+    self.assertFalse(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
+    for leaf in exp.leaf_nodes:
+      self.assertFalse(
+          pg.io.path_exists(run.output_path_for(leaf, 'cache.json'))
+      )
+    self.assertEqual(exp.usage_summary.cached.total.num_requests, 0)
+    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 6)
+
+  def test_parallel_runner(self):
+    plugin = TestPlugin()
+    exp = test_helper.test_experiment()
+    root_dir = os.path.join(tempfile.gettempdir(), 'test_parallel_runner')
+    run = exp.run(root_dir, runner='parallel', plugins=[plugin])
+
+    self.assertIsNotNone(plugin.start_time)
+    self.assertIsNotNone(plugin.complete_time)
+    self.assertGreater(plugin.complete_time, plugin.start_time)
+
+    self.assertEqual(
+        len(plugin.started_experiments), len(exp.nodes)
+    )
+    self.assertEqual(
+        len(plugin.completed_experiments), len(exp.nodes)
+    )
+    self.assertEqual(
+        len(plugin.started_example_ids), 6 * 10
+    )
+    self.assertEqual(
+        len(plugin.completed_example_ids), 6 * 10
+    )
+    self.assert_same_list(plugin.skipped_experiments, [])
+    self.assert_same_list(plugin.skipped_example_ids, [])
+    self.assertTrue(
+        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
+    )
+
+    for node in exp.nodes:
+      self.assertTrue(node.progress.is_started)
+      self.assertTrue(node.progress.is_completed)
+      if node.is_leaf:
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_completed, 10)
+        self.assertEqual(node.progress.num_failed, 1)
+      else:
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_failed, 0)
+        self.assertEqual(node.progress.num_processed, node.progress.num_total)
+
+  def test_debug_runner(self):
+    plugin = TestPlugin()
+    exp = test_helper.test_experiment()
+    root_dir = os.path.join(tempfile.gettempdir(), 'test_debug_runner')
+    run = exp.run(root_dir, runner='debug', plugins=[plugin])
+
+    self.assertIsNotNone(plugin.start_time)
+    self.assertIsNotNone(plugin.complete_time)
+    self.assertGreater(plugin.complete_time, plugin.start_time)
+
+    self.assertEqual(
+        len(plugin.started_experiments), len(exp.nodes)
+    )
+    self.assertEqual(
+        len(plugin.completed_experiments), len(exp.nodes)
+    )
+    self.assertEqual(
+        len(plugin.started_example_ids), 6 * 1
+    )
+    self.assertEqual(
+        len(plugin.completed_example_ids), 6 * 1
+    )
+    self.assert_same_list(plugin.skipped_experiments, [])
+    self.assert_same_list(plugin.skipped_example_ids, [])
+    self.assertFalse(
+        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
+    )
+
+    for node in exp.nodes:
+      self.assertTrue(node.progress.is_started)
+      self.assertTrue(node.progress.is_completed)
+      if node.is_leaf:
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_completed, 1)
+        self.assertEqual(node.progress.num_failed, 0)
+      else:
+        self.assertEqual(node.progress.num_skipped, 0)
+        self.assertEqual(node.progress.num_failed, 0)
+        self.assertEqual(node.progress.num_processed, node.progress.num_total)
+
+
+if __name__ == '__main__':
+  unittest.main()
langfun/core/eval/v2/test_helper.py
ADDED
@@ -0,0 +1,80 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Helper classes and functions for evaluation tests."""
+
+from langfun.core import language_model
+from langfun.core import llms
+from langfun.core import message as message_lib
+from langfun.core import structured
+
+from langfun.core.eval.v2 import evaluation as evaluation_lib
+from langfun.core.eval.v2 import example as example_lib
+from langfun.core.eval.v2 import experiment as experiment_lib
+from langfun.core.eval.v2 import metrics as metrics_lib
+
+import pyglove as pg
+
+Example = example_lib.Example
+Suite = experiment_lib.Suite
+Evaluation = evaluation_lib.Evaluation
+RunId = experiment_lib.RunId
+Run = experiment_lib.Run
+
+
+@pg.functor()
+def test_inputs(num_examples: int | None = 10):
+  if num_examples is None:
+    num_examples = 20
+  return [
+      pg.Dict(x=i, y=i ** 2, groundtruth=i + i ** 2)
+      for i in range(num_examples)
+  ]
+
+
+class TestLLM(llms.Fake):
+  """Test language model."""
+
+  offset: int = 0
+
+  def _response_from(self, prompt: message_lib.Message) -> message_lib.Message:
+    return message_lib.AIMessage(
+        str(prompt.metadata.x + prompt.metadata.y + self.offset)
+    )
+
+  @property
+  def resource_id(self) -> str:
+    return f'test_llm:{self.offset}'
+
+
+class TestEvaluation(Evaluation):
+  """Test evaluation class."""
+  inputs = test_inputs()
+  metrics = [metrics_lib.Match()]
+  lm: language_model.LanguageModel = TestLLM()
+
+  def process(self, v):
+    if v.x == 5:
+      raise ValueError('x should not be 5')
+    return structured.query(
+        '{{x}} + {{y}} = ?', int, lm=self.lm, x=v.x, y=v.y,
+        metadata_x=v.x, metadata_y=v.y
+    )
+
+
+def test_experiment():
+  """Returns a test experiment."""
+  return Suite([
+      TestEvaluation(lm=TestLLM(offset=0)),
+      TestEvaluation(lm=TestLLM(offset=pg.oneof(range(5)))),
+  ])
langfun/core/language_model.py
CHANGED
@@ -17,6 +17,8 @@ import abc
 import contextlib
 import dataclasses
 import enum
+import functools
+import math
 import threading
 import time
 from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
@@ -875,7 +877,7 @@ class LanguageModel(component.Component):
     return DEFAULT_MAX_CONCURRENCY  # Default of 1


-class UsageSummary(pg.Object):
+class UsageSummary(pg.Object, pg.views.HtmlTreeView.Extension):
   """Usage sumary."""

   class AggregatedUsage(pg.Object):
@@ -897,20 +899,131 @@ class UsageSummary(pg.Object):
       aggregated = self.breakdown.get(model_id, None)
       with pg.notify_on_change(False):
         self.breakdown[model_id] = usage + aggregated
-        self.rebind(
+        self.rebind(
+            total=self.total + usage,
+            raise_on_no_change=False
+        )
+
+    def merge(self, other: 'UsageSummary.AggregatedUsage') -> None:
+      """Merges the usage summary."""
+      with pg.notify_on_change(False):
+        for model_id, usage in other.breakdown.items():
+          self.add(model_id, usage)
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._usage_badge = None
+    self._lock = threading.Lock()

   @property
   def total(self) -> LMSamplingUsage:
     return self.cached.total + self.uncached.total

-  def
+  def add(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
     """Updates the usage summary."""
-
-
-
-
-
+    with self._lock:
+      if is_cached:
+        usage.rebind(estimated_cost=0.0, skip_notification=True)
+        self.cached.add(model_id, usage)
+      else:
+        self.uncached.add(model_id, usage)
+      self._update_view()
+
+  def merge(self, other: 'UsageSummary', as_cached: bool = False) -> None:
+    """Aggregates the usage summary.
+
+    Args:
+      other: The usage summary to merge.
+      as_cached: Whether to merge the usage summary as cached.
+    """
+    with self._lock:
+      self.cached.merge(other.cached)
+      if as_cached:
+        self.cached.merge(other.uncached)
+      else:
+        self.uncached.merge(other.uncached)
+      self._update_view()
+
+  def _sym_nondefault(self) -> dict[str, Any]:
+    """Overrides nondefault values so volatile values are not included."""
+    return dict()
+
+  #
+  # Html views for the usage summary.
+  #
+
+  def _update_view(self):
+    if self._usage_badge is not None:
+      self._usage_badge.update(
+          self._badge_text(),
+          tooltip=pg.format(self.total, verbose=False),
+          styles=dict(color=self._badge_color()),
+      )

+  def _badge_text(self) -> str:
+    if self.total.estimated_cost is not None:
+      return f'{self.total.estimated_cost:.3f}'
+    return '0.000'
+
+  def _badge_color(self) -> str | None:
+    if self.total.estimated_cost is None or self.total.estimated_cost < 1.0:
+      return None
+
+    # Step 1: The normal cost range is around 1e-3 to 1e5.
+    # Therefore we normalize the log10 value from [-3, 5] to [0, 1].
+    normalized_value = (math.log10(self.total.estimated_cost) + 3) / (5 + 3)
+
+    # Step 2: Interpolate between green and red
+    red = int(255 * normalized_value)
+    green = int(255 * (1 - normalized_value))
+    return f'rgb({red}, {green}, 0)'
+
+  def _html_tree_view(
+      self,
+      *,
+      view: pg.views.HtmlTreeView,
+      extra_flags: dict[str, Any] | None = None,
+      **kwargs
+  ) -> pg.Html:
+    extra_flags = extra_flags or {}
+    as_badge = extra_flags.pop('as_badge', False)
+    interactive = extra_flags.get('interactive', True)
+    if as_badge:
+      usage_badge = self._usage_badge
+      if usage_badge is None:
+        usage_badge = pg.views.html.controls.Badge(
+            self._badge_text(),
+            tooltip=pg.format(self.total, verbose=False),
+            css_classes=['usage-summary'],
+            styles=dict(color=self._badge_color()),
+            interactive=True,
+        )
+        if interactive:
+          self._usage_badge = usage_badge
+      return usage_badge.to_html()
+    return super()._html_tree_view(
+        view=view,
+        extra_flags=extra_flags,
+        **kwargs
+    )
+
+  @classmethod
+  @functools.cache
+  def _html_tree_view_css_styles(cls) -> list[str]:
+    return super()._html_tree_view_css_styles() + [
+        """
+        .usage-summary.label {
+          display: inline-flex;
+          border-radius: 5px;
+          padding: 5px;
+          background-color: #f1f1f1;
+          color: #CCC;
+        }
+        .usage-summary.label::before {
+          content: '$';
+        }
+        """
+    ]

 pg.members(
     dict(
@@ -938,12 +1051,10 @@ class _UsageTracker:
   def __init__(self, model_ids: set[str] | None):
     self.model_ids = model_ids
     self.usage_summary = UsageSummary()
-    self._lock = threading.Lock()

   def track(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
     if self.model_ids is None or model_id in self.model_ids:
-
-      self.usage_summary.update(model_id, usage, is_cached)
+      self.usage_summary.add(model_id, usage, is_cached)


 @contextlib.contextmanager

langfun/core/language_model_test.py
CHANGED
@@ -685,7 +685,6 @@ class LanguageModelTest(unittest.TestCase):
       lm2('hi')
       list(concurrent.concurrent_map(call_lm, ['hi', 'hello']))

-    print(usages2)
     self.assertEqual(usages2.uncached.breakdown, {
         'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
     })
@@ -777,7 +776,7 @@ class UsageSummaryTest(unittest.TestCase):
     self.assertFalse(usage_summary.uncached)

     # Add uncached.
-    usage_summary.
+    usage_summary.add(
         'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
     )
     self.assertEqual(
@@ -788,7 +787,7 @@ class UsageSummaryTest(unittest.TestCase):
     )
     # Add cached.
     self.assertFalse(usage_summary.cached)
-    usage_summary.
+    usage_summary.add(
         'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), True
     )
     self.assertEqual(
@@ -798,7 +797,7 @@ class UsageSummaryTest(unittest.TestCase):
         usage_summary.cached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 0.0)
     )
     # Add UsageNotAvailable.
-    usage_summary.
+    usage_summary.add(
         'model1', lm_lib.UsageNotAvailable(num_requests=1), False
     )
     self.assertEqual(
@@ -808,6 +807,100 @@ class UsageSummaryTest(unittest.TestCase):
         usage_summary.uncached.total, lm_lib.UsageNotAvailable(num_requests=2)
     )

+  def test_merge(self):
+    usage_summary = lm_lib.UsageSummary()
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary.add(
+        'model2', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2 = lm_lib.UsageSummary()
+    usage_summary2.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2.add(
+        'model3', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    usage_summary2.merge(usage_summary)
+    self.assertEqual(
+        usage_summary2,
+        lm_lib.UsageSummary(
+            cached=lm_lib.UsageSummary.AggregatedUsage(
+                total=lm_lib.LMSamplingUsage(
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    total_tokens=0,
+                    num_requests=0,
+                    estimated_cost=0.0,
+                ),
+                breakdown={}
+            ),
+            uncached=lm_lib.UsageSummary.AggregatedUsage(
+                total=lm_lib.LMSamplingUsage(
+                    prompt_tokens=5,
+                    completion_tokens=10,
+                    total_tokens=15,
+                    num_requests=5,
+                    estimated_cost=25.0
+                ),
+                breakdown=dict(
+                    model1=lm_lib.LMSamplingUsage(
+                        prompt_tokens=3,
+                        completion_tokens=6,
+                        total_tokens=9,
+                        num_requests=3,
+                        estimated_cost=15.0
+                    ),
+                    model3=lm_lib.LMSamplingUsage(
+                        prompt_tokens=1,
+                        completion_tokens=2,
+                        total_tokens=3,
+                        num_requests=1,
+                        estimated_cost=5.0
+                    ),
+                    model2=lm_lib.LMSamplingUsage(
+                        prompt_tokens=1,
+                        completion_tokens=2,
+                        total_tokens=3,
+                        num_requests=1,
+                        estimated_cost=5.0
+                    )
+                )
+            )
+        )
+    )
+
+  def test_html_view(self):
+    usage_summary = lm_lib.UsageSummary()
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertIn(
+        '5.000',
+        usage_summary.to_html(extra_flags=dict(as_badge=True)).content
+    )
+    usage_summary.add(
+        'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+    )
+    self.assertIn(
+        '10.000',
+        usage_summary.to_html(
+            extra_flags=dict(as_badge=True, interactive=True)
+        ).content
+    )
+    self.assertTrue(
+        usage_summary.to_html().content.startswith('<details open')
+    )
+    with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
+      usage_summary.add(
+          'model2', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
+      )
+    self.assertEqual(len(scripts), 4)
+

 if __name__ == '__main__':
   unittest.main()
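
Net effect of the language_model.py changes: UsageSummary gains an HTML badge view, the lock moves from _UsageTracker into UsageSummary itself, the old usage_summary.update(...) call becomes add(model_id, usage, is_cached), and a merge() method is introduced for aggregating summaries. A small hedged sketch mirroring UsageSummaryTest above (variable names are illustrative; only APIs visible in this diff are used):

    from langfun.core import language_model as lm_lib

    s1 = lm_lib.UsageSummary()
    s1.add('model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False)  # uncached
    s1.add('model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), True)   # cached; cost zeroed

    s2 = lm_lib.UsageSummary()
    s2.add('model2', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False)
    s2.merge(s1)  # Folds s1's cached and uncached buckets into s2.

    print(s2.total.num_requests)  # 3: two uncached requests plus one cached.
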
langfun/core/llms/__init__.py
CHANGED
@@ -24,6 +24,9 @@ from langfun.core.llms.fake import StaticMapping
 from langfun.core.llms.fake import StaticResponse
 from langfun.core.llms.fake import StaticSequence

+# Compositional models.
+from langfun.core.llms.compositional import RandomChoice
+
 # REST-based models.
 from langfun.core.llms.rest import REST

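
RandomChoice itself lives in the new langfun/core/llms/compositional.py (+101 lines, not shown in this excerpt). A hedged sketch of how such a compositional model might be used follows; the candidates argument and the fake models are assumptions, since the constructor is not part of this diff:

    # Assumption: RandomChoice accepts a list of candidate models; the exact
    # field name and defaults are not visible in this diff excerpt.
    import langfun as lf

    lm = lf.llms.RandomChoice(candidates=[
        lf.llms.StaticResponse('response A'),
        lf.llms.StaticResponse('response B'),
    ])
    print(lm('hi'))  # Each call is served by one of the candidate models.
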