langfun 0.0.2.dev20240605__py3-none-any.whl → 0.0.2.dev20240607__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/core/__init__.py CHANGED
@@ -111,12 +111,18 @@ from langfun.core.language_model import RetryableLMError
111
111
  from langfun.core.language_model import RateLimitError
112
112
  from langfun.core.language_model import TemporaryLMError
113
113
 
114
+ # Context manager for tracking usages.
115
+ from langfun.core.language_model import track_usages
116
+
114
117
  # Components for building agents.
115
118
  from langfun.core.memory import Memory
116
119
 
117
120
  # Utility for console output.
118
121
  from langfun.core import console
119
122
 
123
+ # Utility for event logging.
124
+ from langfun.core import logging
125
+
120
126
  # Import internal modules.
121
127
 
122
128
  # pylint: enable=g-import-not-at-top
langfun/core/component.py CHANGED
@@ -210,6 +210,16 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
210
210
  return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
211
211
 
212
212
 
213
def context_value(var_name: str, default: Any = RAISE_IF_HAS_ERROR) -> Any:
  """Returns the value of a variable defined in `lf.context`.

  Args:
    var_name: Name of the contextual variable to look up.
    default: Value returned when `var_name` is not set in the current context.
      When left as `RAISE_IF_HAS_ERROR`, a missing variable raises instead.

  Returns:
    The contextual value of `var_name`, or `default` if it is not set.

  Raises:
    KeyError: If `var_name` is not set and no `default` was supplied.
  """
  override = get_contextual_override(var_name)
  if override is not None:
    return override.value
  # Identity check against the sentinel: `==` would call `default.__eq__`,
  # which may be expensive, non-boolean, or raising (e.g. array-likes).
  if default is RAISE_IF_HAS_ERROR:
    raise KeyError(f'{var_name!r} does not exist in current context.')
  return default
221
+
222
+
213
223
  def all_contextual_values() -> dict[str, Any]:
214
224
  """Returns all contextual values provided from `lf.context` in scope."""
215
225
  overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
@@ -84,6 +84,11 @@ class ComponentContextTest(unittest.TestCase):
84
84
  lf.get_contextual_override('y'),
85
85
  lf.ContextualOverride(3, cascade=False, override_attrs=False),
86
86
  )
87
+ self.assertEqual(lf.context_value('x'), 3)
88
+ self.assertIsNone(lf.context_value('f', None))
89
+ with self.assertRaisesRegex(KeyError, '.* does not exist'):
90
+ lf.context_value('f')
91
+
87
92
  self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
88
93
 
89
94
  # Member attributes take precedence over `lf.context`.
@@ -14,10 +14,11 @@
14
14
  """Interface for language model."""
15
15
 
16
16
  import abc
17
+ import contextlib
17
18
  import dataclasses
18
19
  import enum
19
20
  import time
20
- from typing import Annotated, Any, Callable, Sequence, Tuple, Type, Union
21
+ from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union
21
22
  from langfun.core import component
22
23
  from langfun.core import concurrent
23
24
  from langfun.core import console
@@ -84,6 +85,13 @@ class LMSamplingUsage(pg.Object):
84
85
  completion_tokens: int
85
86
  total_tokens: int
86
87
 
88
def __add__(self, other: 'LMSamplingUsage') -> 'LMSamplingUsage':
  """Returns a new `LMSamplingUsage` holding the field-wise token sums.

  Neither operand is modified.
  """
  prompt = self.prompt_tokens + other.prompt_tokens
  completion = self.completion_tokens + other.completion_tokens
  total = self.total_tokens + other.total_tokens
  return LMSamplingUsage(
      prompt_tokens=prompt, completion_tokens=completion, total_tokens=total
  )
94
+
87
95
 
88
96
  class LMSamplingResult(pg.Object):
89
97
  """Language model response."""
@@ -408,6 +416,16 @@ class LanguageModel(component.Component):
408
416
  total_tokens=usage.total_tokens // n,
409
417
  )
410
418
 
419
+ # Track usage.
420
+ if usage:
421
+ tracked_usages = component.context_value('__tracked_usages__', [])
422
+ model_id = self.model_id
423
+ for usage_dict in tracked_usages:
424
+ if model_id in usage_dict:
425
+ usage_dict[model_id] += usage
426
+ else:
427
+ usage_dict[model_id] = usage
428
+
411
429
  # Track the prompt for corresponding response.
412
430
  response.source = prompt
413
431
 
@@ -692,3 +710,15 @@ class LanguageModel(component.Component):
692
710
  return max(int(requests_per_min / 60), 1) # Max concurrency can't be zero
693
711
  else:
694
712
  return DEFAULT_MAX_CONCURRENCY # Default of 1
713
+
714
+
715
@contextlib.contextmanager
def track_usages() -> Iterator[dict[str, LMSamplingUsage]]:
  """Context manager to track the usage of language models by model ID.

  Yields a dict mapping model ID to the accumulated `LMSamplingUsage` of calls
  made inside the scope. Scopes nest: every active scope's dict is kept in the
  `__tracked_usages__` context variable, so usage recorded inside an inner
  scope is also added to each enclosing scope's dict.

  Example:
    with lf.track_usages() as usages:
      lm('hi')
    print(usages)  # {'<model_id>': LMSamplingUsage(...)}
  """
  # Dicts belonging to enclosing `track_usages` scopes; empty at top level.
  parent_usages = component.context_value('__tracked_usages__', [])
  current_usage = {}
  # Build a new list instead of appending so the enclosing scopes' list object
  # is never mutated by this scope.
  with component.context(__tracked_usages__=parent_usages + [current_usage]):
    yield current_usage
@@ -580,6 +580,18 @@ class LanguageModelTest(unittest.TestCase):
580
580
  self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
581
581
  self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
582
582
 
583
def test_track_usages(self):
  with lm_lib.track_usages() as outer_usages:
    model = MockModel()
    _ = model('hi')
    with lm_lib.track_usages() as inner_usages:
      _ = model('hi')
  # The inner scope sees only the second call.
  self.assertEqual(
      inner_usages,
      {'MockModel': lm_lib.LMSamplingUsage(100, 100, 200)},
  )
  # The outer scope accumulates both calls, including the nested one.
  self.assertEqual(
      outer_usages,
      {'MockModel': lm_lib.LMSamplingUsage(200, 200, 400)},
  )


if __name__ == '__main__':
  unittest.main()
@@ -0,0 +1,168 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Langfun event logging."""
15
+
16
+ import datetime
17
+ import io
18
+ import typing
19
+ from typing import Any, Literal, ContextManager
20
+
21
+ from langfun.core import console
22
+ import pyglove as pg
23
+
24
+
25
# Log levels ordered from least to most severe. `LogEntry.should_output`
# compares positions in `_LOG_LEVELS`, so 'warning' must precede 'error'.
LogLevel = Literal['debug', 'info', 'warning', 'error', 'fatal']
_LOG_LEVELS = list(typing.get_args(LogLevel))
# Thread-local key under which the minimum log level is stored.
_TLS_KEY_MIN_LOG_LEVEL = '_event_log_level'
28
+
29
+
30
def use_log_level(log_level: LogLevel | None = 'info') -> ContextManager[None]:
  """Returns a context manager that sets the minimum event log level.

  While the returned context is active on the current thread,
  `get_log_level()` reports `log_level`.

  Args:
    log_level: Minimum level of entries to output, or None.

  Returns:
    A context manager backed by a pyglove thread-local value scope.
  """
  # NOTE(review): the trailing 'info' is forwarded to pyglove unchanged;
  # presumably the value used when nothing was set before — confirm.
  scope = pg.object_utils.thread_local_value_scope(
      _TLS_KEY_MIN_LOG_LEVEL, log_level, 'info'
  )
  return scope
34
+
35
+
36
def get_log_level() -> LogLevel | None:
  """Returns the minimum log level active on the current thread.

  Falls back to 'info' when no `use_log_level` scope has set a value; can be
  None inside `use_log_level(None)`.
  """
  current = pg.object_utils.thread_local_get(_TLS_KEY_MIN_LOG_LEVEL, 'info')
  return current
39
+
40
+
41
+ class LogEntry(pg.Object):
42
+ """Event log entry."""
43
+ time: datetime.datetime
44
+ level: LogLevel
45
+ message: str
46
+ metadata: dict[str, Any] = pg.Dict()
47
+ indent: int = 0
48
+
49
+ def should_output(self, min_log_level: LogLevel) -> bool:
50
+ return _LOG_LEVELS.index(self.level) >= _LOG_LEVELS.index(min_log_level)
51
+
52
+ def _repr_html_(self) -> str:
53
+ s = io.StringIO()
54
+ padding_left = 50 * self.indent
55
+ s.write(f'<div style="padding-left: {padding_left}px;">')
56
+ s.write(self._message_display)
57
+ if self.metadata:
58
+ s.write('<div style="padding-left: 20px; margin-top: 10px">')
59
+ s.write('<table style="border-top: 1px solid #EEEEEE;">')
60
+ for k, v in self.metadata.items():
61
+ if hasattr(v, '_repr_html_'):
62
+ cs = v._repr_html_() # pylint: disable=protected-access
63
+ else:
64
+ cs = f'<span style="white-space: pre-wrap">{str(v)}</span>'
65
+ key_span = self._round_text(k, color='#F1C40F', margin_bottom='0px')
66
+ s.write(
67
+ '<tr>'
68
+ '<td style="padding: 5px; vertical-align: top; '
69
+ f'border-bottom: 1px solid #EEEEEE">{key_span}</td>'
70
+ '<td style="padding: 5px; vertical-align: top; '
71
+ f'border-bottom: 1px solid #EEEEEE">{cs}</td></tr>'
72
+ )
73
+ s.write('</table></div>')
74
+ return s.getvalue()
75
+
76
+ @property
77
+ def _message_text_color(self) -> str:
78
+ match self.level:
79
+ case 'debug':
80
+ return '#EEEEEE'
81
+ case 'info':
82
+ return '#A3E4D7'
83
+ case 'error':
84
+ return '#F5C6CB'
85
+ case 'fatal':
86
+ return '#F19CBB'
87
+ case _:
88
+ raise ValueError(f'Unknown log level: {self.level}')
89
+
90
+ @property
91
+ def _time_display(self) -> str:
92
+ display_text = self.time.strftime('%H:%M:%S')
93
+ alt_text = self.time.strftime('%Y-%m-%d %H:%M:%S.%f')
94
+ return (
95
+ '<span style="background-color: #BBBBBB; color: white; '
96
+ 'border-radius:5px; padding:0px 5px 0px 5px;" '
97
+ f'title="{alt_text}">{display_text}</span>'
98
+ )
99
+
100
+ @property
101
+ def _message_display(self) -> str:
102
+ return self._round_text(
103
+ self._time_display + '&nbsp;' + self.message,
104
+ color=self._message_text_color,
105
+ )
106
+
107
+ def _round_text(
108
+ self,
109
+ text: str,
110
+ *,
111
+ color: str = '#EEEEEE',
112
+ display: str = 'inline-block',
113
+ margin_top: str = '5px',
114
+ margin_bottom: str = '5px',
115
+ whitespace: str = 'pre-wrap') -> str:
116
+ return (
117
+ f'<span style="background:{color}; display:{display};'
118
+ f'border-radius:10px; padding:5px; '
119
+ f'margin-top: {margin_top}; margin-bottom: {margin_bottom}; '
120
+ f'white-space: {whitespace}">{text}</span>'
121
+ )
122
+
123
+
124
def log(level: LogLevel,
        message: str,
        *,
        indent: int = 0,
        **kwargs) -> LogEntry:
  """Creates a log entry and outputs it if it meets the current log level.

  Args:
    level: Severity of the message.
    message: The message text.
    indent: Visual nesting depth of the entry.
    **kwargs: Metadata attached to the entry.

  Returns:
    The created `LogEntry`, whether or not it was output.
  """
  entry = LogEntry(
      level=level,
      message=message,
      metadata=kwargs,
      indent=indent,
      time=datetime.datetime.now(),
  )
  if not entry.should_output(get_log_level()):
    return entry
  # Notebooks get `console.display`; everything else goes to `console.write`.
  if console.under_notebook():
    console.display(entry)
  else:
    # TODO(daiyip): Improve the console output formatting.
    console.write(entry)
  return entry
144
+
145
+
146
def debug(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
  """Shorthand for `log('debug', ...)`; returns the created entry."""
  entry = log('debug', message, indent=indent, **kwargs)
  return entry
149
+
150
+
151
def info(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
  """Shorthand for `log('info', ...)`; returns the created entry."""
  entry = log('info', message, indent=indent, **kwargs)
  return entry
154
+
155
+
156
def warning(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
  """Logs a warning message to the session.

  Shorthand for `log('warning', ...)`; returns the created entry.
  """
  return log('warning', message, indent=indent, **kwargs)
159
+
160
+
161
def error(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
  """Shorthand for `log('error', ...)`; returns the created entry."""
  entry = log('error', message, indent=indent, **kwargs)
  return entry
164
+
165
+
166
def fatal(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
  """Shorthand for `log('fatal', ...)`; returns the created entry."""
  entry = log('fatal', message, indent=indent, **kwargs)
  return entry
@@ -0,0 +1,51 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for langfun.core.logging."""
15
+
16
+ import unittest
17
+
18
+ from langfun.core import logging
19
+
20
+
21
class LoggingTest(unittest.TestCase):
  """Tests for `langfun.core.logging`."""

  def test_use_log_level(self):
    # 'info' is reported outside any `use_log_level` scope.
    self.assertEqual(logging.get_log_level(), 'info')
    with logging.use_log_level('debug'):
      self.assertEqual(logging.get_log_level(), 'debug')
      with logging.use_log_level(None):
        self.assertIsNone(logging.get_log_level())
      # Leaving a nested scope restores the enclosing level.
      self.assertEqual(logging.get_log_level(), 'debug')
    self.assertEqual(logging.get_log_level(), 'info')

  def test_log(self):
    entry = logging.log('info', 'hi', indent=1, x=1, y=2)
    self.assertEqual(entry.level, 'info')
    self.assertEqual(entry.message, 'hi')
    self.assertEqual(entry.indent, 1)
    self.assertEqual(entry.metadata, {'x': 1, 'y': 2})

    self.assertEqual(logging.debug('hi').level, 'debug')
    self.assertEqual(logging.info('hi').level, 'info')
    self.assertEqual(logging.warning('hi').level, 'warning')
    self.assertEqual(logging.error('hi').level, 'error')
    self.assertEqual(logging.fatal('hi').level, 'fatal')

  def test_repr_html(self):
    entry = logging.log('info', 'hi', indent=1, x=1, y=2)
    html = entry._repr_html_()  # pylint: disable=protected-access
    self.assertIn('<div', html)
    # The message follows the time badge.
    self.assertIn('&nbsp;hi', html)


if __name__ == '__main__':
  unittest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240605
3
+ Version: 0.0.2.dev20240607
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -1,15 +1,17 @@
1
1
  langfun/__init__.py,sha256=P62MnqA6-f0h8iYfQ3MT6Yg7a4qRnQeb4GrIn6dcSnY,2274
2
- langfun/core/__init__.py,sha256=ZheiCpop_GAZbVpnSS-uPBJaEEM15Td5xFGGizSGqko,4514
3
- langfun/core/component.py,sha256=oxesbC0BoE_TbtxwW5x-BAZWxZyyJbuPiX5S38RqCv0,9909
4
- langfun/core/component_test.py,sha256=uR-_Sz_42Jxc5qzLIB-f5_pXmNwnC01Xlbv5NOQSeSU,8021
2
+ langfun/core/__init__.py,sha256=F3WGAww--u0CdQkg4ENBudsd0ZeGdecscF0R3YXSWmE,4670
3
+ langfun/core/component.py,sha256=Icyoj9ICoJoK2r2PHbrFXbxnseOr9QZZOvKWklLWNo8,10276
4
+ langfun/core/component_test.py,sha256=q15Xn51cVTu2RKxZ9U5VQgT3bm6RQ4638bKhWBtvW5o,8220
5
5
  langfun/core/concurrent.py,sha256=TRc49pJ3HQro2kb5FtcWkHjhBm8UcgE8RJybU5cU3-0,24537
6
6
  langfun/core/concurrent_test.py,sha256=mwFMZhDUdppnDr7vDSTwcbMHwrdsIoKJwRYNtl4ZWL4,15185
7
7
  langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
8
8
  langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
9
9
  langfun/core/langfunc.py,sha256=RvIcRjIq0jWYRu1xim-FYe4HSrt97r3GMBO_PuagUmw,11060
10
10
  langfun/core/langfunc_test.py,sha256=_mfARnakX3oji5HDigFSLMd6yQ2wma-2Mgbztwqn73g,8501
11
- langfun/core/language_model.py,sha256=PocBg1t3uB0a_bJntLW5aagHhNbZsVdp2iduSBEW6ro,21240
12
- langfun/core/language_model_test.py,sha256=NZaSUls6cZdtxiqkqumWbtkx9zgNiJlsviYZOWkuHig,20137
11
+ langfun/core/language_model.py,sha256=geHX9bHLjWGlubgGvSLTGluFyB7x_NIk09p5KPkp1Wg,22328
12
+ langfun/core/language_model_test.py,sha256=UfAg7ExCWSVqFpFGFOvQWSwO343FcvzVImb_F3hTfa0,20517
13
+ langfun/core/logging.py,sha256=FyZRxUy2TTF6tWLhQCRpCvfH55WGUdNgQjUTK_SQLnY,5320
14
+ langfun/core/logging_test.py,sha256=qvm3RObYP3knO2PnXR9evBRl4gH621GnjnwywbGbRfg,1833
13
15
  langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
14
16
  langfun/core/message.py,sha256=Rw3yC9HyGRjMhfDgyNjGlSCALEyDDbJ0_o6qTXeeDiQ,15738
15
17
  langfun/core/message_test.py,sha256=b6DDRoQ5j3uK-dc0QPSLelNTKaXX10MxJrRiI61iGX4,9574
@@ -113,8 +115,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
113
115
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
114
116
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
115
117
  langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
116
- langfun-0.0.2.dev20240605.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
117
- langfun-0.0.2.dev20240605.dist-info/METADATA,sha256=NMWv4oYMcRuZXl22coMTShqGC8hj_Y2PGfTk7n-Alt0,3550
118
- langfun-0.0.2.dev20240605.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
119
- langfun-0.0.2.dev20240605.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
120
- langfun-0.0.2.dev20240605.dist-info/RECORD,,
118
+ langfun-0.0.2.dev20240607.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
119
+ langfun-0.0.2.dev20240607.dist-info/METADATA,sha256=rFmOTbbcV4Jlik6ma8c87fUqXmYpscGfb3_FcMN2Nf0,3550
120
+ langfun-0.0.2.dev20240607.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
121
+ langfun-0.0.2.dev20240607.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
122
+ langfun-0.0.2.dev20240607.dist-info/RECORD,,