langfun 0.1.2.dev202411090804__py3-none-any.whl → 0.1.2.dev202411140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. Click here for more details.

Files changed (36)
  1. langfun/core/console.py +10 -2
  2. langfun/core/console_test.py +17 -0
  3. langfun/core/eval/__init__.py +2 -0
  4. langfun/core/eval/v2/__init__.py +38 -0
  5. langfun/core/eval/v2/checkpointing.py +135 -0
  6. langfun/core/eval/v2/checkpointing_test.py +89 -0
  7. langfun/core/eval/v2/evaluation.py +627 -0
  8. langfun/core/eval/v2/evaluation_test.py +156 -0
  9. langfun/core/eval/v2/example.py +295 -0
  10. langfun/core/eval/v2/example_test.py +114 -0
  11. langfun/core/eval/v2/experiment.py +949 -0
  12. langfun/core/eval/v2/experiment_test.py +304 -0
  13. langfun/core/eval/v2/metric_values.py +156 -0
  14. langfun/core/eval/v2/metric_values_test.py +80 -0
  15. langfun/core/eval/v2/metrics.py +357 -0
  16. langfun/core/eval/v2/metrics_test.py +203 -0
  17. langfun/core/eval/v2/progress.py +348 -0
  18. langfun/core/eval/v2/progress_test.py +82 -0
  19. langfun/core/eval/v2/progress_tracking.py +209 -0
  20. langfun/core/eval/v2/progress_tracking_test.py +56 -0
  21. langfun/core/eval/v2/reporting.py +144 -0
  22. langfun/core/eval/v2/reporting_test.py +41 -0
  23. langfun/core/eval/v2/runners.py +417 -0
  24. langfun/core/eval/v2/runners_test.py +311 -0
  25. langfun/core/eval/v2/test_helper.py +80 -0
  26. langfun/core/language_model.py +122 -11
  27. langfun/core/language_model_test.py +97 -4
  28. langfun/core/llms/__init__.py +3 -0
  29. langfun/core/llms/compositional.py +101 -0
  30. langfun/core/llms/compositional_test.py +73 -0
  31. langfun/core/llms/vertexai.py +4 -4
  32. {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/METADATA +1 -1
  33. {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/RECORD +36 -12
  34. {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/WHEEL +1 -1
  35. {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/LICENSE +0 -0
  36. {langfun-0.1.2.dev202411090804.dist-info → langfun-0.1.2.dev202411140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,311 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import tempfile
16
+ import threading
17
+ import time
18
+ from typing import Any
19
+ import unittest
20
+
21
+ from langfun.core.eval.v2 import example as example_lib
22
+ from langfun.core.eval.v2 import experiment as experiment_lib
23
+ from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
24
+ from langfun.core.eval.v2 import test_helper
25
+ import pyglove as pg
26
+
27
+
28
# Short aliases for the experiment-framework classes used throughout the tests.
Runner, Experiment, Suite, Plugin = (
    experiment_lib.Runner,
    experiment_lib.Experiment,
    experiment_lib.Suite,
    experiment_lib.Plugin,
)
Example = example_lib.Example
33
+
34
+
35
class TestPlugin(Plugin):
  """Plugin that records every runner callback it receives, for assertions."""

  # Experiments (wrapped in `pg.Ref`) in the order their callbacks fired.
  started_experiments: list[Experiment] = []
  completed_experiments: list[Experiment] = []
  skipped_experiments: list[Experiment] = []
  # Example IDs in the order their callbacks fired.
  started_example_ids: list[int] = []
  completed_example_ids: list[int] = []
  skipped_example_ids: list[int] = []
  # `time.time()` at run start / run completion; None until the callback.
  start_time: float | None = None
  complete_time: float | None = None

  def _on_bound(self):
    super()._on_bound()
    # Serializes appends to the lists above; callbacks presumably arrive from
    # worker threads when the parallel runner is used.
    self._lock = threading.Lock()

  def _record(self, field_name: str, value: Any) -> None:
    """Appends `value` to the list field `field_name` while holding the lock."""
    with pg.notify_on_change(False), self._lock:
      getattr(self, field_name).append(value)

  def _stamp(self, field_name: str) -> None:
    """Stores the current wall-clock time into the field `field_name`."""
    with pg.notify_on_change(False), pg.allow_writable_accessors(True):
      setattr(self, field_name, time.time())

  def on_run_start(self, runner: Runner, root: Experiment):
    del runner, root
    self._stamp('start_time')

  def on_run_complete(self, runner: Runner, root: Experiment):
    del runner, root
    self._stamp('complete_time')

  def on_experiment_start(self, runner: Runner, experiment: Experiment):
    del runner
    self._record('started_experiments', pg.Ref(experiment))

  def on_experiment_skipped(self, runner: Runner, experiment: Experiment):
    del runner
    self._record('skipped_experiments', pg.Ref(experiment))

  def on_experiment_complete(self, runner: Runner, experiment: Experiment):
    del runner
    self._record('completed_experiments', pg.Ref(experiment))

  def on_example_start(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._record('started_example_ids', example.id)

  def on_example_skipped(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._record('skipped_example_ids', example.id)

  def on_example_complete(
      self, runner: Runner, experiment: Experiment, example: Example):
    del runner, experiment
    self._record('completed_example_ids', example.id)
91
+
92
+
93
class RunnerTest(unittest.TestCase):
  """End-to-end tests for the sequential, parallel and debug runners."""

  def assert_same_list(self, actual: list[Any], expected: list[Any]):
    """Asserts `actual` holds exactly the objects of `expected`, in order.

    Elements are compared by identity (`is`). On the first mismatch the test
    fails with the symbolic diff of the two elements in the failure message
    (instead of printing it to stdout, detached from the failure report).
    """
    self.assertEqual(len(actual), len(expected))
    for i, (x, y) in enumerate(zip(actual, expected)):
      if x is not y:
        self.fail(f'Element {i} is not the expected object: {pg.diff(x, y)}')

  def _assert_run_timed(self, plugin):
    """Asserts the plugin saw the run start and a strictly later run end."""
    self.assertIsNotNone(plugin.start_time)
    self.assertIsNotNone(plugin.complete_time)
    self.assertGreater(plugin.complete_time, plugin.start_time)

  def _assert_all_nodes_completed(
      self, exp, leaf_num_completed: int, leaf_num_failed: int
  ):
    """Asserts the progress of every node of `exp` after a finished run.

    Args:
      exp: The root experiment that was run.
      leaf_num_completed: Expected `num_completed` of every leaf node.
      leaf_num_failed: Expected `num_failed` of every leaf node.
    """
    for node in exp.nodes:
      self.assertTrue(node.progress.is_started)
      self.assertTrue(node.progress.is_completed)
      self.assertEqual(node.progress.num_skipped, 0)
      if node.is_leaf:
        self.assertEqual(node.progress.num_completed, leaf_num_completed)
        self.assertEqual(node.progress.num_failed, leaf_num_failed)
      else:
        self.assertEqual(node.progress.num_failed, 0)
        self.assertEqual(node.progress.num_processed, node.progress.num_total)

  def test_basic(self):
    plugin = TestPlugin()
    exp = test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_sequential_runner')
    run = exp.run(root_dir, runner='sequential', plugins=[plugin])

    self._assert_run_timed(plugin)
    self.assert_same_list(
        plugin.started_experiments,
        exp.nonleaf_nodes + exp.leaf_nodes
    )
    self.assert_same_list(
        plugin.completed_experiments,
        exp.leaf_nodes + list(reversed(exp.nonleaf_nodes))
    )
    self.assert_same_list(
        plugin.started_example_ids, list(range(1, 11)) * 6
    )
    self.assert_same_list(
        plugin.completed_example_ids, list(range(1, 11)) * 6
    )
    self.assert_same_list(plugin.skipped_experiments, [])
    self.assert_same_list(plugin.skipped_example_ids, [])
    self.assertTrue(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )
    # Each leaf processes 10 examples; the one with x == 5 raises.
    self._assert_all_nodes_completed(
        exp, leaf_num_completed=10, leaf_num_failed=1
    )

  def test_raise_if_has_error(self):
    root_dir = os.path.join(tempfile.gettempdir(), 'test_raise_if_has_error')
    exp = test_helper.TestEvaluation()
    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      exp.run(
          root_dir, runner='sequential', plugins=[], raise_if_has_error=True
      )

    with self.assertRaisesRegex(ValueError, 'x should not be 5'):
      exp.run(root_dir, runner='parallel', plugins=[], raise_if_has_error=True)

  def test_example_ids(self):
    root_dir = os.path.join(tempfile.gettempdir(), 'test_example_ids')
    exp = test_helper.test_experiment()
    plugin = TestPlugin()
    _ = exp.run(
        root_dir, runner='sequential', plugins=[plugin], example_ids=[5, 7, 9]
    )
    self.assertEqual(plugin.started_example_ids, [5, 7, 9] * 6)
    self.assertEqual(plugin.completed_example_ids, [5, 7, 9] * 6)

  def test_filter(self):
    plugin = TestPlugin()
    exp = test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_filter')

    _ = exp.run(
        root_dir, runner='sequential', plugins=[plugin],
        filter=lambda e: e.lm.offset != 0
    )
    self.assert_same_list(
        plugin.started_experiments,
        exp.nonleaf_nodes + exp.leaf_nodes[2:]
    )
    self.assert_same_list(
        plugin.skipped_experiments, exp.leaf_nodes[:2]
    )
    self.assert_same_list(
        plugin.completed_experiments,
        exp.leaf_nodes[2:] + [exp.children[1], exp]
    )

  def test_use_cache(self):
    @pg.functor()
    def test_inputs(num_examples: int = 10):
      # Every input appears twice, so half of the LM calls can hit a cache.
      return [
          pg.Dict(
              x=i // 2, y=(i // 2) ** 2,
              groundtruth=(i // 2 + (i // 2) ** 2)
          ) for i in range(num_examples)
      ]

    exp = test_helper.TestEvaluation(
        inputs=test_inputs(num_examples=pg.oneof([2, 4]))
    )
    # Global cache.
    root_dir = os.path.join(tempfile.gettempdir(), 'global_cache')
    run = exp.run(root_dir, runner='sequential', use_cache='global', plugins=[])
    self.assertTrue(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 4)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 2)

    # Per-dataset cache.
    root_dir = os.path.join(tempfile.gettempdir(), 'per_dataset')
    run = exp.run(
        root_dir, runner='sequential', use_cache='per_dataset', plugins=[]
    )
    for leaf in exp.leaf_nodes:
      self.assertTrue(
          pg.io.path_exists(run.output_path_for(leaf, 'cache.json'))
      )
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 3)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 3)

    # No cache.
    root_dir = os.path.join(tempfile.gettempdir(), 'no')
    run = exp.run(root_dir, runner='sequential', use_cache='no', plugins=[])
    self.assertFalse(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
    for leaf in exp.leaf_nodes:
      self.assertFalse(
          pg.io.path_exists(run.output_path_for(leaf, 'cache.json'))
      )
    self.assertEqual(exp.usage_summary.cached.total.num_requests, 0)
    self.assertEqual(exp.usage_summary.uncached.total.num_requests, 6)

  def test_parallel_runner(self):
    plugin = TestPlugin()
    exp = test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_parallel_runner')
    run = exp.run(root_dir, runner='parallel', plugins=[plugin])

    self._assert_run_timed(plugin)
    # Callback order is nondeterministic in parallel, so only counts are
    # checked.
    self.assertEqual(len(plugin.started_experiments), len(exp.nodes))
    self.assertEqual(len(plugin.completed_experiments), len(exp.nodes))
    self.assertEqual(len(plugin.started_example_ids), 6 * 10)
    self.assertEqual(len(plugin.completed_example_ids), 6 * 10)
    self.assert_same_list(plugin.skipped_experiments, [])
    self.assert_same_list(plugin.skipped_example_ids, [])
    self.assertTrue(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )
    self._assert_all_nodes_completed(
        exp, leaf_num_completed=10, leaf_num_failed=1
    )

  def test_debug_runner(self):
    plugin = TestPlugin()
    exp = test_helper.test_experiment()
    root_dir = os.path.join(tempfile.gettempdir(), 'test_debug_runner')
    run = exp.run(root_dir, runner='debug', plugins=[plugin])

    self._assert_run_timed(plugin)
    self.assertEqual(len(plugin.started_experiments), len(exp.nodes))
    self.assertEqual(len(plugin.completed_experiments), len(exp.nodes))
    # The debug runner processes a single example per leaf.
    self.assertEqual(len(plugin.started_example_ids), 6 * 1)
    self.assertEqual(len(plugin.completed_example_ids), 6 * 1)
    self.assert_same_list(plugin.skipped_experiments, [])
    self.assert_same_list(plugin.skipped_example_ids, [])
    self.assertFalse(
        pg.io.path_exists(os.path.join(run.output_root, 'run.json'))
    )
    self._assert_all_nodes_completed(
        exp, leaf_num_completed=1, leaf_num_failed=0
    )
308
+
309
+
310
# Allows running this test module directly as a script.
if __name__ == '__main__':
  unittest.main()
@@ -0,0 +1,80 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Helper classes and functions for evaluation tests."""
15
+
16
+ from langfun.core import language_model
17
+ from langfun.core import llms
18
+ from langfun.core import message as message_lib
19
+ from langfun.core import structured
20
+
21
+ from langfun.core.eval.v2 import evaluation as evaluation_lib
22
+ from langfun.core.eval.v2 import example as example_lib
23
+ from langfun.core.eval.v2 import experiment as experiment_lib
24
+ from langfun.core.eval.v2 import metrics as metrics_lib
25
+
26
+ import pyglove as pg
27
+
28
# Short aliases for the evaluation classes used by the test helpers.
Example = example_lib.Example
Evaluation = evaluation_lib.Evaluation
Suite, RunId, Run = (
    experiment_lib.Suite,
    experiment_lib.RunId,
    experiment_lib.Run,
)
33
+
34
+
35
@pg.functor()
def test_inputs(num_examples: int | None = 10):
  """Returns examples `x=i, y=i**2, groundtruth=x+y` for i in [0, num_examples).

  Args:
    num_examples: Number of examples to produce; None means 20.
  """
  count = 20 if num_examples is None else num_examples
  examples = []
  for i in range(count):
    square = i ** 2
    examples.append(pg.Dict(x=i, y=square, groundtruth=i + square))
  return examples
43
+
44
+
45
class TestLLM(llms.Fake):
  """Fake LM that replies with `metadata.x + metadata.y + offset` as text."""

  # Added to every reply; a non-zero offset makes the reply differ from the
  # `groundtruth` (x + y) produced by `test_inputs`.
  offset: int = 0

  def _response_from(self, prompt: message_lib.Message) -> message_lib.Message:
    metadata = prompt.metadata
    answer = metadata.x + metadata.y + self.offset
    return message_lib.AIMessage(str(answer))

  @property
  def resource_id(self) -> str:
    # Distinct per offset, so differently-offset models are told apart.
    return f'test_llm:{self.offset}'
58
+
59
+
60
class TestEvaluation(Evaluation):
  """Evaluation that asks an LM to add `x` and `y` of each example."""
  inputs = test_inputs()
  metrics = [metrics_lib.Match()]
  lm: language_model.LanguageModel = TestLLM()

  def process(self, v):
    # One deliberately failing input, used to exercise error handling.
    if v.x == 5:
      raise ValueError('x should not be 5')
    x = v.x
    y = v.y
    # `metadata_*` values reach the LM as `prompt.metadata.*`.
    return structured.query(
        '{{x}} + {{y}} = ?', int, lm=self.lm, x=x, y=y,
        metadata_x=x, metadata_y=y
    )
73
+
74
+
75
def test_experiment():
  """Returns a suite of one fixed-offset and one offset-swept evaluation."""
  fixed_lm = TestLLM(offset=0)
  swept_lm = TestLLM(offset=pg.oneof(range(5)))
  evaluations = [
      TestEvaluation(lm=fixed_lm),
      TestEvaluation(lm=swept_lm),
  ]
  return Suite(evaluations)
@@ -17,6 +17,8 @@ import abc
17
17
  import contextlib
18
18
  import dataclasses
19
19
  import enum
20
+ import functools
21
+ import math
20
22
  import threading
21
23
  import time
22
24
  from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
@@ -875,7 +877,7 @@ class LanguageModel(component.Component):
875
877
  return DEFAULT_MAX_CONCURRENCY # Default of 1
876
878
 
877
879
 
878
- class UsageSummary(pg.Object):
880
+ class UsageSummary(pg.Object, pg.views.HtmlTreeView.Extension):
879
881
  """Usage sumary."""
880
882
 
881
883
  class AggregatedUsage(pg.Object):
@@ -897,20 +899,131 @@ class UsageSummary(pg.Object):
897
899
  aggregated = self.breakdown.get(model_id, None)
898
900
  with pg.notify_on_change(False):
899
901
  self.breakdown[model_id] = usage + aggregated
900
- self.rebind(total=self.total + usage, skip_notification=True)
902
+ self.rebind(
903
+ total=self.total + usage,
904
+ raise_on_no_change=False
905
+ )
906
+
907
def merge(self, other: 'UsageSummary.AggregatedUsage') -> None:
  """Adds every per-model entry of `other` into this aggregation.

  Change notifications are suppressed while the entries are folded in.
  """
  with pg.notify_on_change(False):
    for entry in other.breakdown.items():
      self.add(*entry)
912
+
913
def _on_bound(self):
  """Sets up the non-symbolic per-instance state."""
  super()._on_bound()
  # Serializes `add` and `merge`.
  self._lock = threading.Lock()
  # Badge control kept by the interactive badge view; None until rendered.
  self._usage_badge = None
901
917
 
902
918
@property
def total(self) -> LMSamplingUsage:
  """Usage summed over the cached and uncached aggregations."""
  cached_total = self.cached.total
  uncached_total = self.uncached.total
  return cached_total + uncached_total
905
921
 
906
def add(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
  """Records one usage entry for `model_id` and refreshes the badge view.

  Args:
    model_id: The model the usage belongs to.
    usage: The usage to record. For cached entries its `estimated_cost` is
      reset to 0.0 in place before recording.
    is_cached: Whether the usage came from a cache hit.
  """
  with self._lock:
    if not is_cached:
      self.uncached.add(model_id, usage)
    else:
      usage.rebind(estimated_cost=0.0, skip_notification=True)
      self.cached.add(model_id, usage)
    self._update_view()
931
+
932
def merge(self, other: 'UsageSummary', as_cached: bool = False) -> None:
  """Folds another usage summary into this one and refreshes the badge view.

  Args:
    other: The usage summary whose usage is added to this one.
    as_cached: If True, `other`'s uncached usage is recorded as cached here;
      otherwise it stays uncached. `other`'s cached usage is always recorded
      as cached.
  """
  with self._lock:
    self.cached.merge(other.cached)
    # NOTE(review): unlike `add(..., is_cached=True)`, merging as cached does
    # not zero the estimated cost of the moved entries — confirm intended.
    destination = self.cached if as_cached else self.uncached
    destination.merge(other.uncached)
    self._update_view()
946
+
947
def _sym_nondefault(self) -> dict[str, Any]:
  """Reports no non-default values, so volatile usage numbers are omitted."""
  return {}
950
+
951
+ #
952
+ # Html views for the usage summary.
953
+ #
954
+
955
def _update_view(self):
  """Refreshes the cached badge, if one exists, with the current totals."""
  badge = self._usage_badge
  if badge is None:
    return
  badge.update(
      self._badge_text(),
      tooltip=pg.format(self.total, verbose=False),
      styles=dict(color=self._badge_color()),
  )
913
962
 
963
def _badge_text(self) -> str:
  """Formats the total estimated cost with 3 decimals ('0.000' if unknown)."""
  cost = self.total.estimated_cost
  return '0.000' if cost is None else f'{cost:.3f}'
967
+
968
+ def _badge_color(self) -> str | None:
969
+ if self.total.estimated_cost is None or self.total.estimated_cost < 1.0:
970
+ return None
971
+
972
+ # Step 1: The normal cost range is around 1e-3 to 1e5.
973
+ # Therefore we normalize the log10 value from [-3, 5] to [0, 1].
974
+ normalized_value = (math.log10(self.total.estimated_cost) + 3) / (5 + 3)
975
+
976
+ # Step 2: Interpolate between green and red
977
+ red = int(255 * normalized_value)
978
+ green = int(255 * (1 - normalized_value))
979
+ return f'rgb({red}, {green}, 0)'
980
+
981
def _html_tree_view(
    self,
    *,
    view: pg.views.HtmlTreeView,
    extra_flags: dict[str, Any] | None = None,
    **kwargs
) -> pg.Html:
  """Renders the usage summary as a cost badge or as the default tree view.

  Args:
    view: The HTML tree view driving the rendering.
    extra_flags: Optional rendering flags. `as_badge` (default False) renders
      a compact cost badge instead of the tree. `interactive` (default True)
      keeps a newly created badge on this object so later usage updates can
      refresh it in place.
    **kwargs: Forwarded to the base implementation for the tree view.

  Returns:
    The rendered HTML.
  """
  # Copy before popping: the caller's dict may be shared with other objects
  # being rendered, which must still see their own `as_badge` flag.
  extra_flags = dict(extra_flags or {})
  as_badge = extra_flags.pop('as_badge', False)
  interactive = extra_flags.get('interactive', True)
  if not as_badge:
    return super()._html_tree_view(
        view=view,
        extra_flags=extra_flags,
        **kwargs
    )

  usage_badge = self._usage_badge
  if usage_badge is None:
    usage_badge = pg.views.html.controls.Badge(
        self._badge_text(),
        tooltip=pg.format(self.total, verbose=False),
        css_classes=['usage-summary'],
        styles=dict(color=self._badge_color()),
        interactive=True,
    )
    if interactive:
      self._usage_badge = usage_badge
  return usage_badge.to_html()
1009
+
1010
@classmethod
@functools.cache
def _html_tree_view_css_styles(cls) -> list[str]:
  """Returns the base CSS styles plus the styles of the usage badge."""
  # `functools.cache` computes this list once per class and returns the same
  # list object on every later call, so callers should not mutate it.
  # The `::before` rule prefixes the badge text with '$'.
  return super()._html_tree_view_css_styles() + [
      """
.usage-summary.label {
display: inline-flex;
border-radius: 5px;
padding: 5px;
background-color: #f1f1f1;
color: #CCC;
}
.usage-summary.label::before {
content: '$';
}
"""
  ]
914
1027
 
915
1028
  pg.members(
916
1029
  dict(
@@ -938,12 +1051,10 @@ class _UsageTracker:
938
1051
def __init__(self, model_ids: set[str] | None):
  """Creates a tracker with an empty usage summary.

  Args:
    model_ids: IDs of the models whose usage is recorded; None records all.
  """
  self.usage_summary = UsageSummary()
  self.model_ids = model_ids
942
1054
 
943
1055
def track(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
  """Adds `usage` to the summary unless `model_id` is filtered out."""
  if self.model_ids is not None and model_id not in self.model_ids:
    return
  self.usage_summary.add(model_id, usage, is_cached)
947
1058
 
948
1059
 
949
1060
  @contextlib.contextmanager
@@ -685,7 +685,6 @@ class LanguageModelTest(unittest.TestCase):
685
685
  lm2('hi')
686
686
  list(concurrent.concurrent_map(call_lm, ['hi', 'hello']))
687
687
 
688
- print(usages2)
689
688
  self.assertEqual(usages2.uncached.breakdown, {
690
689
  'model2': lm_lib.LMSamplingUsage(100, 100, 200, 1, 1.0),
691
690
  })
@@ -777,7 +776,7 @@ class UsageSummaryTest(unittest.TestCase):
777
776
  self.assertFalse(usage_summary.uncached)
778
777
 
779
778
  # Add uncached.
780
- usage_summary.update(
779
+ usage_summary.add(
781
780
  'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
782
781
  )
783
782
  self.assertEqual(
@@ -788,7 +787,7 @@ class UsageSummaryTest(unittest.TestCase):
788
787
  )
789
788
  # Add cached.
790
789
  self.assertFalse(usage_summary.cached)
791
- usage_summary.update(
790
+ usage_summary.add(
792
791
  'model1', lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), True
793
792
  )
794
793
  self.assertEqual(
@@ -798,7 +797,7 @@ class UsageSummaryTest(unittest.TestCase):
798
797
  usage_summary.cached.total, lm_lib.LMSamplingUsage(1, 2, 3, 1, 0.0)
799
798
  )
800
799
  # Add UsageNotAvailable.
801
- usage_summary.update(
800
+ usage_summary.add(
802
801
  'model1', lm_lib.UsageNotAvailable(num_requests=1), False
803
802
  )
804
803
  self.assertEqual(
@@ -808,6 +807,100 @@ class UsageSummaryTest(unittest.TestCase):
808
807
  usage_summary.uncached.total, lm_lib.UsageNotAvailable(num_requests=2)
809
808
  )
810
809
 
810
def test_merge(self):
  def unit_usage():
    return lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0)

  def scaled_usage(n):
    # `n` times the unit usage above.
    return lm_lib.LMSamplingUsage(
        prompt_tokens=n,
        completion_tokens=2 * n,
        total_tokens=3 * n,
        num_requests=n,
        estimated_cost=5.0 * n,
    )

  usage_summary = lm_lib.UsageSummary()
  for model_id in ('model1', 'model2', 'model1'):
    usage_summary.add(model_id, unit_usage(), False)

  usage_summary2 = lm_lib.UsageSummary()
  for model_id in ('model1', 'model3'):
    usage_summary2.add(model_id, unit_usage(), False)

  usage_summary2.merge(usage_summary)

  expected = lm_lib.UsageSummary(
      cached=lm_lib.UsageSummary.AggregatedUsage(
          total=scaled_usage(0), breakdown={}
      ),
      uncached=lm_lib.UsageSummary.AggregatedUsage(
          total=scaled_usage(5),
          breakdown=dict(
              model1=scaled_usage(3),
              model3=scaled_usage(1),
              model2=scaled_usage(1),
          ),
      ),
  )
  self.assertEqual(usage_summary2, expected)
+
877
def test_html_view(self):
  def add_usage(model_id):
    usage_summary.add(
        model_id, lm_lib.LMSamplingUsage(1, 2, 3, 1, 5.0), False
    )

  def badge_html(**flags):
    return usage_summary.to_html(
        extra_flags=dict(as_badge=True, **flags)
    ).content

  usage_summary = lm_lib.UsageSummary()
  add_usage('model1')
  self.assertIn('5.000', badge_html())

  add_usage('model1')
  self.assertIn('10.000', badge_html(interactive=True))

  # Without `as_badge`, the default tree view is rendered.
  self.assertTrue(
      usage_summary.to_html().content.startswith('<details open')
  )

  # Adding usage refreshes the interactive badge through scripts.
  with pg.views.html.controls.HtmlControl.track_scripts() as scripts:
    add_usage('model2')
    self.assertEqual(len(scripts), 4)
+
811
904
 
812
905
  if __name__ == '__main__':
813
906
  unittest.main()
@@ -24,6 +24,9 @@ from langfun.core.llms.fake import StaticMapping
24
24
  from langfun.core.llms.fake import StaticResponse
25
25
  from langfun.core.llms.fake import StaticSequence
26
26
 
27
+ # Compositional models.
28
+ from langfun.core.llms.compositional import RandomChoice
29
+
27
30
  # REST-based models.
28
31
  from langfun.core.llms.rest import REST
29
32