langfun 0.0.2.dev20240606__py3-none-any.whl → 0.0.2.dev20240607__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +6 -0
- langfun/core/component.py +10 -0
- langfun/core/component_test.py +5 -0
- langfun/core/language_model.py +31 -1
- langfun/core/language_model_test.py +12 -0
- langfun/core/logging.py +168 -0
- langfun/core/logging_test.py +51 -0
- {langfun-0.0.2.dev20240606.dist-info → langfun-0.0.2.dev20240607.dist-info}/METADATA +1 -1
- {langfun-0.0.2.dev20240606.dist-info → langfun-0.0.2.dev20240607.dist-info}/RECORD +12 -10
- {langfun-0.0.2.dev20240606.dist-info → langfun-0.0.2.dev20240607.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240606.dist-info → langfun-0.0.2.dev20240607.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240606.dist-info → langfun-0.0.2.dev20240607.dist-info}/top_level.txt +0 -0
langfun/core/__init__.py
CHANGED
@@ -111,12 +111,18 @@ from langfun.core.language_model import RetryableLMError
|
|
111
111
|
from langfun.core.language_model import RateLimitError
|
112
112
|
from langfun.core.language_model import TemporaryLMError
|
113
113
|
|
114
|
+
# Context manager for tracking usages.
|
115
|
+
from langfun.core.language_model import track_usages
|
116
|
+
|
114
117
|
# Components for building agents.
|
115
118
|
from langfun.core.memory import Memory
|
116
119
|
|
117
120
|
# Utility for console output.
|
118
121
|
from langfun.core import console
|
119
122
|
|
123
|
+
# Utility for event logging.
|
124
|
+
from langfun.core import logging
|
125
|
+
|
120
126
|
# Import internal modules.
|
121
127
|
|
122
128
|
# pylint: enable=g-import-not-at-top
|
langfun/core/component.py
CHANGED
@@ -210,6 +210,16 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
|
|
210
210
|
return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
|
211
211
|
|
212
212
|
|
213
|
+
def context_value(var_name: str, default: Any = RAISE_IF_HAS_ERROR) -> Any:
  """Returns the value of a variable defined in `lf.context`.

  Args:
    var_name: Name of the contextual variable to look up.
    default: Value to return when `var_name` is not defined in the current
      context. If omitted, a `KeyError` is raised instead.

  Returns:
    The contextual value of `var_name`, or `default` when it is absent.

  Raises:
    KeyError: If `var_name` is absent and no `default` was provided.
  """
  override = get_contextual_override(var_name)
  if override is None:
    # Compare by identity: only the sentinel itself means "no default was
    # given". An equality check could misfire (or raise) for user defaults
    # with custom or element-wise `__eq__`.
    if default is RAISE_IF_HAS_ERROR:
      raise KeyError(f'{var_name!r} does not exist in current context.')
    return default
  return override.value
|
221
|
+
|
222
|
+
|
213
223
|
def all_contextual_values() -> dict[str, Any]:
|
214
224
|
"""Returns all contextual values provided from `lf.context` in scope."""
|
215
225
|
overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
|
langfun/core/component_test.py
CHANGED
@@ -84,6 +84,11 @@ class ComponentContextTest(unittest.TestCase):
|
|
84
84
|
lf.get_contextual_override('y'),
|
85
85
|
lf.ContextualOverride(3, cascade=False, override_attrs=False),
|
86
86
|
)
|
87
|
+
self.assertEqual(lf.context_value('x'), 3)
|
88
|
+
self.assertIsNone(lf.context_value('f', None))
|
89
|
+
with self.assertRaisesRegex(KeyError, '.* does not exist'):
|
90
|
+
lf.context_value('f')
|
91
|
+
|
87
92
|
self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
|
88
93
|
|
89
94
|
# Member attributes take precedence over `lf.context`.
|
langfun/core/language_model.py
CHANGED
@@ -14,10 +14,11 @@
|
|
14
14
|
"""Interface for language model."""
|
15
15
|
|
16
16
|
import abc
|
17
|
+
import contextlib
|
17
18
|
import dataclasses
|
18
19
|
import enum
|
19
20
|
import time
|
20
|
-
from typing import Annotated, Any, Callable, Sequence, Tuple, Type, Union
|
21
|
+
from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union
|
21
22
|
from langfun.core import component
|
22
23
|
from langfun.core import concurrent
|
23
24
|
from langfun.core import console
|
@@ -84,6 +85,13 @@ class LMSamplingUsage(pg.Object):
|
|
84
85
|
completion_tokens: int
|
85
86
|
total_tokens: int
|
86
87
|
|
88
|
+
def __add__(self, other: 'LMSamplingUsage') -> 'LMSamplingUsage':
  """Combines two usage records into a new one by summing each token count."""
  prompt = self.prompt_tokens + other.prompt_tokens
  completion = self.completion_tokens + other.completion_tokens
  total = self.total_tokens + other.total_tokens
  return LMSamplingUsage(
      prompt_tokens=prompt,
      completion_tokens=completion,
      total_tokens=total,
  )
|
94
|
+
|
87
95
|
|
88
96
|
class LMSamplingResult(pg.Object):
|
89
97
|
"""Language model response."""
|
@@ -408,6 +416,16 @@ class LanguageModel(component.Component):
|
|
408
416
|
total_tokens=usage.total_tokens // n,
|
409
417
|
)
|
410
418
|
|
419
|
+
# Track usage.
|
420
|
+
if usage:
|
421
|
+
tracked_usages = component.context_value('__tracked_usages__', [])
|
422
|
+
model_id = self.model_id
|
423
|
+
for usage_dict in tracked_usages:
|
424
|
+
if model_id in usage_dict:
|
425
|
+
usage_dict[model_id] += usage
|
426
|
+
else:
|
427
|
+
usage_dict[model_id] = usage
|
428
|
+
|
411
429
|
# Track the prompt for corresponding response.
|
412
430
|
response.source = prompt
|
413
431
|
|
@@ -692,3 +710,15 @@ class LanguageModel(component.Component):
|
|
692
710
|
return max(int(requests_per_min / 60), 1) # Max concurrency can't be zero
|
693
711
|
else:
|
694
712
|
return DEFAULT_MAX_CONCURRENCY # Default of 1
|
713
|
+
|
714
|
+
|
715
|
+
@contextlib.contextmanager
def track_usages() -> Iterator[dict[str, LMSamplingUsage]]:
  """Context manager to track the usage of language models by model ID.

  The yielded dict is filled in place with the `LMSamplingUsage` accumulated
  by language model calls made inside the scope, keyed by model ID. The dict
  remains readable after the scope exits.

  Scopes nest: the new dict is appended to the trackers already active in
  the enclosing context, so outer scopes keep receiving usage from calls
  made inside an inner scope.

  Yields:
    A dict mapping model ID to accumulated `LMSamplingUsage`.
  """
  parent_trackers = component.context_value('__tracked_usages__', [])
  current_usage = {}
  # Extend (rather than replace) the active tracker list so enclosing
  # `track_usages` scopes are still updated while this one is active.
  with component.context(__tracked_usages__=parent_trackers + [current_usage]):
    yield current_usage
|
@@ -580,6 +580,18 @@ class LanguageModelTest(unittest.TestCase):
|
|
580
580
|
self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
|
581
581
|
self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
|
582
582
|
|
583
|
+
def test_track_usages(self):
  with lm_lib.track_usages() as outer_usages:
    model = MockModel()
    _ = model('hi')
    with lm_lib.track_usages() as inner_usages:
      _ = model('hi')
  # The inner tracker only observes the call made inside its own scope.
  self.assertEqual(
      inner_usages,
      {'MockModel': lm_lib.LMSamplingUsage(100, 100, 200)},
  )
  # The outer tracker observes both calls, including the nested one.
  self.assertEqual(
      outer_usages,
      {'MockModel': lm_lib.LMSamplingUsage(200, 200, 400)},
  )
|
583
595
|
|
584
596
|
if __name__ == '__main__':
  # Run this module's tests when the file is executed directly.
  unittest.main()
|
langfun/core/logging.py
ADDED
@@ -0,0 +1,168 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Langfun event logging."""
|
15
|
+
|
16
|
+
import datetime
import html
import io
import typing
from typing import Any, Literal, ContextManager

from langfun.core import console
import pyglove as pg
|
23
|
+
|
24
|
+
|
25
|
+
LogLevel = Literal['debug', 'info', 'error', 'warning', 'fatal']
|
26
|
+
_LOG_LEVELS = list(typing.get_args(LogLevel))
|
27
|
+
_TLS_KEY_MIN_LOG_LEVEL = '_event_log_level'
|
28
|
+
|
29
|
+
|
30
|
+
def use_log_level(log_level: LogLevel | None = 'info') -> ContextManager[None]:
  """Scopes the minimum log level to a `with` block.

  Inside the block `get_log_level()` reports `log_level`; the previous level
  is restored when the block exits.

  Args:
    log_level: The minimum level of entries to emit, or None.

  Returns:
    A context manager that sets the thread-local minimum log level.
  """
  fallback_level = 'info'
  scope_fn = pg.object_utils.thread_local_value_scope
  return scope_fn(_TLS_KEY_MIN_LOG_LEVEL, log_level, fallback_level)
|
34
|
+
|
35
|
+
|
36
|
+
def get_log_level() -> LogLevel | None:
  """Returns the minimum log level currently in effect.

  Falls back to 'info' when no `use_log_level` scope has set a level.
  """
  fallback_level = 'info'
  current = pg.object_utils.thread_local_get(
      _TLS_KEY_MIN_LOG_LEVEL, fallback_level
  )
  return current
|
39
|
+
|
40
|
+
|
41
|
+
class LogEntry(pg.Object):
|
42
|
+
"""Event log entry."""
|
43
|
+
time: datetime.datetime
|
44
|
+
level: LogLevel
|
45
|
+
message: str
|
46
|
+
metadata: dict[str, Any] = pg.Dict()
|
47
|
+
indent: int = 0
|
48
|
+
|
49
|
+
def should_output(self, min_log_level: LogLevel) -> bool:
|
50
|
+
return _LOG_LEVELS.index(self.level) >= _LOG_LEVELS.index(min_log_level)
|
51
|
+
|
52
|
+
def _repr_html_(self) -> str:
|
53
|
+
s = io.StringIO()
|
54
|
+
padding_left = 50 * self.indent
|
55
|
+
s.write(f'<div style="padding-left: {padding_left}px;">')
|
56
|
+
s.write(self._message_display)
|
57
|
+
if self.metadata:
|
58
|
+
s.write('<div style="padding-left: 20px; margin-top: 10px">')
|
59
|
+
s.write('<table style="border-top: 1px solid #EEEEEE;">')
|
60
|
+
for k, v in self.metadata.items():
|
61
|
+
if hasattr(v, '_repr_html_'):
|
62
|
+
cs = v._repr_html_() # pylint: disable=protected-access
|
63
|
+
else:
|
64
|
+
cs = f'<span style="white-space: pre-wrap">{str(v)}</span>'
|
65
|
+
key_span = self._round_text(k, color='#F1C40F', margin_bottom='0px')
|
66
|
+
s.write(
|
67
|
+
'<tr>'
|
68
|
+
'<td style="padding: 5px; vertical-align: top; '
|
69
|
+
f'border-bottom: 1px solid #EEEEEE">{key_span}</td>'
|
70
|
+
'<td style="padding: 5px; vertical-align: top; '
|
71
|
+
f'border-bottom: 1px solid #EEEEEE">{cs}</td></tr>'
|
72
|
+
)
|
73
|
+
s.write('</table></div>')
|
74
|
+
return s.getvalue()
|
75
|
+
|
76
|
+
@property
|
77
|
+
def _message_text_color(self) -> str:
|
78
|
+
match self.level:
|
79
|
+
case 'debug':
|
80
|
+
return '#EEEEEE'
|
81
|
+
case 'info':
|
82
|
+
return '#A3E4D7'
|
83
|
+
case 'error':
|
84
|
+
return '#F5C6CB'
|
85
|
+
case 'fatal':
|
86
|
+
return '#F19CBB'
|
87
|
+
case _:
|
88
|
+
raise ValueError(f'Unknown log level: {self.level}')
|
89
|
+
|
90
|
+
@property
|
91
|
+
def _time_display(self) -> str:
|
92
|
+
display_text = self.time.strftime('%H:%M:%S')
|
93
|
+
alt_text = self.time.strftime('%Y-%m-%d %H:%M:%S.%f')
|
94
|
+
return (
|
95
|
+
'<span style="background-color: #BBBBBB; color: white; '
|
96
|
+
'border-radius:5px; padding:0px 5px 0px 5px;" '
|
97
|
+
f'title="{alt_text}">{display_text}</span>'
|
98
|
+
)
|
99
|
+
|
100
|
+
@property
|
101
|
+
def _message_display(self) -> str:
|
102
|
+
return self._round_text(
|
103
|
+
self._time_display + ' ' + self.message,
|
104
|
+
color=self._message_text_color,
|
105
|
+
)
|
106
|
+
|
107
|
+
def _round_text(
|
108
|
+
self,
|
109
|
+
text: str,
|
110
|
+
*,
|
111
|
+
color: str = '#EEEEEE',
|
112
|
+
display: str = 'inline-block',
|
113
|
+
margin_top: str = '5px',
|
114
|
+
margin_bottom: str = '5px',
|
115
|
+
whitespace: str = 'pre-wrap') -> str:
|
116
|
+
return (
|
117
|
+
f'<span style="background:{color}; display:{display};'
|
118
|
+
f'border-radius:10px; padding:5px; '
|
119
|
+
f'margin-top: {margin_top}; margin-bottom: {margin_bottom}; '
|
120
|
+
f'white-space: {whitespace}">{text}</span>'
|
121
|
+
)
|
122
|
+
|
123
|
+
|
124
|
+
def log(level: LogLevel,
        message: str,
        *,
        indent: int = 0,
        **kwargs) -> LogEntry:
  """Logs a message.

  A `LogEntry` is always created and returned. It is emitted only when it
  meets the minimum level from `get_log_level()`: it is displayed when
  running under a notebook and written to the console otherwise. A minimum
  level of `None` emits nothing.

  Args:
    level: The log level of the message.
    message: The message text.
    indent: Nesting depth used when rendering the entry as HTML.
    **kwargs: Extra metadata attached to the entry.

  Returns:
    The created `LogEntry`.
  """
  entry = LogEntry(
      indent=indent,
      level=level,
      time=datetime.datetime.now(),
      message=message,
      metadata=kwargs,
  )
  min_log_level = get_log_level()
  # `use_log_level(None)` yields a `None` minimum level. Treat it as
  # "logging disabled" here instead of relying on `should_output` to rank it.
  if min_log_level is not None and entry.should_output(min_log_level):
    if console.under_notebook():
      console.display(entry)
    else:
      # TODO(daiyip): Improve the console output formatting.
      console.write(entry)
  return entry
|
144
|
+
|
145
|
+
|
146
|
+
def debug(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
  """Logs `message` at the 'debug' level.

  Args:
    message: The message text.
    indent: Nesting depth used when rendering the entry.
    **kwargs: Extra metadata attached to the entry.

  Returns:
    The created log entry.
  """
  return log(level='debug', message=message, indent=indent, **kwargs)
|
149
|
+
|
150
|
+
|
151
|
+
def info(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
  """Logs `message` at the 'info' level.

  Args:
    message: The message text.
    indent: Nesting depth used when rendering the entry.
    **kwargs: Extra metadata attached to the entry.

  Returns:
    The created log entry.
  """
  return log(level='info', message=message, indent=indent, **kwargs)
|
154
|
+
|
155
|
+
|
156
|
+
def warning(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
  """Logs a warning message to the session.

  Args:
    message: The message text.
    indent: Nesting depth used when rendering the entry.
    **kwargs: Extra metadata attached to the entry.

  Returns:
    The created log entry.
  """
  return log('warning', message, indent=indent, **kwargs)
|
159
|
+
|
160
|
+
|
161
|
+
def error(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
  """Logs `message` at the 'error' level.

  Args:
    message: The message text.
    indent: Nesting depth used when rendering the entry.
    **kwargs: Extra metadata attached to the entry.

  Returns:
    The created log entry.
  """
  return log(level='error', message=message, indent=indent, **kwargs)
|
164
|
+
|
165
|
+
|
166
|
+
def fatal(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
  """Logs `message` at the 'fatal' level.

  Args:
    message: The message text.
    indent: Nesting depth used when rendering the entry.
    **kwargs: Extra metadata attached to the entry.

  Returns:
    The created log entry.
  """
  return log(level='fatal', message=message, indent=indent, **kwargs)
|
@@ -0,0 +1,51 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for langfun.core.logging."""
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
from langfun.core import logging
|
19
|
+
|
20
|
+
|
21
|
+
class LoggingTest(unittest.TestCase):
  """Tests for `langfun.core.logging`."""

  def test_use_log_level(self):
    self.assertEqual(logging.get_log_level(), 'info')
    with logging.use_log_level('debug'):
      self.assertEqual(logging.get_log_level(), 'debug')
      with logging.use_log_level(None):
        # `assertIsNone` takes a failure message as its second argument, not
        # an expected value, so only the actual value is passed.
        self.assertIsNone(logging.get_log_level())
      self.assertEqual(logging.get_log_level(), 'debug')
    self.assertEqual(logging.get_log_level(), 'info')

  def test_log(self):
    entry = logging.log('info', 'hi', indent=1, x=1, y=2)
    self.assertEqual(entry.level, 'info')
    self.assertEqual(entry.message, 'hi')
    self.assertEqual(entry.indent, 1)
    self.assertEqual(entry.metadata, {'x': 1, 'y': 2})

    self.assertEqual(logging.debug('hi').level, 'debug')
    self.assertEqual(logging.info('hi').level, 'info')
    self.assertEqual(logging.warning('hi').level, 'warning')
    self.assertEqual(logging.error('hi').level, 'error')
    self.assertEqual(logging.fatal('hi').level, 'fatal')

  def test_repr_html(self):
    entry = logging.log('info', 'hi', indent=1, x=1, y=2)
    self.assertIn('<div', entry._repr_html_())
|
48
|
+
|
49
|
+
|
50
|
+
if __name__ == '__main__':
  # Run this module's tests when the file is executed directly.
  unittest.main()
|
@@ -1,15 +1,17 @@
|
|
1
1
|
langfun/__init__.py,sha256=P62MnqA6-f0h8iYfQ3MT6Yg7a4qRnQeb4GrIn6dcSnY,2274
|
2
|
-
langfun/core/__init__.py,sha256=
|
3
|
-
langfun/core/component.py,sha256=
|
4
|
-
langfun/core/component_test.py,sha256=
|
2
|
+
langfun/core/__init__.py,sha256=F3WGAww--u0CdQkg4ENBudsd0ZeGdecscF0R3YXSWmE,4670
|
3
|
+
langfun/core/component.py,sha256=Icyoj9ICoJoK2r2PHbrFXbxnseOr9QZZOvKWklLWNo8,10276
|
4
|
+
langfun/core/component_test.py,sha256=q15Xn51cVTu2RKxZ9U5VQgT3bm6RQ4638bKhWBtvW5o,8220
|
5
5
|
langfun/core/concurrent.py,sha256=TRc49pJ3HQro2kb5FtcWkHjhBm8UcgE8RJybU5cU3-0,24537
|
6
6
|
langfun/core/concurrent_test.py,sha256=mwFMZhDUdppnDr7vDSTwcbMHwrdsIoKJwRYNtl4ZWL4,15185
|
7
7
|
langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
|
8
8
|
langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
|
9
9
|
langfun/core/langfunc.py,sha256=RvIcRjIq0jWYRu1xim-FYe4HSrt97r3GMBO_PuagUmw,11060
|
10
10
|
langfun/core/langfunc_test.py,sha256=_mfARnakX3oji5HDigFSLMd6yQ2wma-2Mgbztwqn73g,8501
|
11
|
-
langfun/core/language_model.py,sha256=
|
12
|
-
langfun/core/language_model_test.py,sha256=
|
11
|
+
langfun/core/language_model.py,sha256=geHX9bHLjWGlubgGvSLTGluFyB7x_NIk09p5KPkp1Wg,22328
|
12
|
+
langfun/core/language_model_test.py,sha256=UfAg7ExCWSVqFpFGFOvQWSwO343FcvzVImb_F3hTfa0,20517
|
13
|
+
langfun/core/logging.py,sha256=FyZRxUy2TTF6tWLhQCRpCvfH55WGUdNgQjUTK_SQLnY,5320
|
14
|
+
langfun/core/logging_test.py,sha256=qvm3RObYP3knO2PnXR9evBRl4gH621GnjnwywbGbRfg,1833
|
13
15
|
langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
|
14
16
|
langfun/core/message.py,sha256=Rw3yC9HyGRjMhfDgyNjGlSCALEyDDbJ0_o6qTXeeDiQ,15738
|
15
17
|
langfun/core/message_test.py,sha256=b6DDRoQ5j3uK-dc0QPSLelNTKaXX10MxJrRiI61iGX4,9574
|
@@ -113,8 +115,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
113
115
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
114
116
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
115
117
|
langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
|
116
|
-
langfun-0.0.2.
|
117
|
-
langfun-0.0.2.
|
118
|
-
langfun-0.0.2.
|
119
|
-
langfun-0.0.2.
|
120
|
-
langfun-0.0.2.
|
118
|
+
langfun-0.0.2.dev20240607.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
119
|
+
langfun-0.0.2.dev20240607.dist-info/METADATA,sha256=rFmOTbbcV4Jlik6ma8c87fUqXmYpscGfb3_FcMN2Nf0,3550
|
120
|
+
langfun-0.0.2.dev20240607.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
121
|
+
langfun-0.0.2.dev20240607.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
122
|
+
langfun-0.0.2.dev20240607.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|